
Commit 0e56bd9

ft embedder docstring

1 parent 235a775

13 files changed

Lines changed: 303 additions & 6 deletions


FlagEmbedding/finetune/embedder/decoder_only/base/arguments.py

Lines changed: 3 additions & 0 deletions
@@ -10,6 +10,9 @@ def default_target_modules() -> List[int]:
 
 @dataclass
 class DecoderOnlyEmbedderModelArguments(AbsEmbedderModelArguments):
+    """
+    Model argument class for decoder-only base models.
+    """
     peft_model_path: str = field(
         default='', metadata={"help": "The peft model checkpoint for initialization."}
     )
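
For context, a minimal sketch of how this argument dataclass would typically be populated, assuming the standard transformers HfArgumentParser API; the CLI values below are illustrative, not taken from the commit:

# Hypothetical usage sketch -- assumes transformers' HfArgumentParser and the
# dataclass added above; the model name is a placeholder.
from transformers import HfArgumentParser

from FlagEmbedding.finetune.embedder.decoder_only.base.arguments import (
    DecoderOnlyEmbedderModelArguments,
)

parser = HfArgumentParser(DecoderOnlyEmbedderModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "meta-llama/Llama-2-7b-hf"]
)
print(model_args.peft_model_path)  # '' unless --peft_model_path is given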

FlagEmbedding/finetune/embedder/decoder_only/base/load_model.py

Lines changed: 26 additions & 0 deletions
@@ -11,6 +11,14 @@
 
 
 def find_largest_checkpoint(checkpoint_dir):
+    """Find the checkpoint with the largest step number in a directory.
+
+    Args:
+        checkpoint_dir (str): Directory containing the checkpoints.
+
+    Returns:
+        str: Path to the matching checkpoint, or None if no match is found.
+    """
     checkpoint_pattern = re.compile(r'checkpoint-(\d+)')
     max_number = -1
     max_checkpoint_file = None
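
The hunk cuts off before the directory scan; a plausible completion of find_largest_checkpoint, inferred from the regex and the two tracking variables above (the loop itself is not part of the diff):

# Hedged sketch: only the three lines above the loop appear in the hunk;
# the loop body is inferred.
import os
import re

def find_largest_checkpoint(checkpoint_dir):
    checkpoint_pattern = re.compile(r'checkpoint-(\d+)')
    max_number = -1
    max_checkpoint_file = None
    for name in os.listdir(checkpoint_dir):
        match = checkpoint_pattern.search(name)
        if match and int(match.group(1)) > max_number:
            max_number = int(match.group(1))
            max_checkpoint_file = os.path.join(checkpoint_dir, name)
    return max_checkpoint_file  # None when no checkpoint-<N> entry matches
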
@@ -28,6 +36,17 @@ def find_largest_checkpoint(checkpoint_dir):
 
 
 def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, resize: bool, resize_tokens: int):
+    """Load and return the model.
+
+    Args:
+        model_args (DecoderOnlyEmbedderModelArguments): Model arguments instance.
+        output_dir (str): Directory to save the model.
+        resize (bool): Whether to resize the token embeddings.
+        resize_tokens (int): The new number of tokens.
+
+    Returns:
+        transformers.PreTrainedModel or PeftModel: The loaded model.
+    """
     if model_args.config_name:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
@@ -99,6 +118,13 @@ def get_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str, re
 
 
 def save_merged_model(model_args: DecoderOnlyEmbedderModelArguments, output_dir: str):
+    """
+    Load a model with the specified configuration and merge it with its PEFT layers, if available.
+
+    Args:
+        model_args (DecoderOnlyEmbedderModelArguments): Model arguments instance.
+        output_dir (str): Directory to save the merged model.
+    """
     if model_args.config_name:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
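
The merge logic itself falls outside the hunk; a minimal sketch of the usual PEFT merge-and-save flow the docstring describes, assuming the standard peft API (model loading details are elided):

# Hedged sketch of a typical merge flow; not copied from the repository.
from peft import PeftModel
from transformers import AutoModel

base = AutoModel.from_pretrained(model_args.model_name_or_path)
if model_args.peft_model_path:
    peft_model = PeftModel.from_pretrained(base, model_args.peft_model_path)
    model = peft_model.merge_and_unload()  # fold adapter weights into the base
else:
    model = base
model.save_pretrained(output_dir)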

FlagEmbedding/finetune/embedder/decoder_only/base/modeling.py

Lines changed: 72 additions & 0 deletions
@@ -9,6 +9,19 @@
 
 
 class BiDecoderOnlyEmbedderModel(AbsEmbedderModel):
+    """Embedder model class for decoder-only models.
+
+    Args:
+        base_model (AutoModel): The base model to train on.
+        tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
+        negatives_cross_device (bool, optional): If True, compute the negative loss across devices. Defaults to ``False``.
+        temperature (float, optional): Temperature to control the scale of the scores. Defaults to ``1.0``.
+        sub_batch_size (int, optional): Sub-batch size during encoding. If negative, the batch is not split into sub-batches.
+            Defaults to ``-1``.
+        kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``'kl_div'``.
+        sentence_pooling_method (str, optional): Pooling method used to get the sentence embedding. Defaults to ``'last_token'``.
+        normalize_embeddings (bool, optional): If True, normalize the embedding vectors. Defaults to ``False``.
+    """
     TRANSFORMER_CLS = AutoModel
 
     def __init__(
@@ -35,6 +48,15 @@ def __init__(
         self.cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
 
     def encode(self, features):
+        """
+        Encode and get the embedding.
+
+        Args:
+            features (Union[list, dict]): Features fed to the model.
+
+        Returns:
+            torch.Tensor: The embedding vectors.
+        """
         if features is None:
             return None
         if not isinstance(features, list):
@@ -70,6 +92,18 @@ def encode(self, features):
         return all_p_reps.contiguous()
 
     def _sentence_embedding(self, last_hidden_state, attention_mask):
+        """Use the pooling method to get the sentence embedding.
+
+        Args:
+            last_hidden_state (torch.Tensor): The model output's last hidden state.
+            attention_mask (torch.Tensor): Mask used to exclude padding tokens during pooling.
+
+        Raises:
+            NotImplementedError: If the specified pooling method is not implemented.
+
+        Returns:
+            torch.Tensor: The sentence embeddings.
+        """
         if self.sentence_pooling_method == "cls":
             return last_hidden_state[:, 0]
         elif self.sentence_pooling_method == "mean":
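
For readers unfamiliar with the pooling variants the docstring names, a self-contained sketch of the three common strategies (cls, mean, last_token); the last_token branch assumes right padding, which may differ from the repository's exact handling:

# Hedged illustration of the pooling strategies; not the repo's implementation.
import torch

def pool(last_hidden_state, attention_mask, method="last_token"):
    if method == "cls":
        return last_hidden_state[:, 0]                   # first-token embedding
    if method == "mean":
        mask = attention_mask.unsqueeze(-1).float()
        summed = (last_hidden_state * mask).sum(dim=1)   # sum over real tokens
        return summed / mask.sum(dim=1).clamp(min=1e-9)  # average, avoiding /0
    if method == "last_token":
        lengths = attention_mask.sum(dim=1) - 1          # index of last real token
        batch = torch.arange(last_hidden_state.size(0))
        return last_hidden_state[batch, lengths]
    raise NotImplementedError(f"pooling method {method} not implemented")
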
@@ -93,25 +127,63 @@ def _sentence_embedding(self, last_hidden_state, attention_mask):
         raise NotImplementedError(f"pooling method {self.sentence_pooling_method} not implemented")
 
     def compute_score(self, q_reps, p_reps):
+        """Compute the scores between query and passage representations.
+
+        Args:
+            q_reps (torch.Tensor): Query representations.
+            p_reps (torch.Tensor): Passage representations.
+
+        Returns:
+            torch.Tensor: The computed scores, adjusted by temperature.
+        """
         scores = self._compute_similarity(q_reps, p_reps) / self.temperature
         scores = scores.view(q_reps.size(0), -1)
         return scores
 
     def _compute_similarity(self, q_reps, p_reps):
+        """Compute the similarity between query and passage representations using the inner product.
+
+        Args:
+            q_reps (torch.Tensor): Query representations.
+            p_reps (torch.Tensor): Passage representations.
+
+        Returns:
+            torch.Tensor: The computed similarity matrix.
+        """
         if len(p_reps.size()) == 2:
             return torch.matmul(q_reps, p_reps.transpose(0, 1))
         return torch.matmul(q_reps, p_reps.transpose(-2, -1))
 
     def compute_loss(self, scores, target):
+        """Compute the loss using cross entropy.
+
+        Args:
+            scores (torch.Tensor): Computed scores.
+            target (torch.Tensor): The target values.
+
+        Returns:
+            torch.Tensor: The computed cross entropy loss.
+        """
         return self.cross_entropy(scores, target)
 
     def gradient_checkpointing_enable(self, **kwargs):
+        """
+        Activate gradient checkpointing for the current model.
+        """
         self.model.gradient_checkpointing_enable(**kwargs)
 
     def enable_input_require_grads(self, **kwargs):
+        """
+        Enable gradients for the input embeddings.
+        """
         self.model.enable_input_require_grads(**kwargs)
 
     def save(self, output_dir: str):
+        """Save the model to the given directory.
+
+        Args:
+            output_dir (str): Directory for saving the model.
+        """
         state_dict = self.model.state_dict()
         state_dict = type(state_dict)(
             {k: v.clone().cpu()
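
Taken together, compute_score and compute_loss implement a temperature-scaled contrastive objective over in-batch passages; a small numeric sketch (shapes and the temperature value are illustrative, not from the repo):

# Hedged end-to-end sketch of the scoring/loss path documented above.
import torch

temperature = 0.02
q_reps = torch.nn.functional.normalize(torch.randn(4, 16), dim=-1)   # 4 queries
p_reps = torch.nn.functional.normalize(torch.randn(8, 16), dim=-1)   # 2 passages per query

scores = torch.matmul(q_reps, p_reps.transpose(0, 1)) / temperature  # (4, 8)
target = torch.arange(4) * 2  # index of each query's positive passage
loss = torch.nn.CrossEntropyLoss(reduction='mean')(scores, target)
print(scores.shape, loss.item())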

FlagEmbedding/finetune/embedder/decoder_only/base/runner.py

Lines changed: 20 additions & 0 deletions
@@ -15,6 +15,13 @@
 
 
 class DecoderOnlyEmbedderRunner(AbsEmbedderRunner):
+    """Runner class for decoder-only embedding models.
+
+    Args:
+        model_args (DecoderOnlyEmbedderModelArguments): Model arguments instance.
+        data_args (AbsEmbedderDataArguments): Data arguments instance.
+        training_args (AbsEmbedderTrainingArguments): Training arguments instance.
+    """
     def __init__(
         self,
         model_args: DecoderOnlyEmbedderModelArguments,
@@ -24,6 +31,11 @@ def __init__(
         super().__init__(model_args, data_args, training_args)
 
     def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderModel]:
+        """Load the tokenizer and model.
+
+        Returns:
+            Tuple[PreTrainedTokenizer, AbsEmbedderModel]: Tokenizer and model instances.
+        """
         tokenizer = AutoTokenizer.from_pretrained(
             self.model_args.tokenizer_name if self.model_args.tokenizer_name else self.model_args.model_name_or_path,
             token=self.model_args.token,
@@ -83,6 +95,11 @@ def load_tokenizer_and_model(self) -> Tuple[PreTrainedTokenizer, AbsEmbedderMode
         return tokenizer, model
 
     def load_trainer(self) -> DecoderOnlyEmbedderTrainer:
+        """Load the trainer.
+
+        Returns:
+            DecoderOnlyEmbedderTrainer: Loaded trainer instance.
+        """
         trainer = DecoderOnlyEmbedderTrainer(
             model=self.model,
             args=self.training_args,
@@ -95,6 +112,9 @@ def load_trainer(self) -> DecoderOnlyEmbedderTrainer:
         return trainer
 
     def run(self):
+        """
+        Run the finetuning process.
+        """
         Path(self.training_args.output_dir).mkdir(parents=True, exist_ok=True)
 
         # Training
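
A hedged sketch of how the runner is presumably driven, mirroring the usual FlagEmbedding finetune entry point; the argument objects are placeholders for parsed dataclasses:

# Hypothetical driver for the runner documented above.
runner = DecoderOnlyEmbedderRunner(
    model_args=model_args,        # DecoderOnlyEmbedderModelArguments
    data_args=data_args,          # AbsEmbedderDataArguments
    training_args=training_args,  # AbsEmbedderTrainingArguments
)
runner.run()  # creates output_dir, loads tokenizer/model, trains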

FlagEmbedding/finetune/embedder/decoder_only/base/trainer.py

Lines changed: 11 additions & 0 deletions
@@ -9,7 +9,18 @@
 
 
 class DecoderOnlyEmbedderTrainer(AbsEmbedderTrainer):
+    """
+    Trainer class for base decoder models.
+    """
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        """Save the model to the given directory.
+
+        Args:
+            output_dir (Optional[str], optional): Output directory to save the model. Defaults to ``None``.
+
+        Raises:
+            NotImplementedError
+        """
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info("Saving model checkpoint to %s", output_dir)
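
The body of _save is truncated in the hunk; a plausible continuation consistent with the docstring's Raises clause (this guard is inferred, not from the diff):

# Hedged sketch of how _save likely finishes -- the exact check is an assumption.
if not hasattr(self.model, 'save'):
    raise NotImplementedError(
        f'MODEL {self.model.__class__.__name__} does not support save interface')
self.model.save(output_dir)
if self.tokenizer is not None and self.is_world_process_zero():
    self.tokenizer.save_pretrained(output_dir)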

FlagEmbedding/finetune/embedder/decoder_only/icl/arguments.py

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,9 @@ def default_target_modules() -> List[int]:
 
 @dataclass
 class DecoderOnlyEmbedderICLModelArguments(AbsEmbedderModelArguments):
+    """
+    Model argument class for decoder-only ICL models.
+    """
     peft_model_path: str = field(
         default='', metadata={"help": "The peft model checkpoint for initialization."}
     )
@@ -73,6 +76,9 @@ class DecoderOnlyEmbedderICLModelArguments(AbsEmbedderModelArguments):
 
 @dataclass
 class DecoderOnlyEmbedderICLDataArguments(AbsEmbedderDataArguments):
+    """
+    Data argument class for decoder-only ICL models.
+    """
     example_query_max_len: int = field(
         default=64,
         metadata={"help": "The max length of example query."}

FlagEmbedding/finetune/embedder/decoder_only/icl/dataset.py

Lines changed: 25 additions & 3 deletions
@@ -15,6 +15,16 @@
 
 
 class DecoderOnlyEmbedderICLSameDatasetTrainDataset(AbsEmbedderSameDatasetTrainDataset):
+    """Dataset class for the ICL model.
+
+    Args:
+        args (DecoderOnlyEmbedderICLDataArguments): Data argument class for the ICL model.
+        default_batch_size (int): The default batch size.
+        seed (int): Random seed to use.
+        tokenizer (PreTrainedTokenizer): The tokenizer to use.
+        process_index (int, optional): Current process index. Defaults to 0.
+        num_processes (int, optional): Total number of processes. Defaults to 1.
+    """
     def __init__(
         self,
         args: DecoderOnlyEmbedderICLDataArguments,
@@ -39,6 +49,16 @@ def __init__(
         self.prefix = self.tokenizer(f"{self.tokenizer.bos_token}", add_special_tokens=False)['input_ids']
 
     def _create_batch_data(self, batch_raw_data):
+        """Create a complete batch of data with queries, documents and teacher scores.
+
+        Args:
+            batch_raw_data (datasets.Dataset): One batch of raw data.
+
+        Returns:
+            List[str]: Queries with instruction format.
+            List[str]: Documents with instruction format.
+            List[float]: Teacher scores for model distillation.
+        """
         queries, passages, teacher_scores = [], [], []
 
         train_group_size, data_type = self._get_train_group_size(batch_raw_data)
@@ -179,10 +199,12 @@ def _create_batch_data(self, batch_raw_data):
 @dataclass
 class AbsEmbedderSameDatasetCollator(DataCollatorWithPadding):
     """
-    EmbedCollator for SameDataset
+    EmbedCollator for SameDataset.
     Note that after using this collator, the training_args should be set as:
-    training_args.per_device_train_batch_size = 1
-    training_args.dataloader_num_workers = 0    # avoid multi-processing
+
+    ``training_args.per_device_train_batch_size = 1``
+
+    ``training_args.dataloader_num_workers = 0    # avoid multi-processing``
     """
     query_max_len: int = 32
     passage_max_len: int = 128
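
Since this collator expects the dataset to assemble a full batch itself, the Trainer must be configured accordingly; a hedged sketch of the relevant settings, assuming standard transformers TrainingArguments fields (the output_dir value is a placeholder):

# Illustrative configuration for the SameDataset collator above.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./ft_output",
    per_device_train_batch_size=1,  # the dataset yields one pre-batched group
    dataloader_num_workers=0,       # avoid multi-processing over shared state
)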

FlagEmbedding/finetune/embedder/decoder_only/icl/load_model.py

Lines changed: 26 additions & 0 deletions
@@ -11,6 +11,14 @@
 
 
 def find_largest_checkpoint(checkpoint_dir):
+    """Find the checkpoint with the largest step number in a directory.
+
+    Args:
+        checkpoint_dir (str): Directory containing the checkpoints.
+
+    Returns:
+        str: Path to the matching checkpoint, or None if no match is found.
+    """
     checkpoint_pattern = re.compile(r'checkpoint-(\d+)')
     max_number = -1
     max_checkpoint_file = None
@@ -28,6 +36,17 @@ def find_largest_checkpoint(checkpoint_dir):
 
 
 def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str, resize: bool, resize_tokens: int):
+    """Load and return the model.
+
+    Args:
+        model_args (DecoderOnlyEmbedderICLModelArguments): Model arguments instance.
+        output_dir (str): Directory to save the model.
+        resize (bool): Whether to resize the token embeddings.
+        resize_tokens (int): The new number of tokens.
+
+    Returns:
+        transformers.PreTrainedModel or PeftModel: The loaded model.
+    """
     if model_args.config_name:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
@@ -99,6 +118,13 @@ def get_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str,
 
 
 def save_merged_model(model_args: DecoderOnlyEmbedderICLModelArguments, output_dir: str):
+    """
+    Load a model with the specified configuration and merge it with its PEFT layers, if available.
+
+    Args:
+        model_args (DecoderOnlyEmbedderICLModelArguments): Model arguments instance.
+        output_dir (str): Directory to save the merged model.
+    """
     if model_args.config_name:
         config = AutoConfig.from_pretrained(
             model_args.config_name,
