1212
1313
1414class AbsEvalDataLoader (ABC ):
15+ """
16+ Base class of data loader for evaluation.
17+
18+ Args:
19+ eval_name (str): The experiment name of current evaluation.
20+ dataset_dir (str, optional): path to the datasets. Defaults to None.
21+ cache_dir (str, optional): Path to HuggingFace cache directory. Defaults to None.
22+ token (str, optional): HF_TOKEN to access the private datasets/models in HF. Defaults to None.
23+ force_redownload (bool, optional): If True, force redownload of the dataset, overwriting the local copy. Defaults to False.
24+ """
1525 def __init__ (
1626 self ,
1727 eval_name : str ,
@@ -43,6 +53,17 @@ def available_splits(self, dataset_name: Optional[str] = None) -> List[str]:
4353 pass
4454
4555 def check_dataset_names (self , dataset_names : Union [str , List [str ]]) -> List [str ]:
56+ """Check the validity of dataset names
57+
58+ Args:
59+ dataset_names (Union[str, List[str]]): a dataset name (str) or a list of dataset names (List[str])
60+
61+ Raises:
62+ ValueError
63+
64+ Returns:
65+ List[str]: List of valid dataset names.
66+ """
4667 available_dataset_names = self .available_dataset_names ()
4768 if isinstance (dataset_names , str ):
4869 dataset_names = [dataset_names ]
@@ -53,6 +74,15 @@ def check_dataset_names(self, dataset_names: Union[str, List[str]]) -> List[str]
5374 return dataset_names
5475
5576 def check_splits (self , splits : Union [str , List [str ]], dataset_name : Optional [str ] = None ) -> List [str ]:
77+ """Check whether the splits are available in the dataset.
78+
79+ Args:
80+ splits (Union[str, List[str]]): Splits to check.
81+ dataset_name (Optional[str], optional): Name of dataset to check. Defaults to None.
82+
83+ Returns:
84+ List[str]: The available splits.
85+ """
5686 available_splits = self .available_splits (dataset_name = dataset_name )
5787 if isinstance (splits , str ):
5888 splits = [splits ]
@@ -65,6 +95,14 @@ def check_splits(self, splits: Union[str, List[str]], dataset_name: Optional[str
6595 return checked_splits
6696
6797 def load_corpus (self , dataset_name : Optional [str ] = None ) -> datasets .DatasetDict :
98+ """Load the corpus from the dataset.
99+
100+ Args:
101+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
102+
103+ Returns:
104+ datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
105+ """
68106 if self .dataset_dir is not None :
69107 if dataset_name is None :
70108 save_dir = self .dataset_dir
@@ -75,6 +113,18 @@ def load_corpus(self, dataset_name: Optional[str] = None) -> datasets.DatasetDic
75113 return self ._load_remote_corpus (dataset_name = dataset_name )
76114
77115 def load_qrels (self , dataset_name : Optional [str ] = None , split : str = 'test' ) -> datasets .DatasetDict :
116+ """Load the qrels (query-document relevance judgments) from the dataset.
117+
118+ Args:
119+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
120+ split (str, optional): The split to load relevance from. Defaults to 'test'.
121+
122+ Raises:
123+ ValueError
124+
125+ Returns:
126+ datasets.DatasetDict: A dict of relevance of query and document.
127+ """
78128 if self .dataset_dir is not None :
79129 if dataset_name is None :
80130 save_dir = self .dataset_dir
@@ -91,6 +141,18 @@ def load_qrels(self, dataset_name: Optional[str] = None, split: str = 'test') ->
91141 return self ._load_remote_qrels (dataset_name = dataset_name , split = split )
92142
93143 def load_queries (self , dataset_name : Optional [str ] = None , split : str = 'test' ) -> datasets .DatasetDict :
144+ """Load the queries from the dataset.
145+
146+ Args:
147+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
148+ split (str, optional): The split to load queries from. Defaults to 'test'.
149+
150+ Raises:
151+ ValueError
152+
153+ Returns:
154+ datasets.DatasetDict: A dict of queries with id as key, query text as value.
155+ """
94156 if self .dataset_dir is not None :
95157 if dataset_name is None :
96158 save_dir = self .dataset_dir
@@ -111,6 +173,18 @@ def _load_remote_corpus(
111173 dataset_name : Optional [str ] = None ,
112174 save_dir : Optional [str ] = None
113175 ) -> datasets .DatasetDict :
176+ """Abstract method to load corpus from a remote dataset, to be overridden in child classes.
177+
178+ Args:
179+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
180+ save_dir (Optional[str], optional): Path to save the new downloaded corpus. Defaults to None.
181+
182+ Raises:
183+ NotImplementedError: Loading remote corpus is not implemented.
184+
185+ Returns:
186+ datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
187+ """
114188 raise NotImplementedError ("Loading remote corpus is not implemented." )
115189
116190 def _load_remote_qrels (
@@ -119,6 +193,19 @@ def _load_remote_qrels(
119193 split : str = 'test' ,
120194 save_dir : Optional [str ] = None
121195 ) -> datasets .DatasetDict :
196+ """Abstract method to load relevance from a remote dataset, to be overridden in child classes.
197+
198+ Args:
199+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
200+ split (str, optional): Split to load from the remote dataset. Defaults to 'test'.
201+ save_dir (Optional[str], optional): Path to save the new downloaded relevance. Defaults to None.
202+
203+ Raises:
204+ NotImplementedError: Loading remote qrels is not implemented.
205+
206+ Returns:
207+ datasets.DatasetDict: A dict of relevance of query and document.
208+ """
122209 raise NotImplementedError ("Loading remote qrels is not implemented." )
123210
124211 def _load_remote_queries (
@@ -127,9 +214,31 @@ def _load_remote_queries(
127214 split : str = 'test' ,
128215 save_dir : Optional [str ] = None
129216 ) -> datasets .DatasetDict :
217+ """Abstract method to load queries from a remote dataset, to be overridden in child classes.
218+
219+ Args:
220+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
221+ split (str, optional): Split to load from the remote dataset. Defaults to 'test'.
222+ save_dir (Optional[str], optional): Path to save the new downloaded queries. Defaults to None.
223+
224+ Raises:
225+ NotImplementedError
226+
227+ Returns:
228+ datasets.DatasetDict: A dict of queries with id as key, query text as value.
229+ """
130230 raise NotImplementedError ("Loading remote queries is not implemented." )
131231
132232 def _load_local_corpus (self , save_dir : str , dataset_name : Optional [str ] = None ) -> datasets .DatasetDict :
233+ """Load corpus from local dataset.
234+
235+ Args:
236+ save_dir (str): Path to save the loaded corpus.
237+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
238+
239+ Returns:
240+ datasets.DatasetDict: A dict of corpus with id as key, title and text as value.
241+ """
133242 corpus_path = os .path .join (save_dir , 'corpus.jsonl' )
134243 if self .force_redownload or not os .path .exists (corpus_path ):
135244 logger .warning (f"Corpus not found in { corpus_path } . Trying to download the corpus from the remote and save it to { save_dir } ." )
@@ -144,6 +253,19 @@ def _load_local_corpus(self, save_dir: str, dataset_name: Optional[str] = None)
144253 return datasets .DatasetDict (corpus )
145254
146255 def _load_local_qrels (self , save_dir : str , dataset_name : Optional [str ] = None , split : str = 'test' ) -> datasets .DatasetDict :
256+ """Load relevance from local dataset.
257+
258+ Args:
259+ save_dir (str): Path to save the loaded relevance.
260+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
261+ split (str, optional): Split to load from the local dataset. Defaults to 'test'.
262+
263+ Raises:
264+ ValueError
265+
266+ Returns:
267+ datasets.DatasetDict: A dict of relevance of query and document.
268+ """
147269 checked_split = self .check_splits (split )
148270 if len (checked_split ) == 0 :
149271 raise ValueError (f"Split { split } not found in the dataset." )
@@ -166,6 +288,19 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
166288 return datasets .DatasetDict (qrels )
167289
168290 def _load_local_queries (self , save_dir : str , dataset_name : Optional [str ] = None , split : str = 'test' ) -> datasets .DatasetDict :
291+ """Load queries from local dataset.
292+
293+ Args:
294+ save_dir (str): Path to save the loaded queries.
295+ dataset_name (Optional[str], optional): Name of the dataset. Defaults to None.
296+ split (str, optional): Split to load from the local dataset. Defaults to 'test'.
297+
298+ Raises:
299+ ValueError
300+
301+ Returns:
302+ datasets.DatasetDict: A dict of queries with id as key, query text as value.
303+ """
169304 checked_split = self .check_splits (split )
170305 if len (checked_split ) == 0 :
171306 raise ValueError (f"Split { split } not found in the dataset." )
@@ -182,6 +317,18 @@ def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None,
182317 return datasets .DatasetDict (queries )
183318
184319 def _download_file (self , download_url : str , save_dir : str ):
320+ """Download a file from the provided URL.
321+
322+ Args:
323+ download_url (str): Source URL of the file.
324+ save_dir (str): Path to the directory to save the downloaded file.
325+
326+ Raises:
327+ FileNotFoundError
328+
329+ Returns:
330+ str: The path of the downloaded file.
331+ """
185332 save_path = os .path .join (save_dir , download_url .split ('/' )[- 1 ])
186333
187334 if self .force_redownload or (not os .path .exists (save_path ) or os .path .getsize (save_path ) == 0 ):
@@ -201,6 +348,14 @@ def _download_file(self, download_url: str, save_dir: str):
201348 return save_path
202349
203350 def _get_fpath_size (self , fpath : str ) -> int :
351+ """Get the total size of the file(s) at the provided path.
352+
353+ Args:
354+ fpath (str): Path to the file or directory whose size to compute.
355+
356+ Returns:
357+ int: The total size in bytes.
358+ """
204359 if not os .path .isdir (fpath ):
205360 return os .path .getsize (fpath )
206361 else :
@@ -212,6 +367,18 @@ def _get_fpath_size(self, fpath: str) -> int:
212367 return total_size
213368
214369 def _download_gz_file (self , download_url : str , save_dir : str ):
370+ """Download and decompress the gzip file from the provided URL.
371+
372+ Args:
373+ download_url (str): Source URL of the gzip file.
374+ save_dir (str): Path to the directory to save the gzip file.
375+
376+ Raises:
377+ FileNotFoundError
378+
379+ Returns:
380+ str: The path to the file after unzip.
381+ """
215382 gz_file_path = self ._download_file (download_url , save_dir )
216383 cmd = ["gzip" , "-d" , gz_file_path ]
217384 try :
@@ -226,6 +393,18 @@ def _download_gz_file(self, download_url: str, save_dir: str):
226393 return file_path
227394
228395 def _download_zip_file (self , download_url : str , save_dir : str ):
396+ """Download and unzip the zip file from the provided URL.
397+
398+ Args:
399+ download_url (str): Source URL of the zip file.
400+ save_dir (str): Path to the directory to save the zip file.
401+
402+ Raises:
403+ FileNotFoundError
404+
405+ Returns:
406+ str: The path to the file after unzip.
407+ """
229408 zip_file_path = self ._download_file (download_url , save_dir )
230409 file_path = zip_file_path .replace (".zip" , "" )
231410 if self .force_redownload or not os .path .exists (file_path ):
0 commit comments