Commit a14a2b5

update bge-en-icl
<div align="center">
<h1> BGE-EN-ICL </h1>
</div>

**BGE-EN-ICL** primarily demonstrates the following capabilities:

- In-context learning ability: providing few-shot examples in the query can significantly enhance the model's ability to handle new tasks.
- Outstanding performance: the model has achieved state-of-the-art (SOTA) performance on both BEIR and AIR-Bench.

## Environment

```bash
conda create -n icl python=3.10
conda activate icl

# You may need to adjust the CUDA version
conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers==4.41.0 deepspeed accelerate datasets peft pandas
pip install flash-attn --no-build-isolation
```
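
Optionally, you can sanity-check the environment before moving on. This is a minimal sketch, assuming only the packages installed by the commands above:

```python
import torch
import transformers

print(torch.__version__, torch.cuda.is_available())  # should report True on a GPU machine
print(transformers.__version__)                      # expected: 4.41.0

try:
    import flash_attn  # only present if you installed flash-attn
    print(flash_attn.__version__)
except ImportError:
    print("flash-attn not installed")
```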

## Model List

| Model | Introduction |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| [BAAI/bge-en-icl](https://huggingface.co/BAAI/bge-en-icl) | BGE-EN-ICL trained on the full dataset |
| BAAI/bge-en-icl-e5data | BGE-EN-ICL trained on the same public dataset as e5-mistral |

## Data List

| Data | Introduction |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| [e5-data](https://huggingface.co/datasets/cfli/bge-e5data) | Public data identical to that used by e5-mistral |
| [full-data](https://huggingface.co/datasets/cfli/bge-full-data) | The full dataset we used for training |
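
Both datasets are hosted on the Hugging Face Hub, so they can be previewed with the standard `datasets` API before training. A minimal sketch (the split and field names are assumptions; check the dataset card for the actual schema):

```python
from datasets import load_dataset

# Stream a few records rather than downloading the full dataset.
ds = load_dataset("cfli/bge-e5data", split="train", streaming=True)
for i, example in enumerate(ds):
    print(example.keys())  # field names depend on the dataset card
    if i >= 2:
        break
```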

## Usage

### Using FlagEmbedding

```bash
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install -e .
```

```python
from FlagEmbedding import FlagICLModel

queries = ["how much protein should a female eat", "summit define"]
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
]
examples = [
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
     'query': 'what is a virtual interface',
     'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."},
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
     'query': 'causes of back pain in female for a week',
     'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."}
]

model = FlagICLModel('BAAI/bge-en-icl',
                     query_instruction_for_retrieval="Given a web search query, retrieve relevant passages that answer the query.",
                     examples_for_task=examples,  # set `examples_for_task=None` to use the model without examples
                     use_fp16=True)  # setting use_fp16 to True speeds up computation with a slight performance degradation

embeddings_1 = model.encode_queries(queries)
embeddings_2 = model.encode_corpus(documents)

# similarity scores between each query and each document
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
```

By default, FlagICLModel uses all available GPUs when encoding. Set `os.environ["CUDA_VISIBLE_DEVICES"]` to select specific GPUs, or set `os.environ["CUDA_VISIBLE_DEVICES"]=""` to disable all GPUs.
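
For example, a minimal sketch (the device IDs are illustrative; set the variable before the model is created):

```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # restrict encoding to GPUs 0 and 1
# os.environ["CUDA_VISIBLE_DEVICES"] = ""   # uncomment to run on CPU only

from FlagEmbedding import FlagICLModel

model = FlagICLModel('BAAI/bge-en-icl', use_fp16=True)
```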

### Using HuggingFace Transformers

With the transformers package, you can use the model like this: first, pass your input through the transformer model, then apply last-token pooling, i.e., select the last hidden state of the final (non-padding) token of each sequence as its embedding.

```python
import torch
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # With left padding, the last position always holds the final token;
    # otherwise, index each sequence by its own length.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}'


def get_detailed_example(task_description: str, query: str, response: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'


def get_new_queries(queries, query_max_len, examples_prefix, tokenizer):
    # Reserve room for the BOS token and the '\n<response></s>' suffix when
    # truncating each query, then prepend the few-shot prefix.
    inputs = tokenizer(
        queries,
        max_length=query_max_len - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
            tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']),
        return_token_type_ids=False,
        truncation=True,
        return_tensors=None,
        add_special_tokens=False
    )
    prefix_ids = tokenizer(examples_prefix, add_special_tokens=False)['input_ids']
    suffix_ids = tokenizer('\n<response>', add_special_tokens=False)['input_ids']
    # pad the total length up to a multiple of 8
    new_max_length = (len(prefix_ids) + len(suffix_ids) + query_max_len + 8) // 8 * 8 + 8
    new_queries = tokenizer.batch_decode(inputs['input_ids'])
    for i in range(len(new_queries)):
        new_queries[i] = examples_prefix + new_queries[i] + '\n<response>'
    return new_max_length, new_queries


task = 'Given a web search query, retrieve relevant passages that answer the query.'
examples = [
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
     'query': 'what is a virtual interface',
     'response': "A virtual interface is a software-defined abstraction that mimics the behavior and characteristics of a physical network interface. It allows multiple logical network connections to share the same physical network interface, enabling efficient utilization of network resources. Virtual interfaces are commonly used in virtualization technologies such as virtual machines and containers to provide network connectivity without requiring dedicated hardware. They facilitate flexible network configurations and help in isolating network traffic for security and management purposes."},
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
     'query': 'causes of back pain in female for a week',
     'response': "Back pain in females lasting a week can stem from various factors. Common causes include muscle strain due to lifting heavy objects or improper posture, spinal issues like herniated discs or osteoporosis, menstrual cramps causing referred pain, urinary tract infections, or pelvic inflammatory disease. Pregnancy-related changes can also contribute. Stress and lack of physical activity may exacerbate symptoms. Proper diagnosis by a healthcare professional is crucial for effective treatment and management."}
]
examples = [get_detailed_example(e['instruct'], e['query'], e['response']) for e in examples]
examples_prefix = '\n\n'.join(examples) + '\n\n'  # if there are no examples, just set examples_prefix = ''
queries = [
    get_detailed_instruct(task, 'how much protein should a female eat'),
    get_detailed_instruct(task, 'summit define')
]
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
]
query_max_len, doc_max_len = 512, 512

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-en-icl')
model = AutoModel.from_pretrained('BAAI/bge-en-icl')
model.eval()

new_query_max_len, new_queries = get_new_queries(queries, query_max_len, examples_prefix, tokenizer)

query_batch_dict = tokenizer(new_queries, max_length=new_query_max_len, padding=True, truncation=True, return_tensors='pt')
doc_batch_dict = tokenizer(documents, max_length=doc_max_len, padding=True, truncation=True, return_tensors='pt')

with torch.no_grad():
    query_outputs = model(**query_batch_dict)
    query_embeddings = last_token_pool(query_outputs.last_hidden_state, query_batch_dict['attention_mask'])
    doc_outputs = model(**doc_batch_dict)
    doc_embeddings = last_token_pool(doc_outputs.last_hidden_state, doc_batch_dict['attention_mask'])

# normalize embeddings
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
scores = (query_embeddings @ doc_embeddings.T) * 100
print(scores.tolist())
```

## Fine-tune

Here is an example of fine-tuning:

```shell
cd ./finetune
torchrun --nproc_per_node 8 \
run.py \
--output_dir ./test \
--model_name_or_path mistralai/Mistral-7B-v0.1 \
--train_data cfli/bge-e5data \
--learning_rate 1e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--lora_alpha 64 \
--lora_rank 32 \
--dataloader_drop_last True \
--normlized True \
--temperature 0.02 \
--query_max_len 512 \
--passage_max_len 512 \
--train_group_size 8 \
--logging_steps 1 \
--save_steps 250 \
--save_total_limit 20 \
--ddp_find_unused_parameters False \
--negatives_cross_device \
--gradient_checkpointing \
--deepspeed ../../LLARA/stage1.json \
--warmup_steps 100 \
--fp16 \
--cache_dir ./cache/model_cache \
--token ... \
--cache_path ./cache/data_cache \
--sub_batch_size 64 \
--target_modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj \
--use_special_tokens \
--symmetric_batch_size 256 \
--symmetric_train_group_size 8 \
--max_class_neg 7
```
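
The command above trains a LoRA adapter on top of the base model rather than a full checkpoint. If you need a standalone model afterwards, one option is to merge the adapter with `peft`. This is a sketch under the assumption that the final adapter weights end up in `--output_dir` (`./test`); depending on `--save_steps`, they may instead sit in a `checkpoint-*` subdirectory, and since `--use_special_tokens` can extend the vocabulary, you may need the tokenizer saved alongside the checkpoint:

```python
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer

# Load the base model, apply the trained adapter, and fold the LoRA
# weights back into the base parameters.
base = AutoModel.from_pretrained("mistralai/Mistral-7B-v0.1")
merged = PeftModel.from_pretrained(base, "./test").merge_and_unload()
merged.save_pretrained("./bge-en-icl-merged")

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.save_pretrained("./bge-en-icl-merged")
```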

## Citation

If you find this repository useful, please give us a star ⭐.

To cite our work:

```
@misc{li2023makinglargelanguagemodels,
      title={Making Large Language Models A Better Foundation For Dense Retrieval},
      author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
      year={2023},
      eprint={2312.15503},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2312.15503},
}
```

---

import os
from dataclasses import dataclass, field
from typing import Optional, List

from transformers import TrainingArguments


def default_list() -> List[str]:
    return ['v_proj', 'q_proj', 'k_proj', 'gate_proj', 'down_proj', 'o_proj', 'up_proj']


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    peft_model_path: str = field(
        default=''
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_lora: bool = field(
        default=True,
        metadata={"help": "If passed, will use LoRA (low-rank parameter-efficient training) to train the model."}
    )
    lora_rank: int = field(
        default=64,
        metadata={"help": "The rank of LoRA."}
    )
    lora_alpha: float = field(
        default=16,
        metadata={"help": "The alpha parameter of LoRA."}
    )
    lora_dropout: float = field(
        default=0.1,
        metadata={"help": "The dropout rate of LoRA modules."}
    )
    target_modules: List[str] = field(
        default_factory=default_list
    )
    save_merged_lora_model: bool = field(
        default=False,
        metadata={"help": "If passed, will merge the LoRA modules and save the entire model."}
    )
    use_flash_attn: bool = field(
        default=True,
        metadata={"help": "If passed, will use flash attention to train the model."}
    )
    token: str = field(
        default=""  # Hugging Face access token; supply it via --token rather than hard-coding it here
    )
    cache_dir: str = field(
        default="/share/LMs"
    )
    from_peft: Optional[str] = field(
        default=None
    )
    modules_to_save: Optional[str] = field(
        default=None
    )
    raw_peft: Optional[str] = field(
        default=None
    )


@dataclass
class DataArguments:
    train_data: str = field(
        default='cfli/bge-e5data',
        metadata={"help": "Path to the training data"}
    )
    train_group_size: int = field(default=8)

    query_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the query. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )

    passage_max_len: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the passage. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )

    max_example_num_per_dataset: int = field(
        default=10000000000, metadata={"help": "The max number of examples for each dataset"}
    )

    query_instruction_for_retrieval: str = field(
        default="query: ", metadata={"help": "Instruction prepended to each query, e.g. 'query: '"}
    )
    passage_instruction_for_retrieval: str = field(
        default="passage: ", metadata={"help": "Instruction prepended to each passage, e.g. 'passage: '"}
    )

    cache_path: str = field(
        default='./data_dir'
    )

    load_from_disk: bool = field(
        default=False, metadata={"help": "Whether to load the data from disk"}
    )

    load_disk_path: Optional[str] = field(
        default=None, metadata={"help": "The path to load the data from", "nargs": "+"}
    )

    save_to_disk: bool = field(
        default=False, metadata={"help": "Whether to save the data to disk"}
    )

    save_disk_path: Optional[str] = field(
        default=None, metadata={"help": "The path to save the data to"}
    )

    num_shards: int = field(
        default=0, metadata={
            "help": "Number of shards to write; takes priority over `save_max_shard_size`. "
                    "By default it is derived from `save_max_shard_size`."}
    )

    save_max_shard_size: str = field(
        default="50GB", metadata={"help": "The max size of each shard"}
    )

    exit_after_save: bool = field(
        default=False, metadata={"help": "Whether to exit after saving the data"}
    )

    shuffle_ratio: float = field(
        default=0.0, metadata={"help": "The ratio of shuffling the text"}
    )
    use_special_tokens: bool = field(default=True)
    nli_all_prompt: bool = field(default=True)
    symmetric_batch_size: int = field(default=128)
    symmetric_train_group_size: int = field(default=8)
    max_class_neg: int = field(default=1000)
    example_query_max_len: int = field(default=64)
    example_passage_max_len: int = field(default=96)
    retrieval_use_examples: bool = field(default=True)


@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    negatives_cross_device: bool = field(default=False, metadata={"help": "Share negatives across devices"})
    temperature: Optional[float] = field(default=0.02)
    fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
    sentence_pooling_method: str = field(default='cls', metadata={"help": "The pooling method, should be cls or mean"})
    normlized: bool = field(default=True)  # spelling kept to match the `--normlized` CLI flag
    sub_batch_size: Optional[int] = field(default=None)
    cache_chunk_size: int = field(default=-1, metadata={"help": "Chunk size used to cache the execution of each step."})
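

# A minimal sketch of how these dataclasses are typically consumed (an
# assumption: the accompanying run.py uses transformers' HfArgumentParser,
# the standard pattern for TrainingArguments subclasses like the one above).
if __name__ == "__main__":
    from transformers import HfArgumentParser

    parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    print(model_args.model_name_or_path, data_args.train_data)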
