1414logger = logging .get_logger (__name__ )
1515
1616
17- # RETRIEVAL_CAND = [(1024,1), (512,2), (256,4), (128,8), (512,1), (256,2), (128,4)]
18- RETRIEVAL_CAND = [(1024 ,1 )]
19-
2017
2118class Data :
19+ def _process_pretrain_data (data , indices ):
20+ outputs = {"labels" : [], "index" : [], "length" : []}
21+ for input_ids , index in zip (data ['input_ids' ], indices ):
22+ outputs ["index" ].append (index )
23+ outputs ["length" ].append (len (input_ids ))
24+ # NOTE: the labels will be automatically generated in Trainer._prepare_inputs
25+ outputs ["labels" ].append (None )
26+ return outputs
27+
2228 def _process_language_modeling (data , indices , tokenizer , min_length , max_length ):
23- outputs = {'input_ids' : [], 'attention_mask' : [], "labels" : [], "length" : [], "index" : []}
29+ outputs = {'input_ids' : [], "labels" : [], "length" : [], "index" : []}
2430
2531 for i , text in enumerate (data ['text' ]):
2632 # truncate text for faster processing
@@ -33,18 +39,20 @@ def _process_language_modeling(data, indices, tokenizer, min_length, max_length)
3339 for k , v in encoded .items ():
3440 encoded [k ] = v [:max_length ]
3541
36- encoded ["labels" ] = encoded ["input_ids" ].copy ()
42+ # NOTE: the labels will be automatically generated in Trainer._prepare_inputs
43+ encoded ["labels" ] = None
3744
3845 for k , v in encoded .items ():
39- outputs [k ].append (v )
46+ if k in outputs :
47+ outputs [k ].append (v )
4048 # length is required for grouping
4149 outputs ["length" ].append (len (encoded ['input_ids' ]))
4250 outputs ["index" ].append (indices [i ])
4351
4452 return outputs
4553
4654 def _process_instruction_tuning (data , indices , tokenizer , chat_template , min_length , max_length , eval_mode = False ):
47- outputs = {'input_ids' : [], 'attention_mask' : [], "labels" : [], "length" : [], "index" : []}
55+ outputs = {'input_ids' : [], "labels" : [], "length" : [], "index" : []}
4856
4957 for i , source in enumerate (data ['conversations' ]):
5058 if source [0 ]["role" ] != 'user' :
@@ -69,6 +77,11 @@ def _process_instruction_tuning(data, indices, tokenizer, chat_template, min_len
6977 add_generation_prompt = eval_mode ,
7078 ).encoded
7179
80+ # NOTE: shift the labels in advance
81+ # labels = encoded["labels"][1:]
82+ # labels.append(-100)
83+ # encoded["labels"] = labels
84+
7285 # skip data that not fall in between min_length and max_length
7386 if min_length is not None and len (encoded ["input_ids" ]) < min_length :
7487 continue
@@ -79,13 +92,14 @@ def _process_instruction_tuning(data, indices, tokenizer, chat_template, min_len
7992 encoded ["labels" ] = labels
8093
8194 for k , v in encoded .items ():
82- outputs [k ].append (v )
95+ if k in outputs :
96+ outputs [k ].append (v )
8397 outputs ['length' ].append (len (encoded ['input_ids' ]))
8498 outputs ['index' ].append (indices [i ])
8599
86100 return outputs
87101
88- def prepare_train_data (data_files = None , tokenizer = None , max_length = 4096 , min_length = 512 , chat_template = "vicuna" , seed = 42 , cache_dir = None , load_from_cache_file = None ):
102+ def prepare_train_data (data_files = None , tokenizer = None , max_length = 4096 , min_length = 512 , chat_template = "vicuna" , seed = 42 , cache_dir = None , load_from_cache_file = None , ignore_index = False , ignore_length = False ):
89103 if data_files is None :
90104 return None
91105
@@ -115,6 +129,7 @@ def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_len
115129 if os .path .isdir (data_file ) and os .path .exists (os .path .join (data_file , "dataset_info.json" )):
116130 # the dataset may be save_to_disk in advance
117131 dataset = datasets .load_from_disk (data_file )
132+ dataset = dataset .map (Data ._process_pretrain_data , batched = True , num_proc = 32 , batch_size = 32 , with_indices = True )
118133
119134 else :
120135 # the dataset is a json file
@@ -145,16 +160,18 @@ def prepare_train_data(data_files=None, tokenizer=None, max_length=4096, min_len
145160 dataset = dataset .train_test_split (max_sample_num , seed = seed )["test" ]
146161
147162 # index column is useless in training
148- if "index" in dataset .column_names :
163+ if "index" in dataset .column_names and ignore_index :
149164 dataset = dataset .remove_columns (["index" ])
165+ if "length" in dataset .column_names and ignore_length :
166+ dataset = dataset .remove_columns (["length" ])
150167
151168 train_datasets .append (dataset )
152169
153170 dataset = datasets .concatenate_datasets (train_datasets )
154171
155172 return dataset
156173
157- def prepare_eval_data (data_files = None , tokenizer = None , max_length = 4096 , min_length = 512 , chat_template = "vicuna" , max_eval_num = None , cache_dir = None , seed = 42 , load_from_cache_file = None ):
174+ def prepare_eval_data (data_files = None , tokenizer = None , max_length = 4096 , min_length = 512 , chat_template = "vicuna" , max_eval_num = None , cache_dir = None , seed = 42 , load_from_cache_file = None , ignore_index = False , ignore_length = False ):
158175 if data_files is None :
159176 return None
160177
@@ -186,4 +203,9 @@ def prepare_eval_data(data_files=None, tokenizer=None, max_length=4096, min_leng
186203 raise ValueError (f"Found neither 'text' nor 'conversations' in the training data!" )
187204
188205 dataset = dataset .map (process_fn , batched = True , num_proc = 32 , remove_columns = dataset .column_names , with_indices = True , load_from_cache_file = load_from_cache_file )
206+ if "index" in dataset .column_names and ignore_index :
207+ dataset = dataset .remove_columns (["index" ])
208+ if "length" in dataset .column_names and ignore_length :
209+ dataset = dataset .remove_columns (["length" ])
210+
189211 return dataset
0 commit comments