Commit c0c040f

Merge pull request #1108 from 545999961/master
update bge-en-icl
2 parents: 658fbb9 + 8d8fb8f

4 files changed: 39 additions & 7 deletions

FlagEmbedding/llm_dense_retriever/README.md
Lines changed: 2 additions & 1 deletion

@@ -208,7 +208,8 @@ run.py \
 --use_special_tokens \
 --symmetric_batch_size 256 \
 --symmetric_train_group_size 8 \
---max_class_neg 7
+--max_class_neg 7 \
+--save_merged_lora_model True
 ```
 
 ## Citation
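
The new `--save_merged_lora_model` flag is parsed by HfArgumentParser into ModelArguments; run.py below reads `model_args.save_merged_lora_model`. The commit does not show the field's definition, so the following is only a hypothetical sketch of what it might look like, with the name taken from run.py's usage and the default assumed:

```python
from dataclasses import dataclass, field

@dataclass
class ModelArguments:
    # Hypothetical sketch (not shown in this commit): the field that
    # HfArgumentParser would populate from --save_merged_lora_model True.
    # The name comes from run.py's usage; the default is assumed.
    save_merged_lora_model: bool = field(
        default=False,
        metadata={"help": "If True, merge the trained LoRA adapter into "
                          "the base model and save the full model."}
    )
```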

FlagEmbedding/llm_dense_retriever/finetune/arguments.py
Lines changed: 2 additions & 2 deletions

@@ -55,10 +55,10 @@ class ModelArguments:
         metadata={"help": "If passed, will use flash attention to train the model."}
     )
     token: str = field(
-        default="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"
+        default=".."
     )
     cache_dir: str = field(
-        default="/share/LMs"
+        default="../LMs"
     )
     from_peft: str = field(
         default=None
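
This hunk scrubs a Hugging Face access token that had been committed in plain text and swaps the absolute cache path for a relative one. A safer pattern than any in-source default, sketched below purely as an illustration (it is not what the commit does), is to resolve the token from the environment at runtime, assuming an HF_TOKEN variable is set:

```python
import os
from dataclasses import dataclass, field

@dataclass
class ModelArguments:
    # Illustration only, not part of the commit: resolve the token from
    # the HF_TOKEN environment variable at runtime so that no credential
    # is stored in source control. ".." mirrors the diff's placeholder.
    token: str = field(
        default_factory=lambda: os.environ.get("HF_TOKEN", "..")
    )
```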

FlagEmbedding/llm_dense_retriever/finetune/load_model.py
Lines changed: 23 additions & 2 deletions

@@ -1,8 +1,25 @@
 import os
+import re
+
 import torch
 from transformers import AutoConfig, AutoModel, AutoTokenizer
 from peft import LoraConfig, TaskType, get_peft_model, PeftModel
 
+def find_largest_checkpoint(checkpoint_dir):
+    checkpoint_pattern = re.compile(r'checkpoint-(\d+)')
+    max_number = -1
+    max_checkpoint_file = None
+    for file in os.listdir(checkpoint_dir):
+        match = checkpoint_pattern.search(file)
+        if match:
+            number = int(match.group(1))
+            if number > max_number:
+                max_number = number
+                max_checkpoint_file = file
+    if max_checkpoint_file:
+        return os.path.join(checkpoint_dir, max_checkpoint_file)
+    else:
+        return None
 
 def get_model(model_args, output_dir, resize, resize_tokens):

@@ -112,8 +129,12 @@ def save_merged_model(model_args, output_dir):
     if os.path.exists(os.path.join(output_dir, 'embedding', 'emb.pth')):
         model.set_input_embeddings(torch.load(os.path.join(output_dir, 'embedding', 'emb.pth')))
 
-    model = PeftModel.from_pretrained(model, output_dir)
-    model = model.merge_and_unload()
+    try:
+        model = PeftModel.from_pretrained(model, output_dir)
+        model = model.merge_and_unload()
+    except:
+        model = PeftModel.from_pretrained(model, find_largest_checkpoint(output_dir))
+        model = model.merge_and_unload()
 
     model.save_pretrained(os.path.join(output_dir, 'full_model'))
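
find_largest_checkpoint compares checkpoint suffixes as integers, so checkpoint-2000 beats checkpoint-500 even though it sorts earlier as a string; save_merged_model then uses it as a fallback when no adapter can be loaded from the top level of output_dir. A small hypothetical demonstration, assuming it runs from the finetune directory so the import resolves, and using made-up step numbers:

```python
import os
import tempfile

from load_model import find_largest_checkpoint

# Build a throwaway output directory with fake checkpoint folders.
out = tempfile.mkdtemp()
for step in (100, 2000, 500):
    os.makedirs(os.path.join(out, f"checkpoint-{step}"))

# Integer comparison picks checkpoint-2000, not the lexicographically
# largest name checkpoint-500.
print(find_largest_checkpoint(out))                 # .../checkpoint-2000
print(find_largest_checkpoint(tempfile.mkdtemp()))  # None: no checkpoints
```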

FlagEmbedding/llm_dense_retriever/finetune/run.py
Lines changed: 12 additions & 2 deletions

@@ -14,7 +14,7 @@
 from data import SameDatasetTrainDataset, SameEmbedCollator
 from modeling import BiEncoderModel
 from trainer import BiTrainer
-from load_model import get_model
+from load_model import get_model, save_merged_model
 
 logger = logging.getLogger(__name__)

@@ -143,6 +143,16 @@ def main():
     # os.makedirs(os.path.join(training_args.output_dir, 'embedding'), exist_ok=True)
     # torch.save(base_model.model.model.embed_tokens, os.path.join(training_args.output_dir, 'embedding', 'emb.pth'))
 
+def save_model():
+    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    model_args: ModelArguments
+    data_args: DataArguments
+    training_args: TrainingArguments
+
+    if model_args.save_merged_lora_model and training_args.process_index == 0:
+        save_merged_model(model_args, training_args.output_dir)
 
 if __name__ == "__main__":
-    main()
+    main()
+    save_model()
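
save_model() re-parses the CLI arguments after training and runs only on the main process (process_index == 0), so a multi-GPU launch merges and writes the full model exactly once. After it completes, the merged weights sit under full_model inside the output directory (written by save_merged_model above) and load without any PEFT dependency. A hypothetical loading sketch, with "output_dir" standing in for the real training output path:

```python
from transformers import AutoModel

# Hypothetical follow-up: load the merged full model that save_model()
# produced. "output_dir" is a placeholder for the real --output_dir.
model = AutoModel.from_pretrained("output_dir/full_model")
```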
