Skip to content

Commit c8c4fad

Browse files
committed
Merge branch 'new-flagembedding-v1' of https://github.com/hanhainebula/FlagEmbedding into new-flagembedding-v1
2 parents 6525bcf + 4c75692 commit c8c4fad

1 file changed

Lines changed: 164 additions & 16 deletions

File tree

examples/inference/embedder/README.md

Lines changed: 164 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -244,38 +244,186 @@ print(scores)
244244

245245
### Using HuggingFace Transformers
246246

247-
With the transformers package, you can use the model like this: First, you pass your input through the transformer model, then you select the last hidden state of the first token (i.e., [CLS]) as the sentence embedding.
247+
#### 1. Normal Model
248+
249+
It supports `BAAI/bge-large-en-v1.5`, `BAAI/bge-base-en-v1.5`, `BAAI/bge-small-en-v1.5`, `BAAI/bge-large-zh-v1.5`, `BAAI/bge-base-zh-v1.5`, `BAAI/bge-small-zh-v1.5`, `BAAI/bge-large-en`, `BAAI/bge-base-en`, `BAAI/bge-small-en`, `BAAI/bge-large-zh`, `BAAI/bge-base-zh`, `BAAI/bge-small-zh`, and the **dense method** of `BAAI/bge-m3`:
248250

249251
```python
250-
from transformers import AutoTokenizer, AutoModel
251252
import torch
252-
# Sentences we want sentence embeddings for
253-
sentences = ["样例数据-1", "样例数据-2"]
253+
from transformers import AutoModel, AutoTokenizer
254254

255-
# Load model from HuggingFace Hub
256255
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh-v1.5')
257256
model = AutoModel.from_pretrained('BAAI/bge-large-zh-v1.5')
258257
model.eval()
259258

260-
# Tokenize sentences
261-
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
262-
# for s2p(short query to long passage) retrieval task, add an instruction to query (not add instruction for passages)
263-
# encoded_input = tokenizer([instruction + q for q in queries], padding=True, truncation=True, return_tensors='pt')
259+
sentences_1 = ["样例数据-1", "样例数据-2"]
260+
sentences_2 = ["样例数据-3", "样例数据-4"]
261+
with torch.no_grad():
262+
encoded_input_1 = tokenizer(sentences_1, padding=True, truncation=True, return_tensors='pt')
263+
encoded_input_2 = tokenizer(sentences_2, padding=True, truncation=True, return_tensors='pt')
264+
model_output_1 = model(**encoded_input_1)
265+
model_output_2 = model(**encoded_input_2)
266+
embeddings_1 = model_output_1[0][:, 0]
267+
embeddings_2 = model_output_2[0][:, 0]
268+
similarity = embeddings_1 @ embeddings_2.T
269+
print(similarity)
270+
```
271+
272+
#### 2. M3 Model
273+
274+
It only supports the **dense method** of `BAAI/bge-m3`, you can refer to the above code.
275+
276+
#### 3. LLM-based Model
277+
278+
It supports `BAAI/bge-multilingual-gemma2`:
279+
280+
```python
281+
import torch
282+
import torch.nn.functional as F
283+
284+
from torch import Tensor
285+
from transformers import AutoTokenizer, AutoModel
286+
287+
288+
def last_token_pool(last_hidden_states: Tensor,
289+
attention_mask: Tensor) -> Tensor:
290+
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
291+
if left_padding:
292+
return last_hidden_states[:, -1]
293+
else:
294+
sequence_lengths = attention_mask.sum(dim=1) - 1
295+
batch_size = last_hidden_states.shape[0]
296+
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
297+
298+
299+
def get_detailed_instruct(task_description: str, query: str) -> str:
300+
return f'<instruct>{task_description}\n<query>{query}'
301+
302+
303+
task = 'Given a web search query, retrieve relevant passages that answer the query.'
304+
queries = [
305+
get_detailed_instruct(task, 'how much protein should a female eat'),
306+
get_detailed_instruct(task, 'summit define')
307+
]
308+
# No need to add instructions for documents
309+
documents = [
310+
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
311+
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
312+
]
313+
input_texts = queries + documents
314+
315+
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-multilingual-gemma2')
316+
model = AutoModel.from_pretrained('BAAI/bge-multilingual-gemma2')
317+
model.eval()
318+
319+
max_length = 4096
320+
# Tokenize the input texts
321+
batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8)
264322

265-
# Compute token embeddings
266323
with torch.no_grad():
267-
model_output = model(**encoded_input)
268-
# Perform pooling. In this case, cls pooling.
269-
sentence_embeddings = model_output[0][:, 0]
324+
outputs = model(**batch_dict)
325+
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
326+
270327
# normalize embeddings
271-
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
272-
print("Sentence embeddings:", sentence_embeddings)
328+
embeddings = F.normalize(embeddings, p=2, dim=1)
329+
scores = (embeddings[:2] @ embeddings[2:].T) * 100
330+
print(scores.tolist())
331+
# [[55.92064666748047, 1.6549524068832397], [-0.2698777914047241, 49.95653533935547]]
273332
```
274333

334+
#### 4. LLM-based ICL Model
335+
336+
It supports `BAAI/bge-en-icl`:
337+
338+
```python
339+
import torch
340+
import torch.nn.functional as F
341+
342+
from torch import Tensor
343+
from transformers import AutoTokenizer, AutoModel
344+
345+
346+
def last_token_pool(last_hidden_states: Tensor,
347+
attention_mask: Tensor) -> Tensor:
348+
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
349+
if left_padding:
350+
return last_hidden_states[:, -1]
351+
else:
352+
sequence_lengths = attention_mask.sum(dim=1) - 1
353+
batch_size = last_hidden_states.shape[0]
354+
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
355+
356+
357+
def get_detailed_instruct(task_description: str, query: str) -> str:
358+
return f'<instruct>{task_description}\n<query>{query}'
359+
360+
def get_detailed_example(task_description: str, query: str, response: str) -> str:
361+
return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'
362+
363+
def get_new_queries(queries, query_max_len, examples_prefix, tokenizer):
364+
inputs = tokenizer(
365+
queries,
366+
max_length=query_max_len - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
367+
tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']),
368+
return_token_type_ids=False,
369+
truncation=True,
370+
return_tensors=None,
371+
add_special_tokens=False
372+
)
373+
prefix_ids = tokenizer(examples_prefix, add_special_tokens=False)['input_ids']
374+
suffix_ids = tokenizer('\n<response>', add_special_tokens=False)['input_ids']
375+
new_max_length = (len(prefix_ids) + len(suffix_ids) + query_max_len + 8) // 8 * 8 + 8
376+
new_queries = tokenizer.batch_decode(inputs['input_ids'])
377+
for i in range(len(new_queries)):
378+
new_queries[i] = examples_prefix + new_queries[i] + '\n<response>'
379+
return new_max_length, new_queries
380+
381+
task = 'Given a web search query, retrieve relevant passages that answer the query.'
382+
examples = [
383+
{'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
384+
'query': 'what is a virtual interface',
385+
'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."},
386+
{'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
387+
'query': 'causes of back pain in female for a week',
388+
'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."}
389+
]
390+
examples = [get_detailed_example(e['instruct'], e['query'], e['response']) for e in examples]
391+
examples_prefix = '\n\n'.join(examples) + '\n\n' # if there are no examples, just set examples_prefix = ''
392+
queries = [
393+
get_detailed_instruct(task, 'how much protein should a female eat'),
394+
get_detailed_instruct(task, 'summit define')
395+
]
396+
documents = [
397+
"As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
398+
"Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
399+
]
400+
query_max_len, doc_max_len = 512, 512
401+
402+
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-en-icl')
403+
model = AutoModel.from_pretrained('BAAI/bge-en-icl')
404+
model.eval()
405+
406+
new_query_max_len, new_queries = get_new_queries(queries, query_max_len, examples_prefix, tokenizer)
407+
408+
query_batch_dict = tokenizer(new_queries, max_length=new_query_max_len, padding=True, truncation=True, return_tensors='pt')
409+
doc_batch_dict = tokenizer(documents, max_length=doc_max_len, padding=True, truncation=True, return_tensors='pt')
410+
411+
with torch.no_grad():
412+
query_outputs = model(**query_batch_dict)
413+
query_embeddings = last_token_pool(query_outputs.last_hidden_state, query_batch_dict['attention_mask'])
414+
doc_outputs = model(**doc_batch_dict)
415+
doc_embeddings = last_token_pool(doc_outputs.last_hidden_state, doc_batch_dict['attention_mask'])
416+
417+
# normalize embeddings
418+
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
419+
doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
420+
scores = (query_embeddings @ doc_embeddings.T) * 100
421+
print(scores.tolist())
422+
```
275423

276424
### Using Sentence-Transformers
277425

278-
You can also use the `bge` models with [sentence-transformers](https://www.sbert.net/):
426+
You can also use the `bge` models with [sentence-transformers](https://www.sbert.net/). It currently supports `BAAI/bge-large-en-v1.5`, `BAAI/bge-base-en-v1.5`, `BAAI/bge-small-en-v1.5`, `BAAI/bge-large-zh-v1.5`, `BAAI/bge-base-zh-v1.5`, `BAAI/bge-small-zh-v1.5`, `BAAI/bge-large-en`, `BAAI/bge-base-en`, `BAAI/bge-small-en`, `BAAI/bge-large-zh`, `BAAI/bge-base-zh`, `BAAI/bge-small-zh`, the **dense method** of `BAAI/bge-m3`, and `BAAI/bge-multilingual-gemma2`:
279427

280428
```
281429
pip install -U sentence-transformers

0 commit comments

Comments
 (0)