Commit c4103af

update embedder FT readme

1 parent 880ed8d commit c4103af

4 files changed

Lines changed: 322 additions & 1 deletion


FlagEmbedding/abc/evaluation/arguments.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -86,6 +86,9 @@ class AbsEvalModelArgs:
     normalize_embeddings: bool = field(
         default=True, metadata={"help": "whether to normalize the embeddings"}
     )
+    pooling_method: str = field(
+        default="cls", metadata={"help": "The pooling method for the embedder."}
+    )
     use_fp16: bool = field(
         default=True, metadata={"help": "whether to use fp16 for inference"}
     )
```

FlagEmbedding/abc/evaluation/runner.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -32,6 +32,7 @@ def get_models(model_args: AbsEvalModelArgs) -> Tuple[FlagAutoModel, Union[FlagA
         model_name_or_path=model_args.embedder_name_or_path,
         model_class=model_args.embedder_model_class,
         normalize_embeddings=model_args.normalize_embeddings,
+        pooling_method=model_args.pooling_method,
         use_fp16=model_args.use_fp16,
         query_instruction_for_retrieval=model_args.query_instruction_for_retrieval,
         query_instruction_format=model_args.query_instruction_format_for_retrieval,
```

examples/evaluation/README.md

Lines changed: 2 additions & 1 deletion
````diff
@@ -56,6 +56,7 @@ First, we will introduce the commonly used parameters, followed by an introducti
 - **`embedder_name_or_path`**: The name or path to the embedder.
 - **`embedder_model_class`**: Class of the model used for embedding (options include 'auto', 'encoder-only-base', etc.). Default is `auto`.
 - **`normalize_embeddings`**: Set to `true` to normalize embeddings.
+- **`pooling_method`**: The pooling method for the embedder.
 - **`use_fp16`**: Use FP16 precision for inference.
 - **`devices`**: List of devices used for inference.
 - **`query_instruction_for_retrieval`**, **`query_instruction_format_for_retrieval`**: Instructions and format for query during retrieval.
@@ -342,4 +343,4 @@ python -m FlagEmbedding.evaluation.custom \
     --cache_dir ./cache/model \
     --reranker_query_max_length 512 \
     --reranker_max_length 1024
-```
+```
````
Lines changed: 316 additions & 0 deletions
# Finetune

In this example, we show how to finetune the embedder with your data.

## 1. Installation

- **with pip**

```shell
pip install -U FlagEmbedding
```

- **from source**

```shell
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding
pip install .
```

For development, install as editable:

```shell
pip install -e .
```

## 2. Data format

Training data should be a JSON Lines file, where each line is a dict like this:

```shell
{"query": str, "pos": List[str], "neg": List[str], "pos_scores": List[float], "neg_scores": List[float], "prompt": str, "type": str}
```

`query` is the query, `pos` is a list of positive texts, and `neg` is a list of negative texts. `pos_scores` is a list of scores corresponding to the `query` and each text in `pos`, and `neg_scores` is a list of scores corresponding to the `query` and each text in `neg`; if you don't use knowledge distillation, both can be omitted. `prompt` is the prompt used for the query and overrides `query_instruction_for_retrieval`. `type` is used for `bge-en-icl`; it includes `normal`, `symmetric_class`, `symmetric_clustering`, etc. If you have no negative texts for a query, you can randomly sample some texts from the entire corpus as negatives.

See [example_data](https://github.com/hanhainebula/FlagEmbedding/tree/new-flagembedding-v1/examples/finetune/embedder/example_data) for more detailed files.
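
For instance, a minimal script like the following writes training examples in this format using only the standard library (the file name and the example texts are illustrative, not part of the shipped example data):

```python
import json

# Illustrative examples only; replace with your own queries, positives, and negatives.
examples = [
    {
        "query": "what is the capital of France?",
        "pos": ["Paris is the capital and largest city of France."],
        "neg": ["Berlin is the capital of Germany.", "Lyon is a city in France."],
        "prompt": "Represent this sentence for searching relevant passages: ",
    },
]

with open("toy_finetune_data.jsonl", "w", encoding="utf-8") as f:
    for example in examples:
        # One JSON object per line, as expected by the finetuning scripts.
        f.write(json.dumps(example, ensure_ascii=False) + "\n")
```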

### Hard Negatives

Hard negative mining is a widely used method to improve the quality of sentence embeddings. You can mine hard negatives with the following commands:

```shell
git clone https://github.com/FlagOpen/FlagEmbedding.git
cd FlagEmbedding/scripts
```

```shell
python hn_mine.py \
    --model_name_or_path BAAI/bge-base-en-v1.5 \
    --input_file toy_finetune_data.jsonl \
    --output_file toy_finetune_data_minedHN.jsonl \
    --range_for_sampling 2-200 \
    --negative_number 15 \
    --use_gpu_for_searching
```

- `input_file`: JSON data for finetuning. The script retrieves the top-k documents for each query and randomly samples negatives from them (excluding the positive documents).
- `output_file`: path to save the JSON data with mined hard negatives for finetuning.
- `negative_number`: the number of sampled negatives.
- `range_for_sampling`: the rank range to sample negatives from. For example, `2-200` means sampling `negative_number` negatives from the top-2 to top-200 retrieved documents. **You can set a larger range to reduce the difficulty of the negatives (e.g., set it to `60-300` to sample negatives from the top-60 to top-300 passages).**
- `candidate_pool`: the pool to retrieve from. The default value is None, in which case the script retrieves from the combination of all `neg` entries in `input_file`. The format of this file is the same as [pretrain data](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/pretrain#2-data-format). If a candidate pool is provided, the script retrieves negatives from it.
- `use_gpu_for_searching`: whether to use faiss-gpu to retrieve negatives.
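
Conceptually, the mining step ranks the corpus by similarity to each query and samples negatives from a window of that ranking. Below is a minimal sketch of the idea in plain NumPy; the function and variable names are placeholders and this is not the actual `hn_mine.py` implementation:

```python
import random

import numpy as np

def mine_hard_negatives(query_emb, corpus_emb, corpus_texts, positives,
                        sample_range=(2, 200), negative_number=15):
    """Sample hard negatives for one query from a window of the similarity ranking."""
    # Cosine similarity via dot product (assumes embeddings are L2-normalized).
    scores = corpus_emb @ query_emb
    ranked = np.argsort(-scores)  # corpus indices, most similar first

    start, end = sample_range
    # Keep only candidates inside the window that are not known positives.
    window = [i for i in ranked[start:end] if corpus_texts[i] not in positives]
    sampled = random.sample(window, min(negative_number, len(window)))
    return [corpus_texts[i] for i in sampled]
```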

## 3. Train

Detailed examples of the various fine-tuning setups can be found in the bash files located in the corresponding folders. Here, we simply provide the training commands for the `standard model`, `bge-m3`, `bge-multilingual-gemma2`, and `bge-en-icl`.

Here are some important arguments:

- **`model_name_or_path`**: The model checkpoint for initialization.
- **`config_name`**: Pretrained config name or path if not the same as model_name.
- **`tokenizer_name`**: Pretrained tokenizer name or path if not the same as model_name.
- **`cache_dir`**: Where to store the pre-trained models downloaded from s3.
- **`trust_remote_code`**: Trust remote code.
- **`token`**: The token to use when accessing the model.
- **`train_data`**: One or more paths to training data. `query: str`, `pos: List[str]`, and `neg: List[str]` are required in the training data. Argument type: multiple.
- **`cache_path`**: Where to store the cached data.
- **`train_group_size`**: The number of passages used per query during training: one positive and `train_group_size - 1` negatives.
- **`query_max_len`**: The maximum total input sequence length after tokenization for the query. Sequences longer than this will be truncated.
- **`passage_max_len`**: The maximum total input sequence length after tokenization for the passage. Sequences longer than this will be truncated.
- **`pad_to_multiple_of`**: If set, pad the sequences to a multiple of the provided value.
- **`max_example_num_per_dataset`**: The max number of examples for each dataset.
- **`query_instruction_for_retrieval`**: Instruction for the query.
- **`query_instruction_format`**: Format for the query instruction.
- **`knowledge_distillation`**: Use knowledge distillation when `pos_scores: List[float]` and `neg_scores: List[float]` are present in the training data.
- **`passage_instruction_for_retrieval`**: Instruction for the passage.
- **`passage_instruction_format`**: Format for the passage instruction.
- **`shuffle_ratio`**: The ratio of shuffling the text.
- **`same_dataset_within_batch`**: All samples in the same batch come from the same dataset.
- **`small_threshold`**: The threshold for small datasets. All small datasets in the same directory will be merged into one dataset.
- **`drop_threshold`**: The threshold for dropping a merged small dataset. If the number of examples in the merged small dataset is less than this threshold, it will be dropped.
- **`negatives_cross_device`**: Share negatives across devices.
- **`temperature`**: Temperature used for the similarity score in the contrastive loss (see the sketch after this list).
- **`fix_position_embedding`**: Freeze the parameters of the position embeddings.
- **`sentence_pooling_method`**: The pooling method. Available options: cls, mean, last_token. Default: cls.
- **`normalize_embeddings`**: Whether to normalize the embeddings.
- **`sub_batch_size`**: Sub-batch size for training.
- **`kd_loss_type`**: The loss type for knowledge distillation. Available options: kl_div, m3_kd_loss. Default: kl_div.
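
To make the roles of `temperature`, `normalize_embeddings`, and `train_group_size` concrete, here is a minimal sketch of an in-batch contrastive (InfoNCE-style) loss of the kind used for embedder fine-tuning. It is a simplified PyTorch illustration, not the library's exact implementation:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(q_emb, p_emb, temperature=0.02, train_group_size=8):
    """q_emb: (batch, dim) query embeddings.
    p_emb: (batch * train_group_size, dim) passage embeddings; for each query,
    the first passage is the positive and the rest are (hard) negatives."""
    # normalize_embeddings=True corresponds to L2-normalizing before the dot product.
    q_emb = F.normalize(q_emb, dim=-1)
    p_emb = F.normalize(p_emb, dim=-1)

    # Similarity of every query to every passage in the batch (in-batch negatives),
    # scaled by the temperature.
    scores = q_emb @ p_emb.T / temperature  # (batch, batch * train_group_size)

    # The positive for query i sits at column i * train_group_size.
    target = torch.arange(q_emb.size(0), device=q_emb.device) * train_group_size
    return F.cross_entropy(scores, target)
```

With `negatives_cross_device`, the passage embeddings from all devices are gathered before computing the scores, which enlarges the pool of in-batch negatives.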

### (1) standard model

```shell
torchrun --nproc_per_node 2 \
    -m FlagEmbedding.finetune.embedder.encoder_only.base \
    --model_name_or_path BAAI/bge-large-en-v1.5 \
    --cache_dir ./cache/model \
    --train_data ./example_data/retrieval \
                 ./example_data/sts/sts.jsonl \
                 ./example_data/classification-no_in_batch_neg \
                 ./example_data/clustering-no_in_batch_neg \
    --cache_path ./cache/data \
    --train_group_size 8 \
    --query_max_len 512 \
    --passage_max_len 512 \
    --pad_to_multiple_of 8 \
    --query_instruction_for_retrieval 'Represent this sentence for searching relevant passages: ' \
    --query_instruction_format '{}{}' \
    --knowledge_distillation False \
    --output_dir ./test_encoder_only_base_bge-large-en-v1.5 \
    --overwrite_output_dir \
    --learning_rate 1e-5 \
    --fp16 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --dataloader_drop_last True \
    --warmup_ratio 0.1 \
    --gradient_checkpointing \
    --deepspeed ../ds_stage0.json \
    --logging_steps 1 \
    --save_steps 1000 \
    --negatives_cross_device \
    --temperature 0.02 \
    --sentence_pooling_method cls \
    --normalize_embeddings True \
    --kd_loss_type kl_div
```
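
After training finishes, you can sanity-check the saved checkpoint by encoding a few sentences. Below is a minimal sketch, assuming the `--output_dir` above and the `FlagModel` interface from FlagEmbedding; adjust paths and instructions to your own setup:

```python
from FlagEmbedding import FlagModel

# Path is the --output_dir used above; replace with your own checkpoint directory.
model = FlagModel(
    "./test_encoder_only_base_bge-large-en-v1.5",
    query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
    use_fp16=True,
)

queries = ["how do plants make food?"]
passages = ["Photosynthesis is the process by which plants convert sunlight into energy."]

q_emb = model.encode_queries(queries)  # the instruction is prepended to queries
p_emb = model.encode(passages)         # passages are encoded as-is
print(q_emb @ p_emb.T)                 # higher score = more similar
```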

### (2) bge-m3

```shell
torchrun --nproc_per_node 2 \
    -m FlagEmbedding.finetune.embedder.encoder_only.m3 \
    --model_name_or_path BAAI/bge-m3 \
    --cache_dir ./cache/model \
    --train_data ./example_data/retrieval \
                 ./example_data/sts/sts.jsonl \
                 ./example_data/classification-no_in_batch_neg \
                 ./example_data/clustering-no_in_batch_neg \
    --cache_path ./cache/data \
    --train_group_size 8 \
    --query_max_len 512 \
    --passage_max_len 512 \
    --pad_to_multiple_of 8 \
    --knowledge_distillation True \
    --same_dataset_within_batch True \
    --small_threshold 0 \
    --drop_threshold 0 \
    --output_dir ./test_encoder_only_m3_bge-m3_sd \
    --overwrite_output_dir \
    --learning_rate 1e-5 \
    --fp16 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --dataloader_drop_last True \
    --warmup_ratio 0.1 \
    --gradient_checkpointing \
    --deepspeed ../ds_stage0.json \
    --logging_steps 1 \
    --save_steps 1000 \
    --negatives_cross_device \
    --temperature 0.02 \
    --sentence_pooling_method cls \
    --normalize_embeddings True \
    --kd_loss_type m3_kd_loss \
    --unified_finetuning True \
    --use_self_distill True \
    --fix_encoder False \
    --self_distill_start_step 0
```

Here are some new arguments:

- **`colbert_dim`**: Dimension of the ColBERT linear layer.
- **`unified_finetuning`**: Use unified fine-tuning.
- **`use_self_distill`**: Use self-distillation when unified fine-tuning is enabled (see the sketch after this list).
- **`fix_encoder`**: Freeze the parameters of the encoder.
- **`self_distill_start_step`**: The step at which self-distillation starts.
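
As a rough illustration of what self-distillation means here, following the bge-m3 idea of using the combined dense/sparse/multi-vector score as a teacher for each individual scoring head. This is a conceptual sketch, not the library's `m3_kd_loss` implementation, and the equal weighting of the three heads is an assumption:

```python
import torch
import torch.nn.functional as F

def self_distill_loss(dense_scores, sparse_scores, colbert_scores, temperature=1.0):
    """All inputs: (batch, num_candidates) similarity scores for each query.
    The averaged ("unified") score distribution acts as the teacher for each head."""
    teacher = (dense_scores + sparse_scores + colbert_scores) / 3
    teacher_probs = F.softmax(teacher.detach() / temperature, dim=-1)

    loss = 0.0
    for student in (dense_scores, sparse_scores, colbert_scores):
        student_log_probs = F.log_softmax(student / temperature, dim=-1)
        # KL divergence between the teacher distribution and each student head.
        loss = loss + F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    return loss / 3
```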

### (3) bge-multilingual-gemma2

```shell
torchrun --nproc_per_node 2 \
    -m FlagEmbedding.finetune.embedder.decoder_only.base \
    --model_name_or_path BAAI/bge-multilingual-gemma2 \
    --cache_dir ./cache/model \
    --use_lora True \
    --lora_rank 32 \
    --lora_alpha 64 \
    --target_modules q_proj k_proj v_proj o_proj gate_proj down_proj up_proj \
    --additional_special_tokens '<instruct>' '<query>' \
    --save_merged_lora_model True \
    --train_data ./example_data/retrieval \
                 ./example_data/sts/sts.jsonl \
                 ./example_data/classification-no_in_batch_neg \
                 ./example_data/clustering-no_in_batch_neg \
    --cache_path ./cache/data \
    --train_group_size 8 \
    --query_max_len 512 \
    --passage_max_len 512 \
    --pad_to_multiple_of 8 \
    --query_instruction_for_retrieval 'Given a query, retrieve passages that are relevant to the query.' \
    --query_instruction_format '<instruct>{}\n<query>{}' \
    --knowledge_distillation True \
    --same_dataset_within_batch True \
    --small_threshold 0 \
    --drop_threshold 0 \
    --output_dir ./test_decoder_only_base_bge-multilingual-gemma2_sd \
    --overwrite_output_dir \
    --learning_rate 1e-4 \
    --fp16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --dataloader_drop_last True \
    --warmup_ratio 0.1 \
    --gradient_checkpointing \
    --deepspeed ../ds_stage1.json \
    --logging_steps 1 \
    --save_steps 1000 \
    --negatives_cross_device \
    --temperature 0.02 \
    --sentence_pooling_method last_token \
    --normalize_embeddings True \
    --kd_loss_type m3_kd_loss
```

Here are some new arguments:

- **`peft_model_path`**: The PEFT model checkpoint for initialization.
- **`use_lora`**: If passed, will use LoRA (low-rank parameter-efficient training) to train the model.
- **`lora_rank`**: The rank of LoRA.
- **`lora_alpha`**: The alpha parameter of LoRA.
- **`lora_dropout`**: The dropout rate of LoRA modules.
- **`target_modules`**: The target modules to apply LoRA to (see the sketch after this list).
- **`use_flash_attn`**: If passed, will use flash attention to train the model.
- **`use_slow_tokenizer`**: If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).
- **`additional_special_tokens`**: Additional special tokens.
- **`save_merged_lora_model`**: If passed, will merge the LoRA modules and save the entire model.
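
For reference, the LoRA-related arguments map onto a configuration like the following, shown here with Hugging Face `peft`. This is an illustrative sketch of how such settings are typically expressed, not the exact code path used by FlagEmbedding, and the dropout value is just an example:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModel

# Base model as in the command above; downloading it requires substantial disk space.
base_model = AutoModel.from_pretrained("BAAI/bge-multilingual-gemma2")

lora_config = LoraConfig(
    r=32,                      # --lora_rank
    lora_alpha=64,             # --lora_alpha
    lora_dropout=0.1,          # --lora_dropout (example value)
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "down_proj", "up_proj"],  # --target_modules
    task_type="FEATURE_EXTRACTION",
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # only the low-rank adapters are trainable
```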

### (4) bge-en-icl

```shell
torchrun --nproc_per_node 2 \
    -m FlagEmbedding.finetune.embedder.decoder_only.icl \
    --model_name_or_path BAAI/bge-en-icl \
    --cache_dir ./cache/model \
    --use_lora True \
    --lora_rank 32 \
    --lora_alpha 64 \
    --target_modules q_proj k_proj v_proj o_proj gate_proj down_proj up_proj \
    --additional_special_tokens '<instruct>' '<query>' \
    --save_merged_lora_model True \
    --train_data ./example_data/retrieval \
                 ./example_data/sts/sts.jsonl \
                 ./example_data/classification-no_in_batch_neg \
                 ./example_data/clustering-no_in_batch_neg \
    --cache_path ./cache/data \
    --train_group_size 8 \
    --query_max_len 512 \
    --passage_max_len 512 \
    --pad_to_multiple_of 8 \
    --query_instruction_for_retrieval 'Given a query, retrieve passages that are relevant to the query.' \
    --query_instruction_format '<instruct>{}\n<query>{}' \
    --knowledge_distillation True \
    --same_dataset_within_batch True \
    --small_threshold 0 \
    --drop_threshold 0 \
    --output_dir ./test_decoder_only_base_bge-en-icl_sd \
    --overwrite_output_dir \
    --learning_rate 1e-4 \
    --fp16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --dataloader_drop_last True \
    --warmup_ratio 0.1 \
    --gradient_checkpointing \
    --deepspeed ../ds_stage1.json \
    --logging_steps 1 \
    --save_steps 1000 \
    --negatives_cross_device \
    --temperature 0.02 \
    --sentence_pooling_method last_token \
    --normalize_embeddings True \
    --kd_loss_type kl_div
```

Here are some new arguments:

- **`peft_model_path`**: The PEFT model checkpoint for initialization.
- **`use_lora`**: If passed, will use LoRA (low-rank parameter-efficient training) to train the model.
- **`lora_rank`**: The rank of LoRA.
- **`lora_alpha`**: The alpha parameter of LoRA.
- **`lora_dropout`**: The dropout rate of LoRA modules.
- **`target_modules`**: The target modules to apply LoRA to.
- **`use_flash_attn`**: If passed, will use flash attention to train the model.
- **`use_slow_tokenizer`**: If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).
- **`from_peft`**: (no metadata provided)
- **`modules_to_save`**: (no metadata provided)
- **`raw_peft`**: (no metadata provided)
- **`additional_special_tokens`**: Additional special tokens.
- **`save_merged_lora_model`**: If passed, will merge the LoRA modules and save the entire model.
- **`example_query_max_len`**: The max length of the example query.
- **`example_passage_max_len`**: The max length of the example passage.
- **`retrieval_use_examples`**: If passed, will use examples for retrieval.
- **`icl_suffix_str`**: The suffix string for the ICL dataset (see the sketch after this list).
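
To give a feel for what in-context-learning fine-tuning consumes, here is a rough sketch of how a query might be rendered with the special tokens, one in-context example, and a suffix string. The exact template is defined by the `icl` training code; the helper below, its field names, and the `"\nResponse:"` suffix are illustrative assumptions only:

```python
# Illustrative only: assumes the '<instruct>{}\n<query>{}' format used above.
def build_icl_prompt(task_instruction, examples, query,
                     query_format="<instruct>{}\n<query>{}", suffix="\nResponse:"):
    parts = []
    for ex in examples:
        # Each in-context example shows the instruction, a query, and its relevant passage.
        parts.append(query_format.format(task_instruction, ex["query"]) + suffix + " " + ex["passage"])
    # The actual query to embed comes last, ending with the suffix.
    parts.append(query_format.format(task_instruction, query) + suffix)
    return "\n\n".join(parts)


print(build_icl_prompt(
    "Given a query, retrieve passages that are relevant to the query.",
    [{"query": "what is the capital of France?",
      "passage": "Paris is the capital and largest city of France."}],
    "how do plants make food?",
))
```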
