 import math
 import time
 import argparse
+import datasets
 from tqdm import tqdm
 from pprint import pprint
 from transformers import AutoTokenizer
@@ -54,8 +55,7 @@ def _map_func(examples):
     results['idx'] = []
     results['max_length'] = []
     for i in range(len(examples['query'])):
-        results['idx'].append(i)
-
+        idx = examples['idx'][i]
         query = examples['query'][i]
         pos, neg = examples['pos'][i], examples['neg'][i]
         all_texts = [query] + pos + neg
@@ -65,6 +65,8 @@ def _map_func(examples):
             tokenized_x = self.tokenizer(x)['input_ids']
             if len(tokenized_x) > max_len:
                 max_len = len(tokenized_x)
+
+        results['idx'].append(idx)
         results['max_length'].append(max_len)
     return results

@@ -120,8 +122,15 @@ def _process_file(self, file_path: str, output_path: str):
         dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=features)['train']
     except:
         dataset = load_dataset('json', data_files=file_path, cache_dir=self.cache_dir, features=kd_features)['train']
-    mapped_dataset = dataset.map(self._map_func, batched=True, num_proc=self.num_proc)

+    dataset_with_idx_list = []
+    for i, data in enumerate(dataset):
+        data['idx'] = i
+        dataset_with_idx_list.append(data)
+    dataset_with_idx = datasets.Dataset.from_list(dataset_with_idx_list)
+
+    mapped_dataset = dataset_with_idx.map(self._map_func, batched=True, num_proc=self.num_proc)
+
     split_info_dict = {}
     for length_l, length_r in self.length_ranges_list:
         save_path = output_path + f'_len-{length_l}-{length_r}.jsonl'
@@ -130,7 +139,8 @@ def _process_file(self, file_path: str, output_path: str):
             continue

         idxs = mapped_dataset.filter(lambda x: length_l <= x['max_length'] < length_r, num_proc=self.num_proc)
-        split_dataset = dataset.select(idxs['idx'])
+        split_dataset = dataset_with_idx.select(idxs['idx'])
+        split_dataset = split_dataset.remove_columns('idx')

         split_info_dict[f'len-{length_l}-{length_r}'] = len(split_dataset)

0 commit comments