88import multiprocessing
99import os
1010import subprocess
11+ import sys
1112from typing import Any , Union
1213from absl import logging
1314import accelerate
@@ -387,6 +388,57 @@ def load_tokenizer(
387388 return tokenizer
388389
389390
391+ def _get_indices_for_valid_length (
392+ dataset : Any ,
393+ input_column : str ,
394+ max_sequence_length : int ,
395+ tokenizer : transformers .PreTrainedTokenizer ,
396+ context_name : str = "the dataset" ,
397+ ) -> tuple [list [int ], int , int ]:
398+ """Gets indices of examples shorter than or equal to max_seq_length.
399+
400+ Args:
401+ dataset: The dataset to check.
402+ input_column: The input column in the dataset.
403+ max_sequence_length: The maximum sequence length.
404+ tokenizer: The tokenizer.
405+ context_name: A name for the dataset used in log messages.
406+
407+ Returns:
408+ A tuple of (indices_to_keep, original_length, dropped_samples).
409+ """
410+ if not dataset :
411+ return [], 0 , 0
412+
413+ original_length = len (dataset )
414+ indices_to_keep = [
415+ i
416+ for i , entry in enumerate (dataset )
417+ if len (tokenizer (entry [input_column ])["input_ids" ]) <= max_sequence_length
418+ ]
419+ dropped_samples = original_length - len (indices_to_keep )
420+
421+ if dropped_samples > 0 :
422+ examples_removed_percent = (dropped_samples * 100 ) / original_length
423+ logging .info (
424+ "(%.2f%%) of examples token length is <= max-seq-length(%d); (%.2f%%) >"
425+ " max-seq-length in %s. %d example(s) were longer than max-seq-length." ,
426+ 100 - examples_removed_percent ,
427+ max_sequence_length ,
428+ examples_removed_percent ,
429+ context_name ,
430+ dropped_samples ,
431+ )
432+ else :
433+ logging .info (
434+ "No samples were dropped from %s because all samples are"
435+ " shorter than max_sequence_length=%d." ,
436+ context_name ,
437+ max_sequence_length ,
438+ )
439+ return indices_to_keep , original_length , dropped_samples
440+
441+
390442def get_filtered_dataset (
391443 dataset : Any ,
392444 input_column : str ,
@@ -411,33 +463,25 @@ def get_filtered_dataset(
411463 ValueError: If more than `example_removed_threshold` of the dataset is
412464 filtered out.
413465 """
414- actual_dataset_length = len (dataset )
415- filtered_dataset = dataset .filter (
416- lambda x : len (tokenizer (x [input_column ])["input_ids" ]) <= max_seq_length
466+ indices_to_keep , original_length , dropped_samples = (
467+ _get_indices_for_valid_length (
468+ dataset , input_column , max_seq_length , tokenizer , "the dataset"
469+ )
417470 )
418- filtered_dataset_length = len (filtered_dataset )
419- if actual_dataset_length != filtered_dataset_length :
420- examples_removed_percent = (
421- (actual_dataset_length - filtered_dataset_length )
422- * 100
423- / actual_dataset_length
424- )
425- logging .info (
426- "(%.2f%%) of examples token length is <= max-seq-length(%d); (%.2f%%) >"
427- " max-seq-length. Filtering out %d example(s) which are longer than"
428- " max-seq-length." ,
429- 100 - examples_removed_percent ,
430- max_seq_length ,
431- examples_removed_percent ,
432- actual_dataset_length - filtered_dataset_length ,
471+
472+ if (
473+ original_length > 0
474+ and dropped_samples / original_length * 100 > example_removed_threshold
475+ ):
476+ examples_removed_percent = (dropped_samples * 100 ) / original_length
477+ raise ValueError (
478+ f"More than { examples_removed_percent :.2f} % of the dataset is filtered"
479+ " out. This may be due to small value of"
480+ f" max-seq-length({ max_seq_length } ) or incorrect template. Please"
481+ " increase the max-seq-length or check the template."
433482 )
434- if examples_removed_percent > example_removed_threshold :
435- raise ValueError (
436- "More than %.2f%% of the dataset is filtered out. This may be due to"
437- " small value of max-seq-length(%d) or incorrect template. Please"
438- " increase the max-seq-length or check the template."
439- % (examples_removed_percent , max_seq_length )
440- )
483+
484+ filtered_dataset = dataset .select (indices_to_keep )
441485 print (f"Some formatted examples from the dataset are: { filtered_dataset [:5 ]} " )
442486 return filtered_dataset
443487
@@ -502,6 +546,62 @@ def load_dataset_with_template(
502546 return raw , templated
503547
504548
def drop_long_sequences(
    dataset: Any,
    dataset_with_template: Any,
    input_column: str,
    max_sequence_length: int,
    tokenizer: transformers.PreTrainedTokenizer,
    dataset_dropped_threshold: float,
    is_train: bool,
) -> tuple[Any, Any, int]:
  """Returns the dataset by removing examples that are longer than max_seq_length.

  Lengths are measured on `dataset_with_template` (the templated text is what
  is actually tokenized for training), and the same surviving indices are then
  applied to both `dataset` and `dataset_with_template` so the two stay
  aligned row-for-row.

  Args:
    dataset: The dataset to filter.
    dataset_with_template: The dataset with template to filter.
    input_column: The input column in the dataset to be used.
    max_sequence_length: The maximum sequence length.
    tokenizer: The tokenizer.
    dataset_dropped_threshold: The threshold for the number of samples dropped
      from the dataset.
    is_train: Whether the dataset is for training.

  Returns:
    A tuple of (filtered_dataset, filtered_dataset_with_template,
    dropped_samples).

  Raises:
    SystemExit: If the percentage of dropped samples exceeds
      `dataset_dropped_threshold`.
  """
  context_name = f"the {'train' if is_train else 'eval'} dataset"
  indices_to_keep, original_length, dropped_samples = (
      _get_indices_for_valid_length(
          dataset_with_template,
          input_column,
          max_sequence_length,
          tokenizer,
          context_name,
      )
  )

  # Guard against division by zero on an empty dataset; an empty dataset
  # drops nothing and never trips the threshold.
  if (
      original_length > 0
      and dropped_samples / original_length * 100 > dataset_dropped_threshold
  ):
    # Fixed: message previously read "dropped from {%s}", printing stray
    # literal braces around the dataset name.
    logging.error(
        "More than %f%% of the samples were dropped from %s after"
        " filtering for max_sequence_length=%d. Please check your dataset.",
        dataset_dropped_threshold,
        context_name,
        max_sequence_length,
    )
    # TODO: Replace this hard exit with a structured error from the error-
    # handling library when available.
    sys.exit(1)

  filtered_dataset = dataset.select(indices_to_keep)
  filtered_dataset_with_template = dataset_with_template.select(indices_to_keep)
  return filtered_dataset, filtered_dataset_with_template, dropped_samples
603+
604+
505605def validate_dataset_with_template (
506606 dataset_name : str ,
507607 split : str ,
0 commit comments