
Commit 0dff36c

Merge pull request #1200 from ZiyiXia/master
Docstring
2 parents 068e86f + 6cc832c commit 0dff36c

55 files changed

Lines changed: 1420 additions & 26 deletions


FlagEmbedding/abc/evaluation/data_loader.py

Lines changed: 1 addition & 1 deletion
@@ -374,7 +374,7 @@ def _download_gz_file(self, download_url: str, save_dir: str):
             save_dir (str): Path to the directory to save the gzip file.
 
         Raises:
-            FileNotFoundError: _description_
+            FileNotFoundError
 
         Returns:
             str: The path to the file after unzip.
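
The corrected Raises entry reflects what the method actually does: fetch a gzip archive and return the path of the unzipped file. As a rough illustration of that pattern, a self-contained sketch (the helper name and URL handling are hypothetical, not the repository's implementation):

import gzip
import os
import shutil
import urllib.request

def download_gz_file(download_url: str, save_dir: str) -> str:
    """Download a .gz file into save_dir and unzip it; return the unzipped path."""
    os.makedirs(save_dir, exist_ok=True)
    gz_path = os.path.join(save_dir, os.path.basename(download_url))
    urllib.request.urlretrieve(download_url, gz_path)
    if not os.path.exists(gz_path):
        raise FileNotFoundError(f"download failed: {gz_path}")
    out_path = gz_path[:-3]  # assumes the file name ends in ".gz"
    with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return out_path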

FlagEmbedding/abc/finetune/embedder/AbsArguments.py

Lines changed: 3 additions & 0 deletions
@@ -38,6 +38,9 @@ class AbsEmbedderModelArguments:
 
 @dataclass
 class AbsEmbedderDataArguments:
+    """
+    Abstract class for data arguments.
+    """
     train_data: str = field(
         default=None, metadata={
             "help": "One or more paths to training data. `query: str`, `pos: List[str]`, `neg: List[str]` are required in the training data.",

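Since AbsEmbedderDataArguments is a dataclass whose fields carry help metadata, it is the kind of class typically parsed with transformers.HfArgumentParser. A minimal sketch of that usage, assuming a simplified two-field stand-in rather than the real argument set:

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class DataArguments:
    """Simplified stand-in for AbsEmbedderDataArguments."""
    train_data: str = field(
        default=None, metadata={"help": "One or more paths to training data."}
    )
    train_group_size: int = field(
        default=8, metadata={"help": "Number of positives and negatives per query."}
    )

parser = HfArgumentParser(DataArguments)
(data_args,) = parser.parse_args_into_dataclasses(args=["--train_data", "./toy.jsonl"])
print(data_args.train_data)  # ./toy.jsonl
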
FlagEmbedding/abc/finetune/embedder/AbsDataset.py

Lines changed: 85 additions & 3 deletions
@@ -21,6 +21,12 @@
 
 
 class AbsEmbedderTrainDataset(Dataset):
+    """Abstract class for training dataset.
+
+    Args:
+        args (AbsEmbedderDataArguments): Data arguments.
+        tokenizer (PreTrainedTokenizer): Tokenizer to use.
+    """
     def __init__(
         self,
         args: AbsEmbedderDataArguments,
@@ -46,6 +52,17 @@ def __init__(
         self.dataset = datasets.concatenate_datasets(train_datasets)
 
     def _load_dataset(self, file_path: str):
+        """Load dataset from path.
+
+        Args:
+            file_path (str): Path to load the datasets from.
+
+        Raises:
+            ValueError: `pos_scores` and `neg_scores` not found in the features of training data.
+
+        Returns:
+            datasets.Dataset: Loaded HF dataset.
+        """
         if dist.get_rank() == 0:
             logger.info(f'loading data from {file_path} ...')
 
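The ValueError documented above concerns distillation features: the loader expects pos_scores and neg_scores columns when teacher scores are required. A small sketch of that kind of validation (the helper below is illustrative, not the repository's code):

import datasets

def check_distill_features(ds: datasets.Dataset) -> None:
    """Raise if the features needed for knowledge distillation are missing."""
    required = {"pos_scores", "neg_scores"}
    missing = required - set(ds.column_names)
    if missing:
        raise ValueError(
            "`pos_scores` and `neg_scores` are required in the training data "
            f"for distillation; missing: {sorted(missing)}"
        )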
@@ -63,6 +80,14 @@ def _load_dataset(self, file_path: str):
         return temp_dataset
 
     def _shuffle_text(self, text):
+        """Shuffle the input text.
+
+        Args:
+            text (str): Input text.
+
+        Returns:
+            str: Shuffled text.
+        """
         if self.shuffle_ratio > 0 and len(text) > 100 and random.random() < self.shuffle_ratio:
             split_text = []
             chunk_size = len(text)//3 + 1
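
The visible context (len(text) > 100, chunk_size = len(text)//3 + 1) suggests the augmentation splits long text into roughly three chunks and permutes them. A hedged reconstruction under those assumptions; the real method may differ in how chunks are rejoined:

import random

def shuffle_text(text: str, shuffle_ratio: float = 0.002) -> str:
    """Chunk-level shuffle: split long text into ~3 chunks and permute them."""
    if shuffle_ratio > 0 and len(text) > 100 and random.random() < shuffle_ratio:
        chunk_size = len(text) // 3 + 1
        split_text = [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]
        random.shuffle(split_text)
        return "".join(split_text)  # the joining delimiter is an assumption
    return text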
@@ -126,6 +151,9 @@ def __getitem__(self, item):
 
 @dataclass
 class AbsEmbedderCollator(DataCollatorWithPadding):
+    """
+    The abstract embedder collator.
+    """
     query_max_len: int = 32
     passage_max_len: int = 128
     sub_batch_size: int = -1
@@ -214,6 +242,16 @@ def __call__(self, features):
 
 
 class AbsEmbedderSameDatasetTrainDataset(AbsEmbedderTrainDataset):
+    """Abstract class for training dataset that samples batches from the same dataset.
+
+    Args:
+        args (AbsEmbedderDataArguments): Data arguments.
+        default_batch_size (int): The default batch size for training.
+        seed (int): Random seed.
+        tokenizer (PreTrainedTokenizer): Tokenizer to use.
+        process_index (int, optional): Current process index. Defaults to 0.
+        num_processes (int, optional): Total number of processes. Defaults to 1.
+    """
     def __init__(
         self,
         args: AbsEmbedderDataArguments,
@@ -296,6 +334,14 @@ def __init__(
         self.refresh_epoch()
 
     def _load_dataset(self, file_path: str):
+        """Load dataset from the given path.
+
+        Args:
+            file_path (str): The path to load from, or to download from the HF hub.
+
+        Returns:
+            datasets.Dataset: The loaded dataset.
+        """
         if dist.get_rank() == 0:
             logger.info(f'loading data from {file_path} ...')
 
@@ -311,6 +357,15 @@ def _load_dataset(self, file_path: str):
 
     @staticmethod
     def _get_file_batch_size(temp_dataset: datasets.Dataset, default_batch_size: int):
+        """Get the appropriate batch size for the dataset.
+
+        Args:
+            temp_dataset (datasets.Dataset): Loaded :data:`datasets.Dataset` object.
+            default_batch_size (int): The default batch size to use if not specified in the dataset.
+
+        Returns:
+            int: The final batch size to use.
+        """
         if 'batch_size' in temp_dataset.column_names:
             return temp_dataset['batch_size'][0]
         if 'type' in temp_dataset.column_names:
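
From the context lines, the resolution order is: an explicit batch_size column wins, then a type-based rule (truncated in this diff), then the caller's default. A toy demonstration of the first and last branches:

import datasets

def get_file_batch_size(ds: datasets.Dataset, default_batch_size: int) -> int:
    """Explicit `batch_size` column wins; the `type`-based branch is omitted here."""
    if "batch_size" in ds.column_names:
        return ds["batch_size"][0]
    return default_batch_size

ds = datasets.Dataset.from_dict({"query": ["q1", "q2"], "batch_size": [16, 16]})
print(get_file_batch_size(ds, default_batch_size=32))  # 16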
@@ -320,6 +375,9 @@ def _get_file_batch_size(temp_dataset: datasets.Dataset, default_batch_size: int):
         return default_batch_size
 
     def refresh_epoch(self):
+        """
+        Refresh the data for the epoch.
+        """
         logger.info(f'-- Rank {self.process_index}: refresh data --')
         self.deterministic_generator.shuffle(self.datasets_inxs)
 
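refresh_epoch reshuffles the dataset order with a dedicated deterministic generator, so each rank derives the same permutation from the shared seed. A minimal sketch of that pattern, assuming random.Random seeded identically on every process:

import random

class EpochRefresher:
    """Deterministic per-epoch reshuffle; identical across ranks for the same seed."""

    def __init__(self, num_datasets: int, seed: int = 42):
        self.datasets_inxs = list(range(num_datasets))
        self.deterministic_generator = random.Random(seed)

    def refresh_epoch(self) -> None:
        # All ranks share the generator state, so the shuffles stay in sync.
        self.deterministic_generator.shuffle(self.datasets_inxs)

r = EpochRefresher(num_datasets=5)
r.refresh_epoch()
print(r.datasets_inxs)  # same order on every process seeded with 42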
@@ -353,6 +411,15 @@ def __getitem__(self, _):
         return queries, passages, teacher_scores, no_in_batch_neg_flag
 
     def _get_train_group_size(self, batch_raw_data):
+        """Get the training group size and data type.
+
+        Args:
+            batch_raw_data (datasets.Dataset): One batch of raw data.
+
+        Returns:
+            int: The training group size.
+            str: The type of data for the task.
+        """
         if 'type' in batch_raw_data:
             data_type = batch_raw_data['type'][0]
             if data_type in ['only_1neg']:
@@ -362,6 +429,16 @@ def _get_train_group_size(self, batch_raw_data):
             return self.args.train_group_size, None
 
     def _create_batch_data(self, batch_raw_data):
+        """Create a complete batch of data with queries, documents and teacher scores.
+
+        Args:
+            batch_raw_data (datasets.Dataset): One batch of raw data.
+
+        Returns:
+            List[str]: Queries with instruction format.
+            List[str]: Documents with instruction format.
+            List[float]: Teacher scores for model distillation.
+        """
         queries, passages, teacher_scores = [], [], []
 
         train_group_size, data_type = self._get_train_group_size(batch_raw_data)
@@ -426,10 +503,12 @@ def _create_batch_data(self, batch_raw_data):
 @dataclass
 class AbsEmbedderSameDatasetCollator(DataCollatorWithPadding):
     """
-    EmbedCollator for SameDataset
+    EmbedCollator for SameDataset.
     Note that after using this collator, the training_args should be set as:
-    training_args.per_device_train_batch_size = 1
-    training_args.dataloader_num_workers = 0    # avoid multi-processing
+
+    ``training_args.per_device_train_batch_size = 1``
+
+    ``training_args.dataloader_num_workers = 0    # avoid multi-processing``
     """
     query_max_len: int = 32
     passage_max_len: int = 128
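
The reworked docstring makes the required trainer settings explicit. In practice that means configuring TrainingArguments like this before handing the collator to the Trainer (output_dir is a placeholder):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",           # placeholder path
    per_device_train_batch_size=1,   # the collator assembles the real batch itself
    dataloader_num_workers=0,        # avoid multi-processing, as the docstring notes
)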
@@ -516,6 +595,9 @@ def __call__(self, features):
 
 
 class EmbedderTrainerCallbackForDataRefresh(TrainerCallback):
+    """
+    Callback class to inspect the state of the training loop and make decisions.
+    """
     def __init__(self, train_dataset: AbsEmbedderSameDatasetTrainDataset):
         self.train_dataset = train_dataset
 
FlagEmbedding/abc/finetune/embedder/AbsModeling.py

Lines changed: 91 additions & 0 deletions
@@ -15,13 +15,27 @@
 
 @dataclass
 class EmbedderOutput(ModelOutput):
+    """
+    Output information returned by the model.
+    """
     q_reps: Optional[Tensor] = None
     p_reps: Optional[Tensor] = None
     loss: Optional[Tensor] = None
     scores: Optional[Tensor] = None
 
 
 class AbsEmbedderModel(ABC, nn.Module):
+    """Abstract class of embedding model for training.
+
+    Args:
+        base_model: The base model to train on.
+        tokenizer (AutoTokenizer, optional): The tokenizer to use. Defaults to ``None``.
+        negatives_cross_device (bool, optional): If True, will compute cross-device negative loss. Defaults to ``False``.
+        temperature (float, optional): Temperature to control the scale of scores. Defaults to ``1.0``.
+        sub_batch_size (int, optional): Sub-batch size during encoding. If negative, will not split into sub-batches.
+            Defaults to ``-1``.
+        kd_loss_type (str, optional): Type of knowledge distillation loss. Defaults to ``"kl_div"``.
+    """
     def __init__(
         self,
         base_model,
@@ -48,21 +62,53 @@ def __init__(
 
     @abstractmethod
     def encode(self, features):
+        """Abstract method to encode and get the embedding.
+
+        Args:
+            features (Union[list, dict]): Features fed to the model.
+        """
         pass
 
     @abstractmethod
     def compute_loss(self, scores, target):
+        """Abstract method to compute the loss.
+
+        Args:
+            scores (torch.Tensor): Computed scores.
+            target (torch.Tensor): The target values.
+        """
         pass
 
     @abstractmethod
     def compute_score(self, q_reps, p_reps):
+        """Abstract method to compute the score.
+
+        Args:
+            q_reps (torch.Tensor): Queries representations.
+            p_reps (torch.Tensor): Passages representations.
+        """
         pass
 
     @abstractmethod
     def save(self, output_dir: str):
+        """Abstract method to save the model.
+
+        Args:
+            output_dir (str): Directory for saving the model.
+        """
         pass
 
     def get_local_score(self, q_reps, p_reps, all_scores):
+        """Get the local score of queries and passages.
+
+        Args:
+            q_reps (torch.Tensor): Queries representations.
+            p_reps (torch.Tensor): Passages representations.
+            all_scores (torch.Tensor): All the query-passage scores computed.
+
+        Returns:
+            torch.Tensor: Local scores to compute loss.
+        """
         group_size = p_reps.size(0) // q_reps.size(0)
         indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size
         specific_scores = []
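
The context lines sketch the slicing: with batch_size queries and batch_size * group_size passages, query i owns score columns [i * group_size, (i + 1) * group_size). A standalone worked example of that selection; the loop body is reconstructed, since the diff truncates it:

import torch

def get_local_score(q_reps, p_reps, all_scores):
    """Keep, for each query, only the scores against its own passage group."""
    group_size = p_reps.size(0) // q_reps.size(0)
    indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size
    specific_scores = []
    for dim in range(group_size):
        # Column `dim` of the result holds each query's score for its dim-th passage.
        specific_scores.append(all_scores[torch.arange(q_reps.size(0)), indices + dim])
    return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1)

q = torch.randn(2, 4)   # 2 queries
p = torch.randn(6, 4)   # group_size = 3 passages per query
scores = q @ p.T        # (2, 6): every query scored against every passage
print(get_local_score(q, p, scores).shape)  # torch.Size([2, 3])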
@@ -73,6 +119,17 @@ def get_local_score(self, q_reps, p_reps, all_scores):
         return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1)
 
     def compute_local_score(self, q_reps, p_reps, compute_score_func=None, **kwargs):
+        """Compute the local score of queries and passages.
+
+        Args:
+            q_reps (torch.Tensor): Queries representations.
+            p_reps (torch.Tensor): Passages representations.
+            compute_score_func (function, optional): Function to compute scores. Defaults to ``None``,
+                which will use :meth:`self.compute_score`.
+
+        Returns:
+            torch.Tensor: Local scores to compute loss.
+        """
         if compute_score_func is None:
             all_scores = self.compute_score(q_reps, p_reps)
         else:
@@ -181,6 +238,17 @@ def forward(
         teacher_scores: Union[None, List[float]] = None,
         no_in_batch_neg_flag: bool = False,
     ):
+        """The computation performed at every call.
+
+        Args:
+            queries (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input queries. Defaults to ``None``.
+            passages (Union[Dict[str, Tensor], List[Dict[str, Tensor]]], optional): Input passages. Defaults to ``None``.
+            teacher_scores (Union[None, List[float]], optional): Teacher scores for distillation. Defaults to ``None``.
+            no_in_batch_neg_flag (bool, optional): If True, use no in-batch negatives and no cross-device negatives. Defaults to ``False``.
+
+        Returns:
+            EmbedderOutput: Output of the forward call of the model.
+        """
         q_reps = self.encode(queries)   # (batch_size, dim)
         p_reps = self.encode(passages)  # (batch_size * group_size, dim)
 
@@ -210,6 +278,20 @@ def forward(
 
     @staticmethod
     def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None):
+        """Compute the distillation loss.
+
+        Args:
+            kd_loss_type (str): Type of knowledge distillation loss, supports "kl_div" and "m3_kd_loss".
+            teacher_targets (torch.Tensor): Targets from the teacher model.
+            student_scores (torch.Tensor): Scores of the student model.
+            group_size (int, optional): Number of passages in each query group. Defaults to ``None``.
+
+        Raises:
+            ValueError: Invalid kd_loss_type.
+
+        Returns:
+            torch.Tensor: A scalar of the computed distillation loss.
+        """
         if kd_loss_type == 'kl_div':
             # teacher_targets: (batch_size, group_size) / (world_size * batch_size, group_size)
             # student_scores: (batch_size, group_size) / (world_size * batch_size, group_size)
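
For kd_loss_type "kl_div" the shapes in the comments are (batch_size, group_size) for both tensors, and the usual formulation is a soft cross-entropy between the teacher distribution and the student log-probabilities. A hedged sketch of that loss; whether teacher_targets arrive pre-normalized in the repository is an assumption:

import torch
import torch.nn.functional as F

def kl_div_distill_loss(teacher_targets: torch.Tensor,
                        student_scores: torch.Tensor) -> torch.Tensor:
    """Soft-label distillation: batch mean of -sum_j p_teacher(j) * log p_student(j)."""
    teacher_probs = torch.softmax(teacher_targets, dim=-1)   # assumes raw teacher scores
    student_log_probs = F.log_softmax(student_scores, dim=-1)
    return -torch.mean(torch.sum(teacher_probs * student_log_probs, dim=-1))

teacher = torch.randn(4, 8)  # (batch_size, group_size)
student = torch.randn(4, 8)
print(kl_div_distill_loss(teacher, student))  # scalar loss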
@@ -236,6 +318,15 @@ def distill_loss(kd_loss_type, teacher_targets, student_scores, group_size=None):
             raise ValueError(f"Invalid kd_loss_type: {kd_loss_type}")
 
     def _dist_gather_tensor(self, t: Optional[torch.Tensor]):
+        """Gather a tensor from all processes in a distributed setting.
+
+        Args:
+            t (Optional[torch.Tensor]): The input tensor to be gathered. If ``None``, no gathering is performed.
+
+        Returns:
+            Union[torch.Tensor, None]: A concatenated tensor from all processes if ``t`` is not ``None``,
+                otherwise returns ``None``.
+        """
         if t is None:
             return None
         t = t.contiguous()
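
_dist_gather_tensor follows the standard cross-device gather used for cross-device negatives: all_gather the tensor from every rank, then put the local tensor back into its slot so gradients still flow through it. A sketch of that pattern, assuming torch.distributed has been initialized:

from typing import Optional

import torch
import torch.distributed as dist

def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Gather `t` from all ranks; keep the local copy so autograd sees it."""
    if t is None:
        return None
    t = t.contiguous()
    all_tensors = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(all_tensors, t)
    # all_gather returns detached copies; re-insert the local tensor to keep gradients.
    all_tensors[dist.get_rank()] = t
    return torch.cat(all_tensors, dim=0)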
