
Commit 235a775
ft base & m3
1 parent 834498f

8 files changed: 330 additions & 2 deletions


FlagEmbedding/abc/finetune/embedder/AbsModeling.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ class AbsEmbedderModel(ABC, nn.Module):
         temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
         sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
             Defaults to ``-1``.
-        kd_loss_type (str, optional): Knowledge distillation type. Defaults to ``"kl_div"``.
+        kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``"kl_div"``.
     """
     def __init__(
         self,
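
Note: ``kd_loss_type="kl_div"`` refers to matching the student's score distribution to a teacher's with a KL divergence. A minimal sketch of such a loss, assuming teacher scores are available per query (illustrative helper, not code from this commit):

    import torch
    import torch.nn.functional as F

    def kd_kl_div_loss(student_scores: torch.Tensor, teacher_scores: torch.Tensor) -> torch.Tensor:
        # Both tensors: (batch_size, num_candidates) score matrices.
        # KL(teacher || student) over each query's candidate distribution.
        student_log_probs = F.log_softmax(student_scores, dim=-1)
        teacher_probs = F.softmax(teacher_scores, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")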

FlagEmbedding/finetune/embedder/encoder_only/base/modeling.py

Lines changed: 71 additions & 0 deletions
@@ -9,6 +9,19 @@
 
 
 class BiEncoderOnlyEmbedderModel(AbsEmbedderModel):
+    """Embedder class for encoder-only models.
+
+    Args:
+        base_model (AutoModel): The base model to train on.
+        tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
+        negatives_cross_device (bool, optional): If True, will compute cross-device negative loss. Defaults to ``False``.
+        temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
+        sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split to sub-batch.
+            Defaults to ``-1``.
+        kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``"kl_div"``.
+        sentence_pooling_method (str, optional): Pooling method to get sentence embedding. Defaults to ``"cls"``.
+        normalize_embeddings (bool, optional): If True, normalize the embedding vector. Defaults to ``False``.
+    """
     TRANSFORMER_CLS = AutoModel
 
     def __init__(
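
For orientation, the arguments documented above would be used roughly like this (hypothetical checkpoint name and hyperparameters, not taken from this commit):

    from transformers import AutoModel, AutoTokenizer
    from FlagEmbedding.finetune.embedder.encoder_only.base.modeling import BiEncoderOnlyEmbedderModel

    base_model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
    model = BiEncoderOnlyEmbedderModel(
        base_model,
        tokenizer=tokenizer,
        temperature=0.02,
        sentence_pooling_method="cls",
        normalize_embeddings=True,
    )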
@@ -35,6 +48,14 @@ def __init__(
         self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
 
     def encode(self, features):
+        """Encode and get the embedding.
+
+        Args:
+            features (Union[list, dict]): Features fed to the model.
+
+        Returns:
+            torch.Tensor: The embedding vectors.
+        """
         if features is None:
             return None
         if not isinstance(features, list):
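
The ``sub_batch_size`` option documented on the class controls chunked encoding. Conceptually it splits the tokenized batch along the batch dimension before the forward pass; a simplified sketch (not the commit's exact implementation):

    import torch

    def encode_in_sub_batches(model_forward, features: dict, sub_batch_size: int) -> torch.Tensor:
        # Slice every tensor in the batch, encode each chunk, then concatenate.
        batch_size = features["input_ids"].size(0)
        reps = []
        for start in range(0, batch_size, sub_batch_size):
            chunk = {k: v[start:start + sub_batch_size] for k, v in features.items()}
            reps.append(model_forward(chunk))
        return torch.cat(reps, dim=0)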
@@ -70,6 +91,18 @@ def encode(self, features):
         return all_p_reps.contiguous()
 
     def _sentence_embedding(self, last_hidden_state, attention_mask):
+        """Use the pooling method to get the sentence embedding.
+
+        Args:
+            last_hidden_state (torch.Tensor): The model output's last hidden state.
+            attention_mask (torch.Tensor): Mask out padding tokens during pooling.
+
+        Raises:
+            NotImplementedError: Specified pooling method not implemented.
+
+        Returns:
+            torch.Tensor: The sentence embeddings.
+        """
         if self.sentence_pooling_method == "cls":
             return last_hidden_state[:, 0]
         elif self.sentence_pooling_method == "mean":
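
The hunk is truncated before the body of the ``"mean"`` branch; masked mean pooling is conventionally written as below (standard formulation, not necessarily this file's exact code):

    import torch

    def mean_pooling(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # Zero out padding positions, then average over real tokens only.
        mask = attention_mask.unsqueeze(-1).float()     # (batch, seq_len, 1)
        summed = (last_hidden_state * mask).sum(dim=1)  # (batch, hidden)
        counts = mask.sum(dim=1).clamp(min=1e-9)        # (batch, 1)
        return summed / counts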
@@ -93,25 +126,63 @@ def _sentence_embedding(self, last_hidden_state, attention_mask):
         raise NotImplementedError(f"pooling method {self.sentence_pooling_method} not implemented")
 
     def compute_score(self, q_reps, p_reps):
+        """Compute the scores between query and passage representations.
+
+        Args:
+            q_reps (torch.Tensor): Query representations.
+            p_reps (torch.Tensor): Passage representations.
+
+        Returns:
+            torch.Tensor: The computed scores, adjusted by temperature.
+        """
         scores = self._compute_similarity(q_reps, p_reps) / self.temperature
         scores = scores.view(q_reps.size(0), -1)
         return scores
 
     def _compute_similarity(self, q_reps, p_reps):
+        """Compute the similarity between query and passage representations using the inner product.
+
+        Args:
+            q_reps (torch.Tensor): Query representations.
+            p_reps (torch.Tensor): Passage representations.
+
+        Returns:
+            torch.Tensor: The computed similarity matrix.
+        """
         if len(p_reps.size()) == 2:
             return torch.matmul(q_reps, p_reps.transpose(0, 1))
         return torch.matmul(q_reps, p_reps.transpose(-2, -1))
 
     def compute_loss(self, scores, target):
+        """Compute the loss using cross entropy.
+
+        Args:
+            scores (torch.Tensor): Computed scores.
+            target (torch.Tensor): The target values.
+
+        Returns:
+            torch.Tensor: The computed cross-entropy loss.
+        """
         return self.cross_entropy(scores, target)
 
     def gradient_checkpointing_enable(self, **kwargs):
+        """
+        Activate gradient checkpointing for the current model.
+        """
         self.model.gradient_checkpointing_enable(**kwargs)
 
     def enable_input_require_grads(self, **kwargs):
+        """
+        Enable gradients for the input embeddings.
+        """
         self.model.enable_input_require_grads(**kwargs)
 
     def save(self, output_dir: str):
+        """Save the model to the given directory.
+
+        Args:
+            output_dir (str): Directory for saving the model.
+        """
         state_dict = self.model.state_dict()
         state_dict = type(state_dict)(
             {k: v.clone().cpu()
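
Taken together, ``compute_score`` and ``compute_loss`` implement the usual in-batch-negatives contrastive objective: each query's positive passage occupies a known column of the score matrix, and cross-entropy pushes its temperature-scaled score above the rest. A self-contained illustration with made-up shapes and values:

    import torch

    temperature = 0.02
    q_reps = torch.randn(4, 768)  # 4 queries
    p_reps = torch.randn(8, 768)  # 2 passages per query: 1 positive + 1 negative

    scores = torch.matmul(q_reps, p_reps.transpose(0, 1)) / temperature  # (4, 8)

    # With a group size of 2, query i's positive passage sits at column i * 2.
    target = torch.arange(4, dtype=torch.long) * 2
    loss = torch.nn.CrossEntropyLoss(reduction="mean")(scores, target)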

FlagEmbedding/finetune/embedder/encoder_only/base/runner.py

Lines changed: 13 additions & 0 deletions
@@ -13,7 +13,15 @@
 
 
 class EncoderOnlyEmbedderRunner(AbsEmbedderRunner):
+    """
+    Finetune runner for base embedding models.
+    """
     def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
+        """Load the tokenizer and model.
+
+        Returns:
+            Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Tokenizer and model instances.
+        """
         tokenizer = AutoTokenizer.from_pretrained(
             self.model_args.model_name_or_path,
             cache_dir=self.model_args.cache_dir,
@@ -58,6 +66,11 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
         return tokenizer, model
 
     def load_trainer(self) -> EncoderOnlyEmbedderTrainer:
+        """Load the trainer.
+
+        Returns:
+            EncoderOnlyEmbedderTrainer: Loaded trainer instance.
+        """
         trainer = EncoderOnlyEmbedderTrainer(
             model=self.model,
             args=self.training_args,

FlagEmbedding/finetune/embedder/encoder_only/base/trainer.py

Lines changed: 11 additions & 0 deletions
@@ -9,7 +9,18 @@
 
 
 class EncoderOnlyEmbedderTrainer(AbsEmbedderTrainer):
+    """
+    Trainer class for base encoder models.
+    """
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        """Save the model to the given directory.
+
+        Args:
+            output_dir (Optional[str], optional): Output directory to save the model. Defaults to ``None``.
+
+        Raises:
+            NotImplementedError
+        """
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info("Saving model checkpoint to %s", output_dir)

FlagEmbedding/finetune/embedder/encoder_only/m3/arguments.py

Lines changed: 6 additions & 0 deletions
@@ -8,11 +8,17 @@
 
 @dataclass
 class EncoderOnlyEmbedderM3ModelArguments(AbsEmbedderModelArguments):
+    """
+    Model argument class for M3.
+    """
     colbert_dim: int = field(default=-1, metadata={"help": "Dim of colbert linear"})
 
 
 @dataclass
 class EncoderOnlyEmbedderM3TrainingArguments(AbsEmbedderTrainingArguments):
+    """
+    Training argument class for M3.
+    """
     unified_finetuning: bool = field(default=False, metadata={"help": "use unify fine-tuning"})
     use_self_distill: bool = field(default=False, metadata={"help": "use self-distill when using unify fine-tuning"})
     fix_encoder: bool = field(default=False, metadata={"help": "Freeze the parameters of encoder"})
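
These dataclasses follow the HuggingFace argument-parsing pattern, so a sketch of how they would typically be consumed (hypothetical driver code, assuming the classes above are importable):

    from transformers import HfArgumentParser

    parser = HfArgumentParser((
        EncoderOnlyEmbedderM3ModelArguments,
        EncoderOnlyEmbedderM3TrainingArguments,
    ))
    # e.g. run with: python train.py --output_dir ./out --unified_finetuning True
    model_args, training_args = parser.parse_args_into_dataclasses()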
