|
| 1 | +# 1. Introduction |
| 2 | + |
| 3 | +In this example, we show how to **inference**, **finetune** and **evaluation** the baai-general-embedding. |
| 4 | + |
| 5 | +# 2. Installation |
| 6 | + |
| 7 | +* **with pip** |
| 8 | +```shell |
| 9 | +pip install -U FlagEmbedding |
| 10 | +``` |
| 11 | + |
| 12 | +* **from source** |
| 13 | +```shell |
| 14 | +git clone https://github.com/FlagOpen/FlagEmbedding.git |
| 15 | +cd FlagEmbedding |
| 16 | +pip install . |
| 17 | +``` |
| 18 | +For development, install as editable: |
| 19 | +```shell |
| 20 | +pip install -e . |
| 21 | +``` |
| 22 | + |
| 23 | +# 3. Inference |
| 24 | + |
| 25 | +We have provided the inference code for two models, namely the **embedder** and the **reranker**. These can be loaded using `FlagAutoModel` and `FlagAutoReranker`, respectively. For more detailed instructions on their use, please refer to the documentation for the [embedder](https://github.com/hanhainebula/FlagEmbedding/blob/new-flagembedding-v1/examples/inference/embedder) and [reranker](https://github.com/hanhainebula/FlagEmbedding/blob/new-flagembedding-v1/examples/inference/reranker). |
| 26 | + |
| 27 | +## 1. Embedder |
| 28 | + |
| 29 | +```python |
| 30 | +from FlagEmbedding import FlagAutoModel |
| 31 | +sentences_1 = ["样例数据-1", "样例数据-2"] |
| 32 | +sentences_2 = ["样例数据-3", "样例数据-4"] |
| 33 | +model = FlagAutoModel.from_finetuned('BAAI/bge-large-zh-v1.5', |
| 34 | + query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:", |
| 35 | + use_fp16=True, |
| 36 | + devices=['cuda:1']) # Setting use_fp16 to True speeds up computation with a slight performance degradation |
| 37 | +embeddings_1 = model.encode_corpus(sentences_1) |
| 38 | +embeddings_2 = model.encode_corpus(sentences_2) |
| 39 | +similarity = embeddings_1 @ embeddings_2.T |
| 40 | +print(similarity) |
| 41 | + |
| 42 | +# for s2p(short query to long passage) retrieval task, suggest to use encode_queries() which will automatically add the instruction to each query |
| 43 | +# corpus in retrieval task can still use encode_corpus(), since they don't need instruction |
| 44 | +queries = ['query_1', 'query_2'] |
| 45 | +passages = ["样例文档-1", "样例文档-2"] |
| 46 | +q_embeddings = model.encode_queries(queries) |
| 47 | +p_embeddings = model.encode_corpus(passages) |
| 48 | +scores = q_embeddings @ p_embeddings.T |
| 49 | +print(scores) |
| 50 | +``` |
| 51 | + |
| 52 | +## 2. Reranker |
| 53 | + |
| 54 | +```python |
| 55 | +from FlagEmbedding import FlagAutoReranker |
| 56 | +pairs = [("样例数据-1", "样例数据-3"), ("样例数据-2", "样例数据-4")] |
| 57 | +model = FlagAutoReranker.from_finetuned('BAAI/bge-reranker-large', |
| 58 | + use_fp16=True, |
| 59 | + devices=['cuda:1']) # Setting use_fp16 to True speeds up computation with a slight performance degradation |
| 60 | +similarity = model.compute_score(pairs, normalize=True) |
| 61 | +print(similarity) |
| 62 | + |
| 63 | +pairs = [("query_1", "样例文档-1"), ("query_2", "样例文档-2")] |
| 64 | +scores = model.compute_score(pairs) |
| 65 | +print(scores) |
| 66 | +``` |
| 67 | + |
| 68 | +# 4. Finetune |
| 69 | + |
| 70 | +We support the finetune of various BGE series models, including bge-large-en-v1.5, bge-m3, bge-en-icl, bge-reranker-v2-m3, bge-reranker-v2-gemma, and bge-reranker-v2-minicpm-layerwise, etc. Here, we take the basic models bge-en-large-v1.5 and bge-reranker-large as examples. For more details, please see the [embedder](https://github.com/hanhainebula/FlagEmbedding/tree/new-flagembedding-v1/examples/finetune/embedder) and [reranker](https://github.com/hanhainebula/FlagEmbedding/tree/new-flagembedding-v1/examples/finetune/reranker) sections. |
| 71 | + |
| 72 | +## 1. Embedder |
| 73 | + |
| 74 | +```shell |
| 75 | +torchrun --nproc_per_node 2 \ |
| 76 | + -m FlagEmbedding.finetune.embedder.encoder_only.base \ |
| 77 | + --model_name_or_path BAAI/bge-large-en-v1.5 \ |
| 78 | + --cache_dir ./cache/model \ |
| 79 | + --train_data ./finetune/embedder/example_data/retrieval \ |
| 80 | + --cache_path ./cache/data \ |
| 81 | + --train_group_size 8 \ |
| 82 | + --query_max_len 512 \ |
| 83 | + --passage_max_len 512 \ |
| 84 | + --pad_to_multiple_of 8 \ |
| 85 | + --query_instruction_for_retrieval 'Represent this sentence for searching relevant passages: ' \ |
| 86 | + --query_instruction_format '{}{}' \ |
| 87 | + --knowledge_distillation False \ |
| 88 | + --output_dir ./test_encoder_only_base_bge-large-en-v1.5 \ |
| 89 | + --overwrite_output_dir \ |
| 90 | + --learning_rate 1e-5 \ |
| 91 | + --fp16 \ |
| 92 | + --num_train_epochs $num_train_epochs \ |
| 93 | + --per_device_train_batch_size $per_device_train_batch_size \ |
| 94 | + --dataloader_drop_last True \ |
| 95 | + --warmup_ratio 0.1 \ |
| 96 | + --gradient_checkpointing \ |
| 97 | + --deepspeed ./finetune/ds_stage0.json \ |
| 98 | + --logging_steps 1 \ |
| 99 | + --save_steps 1000 \ |
| 100 | + --negatives_cross_device \ |
| 101 | + --temperature 0.02 \ |
| 102 | + --sentence_pooling_method cls \ |
| 103 | + --normalize_embeddings True \ |
| 104 | + --kd_loss_type kl_div |
| 105 | +``` |
| 106 | + |
| 107 | +## 2. Reranker |
| 108 | + |
| 109 | +```shell |
| 110 | +torchrun --nproc_per_node 2 \ |
| 111 | + -m FlagEmbedding.finetune.reranker.encoder_only.base \ |
| 112 | + --model_name_or_path BAAI/bge-reranker-large \ |
| 113 | + --cache_dir ./cache/model \ |
| 114 | + --train_data ./finetune/reranker/example_data/normal/examples.jsonl \ |
| 115 | + --cache_path ~/.cache \ |
| 116 | + --train_group_size 8 \ |
| 117 | + --query_max_len 256 \ |
| 118 | + --passage_max_len 256 \ |
| 119 | + --pad_to_multiple_of 8 \ |
| 120 | + --knowledge_distillation True \ |
| 121 | + --output_dir ./test_encoder_only_base_bge-reranker-large \ |
| 122 | + --overwrite_output_dir \ |
| 123 | + --learning_rate 6e-5 \ |
| 124 | + --fp16 \ |
| 125 | + --num_train_epochs $num_train_epochs \ |
| 126 | + --per_device_train_batch_size 2 \ |
| 127 | + --gradient_accumulation_steps 1 \ |
| 128 | + --dataloader_drop_last True \ |
| 129 | + --warmup_ratio 0.1 \ |
| 130 | + --gradient_checkpointing \ |
| 131 | + --weight_decay 0.01 \ |
| 132 | + --deepspeed ./finetune/ds_stage0.json \ |
| 133 | + --logging_steps 1 \ |
| 134 | + --save_steps 1000 \ |
| 135 | +``` |
| 136 | + |
| 137 | +# 5. Evaluation |
| 138 | + |
| 139 | +We support evaluations on MTEB, BEIR, MSMARCO, MIRACL, MLDR, MKQA, and AIR-Bench. Here, we provide an example of evaluating MSMARCO passages. For more details, please refer to the [evaluation examples](https://github.com/hanhainebula/FlagEmbedding/tree/new-flagembedding-v1/examples/evaluation). |
| 140 | + |
| 141 | +```shell |
| 142 | +export HF_HUB_CACHE="$HOME/.cache/huggingface/hub" |
| 143 | + |
| 144 | +python -m FlagEmbedding.evaluation.msmarco \ |
| 145 | + --eval_name msmarco \ |
| 146 | + --dataset_dir ./data/msmarco \ |
| 147 | + --dataset_names passage \ |
| 148 | + --splits dev dl19 dl20 \ |
| 149 | + --corpus_embd_save_dir ./data/msmarco/corpus_embd \ |
| 150 | + --output_dir ./data/msmarco/search_results \ |
| 151 | + --search_top_k 1000 \ |
| 152 | + --rerank_top_k 100 \ |
| 153 | + --cache_path ./cache/data \ |
| 154 | + --overwrite True \ |
| 155 | + --k_values 10 100 \ |
| 156 | + --eval_output_method markdown \ |
| 157 | + --eval_output_path ./data/msmarco/msmarco_eval_results.md \ |
| 158 | + --eval_metrics ndcg_at_10 mrr_at_10 recall_at_100 \ |
| 159 | + --embedder_name_or_path BAAI/bge-large-en-v1.5 \ |
| 160 | + --embedder_batch_size 512 \ |
| 161 | + --embedder_query_max_length 512 \ |
| 162 | + --embedder_passage_max_length 512 \ |
| 163 | + --reranker_name_or_path BAAI/bge-reranker-v2-m3 \ |
| 164 | + --reranker_batch_size 512 \ |
| 165 | + --reranker_query_max_length 512 \ |
| 166 | + --reranker_max_length 1024 \ |
| 167 | + --devices cuda:0 cuda:1 cuda:2 cuda:3 cuda:4 cuda:5 cuda:6 cuda:7 \ |
| 168 | + --cache_dir ./cache/model |
| 169 | +``` |
| 170 | + |
0 commit comments