Skip to content

Commit cc13200

Browse files
committed
Clean up code
- Remove the main() function from the evaluation __main__.py files - Remove extra indentation
1 parent b303f7c commit cc13200

12 files changed

Lines changed: 160 additions & 184 deletions

File tree

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,19 @@ class AbsEmbedder(ABC):
4141
"""
4242

4343
def __init__(
44-
self,
45-
model_name_or_path: str,
46-
normalize_embeddings: bool = True,
47-
use_fp16: bool = True,
48-
query_instruction_for_retrieval: Optional[str] = None,
49-
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
50-
devices: Optional[Union[str, int, List[str], List[int]]] = None,
51-
# inference
52-
batch_size: int = 256,
53-
query_max_length: int = 512,
54-
passage_max_length: int = 512,
55-
instruction: Optional[str] = None,
56-
instruction_format: str = "{}{}",
57-
convert_to_numpy: bool = True,
58-
**kwargs: Any,
44+
self,
45+
model_name_or_path: str,
46+
normalize_embeddings: bool = True,
47+
use_fp16: bool = True,
48+
query_instruction_for_retrieval: Optional[str] = None,
49+
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
50+
devices: Optional[Union[str, int, List[str], List[int]]] = None,
51+
# inference
52+
batch_size: int = 256,
53+
query_max_length: int = 512,
54+
passage_max_length: int = 512,
55+
convert_to_numpy: bool = True,
56+
**kwargs: Any,
5957
):
6058
self.model_name_or_path = model_name_or_path
6159
self.normalize_embeddings = normalize_embeddings
@@ -67,8 +65,6 @@ def __init__(
6765
self.batch_size = batch_size
6866
self.query_max_length = query_max_length
6967
self.passage_max_length = passage_max_length
70-
self.instruction = instruction
71-
self.instruction_format = instruction_format
7268
self.convert_to_numpy = convert_to_numpy
7369

7470
for k in kwargs:
@@ -132,12 +128,12 @@ def get_detailed_instruct(instruction_format: str, instruction: str, sentence: s
132128
return instruction_format.format(instruction, sentence)
133129

134130
def encode_queries(
135-
self,
136-
queries: Union[List[str], str],
137-
batch_size: Optional[int] = None,
138-
max_length: Optional[int] = None,
139-
convert_to_numpy: Optional[bool] = None,
140-
**kwargs: Any
131+
self,
132+
queries: Union[List[str], str],
133+
batch_size: Optional[int] = None,
134+
max_length: Optional[int] = None,
135+
convert_to_numpy: Optional[bool] = None,
136+
**kwargs: Any
141137
):
142138
"""encode the queries using the instruction if provided.
143139
@@ -166,12 +162,12 @@ def encode_queries(
166162
)
167163

168164
def encode_corpus(
169-
self,
170-
corpus: Union[List[str], str],
171-
batch_size: Optional[int] = None,
172-
max_length: Optional[int] = None,
173-
convert_to_numpy: Optional[bool] = None,
174-
**kwargs: Any
165+
self,
166+
corpus: Union[List[str], str],
167+
batch_size: Optional[int] = None,
168+
max_length: Optional[int] = None,
169+
convert_to_numpy: Optional[bool] = None,
170+
**kwargs: Any
175171
):
176172
"""encode the corpus using the instruction if provided.
177173
@@ -203,14 +199,14 @@ def encode_corpus(
203199
)
204200

205201
def encode(
206-
self,
207-
sentences: Union[List[str], str],
208-
batch_size: Optional[int] = None,
209-
max_length: Optional[int] = None,
210-
convert_to_numpy: Optional[bool] = None,
211-
instruction: Optional[str] = None,
212-
instruction_format: Optional[str] = None,
213-
**kwargs: Any
202+
self,
203+
sentences: Union[List[str], str],
204+
batch_size: Optional[int] = None,
205+
max_length: Optional[int] = None,
206+
convert_to_numpy: Optional[bool] = None,
207+
instruction: Optional[str] = None,
208+
instruction_format: Optional[str] = None,
209+
**kwargs: Any
214210
):
215211
"""encode the input sentences with the embedding model.
216212
@@ -265,13 +261,13 @@ def __del__(self):
265261

266262
@abstractmethod
267263
def encode_single_device(
268-
self,
269-
sentences: Union[List[str], str],
270-
batch_size: int = 256,
271-
max_length: int = 512,
272-
convert_to_numpy: bool = True,
273-
device: Optional[str] = None,
274-
**kwargs: Any,
264+
self,
265+
sentences: Union[List[str], str],
266+
batch_size: int = 256,
267+
max_length: int = 512,
268+
convert_to_numpy: bool = True,
269+
device: Optional[str] = None,
270+
**kwargs: Any,
275271
):
276272
"""
277273
This method should encode sentences and return embeddings on a single device.
@@ -280,8 +276,8 @@ def encode_single_device(
280276

281277
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
282278
def start_multi_process_pool(
283-
self,
284-
process_target_func: Any,
279+
self,
280+
process_target_func: Any,
285281
) -> Dict[Literal["input", "output", "processes"], Any]:
286282
"""
287283
Starts a multi-process pool to process the encoding with several independent processes
@@ -320,7 +316,7 @@ def start_multi_process_pool(
320316
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L976
321317
@staticmethod
322318
def _encode_multi_process_worker(
323-
target_device: str, model: 'AbsEmbedder', input_queue: Queue, results_queue: Queue
319+
target_device: str, model: 'AbsEmbedder', input_queue: Queue, results_queue: Queue
324320
) -> None:
325321
"""
326322
Internal working process to encode sentences in multi-process setup
@@ -364,10 +360,10 @@ def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"],
364360

365361
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L877
366362
def encode_multi_process(
367-
self,
368-
sentences: List[str],
369-
pool: Dict[Literal["input", "output", "processes"], Any],
370-
**kwargs
363+
self,
364+
sentences: List[str],
365+
pool: Dict[Literal["input", "output", "processes"], Any],
366+
**kwargs
371367
):
372368
chunk_size = math.ceil(len(sentences) / len(pool["processes"]))
373369

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .arguments import AIRBenchEvalModelArgs, AIRBenchEvalArgs
2+
from .runner import AIRBenchEvalRunner
3+
4+
__all__ = [
5+
"AIRBenchEvalModelArgs",
6+
"AIRBenchEvalArgs",
7+
"AIRBenchEvalRunner"
8+
]
Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,28 @@
11
from transformers import HfArgumentParser
22

3-
from .arguments import AIRBenchEvalArgs, AIRBenchEvalModelArgs
4-
from .runner import AIRBenchEvalRunner
3+
from FlagEmbedding.evaluation.air_bench import (
4+
AIRBenchEvalArgs, AIRBenchEvalModelArgs,
5+
AIRBenchEvalRunner
6+
)
57

6-
def main():
7-
parser = HfArgumentParser((
8-
AIRBenchEvalArgs,
9-
AIRBenchEvalModelArgs
10-
))
118

12-
eval_args, model_args = parser.parse_args_into_dataclasses()
13-
eval_args: AIRBenchEvalArgs
14-
model_args: AIRBenchEvalModelArgs
9+
parser = HfArgumentParser((
10+
AIRBenchEvalArgs,
11+
AIRBenchEvalModelArgs
12+
))
1513

16-
runner = AIRBenchEvalRunner(
17-
eval_args=eval_args,
18-
model_args=model_args
19-
)
14+
eval_args, model_args = parser.parse_args_into_dataclasses()
15+
eval_args: AIRBenchEvalArgs
16+
model_args: AIRBenchEvalModelArgs
2017

21-
runner.run()
18+
runner = AIRBenchEvalRunner(
19+
eval_args=eval_args,
20+
model_args=model_args
21+
)
2222

23-
print("==============================================")
24-
print("Search results have been generated.")
25-
print("For computing metrics, please refer to the official AIR-Bench docs:")
26-
print("- https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/submit_to_leaderboard.md")
23+
runner.run()
2724

28-
if __name__ == "__main__":
29-
main()
25+
print("==============================================")
26+
print("Search results have been generated.")
27+
print("For computing metrics, please refer to the official AIR-Bench docs:")
28+
print("- https://github.com/AIR-Bench/AIR-Bench/blob/main/docs/submit_to_leaderboard.md")

FlagEmbedding/evaluation/air_bench/arguments.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from dataclasses import dataclass, field
22
from typing import List, Optional
33

4-
from air_benchmark import EvalArgs as AIRBenchEvalArgs
5-
64

75
@dataclass
86
class AIRBenchEvalModelArgs:

FlagEmbedding/evaluation/beir/__main__.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,19 @@
55
BEIREvalRunner
66
)
77

8-
def main():
9-
parser = HfArgumentParser((
10-
BEIREvalArgs,
11-
BEIREvalModelArgs
12-
))
138

14-
eval_args, model_args = parser.parse_args_into_dataclasses()
15-
eval_args: BEIREvalArgs
16-
model_args: BEIREvalModelArgs
9+
parser = HfArgumentParser((
10+
BEIREvalArgs,
11+
BEIREvalModelArgs
12+
))
1713

18-
runner = BEIREvalRunner(
19-
eval_args=eval_args,
20-
model_args=model_args
21-
)
14+
eval_args, model_args = parser.parse_args_into_dataclasses()
15+
eval_args: BEIREvalArgs
16+
model_args: BEIREvalModelArgs
2217

23-
runner.run()
18+
runner = BEIREvalRunner(
19+
eval_args=eval_args,
20+
model_args=model_args
21+
)
2422

25-
if __name__ == "__main__":
26-
main()
23+
runner.run()

FlagEmbedding/evaluation/custom/__main__.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,19 @@
55
CustomEvalRunner
66
)
77

8-
def main():
9-
parser = HfArgumentParser((
10-
CustomEvalArgs,
11-
CustomEvalModelArgs
12-
))
138

14-
eval_args, model_args = parser.parse_args_into_dataclasses()
15-
eval_args: CustomEvalArgs
16-
model_args: CustomEvalModelArgs
9+
parser = HfArgumentParser((
10+
CustomEvalArgs,
11+
CustomEvalModelArgs
12+
))
1713

18-
runner = CustomEvalRunner(
19-
eval_args=eval_args,
20-
model_args=model_args
21-
)
14+
eval_args, model_args = parser.parse_args_into_dataclasses()
15+
eval_args: CustomEvalArgs
16+
model_args: CustomEvalModelArgs
2217

23-
runner.run()
18+
runner = CustomEvalRunner(
19+
eval_args=eval_args,
20+
model_args=model_args
21+
)
2422

25-
if __name__ == "__main__":
26-
main()
23+
runner.run()

FlagEmbedding/evaluation/custom/data_loader.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
import os
2-
import json
31
import logging
4-
import datasets
52
from tqdm import tqdm
63
from typing import List, Optional
74

@@ -15,4 +12,4 @@ def available_dataset_names(self) -> List[str]:
1512
return []
1613

1714
def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
18-
return ["train", "dev", "test"]
15+
return ["test"]

FlagEmbedding/evaluation/miracl/__main__.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,19 @@
55
MIRACLEvalRunner
66
)
77

8-
def main():
9-
parser = HfArgumentParser((
10-
MIRACLEvalArgs,
11-
MIRACLEvalModelArgs
12-
))
138

14-
eval_args, model_args = parser.parse_args_into_dataclasses()
15-
eval_args: MIRACLEvalArgs
16-
model_args: MIRACLEvalModelArgs
9+
parser = HfArgumentParser((
10+
MIRACLEvalArgs,
11+
MIRACLEvalModelArgs
12+
))
1713

18-
runner = MIRACLEvalRunner(
19-
eval_args=eval_args,
20-
model_args=model_args
21-
)
14+
eval_args, model_args = parser.parse_args_into_dataclasses()
15+
eval_args: MIRACLEvalArgs
16+
model_args: MIRACLEvalModelArgs
2217

23-
runner.run()
18+
runner = MIRACLEvalRunner(
19+
eval_args=eval_args,
20+
model_args=model_args
21+
)
2422

25-
if __name__ == "__main__":
26-
main()
23+
runner.run()

FlagEmbedding/evaluation/mkqa/__main__.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,19 @@
55
MKQAEvalRunner
66
)
77

8-
def main():
9-
parser = HfArgumentParser((
10-
MKQAEvalArgs,
11-
MKQAEvalModelArgs
12-
))
138

14-
eval_args, model_args = parser.parse_args_into_dataclasses()
15-
eval_args: MKQAEvalArgs
16-
model_args: MKQAEvalModelArgs
9+
parser = HfArgumentParser((
10+
MKQAEvalArgs,
11+
MKQAEvalModelArgs
12+
))
1713

18-
runner = MKQAEvalRunner(
19-
eval_args=eval_args,
20-
model_args=model_args
21-
)
14+
eval_args, model_args = parser.parse_args_into_dataclasses()
15+
eval_args: MKQAEvalArgs
16+
model_args: MKQAEvalModelArgs
2217

23-
runner.run()
18+
runner = MKQAEvalRunner(
19+
eval_args=eval_args,
20+
model_args=model_args
21+
)
2422

25-
if __name__ == "__main__":
26-
main()
23+
runner.run()

0 commit comments

Comments
 (0)