
Commit 1f9e939

Rayan Dasoriya authored and copybara-github committed
No public description
MG_DOCKER_CODES_PIPER_ORIGIN_REV_ID: 878215121
1 parent 0628351 commit 1f9e939

2 files changed: 16 additions & 56 deletions


notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/dataset_validation_util.py

Lines changed: 8 additions & 28 deletions
@@ -8,7 +8,6 @@
 import multiprocessing
 import os
 import subprocess
-import sys
 from typing import Any, Union
 from absl import logging
 import accelerate
@@ -552,51 +551,32 @@ def drop_long_sequences(
     input_column: str,
     max_sequence_length: int,
     tokenizer: transformers.PreTrainedTokenizer,
-    dataset_dropped_threshold: float,
     is_train: bool,
 ) -> tuple[Any, Any, int]:
-  """Returns the dataset by removing examples that are longer than max_seq_length.
+  """Drops examples longer than max_seq_length from the dataset.
 
   Args:
     dataset: The dataset to filter.
     dataset_with_template: The dataset with template to filter.
     input_column: The input column in the dataset to be used.
     max_sequence_length: The maximum sequence length.
     tokenizer: The tokenizer.
-    dataset_dropped_threshold: The threshold for the number of samples dropped
-      from the dataset.
     is_train: Whether the dataset is for training.
 
   Returns:
     A tuple of (filtered_dataset, filtered_dataset_with_template,
     dropped_samples).
   """
+
   context_name = f"the {'train' if is_train else 'eval'} dataset"
-  indices_to_keep, original_length, dropped_samples = (
-      _get_indices_for_valid_length(
-          dataset_with_template,
-          input_column,
-          max_sequence_length,
-          tokenizer,
-          context_name,
-      )
+  indices_to_keep, _, dropped_samples = _get_indices_for_valid_length(
+      dataset_with_template,
+      input_column,
+      max_sequence_length,
+      tokenizer,
+      context_name,
   )
 
-  if (
-      original_length > 0
-      and dropped_samples / original_length * 100 > dataset_dropped_threshold
-  ):
-    logging.error(
-        "More than %f%% of the samples were dropped from {%s} after"
-        " filtering for max_sequence_length=%d. Please check your dataset.",
-        dataset_dropped_threshold,
-        context_name,
-        max_sequence_length,
-    )
-
-    # handling library when available.
-    sys.exit(1)
-
   filtered_dataset = dataset.select(indices_to_keep)
   filtered_dataset_with_template = dataset_with_template.select(indices_to_keep)
   return filtered_dataset, filtered_dataset_with_template, dropped_samples
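
Note: _get_indices_for_valid_length is untouched by this commit, so its body is not shown. From the call site one can only infer that it takes (dataset_with_template, input_column, max_sequence_length, tokenizer, context_name) and returns a (indices_to_keep, original_length, dropped_samples) triple; the new code discards original_length, whose only consumer was the deleted threshold check. A minimal sketch of what such a helper might look like, purely as an assumption for orientation:

from typing import Any

from absl import logging
import transformers


def _get_indices_for_valid_length(
    dataset_with_template: Any,
    input_column: str,
    max_sequence_length: int,
    tokenizer: transformers.PreTrainedTokenizer,
    context_name: str,
) -> tuple[list[int], int, int]:
  """Assumed behavior: keep indices whose tokenized length fits the limit."""
  indices_to_keep = []
  for i, example in enumerate(dataset_with_template):
    # Tokenize the templated input and keep examples within the limit.
    token_ids = tokenizer(example[input_column])["input_ids"]
    if len(token_ids) <= max_sequence_length:
      indices_to_keep.append(i)
  original_length = len(dataset_with_template)
  dropped_samples = original_length - len(indices_to_keep)
  logging.info(
      "Dropped %d of %d examples from %s.",
      dropped_samples,
      original_length,
      context_name,
  )
  return indices_to_keep, original_length, dropped_samples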

notebooks/community/model_garden/docker_source_codes/notebook_util/dataset_validation_util.py

Lines changed: 8 additions & 28 deletions
(Identical change to notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/dataset_validation_util.py above; the diff is not repeated here.)
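
The practical effect of the commit, in both files, is that drop_long_sequences no longer terminates the process when a large fraction of the dataset is filtered out: the dataset_dropped_threshold flag, the percentage check, and the sys.exit(1) are gone, and dropped_samples is simply returned. A caller that still wants the old guard can rebuild it from the return value. A hypothetical sketch, reusing the percentage formula from the deleted lines (the dataset, tokenizer, and threshold names are illustrative assumptions, not part of this commit):

from absl import logging

# Assumed caller-side values; the threshold plays the role of the removed flag.
DATASET_DROPPED_THRESHOLD = 10.0  # Percent.
MAX_SEQUENCE_LENGTH = 2048

original_length = len(train_dataset_with_template)
train_dataset, train_dataset_with_template, dropped_samples = (
    drop_long_sequences(
        dataset=train_dataset,
        dataset_with_template=train_dataset_with_template,
        input_column="text",
        max_sequence_length=MAX_SEQUENCE_LENGTH,
        tokenizer=tokenizer,
        is_train=True,
    )
)

# Same percentage math as the removed in-function check, but the caller now
# chooses the failure mode (an exception here) instead of the utility
# calling sys.exit(1).
if (
    original_length > 0
    and dropped_samples / original_length * 100 > DATASET_DROPPED_THRESHOLD
):
  logging.error(
      "More than %f%% of the samples were dropped from the train dataset"
      " after filtering for max_sequence_length=%d.",
      DATASET_DROPPED_THRESHOLD,
      MAX_SEQUENCE_LENGTH,
  )
  raise ValueError("Too many samples dropped; please check your dataset.")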
