Skip to content

Commit cb4916f

Browse files
Rayan Dasoriya and copybara-github
authored and committed
No public description
MG_DOCKER_CODES_PIPER_ORIGIN_REV_ID: 875506020
1 parent 2933fe6 commit cb4916f

2 files changed

Lines changed: 250 additions & 50 deletions

File tree

notebooks/community/model_garden/docker_source_codes/model_oss/notebook_util/dataset_validation_util.py

Lines changed: 125 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import multiprocessing
99
import os
1010
import subprocess
11+
import sys
1112
from typing import Any, Union
1213
from absl import logging
1314
import accelerate
@@ -387,6 +388,57 @@ def load_tokenizer(
387388
return tokenizer
388389

389390

391+
def _get_indices_for_valid_length(
    dataset: Any,
    input_column: str,
    max_sequence_length: int,
    tokenizer: transformers.PreTrainedTokenizer,
    context_name: str = "the dataset",
) -> tuple[list[int], int, int]:
  """Gets indices of examples shorter than or equal to max_seq_length.

  Args:
    dataset: The dataset to check.
    input_column: The input column in the dataset.
    max_sequence_length: The maximum sequence length.
    tokenizer: The tokenizer.
    context_name: A name for the dataset used in log messages.

  Returns:
    A tuple of (indices_to_keep, original_length, dropped_samples).
  """
  # An empty (or falsy) dataset has nothing to measure.
  if not dataset:
    return [], 0, 0

  total = len(dataset)
  kept_indices = []
  for idx, row in enumerate(dataset):
    # Token length of the example as the model would see it.
    token_count = len(tokenizer(row[input_column])["input_ids"])
    if token_count <= max_sequence_length:
      kept_indices.append(idx)
  num_dropped = total - len(kept_indices)

  if num_dropped:
    pct_dropped = num_dropped * 100 / total
    logging.info(
        "(%.2f%%) of examples token length is <= max-seq-length(%d); (%.2f%%) >"
        " max-seq-length in %s. %d example(s) were longer than max-seq-length.",
        100 - pct_dropped,
        max_sequence_length,
        pct_dropped,
        context_name,
        num_dropped,
    )
  else:
    logging.info(
        "No samples were dropped from %s because all samples are"
        " shorter than max_sequence_length=%d.",
        context_name,
        max_sequence_length,
    )
  return kept_indices, total, num_dropped
440+
441+
390442
def get_filtered_dataset(
391443
dataset: Any,
392444
input_column: str,
@@ -411,33 +463,25 @@ def get_filtered_dataset(
411463
ValueError: If more than `example_removed_threshold` of the dataset is
412464
filtered out.
413465
"""
414-
actual_dataset_length = len(dataset)
415-
filtered_dataset = dataset.filter(
416-
lambda x: len(tokenizer(x[input_column])["input_ids"]) <= max_seq_length
466+
indices_to_keep, original_length, dropped_samples = (
467+
_get_indices_for_valid_length(
468+
dataset, input_column, max_seq_length, tokenizer, "the dataset"
469+
)
417470
)
418-
filtered_dataset_length = len(filtered_dataset)
419-
if actual_dataset_length != filtered_dataset_length:
420-
examples_removed_percent = (
421-
(actual_dataset_length - filtered_dataset_length)
422-
* 100
423-
/ actual_dataset_length
424-
)
425-
logging.info(
426-
"(%.2f%%) of examples token length is <= max-seq-length(%d); (%.2f%%) >"
427-
" max-seq-length. Filtering out %d example(s) which are longer than"
428-
" max-seq-length.",
429-
100 - examples_removed_percent,
430-
max_seq_length,
431-
examples_removed_percent,
432-
actual_dataset_length - filtered_dataset_length,
471+
472+
if (
473+
original_length > 0
474+
and dropped_samples / original_length * 100 > example_removed_threshold
475+
):
476+
examples_removed_percent = (dropped_samples * 100) / original_length
477+
raise ValueError(
478+
f"More than {examples_removed_percent:.2f}% of the dataset is filtered"
479+
" out. This may be due to small value of"
480+
f" max-seq-length({max_seq_length}) or incorrect template. Please"
481+
" increase the max-seq-length or check the template."
433482
)
434-
if examples_removed_percent > example_removed_threshold:
435-
raise ValueError(
436-
"More than %.2f%% of the dataset is filtered out. This may be due to"
437-
" small value of max-seq-length(%d) or incorrect template. Please"
438-
" increase the max-seq-length or check the template."
439-
% (examples_removed_percent, max_seq_length)
440-
)
483+
484+
filtered_dataset = dataset.select(indices_to_keep)
441485
print(f"Some formatted examples from the dataset are: {filtered_dataset[:5]}")
442486
return filtered_dataset
443487

@@ -502,6 +546,62 @@ def load_dataset_with_template(
502546
return raw, templated
503547

504548

549+
def drop_long_sequences(
    dataset: Any,
    dataset_with_template: Any,
    input_column: str,
    max_sequence_length: int,
    tokenizer: transformers.PreTrainedTokenizer,
    dataset_dropped_threshold: float,
    is_train: bool,
) -> tuple[Any, Any, int]:
  """Returns the dataset by removing examples that are longer than max_seq_length.

  Args:
    dataset: The dataset to filter.
    dataset_with_template: The dataset with template to filter.
    input_column: The input column in the dataset to be used.
    max_sequence_length: The maximum sequence length.
    tokenizer: The tokenizer.
    dataset_dropped_threshold: The threshold (in percent) for the number of
      samples dropped from the dataset.
    is_train: Whether the dataset is for training.

  Returns:
    A tuple of (filtered_dataset, filtered_dataset_with_template,
    dropped_samples).
  """
  context_name = f"the {'train' if is_train else 'eval'} dataset"
  # Length is measured on the templated text (what the model actually sees);
  # the same indices are then applied to both datasets to keep them aligned.
  indices_to_keep, original_length, dropped_samples = (
      _get_indices_for_valid_length(
          dataset_with_template,
          input_column,
          max_sequence_length,
          tokenizer,
          context_name,
      )
  )

  if (
      original_length > 0
      and dropped_samples / original_length * 100 > dataset_dropped_threshold
  ):
    # Fix: the placeholder was written as "{%s}", so the log rendered a
    # literal "{the train dataset}" — stray braces left over from an
    # f-string-to-%-format conversion.
    logging.error(
        "More than %f%% of the samples were dropped from %s after"
        " filtering for max_sequence_length=%d. Please check your dataset.",
        dataset_dropped_threshold,
        context_name,
        max_sequence_length,
    )
    # TODO: replace this hard exit with a typed exception once an error
    # handling library is available.
    sys.exit(1)

  filtered_dataset = dataset.select(indices_to_keep)
  filtered_dataset_with_template = dataset_with_template.select(indices_to_keep)
  return filtered_dataset, filtered_dataset_with_template, dropped_samples
603+
604+
505605
def validate_dataset_with_template(
506606
dataset_name: str,
507607
split: str,

notebooks/community/model_garden/docker_source_codes/notebook_util/dataset_validation_util.py

Lines changed: 125 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import multiprocessing
99
import os
1010
import subprocess
11+
import sys
1112
from typing import Any, Union
1213
from absl import logging
1314
import accelerate
@@ -387,6 +388,57 @@ def load_tokenizer(
387388
return tokenizer
388389

389390

391+
def _get_indices_for_valid_length(
    dataset: Any,
    input_column: str,
    max_sequence_length: int,
    tokenizer: transformers.PreTrainedTokenizer,
    context_name: str = "the dataset",
) -> tuple[list[int], int, int]:
  """Gets indices of examples shorter than or equal to max_seq_length.

  Args:
    dataset: The dataset to check.
    input_column: The input column in the dataset.
    max_sequence_length: The maximum sequence length.
    tokenizer: The tokenizer.
    context_name: A name for the dataset used in log messages.

  Returns:
    A tuple of (indices_to_keep, original_length, dropped_samples).
  """
  # An empty (or falsy) dataset has nothing to measure.
  if not dataset:
    return [], 0, 0

  total = len(dataset)
  kept_indices = []
  for idx, row in enumerate(dataset):
    # Token length of the example as the model would see it.
    token_count = len(tokenizer(row[input_column])["input_ids"])
    if token_count <= max_sequence_length:
      kept_indices.append(idx)
  num_dropped = total - len(kept_indices)

  if num_dropped:
    pct_dropped = num_dropped * 100 / total
    logging.info(
        "(%.2f%%) of examples token length is <= max-seq-length(%d); (%.2f%%) >"
        " max-seq-length in %s. %d example(s) were longer than max-seq-length.",
        100 - pct_dropped,
        max_sequence_length,
        pct_dropped,
        context_name,
        num_dropped,
    )
  else:
    logging.info(
        "No samples were dropped from %s because all samples are"
        " shorter than max_sequence_length=%d.",
        context_name,
        max_sequence_length,
    )
  return kept_indices, total, num_dropped
440+
441+
390442
def get_filtered_dataset(
391443
dataset: Any,
392444
input_column: str,
@@ -411,33 +463,25 @@ def get_filtered_dataset(
411463
ValueError: If more than `example_removed_threshold` of the dataset is
412464
filtered out.
413465
"""
414-
actual_dataset_length = len(dataset)
415-
filtered_dataset = dataset.filter(
416-
lambda x: len(tokenizer(x[input_column])["input_ids"]) <= max_seq_length
466+
indices_to_keep, original_length, dropped_samples = (
467+
_get_indices_for_valid_length(
468+
dataset, input_column, max_seq_length, tokenizer, "the dataset"
469+
)
417470
)
418-
filtered_dataset_length = len(filtered_dataset)
419-
if actual_dataset_length != filtered_dataset_length:
420-
examples_removed_percent = (
421-
(actual_dataset_length - filtered_dataset_length)
422-
* 100
423-
/ actual_dataset_length
424-
)
425-
logging.info(
426-
"(%.2f%%) of examples token length is <= max-seq-length(%d); (%.2f%%) >"
427-
" max-seq-length. Filtering out %d example(s) which are longer than"
428-
" max-seq-length.",
429-
100 - examples_removed_percent,
430-
max_seq_length,
431-
examples_removed_percent,
432-
actual_dataset_length - filtered_dataset_length,
471+
472+
if (
473+
original_length > 0
474+
and dropped_samples / original_length * 100 > example_removed_threshold
475+
):
476+
examples_removed_percent = (dropped_samples * 100) / original_length
477+
raise ValueError(
478+
f"More than {examples_removed_percent:.2f}% of the dataset is filtered"
479+
" out. This may be due to small value of"
480+
f" max-seq-length({max_seq_length}) or incorrect template. Please"
481+
" increase the max-seq-length or check the template."
433482
)
434-
if examples_removed_percent > example_removed_threshold:
435-
raise ValueError(
436-
"More than %.2f%% of the dataset is filtered out. This may be due to"
437-
" small value of max-seq-length(%d) or incorrect template. Please"
438-
" increase the max-seq-length or check the template."
439-
% (examples_removed_percent, max_seq_length)
440-
)
483+
484+
filtered_dataset = dataset.select(indices_to_keep)
441485
print(f"Some formatted examples from the dataset are: {filtered_dataset[:5]}")
442486
return filtered_dataset
443487

@@ -502,6 +546,62 @@ def load_dataset_with_template(
502546
return raw, templated
503547

504548

549+
def drop_long_sequences(
    dataset: Any,
    dataset_with_template: Any,
    input_column: str,
    max_sequence_length: int,
    tokenizer: transformers.PreTrainedTokenizer,
    dataset_dropped_threshold: float,
    is_train: bool,
) -> tuple[Any, Any, int]:
  """Returns the dataset by removing examples that are longer than max_seq_length.

  Args:
    dataset: The dataset to filter.
    dataset_with_template: The dataset with template to filter.
    input_column: The input column in the dataset to be used.
    max_sequence_length: The maximum sequence length.
    tokenizer: The tokenizer.
    dataset_dropped_threshold: The threshold (in percent) for the number of
      samples dropped from the dataset.
    is_train: Whether the dataset is for training.

  Returns:
    A tuple of (filtered_dataset, filtered_dataset_with_template,
    dropped_samples).
  """
  context_name = f"the {'train' if is_train else 'eval'} dataset"
  # Length is measured on the templated text (what the model actually sees);
  # the same indices are then applied to both datasets to keep them aligned.
  indices_to_keep, original_length, dropped_samples = (
      _get_indices_for_valid_length(
          dataset_with_template,
          input_column,
          max_sequence_length,
          tokenizer,
          context_name,
      )
  )

  if (
      original_length > 0
      and dropped_samples / original_length * 100 > dataset_dropped_threshold
  ):
    # Fix: the placeholder was written as "{%s}", so the log rendered a
    # literal "{the train dataset}" — stray braces left over from an
    # f-string-to-%-format conversion.
    logging.error(
        "More than %f%% of the samples were dropped from %s after"
        " filtering for max_sequence_length=%d. Please check your dataset.",
        dataset_dropped_threshold,
        context_name,
        max_sequence_length,
    )
    # TODO: replace this hard exit with a typed exception once an error
    # handling library is available.
    sys.exit(1)

  filtered_dataset = dataset.select(indices_to_keep)
  filtered_dataset_with_template = dataset_with_template.select(indices_to_keep)
  return filtered_dataset, filtered_dataset_with_template, dropped_samples
603+
604+
505605
def validate_dataset_with_template(
506606
dataset_name: str,
507607
split: str,

0 commit comments

Comments
 (0)