<div align="center">
<h1>Llama2Vec: Unsupervised Adaptation of Large Language Models for Dense Retrieval (LLARA) [<a href="https://arxiv.org/abs/2312.15503">paper</a>]</h1>
</div>

Llama2Vec consists of two pretext tasks:
- **EBAE** (Embedding-Based Auto-Encoding)
- **EBAR** (Embedding-Based Auto-Regression)

The LLM is prompted to **reconstruct the input sentence** and **predict the next sentence** based on its text embeddings.
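
A minimal sketch of how the two pretext tasks are phrased as prompts, using the templates that appear in the Usage section below (pairing EBAE with the summarization prompt and EBAR with the prediction prompt follows the task descriptions above):

```python
# EBAE: the embedding collected at the trailing special tokens is trained to
# reconstruct (summarize) the input passage itself.
EBAE_PROMPT = '"{passage}", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'

# EBAR: the embedding collected at the trailing special tokens is trained to
# predict the passage that follows the query.
EBAR_PROMPT = '"{query}", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'

print(EBAR_PROMPT.format(query="What is llama?"))
```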

It is known for the following features:
- simple
- lightweight
- highly effective

## Environment
```bash
conda create -n llara python=3.10

conda activate llara

# You may need to adjust the CUDA version
conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers==4.41.0 deepspeed accelerate datasets peft pandas
pip install flash-attn --no-build-isolation
```
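
Optionally, a quick sanity check that the key packages are importable (assumes a CUDA-capable GPU is visible to PyTorch):

```python
import torch
import transformers
import flash_attn  # raises ImportError if flash-attn was not built correctly

print(transformers.__version__)   # expected: 4.41.0
print(torch.cuda.is_available())  # should be True on a GPU machine
```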

## Model List

| Model | Introduction |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| [BAAI/LLARA-pretrain](https://huggingface.co/BAAI/LLARA-pretrain) | LLARA that has undergone unsupervised adaptation on Wikipedia |
| [BAAI/LLARA-passage](https://huggingface.co/BAAI/LLARA-passage) | The LLARA-pretrain model fine-tuned on MS MARCO passage (hard negatives mined by a dense retriever) |
| [BAAI/LLARA-document](https://huggingface.co/BAAI/LLARA-document) | The LLARA-pretrain model fine-tuned on MS MARCO document |
| [BAAI/LLARA-beir](https://huggingface.co/BAAI/LLARA-beir) | The LLARA-pretrain model fine-tuned on MS MARCO passage (hard negatives mined by BM25) |

## Usage

```python
import torch
from transformers import AutoModel, AutoTokenizer

def get_query_inputs(queries, tokenizer, max_length=512):
    # EBAR-style prompt: the query is wrapped in quotes and followed by the
    # "predict the following passage" instruction plus 8 placeholder tokens.
    prefix = '"'
    suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]  # drop the BOS token
    queries_inputs = []
    for query in queries:
        inputs = tokenizer(query,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        queries_inputs.append(inputs)
    return tokenizer.pad(
        queries_inputs,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )

def get_passage_inputs(passages, tokenizer, max_length=512):
    # EBAE-style prompt: the passage is wrapped in quotes and followed by the
    # "summarize the above passage" instruction plus 8 placeholder tokens.
    prefix = '"'
    suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]  # drop the BOS token
    passages_inputs = []
    for passage in passages:
        inputs = tokenizer(passage,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        passages_inputs.append(inputs)
    return tokenizer.pad(
        passages_inputs,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('BAAI/LLARA-passage')
model = AutoModel.from_pretrained('BAAI/LLARA-passage')

# Define query and passage inputs
query = "What is llama?"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
query_input = get_query_inputs([query], tokenizer)
passage_input = get_passage_inputs([passage], tokenizer)

with torch.no_grad():
    # compute query embedding
    query_outputs = model(**query_input, return_dict=True, output_hidden_states=True)
    query_embedding = query_outputs.hidden_states[-1][:, -8:, :]
    query_embedding = torch.mean(query_embedding, dim=1)
    query_embedding = torch.nn.functional.normalize(query_embedding, dim=-1)

    # compute passage embedding
    passage_outputs = model(**passage_input, return_dict=True, output_hidden_states=True)
    passage_embeddings = passage_outputs.hidden_states[-1][:, -8:, :]
    passage_embeddings = torch.mean(passage_embeddings, dim=1)
    passage_embeddings = torch.nn.functional.normalize(passage_embeddings, dim=-1)

    # compute similarity score
    score = query_embedding @ passage_embeddings.T
    print(score)
```
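
The query and passage embeddings are obtained by mean-pooling the last-layer hidden states of the 8 trailing placeholder tokens and L2-normalizing the result, so the printed score is the cosine similarity between the two vectors.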

## Unsupervised Adaptation (pretrain)
1. You can get the complete data here: [cfli/pretrain_wiki](https://huggingface.co/datasets/cfli/pretrain_wiki)
2. Here is an example of pretraining:
```shell
cd ./pretrain
torchrun --nproc_per_node 8 \
run.py \
--output_dir ./output \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--train_data ../data/pretrain/toy_pretrain_data.jsonl \
--learning_rate 1e-5 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataloader_drop_last True \
--cutoff_len 128 \
--logging_steps 1 \
--save_steps 500 \
--save_total_limit 20 \
--gradient_checkpointing \
--ddp_find_unused_parameters False \
--use_flash_attn False \
--deepspeed ../stage1.json \
--warmup_ratio 0.1 \
--remove_stop_words True \
--use_lora False \
--bf16 \
--cache_dir ./LMs \
--token ...
```
If you want to pretrain on the complete data, please use the hyper-parameters reported in our paper.
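
To take a first look at the complete corpus, a minimal sketch with the `datasets` library is shown below; it assumes the dataset loads with its default configuration and makes no assumption about split or field names beyond what `load_dataset` reports:

```python
from datasets import load_dataset

# Download cfli/pretrain_wiki from the Hugging Face Hub and inspect it.
ds = load_dataset("cfli/pretrain_wiki")
print(ds)                         # available splits and their sizes
first_split = list(ds.keys())[0]
print(ds[first_split][0])         # first record, showing the actual field names
```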

## Fine-tune

Here is an example of fine-tuning:
```shell
cd ./finetune
torchrun --nproc_per_node 8 \
run.py \
--output_dir ./output \
--model_name_or_path BAAI/LLARA-pretrain \
--train_data ../data/finetune/toy_finetune_data.jsonl \
--learning_rate 3e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--dataloader_drop_last True \
--normlized True \
--temperature 0.01 \
--query_max_len 64 \
--passage_max_len 160 \
--train_group_size 16 \
--logging_steps 10 \
--save_steps 500 \
--save_total_limit 3 \
--ddp_find_unused_parameters False \
--negatives_cross_device \
--gradient_checkpointing \
--deepspeed ../stage1.json \
--warmup_ratio 0.1 \
--fp16 \
--cache_dir ./LMs \
--token ...
```
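
The authoritative schema for the training file is `toy_finetune_data.jsonl` in the repository. Purely as an illustration, the sketch below writes one record in the query/pos/neg layout used by other FlagEmbedding fine-tuning scripts; this layout is an assumption here, so check the toy file before preparing your own data:

```python
import json

# Hypothetical record, one JSON object per line: a query, positive passages,
# and hard-negative passages. Verify against toy_finetune_data.jsonl.
record = {
    "query": "what is a llama?",
    "pos": ["The llama is a domesticated South American camelid ..."],
    "neg": ["Llamas are social animals and live with others as a herd ..."],
}
with open("my_finetune_data.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")
```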

## Citation

If you find this repository useful, please give us a star ⭐.

To cite our work:

```
@misc{li2023makinglargelanguagemodels,
      title={Making Large Language Models A Better Foundation For Dense Retrieval},
      author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
      year={2023},
      eprint={2312.15503},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2312.15503},
}
```