Commit 834498f

abs ft reranker
1 parent f43b46d commit 834498f

6 files changed

Lines changed: 135 additions & 2 deletions

FlagEmbedding/abc/finetune/embedder/AbsTrainer.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
                 returns only the loss.
 
         Returns:
-            Union[torch.Tensor, tuple(torch.Tensor, ModelOutput)]: The computed loss. If ``return_outputs`` is ``True``,
+            Union[torch.Tensor, tuple(torch.Tensor, EmbedderOutput)]: The computed loss. If ``return_outputs`` is ``True``,
                 also returns the model's outputs in a tuple ``(loss, outputs)``.
         """
 

FlagEmbedding/abc/finetune/reranker/AbsArguments.py

Lines changed: 4 additions & 1 deletion
@@ -8,7 +8,7 @@
 @dataclass
 class AbsRerankerModelArguments:
     """
-    Abstract class for model arguments.
+    Abstract class for reranker model arguments.
     """
 
     model_name_or_path: str = field(
@@ -46,6 +46,9 @@ class AbsRerankerModelArguments:
 
 @dataclass
 class AbsRerankerDataArguments:
+    """
+    Abstract class for reranker data arguments.
+    """
     train_data: str = field(
         default=None, metadata={
             "help": "One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data.",

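To make the required schema concrete: a JSONL training file holds one such object per line. A minimal sketch with invented texts (the optional `pos_scores`/`neg_scores` fields referenced in AbsDataset.py below could be added alongside):

import json

# One hypothetical training record; `query`, `pos`, and `neg` are the
# required fields named in the help text above. All values are invented.
record = {
    "query": "what is a reranker?",
    "pos": ["A reranker scores query-passage pairs to reorder retrieved candidates."],
    "neg": ["The capital of France is Paris."],
}
with open("train_data.jsonl", "w") as f:
    f.write(json.dumps(record) + "\n")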
FlagEmbedding/abc/finetune/reranker/AbsDataset.py

Lines changed: 43 additions & 0 deletions
@@ -21,6 +21,12 @@
 
 
 class AbsRerankerTrainDataset(Dataset):
+    """Abstract class for reranker training dataset.
+
+    Args:
+        args (AbsRerankerDataArguments): Data arguments.
+        tokenizer (PreTrainedTokenizer): Tokenizer to use.
+    """
     def __init__(
         self,
         args: AbsRerankerDataArguments,
@@ -47,6 +53,17 @@ def __init__(
         self.max_length = self.args.query_max_len + self.args.passage_max_len
 
     def _load_dataset(self, file_path: str):
+        """Load dataset from path.
+
+        Args:
+            file_path (str): Path to load the datasets from.
+
+        Raises:
+            ValueError: If `pos_scores` and `neg_scores` are not found in the features of the training data.
+
+        Returns:
+            datasets.Dataset: Loaded HF dataset.
+        """
         if dist.get_rank() == 0:
             logger.info(f'loading data from {file_path} ...')
 
@@ -64,6 +81,14 @@ def _load_dataset(self, file_path: str):
         return temp_dataset
 
     def _shuffle_text(self, text):
+        """Shuffle the input text.
+
+        Args:
+            text (str): Input text.
+
+        Returns:
+            str: Shuffled text.
+        """
         if self.args.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.args.shuffle_ratio:
             split_text = []
             chunk_size = len(text)//3 + 1
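Aside: the hunk above is cut off before the shuffled chunks are recombined. A minimal standalone sketch of the augmentation, with the reassembly step assumed:

import random

def shuffle_text(text: str, shuffle_ratio: float = 0.2) -> str:
    # Mirrors the visible logic: sufficiently long texts are split into
    # three chunks whose order is shuffled. Joining them back together is
    # an assumption, since the diff truncates before that step.
    if shuffle_ratio > 0 and len(text) > 100 and random.random() < shuffle_ratio:
        chunk_size = len(text) // 3 + 1
        chunks = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        random.shuffle(chunks)
        return "".join(chunks)
    return text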
@@ -78,6 +103,15 @@ def __len__(self):
         return len(self.dataset)
 
     def create_one_example(self, qry_encoding: str, doc_encoding: str):
+        """Creates a single input example by encoding and preparing a query and document pair for the model.
+
+        Args:
+            qry_encoding (str): Query to be encoded.
+            doc_encoding (str): Document to be encoded.
+
+        Returns:
+            dict: A dictionary containing tokenized and prepared inputs, ready for model consumption.
+        """
         qry_inputs = self.tokenizer.encode(qry_encoding, truncation=True, max_length=self.args.query_max_len + self.args.passage_max_len // 4, add_special_tokens=False)
         doc_inputs = self.tokenizer.encode(doc_encoding, truncation=True, max_length=self.args.passage_max_len + self.args.query_max_len // 2, add_special_tokens=False)
         item = self.tokenizer.prepare_for_model(
@@ -143,6 +177,9 @@ def __getitem__(self, item):
 
 @dataclass
 class AbsRerankerCollator(DataCollatorWithPadding):
+    """
+    The abstract reranker collator.
+    """
     query_max_len: int = 32
     passage_max_len: int = 128
 
@@ -171,6 +208,12 @@ def __call__(self, features) -> list[BatchEncoding]:
         }
 
 class AbsLLMRerankerTrainDataset(AbsRerankerTrainDataset):
+    """Abstract class for LLM reranker training dataset.
+
+    Args:
+        args (AbsRerankerDataArguments): Data arguments.
+        tokenizer (PreTrainedTokenizer): Tokenizer to use.
+    """
     def __init__(
         self,
         args: AbsRerankerDataArguments,
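To see the length budget of create_one_example in action, here is a minimal standalone sketch of the same two-step tokenization. The checkpoint, the texts, and the prepare_for_model arguments beyond the visible truncation point are assumptions:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
query_max_len, passage_max_len = 32, 128  # defaults from AbsRerankerCollator above

# The query may borrow up to a quarter of the passage budget, and vice versa.
qry_inputs = tokenizer.encode(
    "how do rerankers work?", truncation=True,
    max_length=query_max_len + passage_max_len // 4, add_special_tokens=False,
)
doc_inputs = tokenizer.encode(
    "A reranker jointly encodes the query and the candidate passage.", truncation=True,
    max_length=passage_max_len + query_max_len // 2, add_special_tokens=False,
)
# prepare_for_model adds special tokens and assembles the final pair input;
# these truncation settings are assumed, since the hunk cuts off mid-call.
item = tokenizer.prepare_for_model(
    qry_inputs, doc_inputs,
    truncation="only_second",
    max_length=query_max_len + passage_max_len,
)
print(tokenizer.decode(item["input_ids"]))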

FlagEmbedding/abc/finetune/reranker/AbsModeling.py

Lines changed: 44 additions & 0 deletions
@@ -18,6 +18,13 @@ class RerankerOutput(ModelOutput):
 
 
 class AbsRerankerModel(ABC, nn.Module):
+    """Abstract class of reranker model for training.
+
+    Args:
+        base_model: The base model to train on.
+        tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
+        train_batch_size (int, optional): Batch size used for training. Defaults to ``4``.
+    """
     def __init__(
         self,
         base_model: None,
@@ -38,16 +45,36 @@ def __init__(
         self.yes_loc = self.tokenizer('Yes', add_special_tokens=False)['input_ids'][-1]
 
     def gradient_checkpointing_enable(self, **kwargs):
+        """
+        Activates gradient checkpointing for the current model.
+        """
         self.model.gradient_checkpointing_enable(**kwargs)
 
     def enable_input_require_grads(self, **kwargs):
+        """
+        Enables the gradients for the input embeddings.
+        """
         self.model.enable_input_require_grads(**kwargs)
 
     @abstractmethod
     def encode(self, features):
+        """Abstract method of encode.
+
+        Args:
+            features (dict): Features to pass to the model.
+        """
         pass
 
     def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None, teacher_scores: Optional[Tensor] = None):
+        """The computation performed at every call.
+
+        Args:
+            pair (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): The query-document pair. Defaults to ``None``.
+            teacher_scores (Optional[Tensor], optional): Teacher scores for knowledge distillation. Defaults to ``None``.
+
+        Returns:
+            RerankerOutput: Output of the reranker model.
+        """
         ranker_logits = self.encode(pair) # (batch_size * num, dim)
         if teacher_scores is not None:
             teacher_scores = torch.Tensor(teacher_scores)
@@ -72,9 +99,23 @@ def forward(self, pair: Union[Dict[str, Tensor], List[Dict[str, Tensor]]] = None
         )
 
     def compute_loss(self, scores, target):
+        """Compute the loss.
+
+        Args:
+            scores (torch.Tensor): Computed scores.
+            target (torch.Tensor): The target value.
+
+        Returns:
+            torch.Tensor: The computed loss.
+        """
         return self.cross_entropy(scores, target)
 
     def save(self, output_dir: str):
+        """Save the model.
+
+        Args:
+            output_dir (str): Directory for saving the model.
+        """
         # self.model.save_pretrained(output_dir)
         state_dict = self.model.state_dict()
         state_dict = type(state_dict)(
@@ -84,5 +125,8 @@ def save(self, output_dir: str):
         self.model.save_pretrained(output_dir, state_dict=state_dict)
 
     def save_pretrained(self, *args, **kwargs):
+        """
+        Save the tokenizer and model.
+        """
         self.tokenizer.save_pretrained(*args, **kwargs)
         return self.model.save_pretrained(*args, **kwargs)
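For intuition, the forward/compute_loss pattern above reduces to a cross-entropy over per-query groups of pair scores. A minimal sketch, assuming each group places its positive passage at index 0 (group size and logits are invented):

import torch
import torch.nn as nn

batch_size, group_size = 2, 4
ranker_logits = torch.randn(batch_size * group_size)  # one score per query-passage pair

# Group scores per query; the positive-at-index-0 convention is an assumption.
grouped_scores = ranker_logits.view(batch_size, group_size)
target = torch.zeros(batch_size, dtype=torch.long)

loss = nn.CrossEntropyLoss()(grouped_scores, target)  # push positives above negatives
print(loss.item())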

FlagEmbedding/abc/finetune/reranker/AbsRunner.py

Lines changed: 30 additions & 0 deletions
@@ -22,6 +22,13 @@
 
 
 class AbsRerankerRunner(ABC):
+    """Abstract class to run reranker model fine-tuning.
+
+    Args:
+        model_args (AbsRerankerModelArguments): Model arguments.
+        data_args (AbsRerankerDataArguments): Data arguments.
+        training_args (AbsRerankerTrainingArguments): Training arguments.
+    """
     def __init__(
         self,
         model_args: AbsRerankerModelArguments,
@@ -70,13 +77,28 @@ def __init__(
 
     @abstractmethod
     def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsRerankerModel]:
+        """Abstract method to load the tokenizer and model.
+
+        Returns:
+            Tuple[PreTrainedTokenizer, AbsRerankerModel]: Loaded tokenizer and model instances.
+        """
         pass
 
     @abstractmethod
    def load_trainer(self) -> AbsRerankerTrainer:
+        """Abstract method to load the trainer.
+
+        Returns:
+            AbsRerankerTrainer: The loaded trainer instance.
+        """
         pass
 
     def load_train_dataset(self) -> AbsRerankerTrainDataset:
+        """Loads the training dataset based on data arguments.
+
+        Returns:
+            AbsRerankerTrainDataset: The loaded dataset instance.
+        """
         if self.model_args.model_type == 'encoder':
             train_dataset = AbsRerankerTrainDataset(
                 args=self.data_args,
@@ -90,6 +112,11 @@ def load_train_dataset(self) -> AbsRerankerTrainDataset:
         return train_dataset
 
     def load_data_collator(self) -> AbsRerankerCollator:
+        """Loads the appropriate data collator.
+
+        Returns:
+            AbsRerankerCollator: Loaded data collator.
+        """
         if self.model_args.model_type == 'encoder':
             RerankerCollator = AbsRerankerCollator
         else:
@@ -106,6 +133,9 @@ def load_data_collator(self) -> AbsRerankerCollator:
         return data_collator
 
     def run(self):
+        """
+        Executes the training process.
+        """
         Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)
 
         # Training
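For orientation, a concrete runner only needs to fill in the two abstract hooks. A hypothetical sketch; MyRerankerModel, MyRerankerTrainer, and the trainer keyword arguments are assumptions, not part of this commit:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

class MyRerankerRunner(AbsRerankerRunner):  # hypothetical subclass
    def load_tokenizer_and_model(self):
        # model_name_or_path comes from AbsRerankerModelArguments.
        tokenizer = AutoTokenizer.from_pretrained(self.model_args.model_name_or_path)
        base_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_args.model_name_or_path, num_labels=1
        )
        # MyRerankerModel stands in for a concrete AbsRerankerModel subclass.
        return tokenizer, MyRerankerModel(base_model, tokenizer=tokenizer)

    def load_trainer(self):
        tokenizer, model = self.load_tokenizer_and_model()
        # MyRerankerTrainer stands in for a concrete AbsRerankerTrainer subclass.
        return MyRerankerTrainer(
            model=model,
            args=self.training_args,
            train_dataset=self.load_train_dataset(),
            data_collator=self.load_data_collator(),
        )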

FlagEmbedding/abc/finetune/reranker/AbsTrainer.py

Lines changed: 13 additions & 0 deletions
@@ -7,6 +7,9 @@
 
 
 class AbsRerankerTrainer(ABC, Trainer):
+    """
+    Abstract class for the reranker trainer.
+    """
     @abstractmethod
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
         pass
@@ -16,6 +19,16 @@ def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
         How the loss is computed by Trainer. By default, all models return the loss in the first element.
 
         Subclass and override for custom behavior.
+
+        Args:
+            model (AbsRerankerModel): The model being trained.
+            inputs (dict): A dictionary of input tensors to be passed to the model.
+            return_outputs (bool, optional): If ``True``, returns both the loss and the model's outputs. Otherwise,
+                returns only the loss. Defaults to ``False``.
+
+        Returns:
+            Union[torch.Tensor, tuple(torch.Tensor, RerankerOutput)]: The computed loss. If ``return_outputs`` is ``True``,
+                also returns the model's outputs in a tuple ``(loss, outputs)``.
         """
 
         outputs = model(**inputs)
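The body shown above follows the standard Trainer contract; a minimal sketch of how such an override typically finishes (the `loss` attribute is assumed to come from RerankerOutput):

def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
    # Forward the batch; the model returns a RerankerOutput carrying the loss.
    outputs = model(**inputs)
    loss = outputs.loss  # attribute name assumed from RerankerOutput
    return (loss, outputs) if return_outputs else loss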
