
Commit 6f224f1

Commit message: ft reranker
1 parent aa9a4a9 commit 6f224f1

15 files changed: 182 additions & 2 deletions

FlagEmbedding/finetune/embedder/decoder_only/base/runner.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def load_trainer(self) -> DecoderOnlyEmbedderTrainer:
         """Load the trainer.
 
         Returns:
-            EncoderOnlyEmbedderTrainer: Loaded trainer instance.
+            DecoderOnlyEmbedderTrainer: Loaded trainer instance.
         """
         trainer = DecoderOnlyEmbedderTrainer(
             model=self.model,

FlagEmbedding/finetune/embedder/decoder_only/icl/runner.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def load_trainer(self) -> DecoderOnlyEmbedderICLTrainer:
         """Load the trainer.
 
         Returns:
-            EncoderOnlyEmbedderTrainer: Loaded trainer instance.
+            DecoderOnlyEmbedderICLTrainer: Loaded trainer instance.
         """
         trainer = DecoderOnlyEmbedderICLTrainer(
             model=self.model,

FlagEmbedding/finetune/reranker/decoder_only/base/arguments.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,9 @@ def default_target_modules() -> List[int]:
 
 @dataclass
 class RerankerModelArguments(AbsRerankerModelArguments):
+    """
+    Model argument class for decoder only reranker.
+    """
     use_lora: bool = field(
         default=True,
         metadata={"help": "If passed, will use LORA (low-rank parameter-efficient training) to train the model."}

FlagEmbedding/finetune/reranker/decoder_only/base/load_model.py

Lines changed: 23 additions & 0 deletions
@@ -10,6 +10,14 @@
 
 
 def find_largest_checkpoint(checkpoint_dir):
+    """Find the largest checkpoint in a directory.
+
+    Args:
+        checkpoint_dir (str): Directory containing the checkpoints.
+
+    Returns:
+        str: Path to the largest checkpoint, or None if no match is found.
+    """
     checkpoint_pattern = re.compile(r'checkpoint-(\d+)')
     max_number = -1
     max_checkpoint_file = None
@@ -27,6 +35,14 @@ def find_largest_checkpoint(checkpoint_dir):
 
 
 def get_model(model_args: RerankerModelArguments):
+    """Get the model.
+
+    Args:
+        model_args (RerankerModelArguments): Model arguments instance.
+
+    Returns:
+        transformers.PreTrainedModel or PeftModel: The loaded model.
+    """
     if model_args.config_name:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
@@ -88,6 +104,13 @@ def get_model(model_args: RerankerModelArguments):
 
 
 def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
+    """
+    Load and save a model with the specified configuration, merging it with PEFT layers if available.
+
+    Args:
+        model_args (RerankerModelArguments): Model arguments instance.
+        output_dir (str): Directory to save the model.
+    """
     if model_args.config_name:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
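The context lines show the scan state (checkpoint_pattern, max_number, max_checkpoint_file) but not the loop itself. A plausible reconstruction, assuming the body iterates over os.listdir; only the three visible statements are taken from the diff:

    import os
    import re

    def find_largest_checkpoint(checkpoint_dir):
        """Find the largest checkpoint in a directory."""
        checkpoint_pattern = re.compile(r'checkpoint-(\d+)')
        max_number = -1
        max_checkpoint_file = None
        # Assumed loop: keep the entry like "checkpoint-500" with the highest step.
        for entry in os.listdir(checkpoint_dir):
            match = checkpoint_pattern.search(entry)
            if match and int(match.group(1)) > max_number:
                max_number = int(match.group(1))
                max_checkpoint_file = os.path.join(checkpoint_dir, entry)
        return max_checkpoint_file  # None when nothing matched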

FlagEmbedding/finetune/reranker/decoder_only/base/modeling.py

Lines changed: 16 additions & 0 deletions
@@ -8,6 +8,14 @@
 
 
 class CrossDecoderModel(AbsRerankerModel):
+    """
+    Model class for decoder only reranker.
+
+    Args:
+        base_model (PreTrainedModel): The underlying pre-trained model used for encoding and scoring input pairs.
+        tokenizer (AutoTokenizer, optional): The tokenizer for encoding input text. Defaults to ``None``.
+        train_batch_size (int, optional): The batch size to use. Defaults to ``4``.
+    """
     def __init__(
         self,
         base_model: PreTrainedModel,
@@ -21,6 +29,14 @@ def __init__(
         )
 
     def encode(self, features):
+        """Encode input features into logits.
+
+        Args:
+            features (dict): Dictionary with input features.
+
+        Returns:
+            torch.Tensor: The logits output from the model.
+        """
         if features is None:
             return None
         outputs = self.model(input_ids=features['input_ids'],
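As a usage sketch, encode takes a tokenized feature dict (at least input_ids, and presumably attention_mask) and returns the model's logits. Everything below beyond the documented constructor arguments is an assumption, including the checkpoint name and the query-passage pairing format:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from FlagEmbedding.finetune.reranker.decoder_only.base.modeling import CrossDecoderModel

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
    base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    model = CrossDecoderModel(base_model, tokenizer=tokenizer, train_batch_size=4)

    # Score one query-passage pair.
    features = tokenizer(
        "what is a reranker?",
        "A reranker scores query-passage pairs for relevance.",
        return_tensors="pt",
        truncation=True,
    )
    with torch.no_grad():
        logits = model.encode(features)  # torch.Tensor of logits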

FlagEmbedding/finetune/reranker/decoder_only/base/runner.py

Lines changed: 21 additions & 0 deletions
@@ -17,6 +17,14 @@
 
 
 class DecoderOnlyRerankerRunner(AbsRerankerRunner):
+    """
+    Decoder only reranker runner for finetuning.
+
+    Args:
+        model_args (RerankerModelArguments): Model arguments instance.
+        data_args (AbsRerankerDataArguments): Data arguments instance.
+        training_args (AbsRerankerTrainingArguments): Training arguments instance.
+    """
     def __init__(
         self,
         model_args: RerankerModelArguments,
@@ -26,6 +34,11 @@ def __init__(
         super().__init__(model_args, data_args, training_args)
 
     def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsRerankerModel]:
+        """Load the tokenizer and model.
+
+        Returns:
+            Tuple[PreTrainedTokenizer, AbsRerankerModel]: Tokenizer and model instances.
+        """
         tokenizer = AutoTokenizer.from_pretrained(
             self.model_args.tokenizer_name if self.model_args.tokenizer_name else self.model_args.model_name_or_path,
             token=self.model_args.token,
@@ -66,6 +79,11 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsRerankerModel]:
         return tokenizer, model
 
     def load_trainer(self) -> DecoderOnlyRerankerTrainer:
+        """Load the trainer.
+
+        Returns:
+            DecoderOnlyRerankerTrainer: Loaded trainer instance.
+        """
         trainer = DecoderOnlyRerankerTrainer(
             model=self.model,
             args=self.training_args,
@@ -76,6 +94,9 @@ def load_trainer(self) -> DecoderOnlyRerankerTrainer:
         return trainer
 
     def run(self):
+        """
+        Run the finetuning.
+        """
         Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)
 
         # Training
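Taken together, the runner wires these pieces up. A hedged end-to-end sketch; the import path for the abstract argument classes is an assumption based on the Abs* names in the signatures above:

    from transformers import HfArgumentParser

    from FlagEmbedding.abc.finetune.reranker import (  # assumed import path
        AbsRerankerDataArguments,
        AbsRerankerTrainingArguments,
    )
    from FlagEmbedding.finetune.reranker.decoder_only.base.arguments import RerankerModelArguments
    from FlagEmbedding.finetune.reranker.decoder_only.base.runner import DecoderOnlyRerankerRunner

    parser = HfArgumentParser(
        (RerankerModelArguments, AbsRerankerDataArguments, AbsRerankerTrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    runner = DecoderOnlyRerankerRunner(model_args, data_args, training_args)
    runner.run()  # creates output_dir if needed, then trains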

FlagEmbedding/finetune/reranker/decoder_only/base/trainer.py

Lines changed: 11 additions & 0 deletions
@@ -11,7 +11,18 @@
 
 
 class DecoderOnlyRerankerTrainer(AbsRerankerTrainer):
+    """
+    Trainer class for decoder only base reranker models.
+    """
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        """Save the model to a directory.
+
+        Args:
+            output_dir (Optional[str], optional): Output directory to save the model. Defaults to ``None``.
+
+        Raises:
+            NotImplementedError
+        """
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info("Saving model checkpoint to %s", output_dir)
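The docstring's Raises: NotImplementedError, together with the visible makedirs and logging lines, suggests a save hook that delegates to a model-level save method when one exists. A speculative sketch; only the first three statements of the body appear in the diff:

    import logging
    import os
    from typing import Optional

    logger = logging.getLogger(__name__)

    class DecoderOnlyRerankerTrainerSketch:
        # Hypothetical stand-in; the real class extends AbsRerankerTrainer.
        def _save(self, output_dir: Optional[str] = None, state_dict=None):
            output_dir = output_dir if output_dir is not None else self.args.output_dir
            os.makedirs(output_dir, exist_ok=True)
            logger.info("Saving model checkpoint to %s", output_dir)
            # Assumed continuation, matching the documented NotImplementedError:
            if not hasattr(self.model, "save"):
                raise NotImplementedError(
                    f"MODEL {self.model.__class__.__name__} does not support save interface"
                )
            self.model.save(output_dir)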

FlagEmbedding/finetune/reranker/decoder_only/layerwise/arguments.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,9 @@ def default_target_modules() -> List[int]:
 
 @dataclass
 class RerankerModelArguments(AbsRerankerModelArguments):
+    """
+    Model argument class for decoder only reranker.
+    """
     use_lora: bool = field(
         default=True,
         metadata={"help": "If passed, will use LORA (low-rank parameter-efficient training) to train the model."}

FlagEmbedding/finetune/reranker/decoder_only/layerwise/load_model.py

Lines changed: 23 additions & 0 deletions
@@ -14,6 +14,14 @@
 
 
 def find_largest_checkpoint(checkpoint_dir):
+    """Find the largest checkpoint in a directory.
+
+    Args:
+        checkpoint_dir (str): Directory containing the checkpoints.
+
+    Returns:
+        str: Path to the largest checkpoint, or None if no match is found.
+    """
     checkpoint_pattern = re.compile(r'checkpoint-(\d+)')
     max_number = -1
     max_checkpoint_file = None
@@ -31,6 +39,14 @@ def find_largest_checkpoint(checkpoint_dir):
 
 
 def get_model(model_args: RerankerModelArguments, only_for_one_logit):
+    """Get the model.
+
+    Args:
+        model_args (RerankerModelArguments): Model arguments instance.
+
+    Returns:
+        transformers.PreTrainedModel or PeftModel: The loaded model.
+    """
     if model_args.config_name:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
@@ -152,6 +168,13 @@ def get_model(model_args: RerankerModelArguments, only_for_one_logit):
 
 
 def save_merged_model(model_args: RerankerModelArguments, output_dir: str):
+    """
+    Load and save a model with the specified configuration, merging it with PEFT layers if available.
+
+    Args:
+        model_args (RerankerModelArguments): Model arguments instance.
+        output_dir (str): Directory to save the model.
+    """
     if model_args.config_name:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
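save_merged_model's body beyond the config loading is not shown in this diff. With peft, the merge step conventionally looks like the sketch below; the function and variable names here are illustrative, not the repo's:

    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    def merge_lora_and_save(base_name_or_path: str, lora_dir: str, output_dir: str):
        # Load the frozen base model, attach the trained LoRA adapter,
        # fold the adapter weights into the base, and save the result.
        base = AutoModelForCausalLM.from_pretrained(base_name_or_path)
        model = PeftModel.from_pretrained(base, lora_dir)
        merged = model.merge_and_unload()
        merged.save_pretrained(output_dir)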

FlagEmbedding/finetune/reranker/decoder_only/layerwise/modeling.py

Lines changed: 9 additions & 0 deletions
@@ -10,6 +10,15 @@
 
 
 class CrossDecoderModel(AbsRerankerModel):
+    """
+    Model class for decoder only reranker.
+
+    Args:
+        base_model (PreTrainedModel): The underlying pre-trained model used for encoding and scoring input pairs.
+        tokenizer (AutoTokenizer, optional): The tokenizer for encoding input text. Defaults to ``None``.
+        train_batch_size (int, optional): The batch size to use. Defaults to ``4``.
+        start_layer (int, optional): Starting layer for layerwise score computation. Defaults to ``8``.
+    """
     def __init__(
         self,
         base_model: PreTrainedModel,
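For intuition about start_layer: a layerwise reranker produces a relevance score from every decoder layer at or above start_layer, rather than only the final layer. A conceptual illustration only, not the repo's implementation; the shared scoring head and last-token pooling are assumptions:

    import torch

    def layerwise_scores(model, features, start_layer: int = 8):
        outputs = model(
            input_ids=features["input_ids"],
            attention_mask=features["attention_mask"],
            output_hidden_states=True,
            return_dict=True,
        )
        head = torch.nn.Linear(model.config.hidden_size, 1, bias=False)  # assumed shared head
        # One score per layer from start_layer upward, pooled from the last token.
        return [head(h[:, -1, :]) for h in outputs.hidden_states[start_layer:]]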
