Skip to content

Commit 61a6859

Browse files
committed
update custom eval
1 parent fb105af commit 61a6859

5 files changed

Lines changed: 96 additions & 22 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from FlagEmbedding.abc.evaluation import (
2+
AbsEvalArgs as CustomEvalArgs,
3+
AbsEvalModelArgs as CustomEvalModelArgs,
4+
)
5+
6+
from .data_loader import CustomEvalDataLoader
7+
from .runner import CustomEvalRunner
8+
9+
__all__ = [
10+
"CustomEvalArgs",
11+
"CustomEvalModelArgs",
12+
"CustomEvalRunner",
13+
"CustomEvalDataLoader",
14+
]
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from transformers import HfArgumentParser
2+
3+
from FlagEmbedding.evaluation.mldr import (
4+
CustomEvalArgs, CustomEvalModelArgs,
5+
CustomEvalRunner
6+
)
7+
8+
9+
parser = HfArgumentParser((
10+
CustomEvalArgs,
11+
CustomEvalModelArgs
12+
))
13+
14+
eval_args, model_args = parser.parse_args_into_dataclasses()
15+
eval_args: CustomEvalArgs
16+
model_args: CustomEvalModelArgs
17+
18+
runner = CustomEvalRunner(
19+
eval_args=eval_args,
20+
model_args=model_args
21+
)
22+
23+
runner.run()
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import os
2+
import json
3+
import logging
4+
import datasets
5+
from tqdm import tqdm
6+
from typing import List, Optional
7+
8+
from FlagEmbedding.abc.evaluation import AbsEvalDataLoader
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class MLDREvalDataLoader(AbsEvalDataLoader):
14+
def available_dataset_names(self) -> List[str]:
15+
return []
16+
17+
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
18+
return ["train", "dev", "test"]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from FlagEmbedding.abc.evaluation import AbsEvalRunner
2+
3+
from .data_loader import MLDREvalDataLoader
4+
5+
6+
class MLDREvalRunner(AbsEvalRunner):
7+
def load_data_loader(self) -> MLDREvalDataLoader:
8+
data_loader = MLDREvalDataLoader(
9+
eval_name=self.eval_args.eval_name,
10+
dataset_dir=self.eval_args.dataset_dir,
11+
cache_dir=self.eval_args.cache_path,
12+
token=self.eval_args.token,
13+
force_redownload=self.eval_args.force_redownload,
14+
)
15+
return data_loader

examples/evaluation/README.md

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -54,47 +54,29 @@ First, we will introduce the commonly used parameters, followed by an introducti
5454
**Parameters for Model Configuration:**
5555

5656
- **`embedder_name_or_path`**: The name or path to the embedder.
57-
5857
- **`embedder_model_class`**: Class of the model used for embedding (options include 'auto', 'encoder-only-base', etc.). Default is `auto`.
59-
6058
- **`normalize_embeddings`**: Set to `true` to normalize embeddings.
61-
6259
- **`use_fp16`**: Use FP16 precision for inference.
63-
6460
- **`devices`**: List of devices used for inference.
65-
6661
- **`query_instruction_for_retrieval`**, **`query_instruction_format_for_retrieval`**: Instructions and format for query during retrieval.
67-
6862
- **`examples_for_task`**, **`examples_instruction_format`**: Example tasks and their instructions format.
69-
7063
- **`trust_remote_code`**: Set to `true` to trust remote code execution.
71-
7264
- **`reranker_name_or_path`**: Name or path to the reranker.
73-
7465
- **`reranker_model_class`**: Reranker model class (options include 'auto', 'decoder-only-base', etc.). Default is `auto`.
75-
7666
- **`reranker_peft_path`**: Path for portable encoder fine-tuning of the reranker.
77-
7867
- **`use_bf16`**: Use BF16 precision for inference.
79-
8068
- **`query_instruction_for_rerank`**, **`query_instruction_format_for_rerank`**: Instructions and format for query during reranking.
81-
8269
- **`passage_instruction_for_rerank`**, **`passage_instruction_format_for_rerank`**: Instructions and format for processing passages during reranking.
83-
8470
- **`cache_dir`**: Cache directory for models.
85-
8671
- **`embedder_batch_size`**, **`reranker_batch_size`**: Batch sizes for embedding and reranking.
87-
8872
- **`embedder_query_max_length`**, **`embedder_passage_max_length`**: Maximum length for embedding queries and passages.
89-
9073
- **`reranker_query_max_length`**, **`reranker_max_length`**: Maximum lengths for reranking queries and reranking in general.
91-
9274
- **`normalize`**: Normalize the reranking scores.
93-
9475
- **`prompt`**: Prompt for the reranker.
95-
9676
- **`cutoff_layers`**, **`compress_ratio`**, **`compress_layers`**: Parameters for configuring the output and compression of layerwise or lightweight rerankers.
9777

78+
***Notice:*** If you evaluate your own model, please set `embedder_model_class` and `reranker_model_class`.
79+
9880
## Usage
9981

10082
### 1. MTEB
@@ -306,8 +288,6 @@ python -m FlagEmbedding.evaluation.air_bench \
306288

307289
### 8. Custom Dataset
308290

309-
You can refer to [MLDR dataset](https://github.com/hanhainebula/FlagEmbedding/tree/new-flagembedding-v1/FlagEmbedding/evaluation/mldr), just need to rewrite `DataLoader`, rewriting the loading method for the required dataset.
310-
311291
The example data for `corpus.jsonl`:
312292

313293
```json
@@ -334,3 +314,27 @@ The example data for `test_qrels.jsonl`:
334314
{"qid": "79085", "docid": "81285", "relevance": 1}
335315
```
336316

317+
Please put the above file in `dataset_dir`, and then you can use the following code:
318+
319+
```shell
320+
python -m FlagEmbedding.evaluation.custom \
321+
--eval_name your_data_name \
322+
--dataset_dir ./your_data_path \
323+
--splits test \
324+
--corpus_embd_save_dir ./your_data_name/corpus_embd \
325+
--output_dir ./your_data_name/search_results \
326+
--search_top_k 1000 \
327+
--rerank_top_k 100 \
328+
--cache_path ./cache/data \
329+
--overwrite False \
330+
--k_values 10 100 \
331+
--eval_output_method markdown \
332+
--eval_output_path ./your_data_name/eval_results.md \
333+
--eval_metrics ndcg_at_10 recall_at_100 \
334+
--embedder_name_or_path BAAI/bge-m3 \
335+
--reranker_name_or_path BAAI/bge-reranker-v2-m3 \
336+
--devices cuda:0 cuda:1 \
337+
--cache_dir ./cache/model \
338+
--reranker_query_max_length 512 \
339+
--reranker_max_length 1024
340+
```

0 commit comments

Comments
 (0)