|
1 | 1 | import os |
| 2 | +import re |
| 3 | + |
2 | 4 | import torch |
3 | 5 | from transformers import AutoConfig, AutoModel, AutoTokenizer |
4 | 6 | from peft import LoraConfig, TaskType, get_peft_model, PeftModel |
5 | 7 |
|
def find_largest_checkpoint(checkpoint_dir):
    """Locate the highest-numbered ``checkpoint-<N>`` entry in *checkpoint_dir*.

    Scans the directory listing for names containing ``checkpoint-<digits>``
    and picks the one with the largest numeric suffix (first seen wins ties).

    Returns:
        Full path to the best-matching entry, or ``None`` when nothing in the
        directory matches the pattern.
    """
    step_pattern = re.compile(r'checkpoint-(\d+)')
    best_name = None
    best_step = -1
    for entry in os.listdir(checkpoint_dir):
        hit = step_pattern.search(entry)
        if hit is None:
            continue
        step = int(hit.group(1))
        # Strict '>' keeps the earliest entry on equal step numbers,
        # matching a plain running-maximum scan.
        if step > best_step:
            best_step = step
            best_name = entry
    return os.path.join(checkpoint_dir, best_name) if best_name else None
6 | 23 |
|
7 | 24 | def get_model(model_args, output_dir, resize, resize_tokens): |
8 | 25 |
|
@@ -112,8 +129,12 @@ def save_merged_model(model_args, output_dir): |
112 | 129 | if os.path.exists(os.path.join(output_dir, 'embedding', 'emb.pth')): |
113 | 130 | model.set_input_embeddings(torch.load(os.path.join(output_dir, 'embedding', 'emb.pth'))) |
114 | 131 |
|
115 | | - model = PeftModel.from_pretrained(model, output_dir) |
116 | | - model = model.merge_and_unload() |
| 132 | + try: |
| 133 | + model = PeftModel.from_pretrained(model, output_dir) |
| 134 | + model = model.merge_and_unload() |
| 135 | + except: |
| 136 | + model = PeftModel.from_pretrained(model, find_largest_checkpoint(output_dir)) |
| 137 | + model = model.merge_and_unload() |
117 | 138 |
|
118 | 139 | model.save_pretrained(os.path.join(output_dir, 'full_model')) |
119 | 140 |
|
|
0 commit comments