Skip to content

Commit e513431

Browse files
committed
Merge branch 'new-flagembedding-v1' of https://github.com/hanhainebula/FlagEmbedding into new-flagembedding-v1
2 parents f683f2f + 33df5c6 commit e513431

85 files changed

Lines changed: 1743 additions & 778 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

FlagEmbedding/abc/inference/AbsEmbedder.py

Lines changed: 63 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,23 @@ class AbsEmbedder(ABC):
3939
Default: `True`.
4040
kwargs (Dict[Any], optional): Additional parameters for HuggingFace Transformers config or children classes.
4141
"""
42+
4243
def __init__(
43-
self,
44-
model_name_or_path: str,
45-
normalize_embeddings: bool = True,
46-
use_fp16: bool = True,
47-
query_instruction_for_retrieval: Optional[str] = None,
48-
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
49-
devices: Optional[Union[str, int, List[str], List[int]]] = None,
50-
# inference
51-
batch_size: int = 256,
52-
query_max_length: int = 512,
53-
passage_max_length: int = 512,
54-
instruction: Optional[str] = None,
55-
instruction_format: str = "{}{}",
56-
convert_to_numpy: bool = True,
57-
**kwargs: Any,
44+
self,
45+
model_name_or_path: str,
46+
normalize_embeddings: bool = True,
47+
use_fp16: bool = True,
48+
query_instruction_for_retrieval: Optional[str] = None,
49+
query_instruction_format: str = "{}{}", # specify the format of query_instruction_for_retrieval
50+
devices: Optional[Union[str, int, List[str], List[int]]] = None,
51+
# inference
52+
batch_size: int = 256,
53+
query_max_length: int = 512,
54+
passage_max_length: int = 512,
55+
instruction: Optional[str] = None,
56+
instruction_format: str = "{}{}",
57+
convert_to_numpy: bool = True,
58+
**kwargs: Any,
5859
):
5960
self.model_name_or_path = model_name_or_path
6061
self.normalize_embeddings = normalize_embeddings
@@ -78,6 +79,7 @@ def __init__(
7879
# tokenizer and model are initialized in the child class
7980
self.tokenizer = None
8081
self.model = None
82+
self.pool = None
8183

8284
@staticmethod
8385
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
@@ -109,12 +111,12 @@ def get_detailed_instruct(instruction_format: str, instruction: str, sentence: s
109111
return instruction_format.format(instruction, sentence)
110112

111113
def encode_queries(
112-
self,
113-
queries: Union[List[str], str],
114-
batch_size: Optional[int] = None,
115-
max_length: Optional[int] = None,
116-
convert_to_numpy: Optional[bool] = None,
117-
**kwargs: Any
114+
self,
115+
queries: Union[List[str], str],
116+
batch_size: Optional[int] = None,
117+
max_length: Optional[int] = None,
118+
convert_to_numpy: Optional[bool] = None,
119+
**kwargs: Any
118120
):
119121
if batch_size is None: batch_size = self.batch_size
120122
if max_length is None: max_length = self.query_max_length
@@ -131,12 +133,12 @@ def encode_queries(
131133
)
132134

133135
def encode_corpus(
134-
self,
135-
corpus: Union[List[str], str],
136-
batch_size: Optional[int] = None,
137-
max_length: Optional[int] = None,
138-
convert_to_numpy: Optional[bool] = None,
139-
**kwargs: Any
136+
self,
137+
corpus: Union[List[str], str],
138+
batch_size: Optional[int] = None,
139+
max_length: Optional[int] = None,
140+
convert_to_numpy: Optional[bool] = None,
141+
**kwargs: Any
140142
):
141143
passage_instruction_for_retrieval = self.kwargs.get("passage_instruction_for_retrieval", None)
142144
passage_instruction_format = self.kwargs.get("passage_instruction_format", "{}{}")
@@ -156,23 +158,27 @@ def encode_corpus(
156158
)
157159

158160
def encode(
159-
self,
160-
sentences: Union[List[str], str],
161-
batch_size: Optional[int] = None,
162-
max_length: Optional[int] = None,
163-
convert_to_numpy: Optional[bool] = None,
164-
instruction: Optional[str] = None,
165-
instruction_format: Optional[str] = None,
166-
**kwargs: Any
161+
self,
162+
sentences: Union[List[str], str],
163+
batch_size: Optional[int] = None,
164+
max_length: Optional[int] = None,
165+
convert_to_numpy: Optional[bool] = None,
166+
instruction: Optional[str] = None,
167+
instruction_format: Optional[str] = None,
168+
**kwargs: Any
167169
):
168170
if instruction is None: instruction = self.instruction
169171
if instruction_format is None: instruction_format = self.instruction_format
172+
if batch_size is None: batch_size = self.batch_size
173+
if max_length is None: max_length = self.passage_max_length
174+
if convert_to_numpy is None: convert_to_numpy = self.convert_to_numpy
170175

171176
if instruction is not None:
172177
if isinstance(sentences, str):
173178
sentences = self.get_detailed_instruct(instruction_format, instruction, sentences)
174179
else:
175-
sentences = [self.get_detailed_instruct(instruction_format, instruction, sentence) for sentence in sentences]
180+
sentences = [self.get_detailed_instruct(instruction_format, instruction, sentence) for sentence in
181+
sentences]
176182

177183
if isinstance(sentences, str) or len(self.target_devices) == 1:
178184
return self.encode_single_device(
@@ -184,27 +190,31 @@ def encode(
184190
**kwargs
185191
)
186192

187-
pool = self.start_multi_process_pool(AbsEmbedder._encode_multi_process_worker)
193+
if self.pool is None:
194+
self.pool = self.start_multi_process_pool(AbsEmbedder._encode_multi_process_worker)
188195
embeddings = self.encode_multi_process(
189196
sentences,
190-
pool,
197+
self.pool,
191198
batch_size=batch_size,
192199
max_length=max_length,
193200
convert_to_numpy=convert_to_numpy,
194201
**kwargs
195202
)
196-
self.stop_multi_process_pool(pool)
197203
return embeddings
198204

205+
def __del__(self):
206+
if self.pool is not None:
207+
self.stop_multi_process_pool(self.pool)
208+
199209
@abstractmethod
200210
def encode_single_device(
201-
self,
202-
sentences: Union[List[str], str],
203-
batch_size: int = 256,
204-
max_length: int = 512,
205-
convert_to_numpy: bool = True,
206-
device: Optional[str] = None,
207-
**kwargs: Any,
211+
self,
212+
sentences: Union[List[str], str],
213+
batch_size: int = 256,
214+
max_length: int = 512,
215+
convert_to_numpy: bool = True,
216+
device: Optional[str] = None,
217+
**kwargs: Any,
208218
):
209219
"""
210220
This method should encode sentences and return embeddings on a single device.
@@ -213,8 +223,8 @@ def encode_single_device(
213223

214224
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L807
215225
def start_multi_process_pool(
216-
self,
217-
process_target_func: Any,
226+
self,
227+
process_target_func: Any,
218228
) -> Dict[Literal["input", "output", "processes"], Any]:
219229
"""
220230
Starts a multi-process pool to process the encoding with several independent processes
@@ -253,7 +263,7 @@ def start_multi_process_pool(
253263
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L976
254264
@staticmethod
255265
def _encode_multi_process_worker(
256-
target_device: str, model: 'AbsEmbedder', input_queue: Queue, results_queue: Queue
266+
target_device: str, model: 'AbsEmbedder', input_queue: Queue, results_queue: Queue
257267
) -> None:
258268
"""
259269
Internal working process to encode sentences in multi-process setup
@@ -297,10 +307,10 @@ def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"],
297307

298308
# adapted from https://github.com/UKPLab/sentence-transformers/blob/1802076d4eae42ff0a5629e1b04e75785d4e193b/sentence_transformers/SentenceTransformer.py#L877
299309
def encode_multi_process(
300-
self,
301-
sentences: List[str],
302-
pool: Dict[Literal["input", "output", "processes"], Any],
303-
**kwargs
310+
self,
311+
sentences: List[str],
312+
pool: Dict[Literal["input", "output", "processes"], Any],
313+
**kwargs
304314
):
305315
chunk_size = math.ceil(len(sentences) / len(pool["processes"]))
306316

FlagEmbedding/abc/inference/AbsReranker.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
# tokenizer and model are initialized in the child class
5858
self.model = None
5959
self.tokenizer = None
60+
self.pool = None
6061

6162
@staticmethod
6263
def get_target_devices(devices: Union[str, int, List[str], List[int]]) -> List[str]:
@@ -137,13 +138,17 @@ def compute_score(
137138
**kwargs
138139
)
139140

140-
pool = self.start_multi_process_pool()
141+
if self.pool is None:
142+
self.pool = self.start_multi_process_pool()
141143
scores = self.encode_multi_process(sentence_pairs,
142-
pool,
144+
self.pool,
143145
**kwargs)
144-
self.stop_multi_process_pool(pool)
145146
return scores
146147

148+
def __del__(self):
149+
if self.pool is not None:
150+
self.stop_multi_process_pool(self.pool)
151+
147152
@abstractmethod
148153
def compute_score_single_gpu(
149154
self,

FlagEmbedding/inference/embedder/encoder_only/m3.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,12 @@ def encode(
198198
Literal["dense_vecs", "lexical_weights", "colbert_vecs"],
199199
Union[np.ndarray, List[Dict[str, float]], List[np.ndarray]]
200200
]:
201+
if batch_size is None: batch_size = self.batch_size
202+
if max_length is None: max_length = self.passage_max_length
203+
if return_dense is None: return_dense = self.return_dense
204+
if return_sparse is None: return_sparse = self.return_sparse
205+
if return_colbert_vecs is None: return_colbert_vecs = self.return_colbert_vecs
206+
201207
return super().encode(
202208
queries,
203209
batch_size=batch_size,

0 commit comments

Comments (0)