
Commit e1a3b6a

upload LLARA code
1 parent 5c92602 commit e1a3b6a

18 files changed

Lines changed: 1917 additions & 0 deletions

LLARA/README.md

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
<div align="center">
<h1> Llama2Vec: Unsupervised Adaptation of Large Language Models for Dense Retrieval (LLARA) [<a href="https://arxiv.org/abs/2312.15503">paper</a>]</h1>
</div>

Llama2Vec consists of two pretext tasks:

- **EBAE** (Embedding-Based Auto-Encoding)
- **EBAR** (Embedding-Based Auto-Regression)

The LLM is prompted to **reconstruct the input sentence** and **predict the next sentence** based on its text embeddings (see the prompt sketch below).

It is known for the following features:

- simple
- lightweight
- highly effective

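Concretely, each pretext task wraps the input text in a short instruction followed by a set of special placeholder tokens, and the hidden states at those placeholder positions serve as the text embedding. Below is a minimal sketch of the two templates, mirroring the prompts used in the Usage section further down; the exact training-time templates live in the pretrain code.

```python
# Illustration only: EBAE uses placeholders <s1>-<s8>, EBAR uses <s9>-<s16>.
EBAE_PROMPT = '"{text}", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
EBAR_PROMPT = '"{text}", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'

print(EBAE_PROMPT.format(text="The llama is a domesticated South American camelid."))
```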
## Environment
```bash
conda create -n llara python=3.10

conda activate llara

# You may need to adjust the CUDA version
conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers==4.41.0 deepspeed accelerate datasets peft pandas
pip install flash-attn --no-build-isolation
```

## Model List
| Model | Introduction |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| [BAAI/LLARA-pretrain](https://huggingface.co/BAAI/LLARA-pretrain) | LLARA that has undergone unsupervised adaptation on Wikipedia |
| [BAAI/LLARA-passage](https://huggingface.co/BAAI/LLARA-passage) | The LLARA-pretrain model fine-tuned on MS MARCO passage (the hard negatives come from a dense retriever) |
| [BAAI/LLARA-document](https://huggingface.co/BAAI/LLARA-document) | The LLARA-pretrain model fine-tuned on MS MARCO document |
| [BAAI/LLARA-beir](https://huggingface.co/BAAI/LLARA-beir) | The LLARA-pretrain model fine-tuned on MS MARCO passage (the hard negatives come from BM25) |

## Usage
```python
import torch
from transformers import AutoModel, AutoTokenizer, LlamaModel

def get_query_inputs(queries, tokenizer, max_length=512):
    prefix = '"'
    suffix = '", predict the following passage within eight words: <s9><s10><s11><s12><s13><s14><s15><s16>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
    queries_inputs = []
    for query in queries:
        inputs = tokenizer(query,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        queries_inputs.append(inputs)
    return tokenizer.pad(
        queries_inputs,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )

def get_passage_inputs(passages, tokenizer, max_length=512):
    prefix = '"'
    suffix = '", summarize the above passage within eight words: <s1><s2><s3><s4><s5><s6><s7><s8>'
    prefix_ids = tokenizer(prefix, return_tensors=None)['input_ids']
    suffix_ids = tokenizer(suffix, return_tensors=None)['input_ids'][1:]
    passages_inputs = []
    for passage in passages:
        inputs = tokenizer(passage,
                           return_tensors=None,
                           max_length=max_length,
                           truncation=True,
                           add_special_tokens=False)
        inputs['input_ids'] = prefix_ids + inputs['input_ids'] + suffix_ids
        inputs['attention_mask'] = [1] * len(inputs['input_ids'])
        passages_inputs.append(inputs)
    return tokenizer.pad(
        passages_inputs,
        padding=True,
        max_length=max_length,
        pad_to_multiple_of=8,
        return_tensors='pt',
    )

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('BAAI/LLARA-passage')
model = AutoModel.from_pretrained('BAAI/LLARA-passage')

# Define query and passage inputs
query = "What is llama?"
title = "Llama"
passage = "The llama is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era."
query_input = get_query_inputs([query], tokenizer)
passage_input = get_passage_inputs([passage], tokenizer)

with torch.no_grad():
    # compute query embedding
    query_outputs = model(**query_input, return_dict=True, output_hidden_states=True)
    query_embedding = query_outputs.hidden_states[-1][:, -8:, :]
    query_embedding = torch.mean(query_embedding, dim=1)
    query_embedding = torch.nn.functional.normalize(query_embedding, dim=-1)

    # compute passage embedding
    passage_outputs = model(**passage_input, return_dict=True, output_hidden_states=True)
    passage_embeddings = passage_outputs.hidden_states[-1][:, -8:, :]
    passage_embeddings = torch.mean(passage_embeddings, dim=1)
    passage_embeddings = torch.nn.functional.normalize(passage_embeddings, dim=-1)

    # compute similarity score
    score = query_embedding @ passage_embeddings.T
    print(score)
```

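Note that the embedding is read from the last-layer hidden states of the final eight positions, i.e. the `<s*>` placeholder tokens appended by the prompt; these are mean-pooled and L2-normalized, so the printed score is simply the inner product of two unit vectors.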
## Unsupervised Adaptation (pretrain)
1. You can get the complete data here: [cfli/pretrain_wiki](https://huggingface.co/datasets/cfli/pretrain_wiki)
2. Here is an example of pretraining:
```shell
cd ./pretrain
torchrun --nproc_per_node 8 \
run.py \
--output_dir ./output \
--model_name_or_path meta-llama/Llama-2-7b-hf \
--train_data ../data/pretrain/toy_pretrain_data.jsonl \
--learning_rate 1e-5 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--dataloader_drop_last True \
--cutoff_len 128 \
--logging_steps 1 \
--save_steps 500 \
--save_total_limit 20 \
--gradient_checkpointing \
--ddp_find_unused_parameters False \
--use_flash_attn False \
--deepspeed ../stage1.json \
--warmup_ratio 0.1 \
--remove_stop_words True \
--use_lora False \
--bf16 \
--cache_dir ./LMs \
--token ...
```
If you want to pretrain on the complete data, please use the hyper-parameters from our paper.

## Fine-tune

Here is an example of fine-tuning:
```shell
cd ./finetune
torchrun --nproc_per_node 8 \
run.py \
--output_dir ./output \
--model_name_or_path BAAI/LLARA-pretrain \
--train_data ../data/finetune/toy_finetune_data.jsonl \
--learning_rate 3e-4 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--dataloader_drop_last True \
--normlized True \
--temperature 0.01 \
--query_max_len 64 \
--passage_max_len 160 \
--train_group_size 16 \
--logging_steps 10 \
--save_steps 500 \
--save_total_limit 3 \
--ddp_find_unused_parameters False \
--negatives_cross_device \
--gradient_checkpointing \
--deepspeed ../stage1.json \
--warmup_ratio 0.1 \
--fp16 \
--cache_dir ./LMs \
--token ...
```

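The flags above map onto the fields defined in `finetune/arguments.py` (included in this commit), e.g. `--query_max_len`, `--passage_max_len`, `--train_group_size`, `--temperature`, `--normlized`, and `--negatives_cross_device`.
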
## Citation

If you find this repository useful, please give us a star ⭐.

To cite our work:

```
@misc{li2023makinglargelanguagemodels,
      title={Making Large Language Models A Better Foundation For Dense Retrieval},
      author={Chaofan Li and Zheng Liu and Shitao Xiao and Yingxia Shao},
      year={2023},
      eprint={2312.15503},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2312.15503},
}
```

LLARA/data/finetune/toy_finetune_data.jsonl

Lines changed: 11 additions & 0 deletions
Large diffs are not rendered by default.

LLARA/data/pretrain/toy_pretrain_data.jsonl

Lines changed: 11 additions & 0 deletions
Large diffs are not rendered by default.

LLARA/finetune/__init__.py

Whitespace-only changes.

LLARA/finetune/arguments.py

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
import os
from dataclasses import dataclass, field
from typing import Optional, List

from transformers import TrainingArguments


def default_list() -> List[str]:
    return ['v_proj', 'q_proj', 'k_proj', 'gate_proj', 'down_proj', 'o_proj', 'up_proj']


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )

    peft_model_path: str = field(
        default=''
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    # cache_dir: Optional[str] = field(
    #     default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    # )
    use_lora: bool = field(
        default=True,
        metadata={"help": "If passed, will use LoRA (low-rank parameter-efficient training) to train the model."}
    )
    lora_rank: int = field(
        default=64,
        metadata={"help": "The rank of LoRA."}
    )
    lora_alpha: float = field(
        default=16,
        metadata={"help": "The alpha parameter of LoRA."}
    )
    lora_dropout: float = field(
        default=0.1,
        metadata={"help": "The dropout rate of LoRA modules."}
    )
    target_modules: List[str] = field(
        default_factory=default_list
    )
    save_merged_lora_model: bool = field(
        default=False,
        metadata={"help": "If passed, will merge the LoRA modules and save the entire model."}
    )
    use_flash_attn: bool = field(
        default=True,
        metadata={"help": "If passed, will use flash attention to train the model."}
    )
    use_slow_tokenizer: bool = field(
        default=False,
        metadata={"help": "If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library)."}
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={"help": "Create the model as an empty shell and only materialize its parameters "
                          "when the pretrained weights are loaded. "
                          "If passed, LLM loading time and RAM consumption will benefit."}
    )
    token: str = field(
        default=""
    )
    cache_dir: str = field(
        default="./LMs"
    )
    from_peft: Optional[str] = field(
        default=None
    )


@dataclass
class DataArguments:
    train_data: str = field(
        default='./toy_finetune_data.jsonl', metadata={"help": "Path to train data"}
    )
    train_group_size: int = field(default=8)

    query_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the query. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )

    passage_max_len: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for the passage. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )

    max_example_num_per_dataset: int = field(
        default=100000000, metadata={"help": "the max number of examples for each dataset"}
    )

    query_instruction_for_retrieval: str = field(
        default="query: ", metadata={"help": "instruction prepended to the query, e.g. 'query: '"}
    )
    passage_instruction_for_retrieval: str = field(
        default="passage: ", metadata={"help": "instruction prepended to the passage, e.g. 'passage: '"}
    )

    cache_path: str = field(
        default='./data_dir'
    )

    load_from_disk: bool = field(
        default=False, metadata={"help": "whether to load the data from disk"}
    )

    load_disk_path: Optional[str] = field(
        default=None, metadata={"help": "the path to load the data from", "nargs": "+"}
    )

    save_to_disk: bool = field(
        default=False, metadata={"help": "whether to save the data to disk"}
    )

    save_disk_path: Optional[str] = field(
        default=None, metadata={"help": "the path to save the data to"}
    )

    num_shards: int = field(
        default=0, metadata={
            "help": "number of shards to write; takes priority over `save_max_shard_size`; by default it is determined by `save_max_shard_size`"}
    )

    save_max_shard_size: str = field(
        default="50GB", metadata={"help": "the max size of each shard"}
    )

    exit_after_save: bool = field(
        default=False, metadata={"help": "whether to exit after saving the data"}
    )

    shuffle_ratio: float = field(
        default=0.0, metadata={"help": "The ratio of shuffling the text"}
    )

    def __post_init__(self):
        if not os.path.exists(self.train_data):
            raise FileNotFoundError(f"cannot find file: {self.train_data}, please set a correct path")

@dataclass
class RetrieverTrainingArguments(TrainingArguments):
    negatives_cross_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    temperature: Optional[float] = field(default=0.02)
    fix_position_embedding: bool = field(default=False, metadata={"help": "Freeze the parameters of position embeddings"})
    sentence_pooling_method: str = field(default='cls', metadata={"help": "the pooling method, should be cls or mean"})
    normlized: bool = field(default=True)
    sub_batch_size: Optional[int] = field(default=None)
    cache_chunk_size: int = field(default=-1, metadata={"help": "chunk size used to cache the execution of each step"})
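
These dataclasses are filled from the command-line flags shown in the README. A minimal sketch of how they could be parsed with `transformers.HfArgumentParser` follows; the actual entry point is `finetune/run.py`, which is not shown in this diff, so treat this as an illustration rather than the repo's code.

```python
# Hypothetical illustration: parse the three dataclasses above from CLI flags.
# Flags such as --model_name_or_path, --train_group_size, --temperature and
# --normlized from the README map directly onto these fields.
from transformers import HfArgumentParser

from arguments import ModelArguments, DataArguments, RetrieverTrainingArguments

parser = HfArgumentParser((ModelArguments, DataArguments, RetrieverTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

print(model_args.model_name_or_path, data_args.train_group_size, training_args.temperature)
```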
