Skip to content

Commit d76e51c

Browse files
authored
Merge pull request #1177 from ZiyiXia/master
Docstring of abc
2 parents dd7d32b + 7ae0ecf commit d76e51c

9 files changed

Lines changed: 400 additions & 12 deletions

File tree

.gitignore

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,4 +139,9 @@ pic2.py
139139
.pyre/
140140

141141
# MacOS associated
142-
.DS_Store
142+
.DS_Store
143+
144+
# results
145+
en_results
146+
zh_results
147+
docs

FlagEmbedding/abc/evaluation/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
@dataclass
1010
class AbsEvalArgs:
11+
"""
12+
Base class for evaluation arguments.
13+
"""
1114
eval_name: str = field(
1215
default=None,
1316
metadata={"help": "The name of the evaluation task, such as msmarco, beir, miracl, etc."}
@@ -77,6 +80,9 @@ class AbsEvalArgs:
7780

7881
@dataclass
7982
class AbsEvalModelArgs:
83+
"""
84+
Base class for model arguments during evaluation.
85+
"""
8086
embedder_name_or_path: str = field(
8187
metadata={"help": "The embedder name or path.", "required": True}
8288
)

FlagEmbedding/abc/evaluation/data_loader.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@
1212

1313

1414
class AbsEvalDataLoader(ABC):
15+
"""
16+
Base class of data loader for evaluation.
17+
18+
Args:
19+
eval_name (str): The experiment name of current evaluation.
20+
dataset_dir (str, optional): path to the datasets. Defaults to None.
21+
cache_dir (str, optional): Path to HuggingFace cache directory. Defaults to None.
22+
token (str, optional): HF_TOKEN to access the private datasets/models in HF. Defaults to None.
23+
force_redownload (bool, optional): If True, force redownload of the dataset, overwriting the local copy. Defaults to False.
24+
"""
1525
def __init__(
1626
self,
1727
eval_name: str,
@@ -43,6 +53,17 @@ def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
4353
pass
4454

4555
def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str]:
56+
"""Check the validity of dataset names
57+
58+
Args:
59+
dataset_names (Union[str, List[str]]): a dataset name (str) or a list of dataset names (List[str])
60+
61+
Raises:
62+
ValueError
63+
64+
Returns:
65+
List[str]: List of valid dataset names.
66+
"""
4667
available_dataset_names = self.available_dataset_names()
4768
if isinstance(dataset_names, str):
4869
dataset_names = [dataset_names]
@@ -53,6 +74,15 @@ def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str]
5374
return dataset_names
5475

5576
def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str] = None) -> List[str]:
77+
"""Check whether the splits are available in the dataset.
78+
79+
Args:
80+
splits (Union[str, List[str]]): Splits to check.
81+
dataset_name (Optional[str], optional): Name of dataset to check. Defaults to None.
82+
83+
Returns:
84+
List[str]: The available splits.
85+
"""
5686
available_splits = self.available_splits(dataset_name=dataset_name)
5787
if isinstance(splits, str):
5888
splits = [splits]
@@ -65,6 +95,14 @@ def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str
6595
return checked_splits
6696

6797
def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
98+
"""Load the corpus from the dataset.
99+
100+
Args:
101+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
102+
103+
Returns:
104+
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
105+
"""
68106
if self.dataset_dir is not None:
69107
if dataset_name is None:
70108
save_dir = self.dataset_dir
@@ -75,6 +113,18 @@ def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic
75113
return self._load_remote_corpus(dataset_name=dataset_name)
76114

77115
def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
116+
"""Load the qrels (query-document relevance judgments) from the dataset.
117+
118+
Args:
119+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
120+
split (str, optional): The split to load relevance from. Defaults to 'test'.
121+
122+
Raises:
123+
ValueError
124+
125+
Returns:
126+
datasets.DatasetDict: A dict of relevance of query and document.
127+
"""
78128
if self.dataset_dir is not None:
79129
if dataset_name is None:
80130
save_dir = self.dataset_dir
@@ -91,6 +141,18 @@ def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') ->
91141
return self._load_remote_qrels(dataset_name=dataset_name, split=split)
92142

93143
def load_queries(self, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
144+
"""Load the queries from the dataset.
145+
146+
Args:
147+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
148+
split (str, optional): The split to load queries from. Defaults to 'test'.
149+
150+
Raises:
151+
ValueError
152+
153+
Returns:
154+
datasets.DatasetDict: A dict of queries with id as key, query text as value.
155+
"""
94156
if self.dataset_dir is not None:
95157
if dataset_name is None:
96158
save_dir = self.dataset_dir
@@ -111,6 +173,18 @@ def _load_remote_corpus(
111173
dataset_name: Optional[str] = None,
112174
save_dir: Optional[str] = None
113175
) -> datasets.DatasetDict:
176+
"""Abstract method to load corpus from a remote dataset, to be overridden in a child class.
177+
178+
Args:
179+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
180+
save_dir (Optional[str], optional): Path to save the new downloaded corpus. Defaults to None.
181+
182+
Raises:
183+
NotImplementedError: Loading remote corpus is not implemented.
184+
185+
Returns:
186+
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
187+
"""
114188
raise NotImplementedError("Loading remote corpus is not implemented.")
115189

116190
def _load_remote_qrels(
@@ -119,6 +193,19 @@ def _load_remote_qrels(
119193
split: str = 'test',
120194
save_dir: Optional[str] = None
121195
) -> datasets.DatasetDict:
196+
"""Abstract method to load relevance from a remote dataset, to be overridden in a child class.
197+
198+
Args:
199+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
200+
split (str, optional): Split to load from the remote dataset. Defaults to 'test'.
201+
save_dir (Optional[str], optional): Path to save the new downloaded relevance. Defaults to None.
202+
203+
Raises:
204+
NotImplementedError: Loading remote qrels is not implemented.
205+
206+
Returns:
207+
datasets.DatasetDict: A dict of relevance of query and document.
208+
"""
122209
raise NotImplementedError("Loading remote qrels is not implemented.")
123210

124211
def _load_remote_queries(
@@ -127,9 +214,31 @@ def _load_remote_queries(
127214
split: str = 'test',
128215
save_dir: Optional[str] = None
129216
) -> datasets.DatasetDict:
217+
"""Abstract method to load queries from a remote dataset, to be overridden in a child class.
218+
219+
Args:
220+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
221+
split (str, optional): Split to load from the remote dataset. Defaults to 'test'.
222+
save_dir (Optional[str], optional): Path to save the new downloaded queries. Defaults to None.
223+
224+
Raises:
225+
NotImplementedError
226+
227+
Returns:
228+
datasets.DatasetDict: A dict of queries with id as key, query text as value.
229+
"""
130230
raise NotImplementedError("Loading remote queries is not implemented.")
131231

132232
def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None) -> datasets.DatasetDict:
233+
"""Load corpus from local dataset.
234+
235+
Args:
236+
save_dir (str): Path to save the loaded corpus.
237+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
238+
239+
Returns:
240+
datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
241+
"""
133242
corpus_path = os.path.join(save_dir, 'corpus.jsonl')
134243
if self.force_redownload or not os.path.exists(corpus_path):
135244
logger.warning(f"Corpus not found in {corpus_path}. Trying to download the corpus from the remote and save it to {save_dir}.")
@@ -144,6 +253,19 @@ def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None)
144253
return datasets.DatasetDict(corpus)
145254

146255
def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
256+
"""Load relevance from local dataset.
257+
258+
Args:
259+
save_dir (str): Path to save the loaded relevance.
260+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
261+
split (str, optional): Split to load from the local dataset. Defaults to 'test'.
262+
263+
Raises:
264+
ValueError
265+
266+
Returns:
267+
datasets.DatasetDict: A dict of relevance of query and document.
268+
"""
147269
checked_split = self.check_splits(split)
148270
if len(checked_split) == 0:
149271
raise ValueError(f"Split {split} not found in the dataset.")
@@ -166,6 +288,19 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
166288
return datasets.DatasetDict(qrels)
167289

168290
def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None, split: str = 'test') -> datasets.DatasetDict:
291+
"""Load queries from local dataset.
292+
293+
Args:
294+
save_dir (str): Path to save the loaded queries.
295+
dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
296+
split (str, optional): Split to load from the local dataset. Defaults to 'test'.
297+
298+
Raises:
299+
ValueError
300+
301+
Returns:
302+
datasets.DatasetDict: A dict of queries with id as key, query text as value.
303+
"""
169304
checked_split = self.check_splits(split)
170305
if len(checked_split) == 0:
171306
raise ValueError(f"Split {split} not found in the dataset.")
@@ -182,6 +317,18 @@ def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None,
182317
return datasets.DatasetDict(queries)
183318

184319
def _download_file(self, download_url: str, save_dir: str):
320+
"""Download file from provided URL.
321+
322+
Args:
323+
download_url (str): Source URL of the file.
324+
save_dir (str): Path to the directory to save the zip file.
325+
326+
Raises:
327+
FileNotFoundError
328+
329+
Returns:
330+
str: The path of the downloaded file.
331+
"""
185332
save_path = os.path.join(save_dir, download_url.split('/')[-1])
186333

187334
if self.force_redownload or (not os.path.exists(save_path) or os.path.getsize(save_path) == 0):
@@ -201,6 +348,14 @@ def _download_file(self, download_url: str, save_dir: str):
201348
return save_path
202349

203350
def _get_fpath_size(self, fpath: str) -> int:
351+
"""Get the total size of the files in provided path.
352+
353+
Args:
354+
fpath (str): path of files to compute the size.
355+
356+
Returns:
357+
int: The total size in bytes.
358+
"""
204359
if not os.path.isdir(fpath):
205360
return os.path.getsize(fpath)
206361
else:
@@ -212,6 +367,18 @@ def _get_fpath_size(self, fpath: str) -> int:
212367
return total_size
213368

214369
def _download_gz_file(self, download_url: str, save_dir: str):
370+
"""Download and unzip the gzip file from provided URL.
371+
372+
Args:
373+
download_url (str): Source URL of the gzip file.
374+
save_dir (str): Path to the directory to save the gzip file.
375+
376+
Raises:
377+
FileNotFoundError
378+
379+
Returns:
380+
str: The path to the file after unzip.
381+
"""
215382
gz_file_path = self._download_file(download_url, save_dir)
216383
cmd = ["gzip", "-d", gz_file_path]
217384
try:
@@ -226,6 +393,18 @@ def _download_gz_file(self, download_url: str, save_dir: str):
226393
return file_path
227394

228395
def _download_zip_file(self, download_url: str, save_dir: str):
396+
"""Download and unzip the zip file from provided URL.
397+
398+
Args:
399+
download_url (str): Source URL of the zip file.
400+
save_dir (str): Path to the directory to save the zip file.
401+
402+
Raises:
403+
FileNotFoundError
404+
405+
Returns:
406+
str: The path to the file after unzip.
407+
"""
229408
zip_file_path = self._download_file(download_url, save_dir)
230409
file_path = zip_file_path.replace(".zip", "")
231410
if self.force_redownload or not os.path.exists(file_path):

0 commit comments

Comments
 (0)