diff --git a/seqio/__init__.py b/seqio/__init__.py
index 3c96519e..5222cde4 100644
--- a/seqio/__init__.py
+++ b/seqio/__init__.py
@@ -16,8 +16,6 @@
 # pylint:disable=wildcard-import,g-bad-import-order
 from seqio.dataset_providers import *
-from seqio.grain_dataset_providers import *
-from seqio.dataset_providers_helpers import *
 from seqio import evaluation
 from seqio import experimental
 from seqio.evaluation import Evaluator
diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py
index 1d367589..9209bc16 100644
--- a/seqio/dataset_providers.py
+++ b/seqio/dataset_providers.py
@@ -25,14 +25,16 @@
 import json
 import os
 import re
-from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, Union, List
 from absl import logging
 import clu.metrics
 import numpy as np
 from packaging import version
 from seqio import metrics as metrics_lib
+from seqio import preprocessors as seqio_preprocessors
 from seqio import utils
+from seqio.feature_converters import FeatureConverter
 from seqio.vocabularies import PassThroughVocabulary
 from seqio.vocabularies import Vocabulary
 import tensorflow.compat.v2 as tf
@@ -1265,6 +1267,7 @@ def postprocess_fn(self, decoded_model_output: Any,
     return decoded_model_output
 
+
 class TaskRegistry(DatasetProviderRegistry):
   """Registry of Tasks."""
   _REGISTRY = {}
@@ -1531,6 +1534,8 @@ def filter_features(ex):
   return dataset
 
+
+
 def _log_padding_fractions(dataset, sequence_length, num_examples=100):
   """Empirically compute the fraction of padding - log the results.
@@ -1650,3 +1655,125 @@ def add(cls,
   def get(cls, name) -> Mixture:
     return super().get(name)
   # pylint: enable=arguments-renamed
+
+
+def get_mixture_or_task(task_or_mixture_name):
+  """Return the Task or Mixture from the appropriate registry."""
+  mixtures = MixtureRegistry.names()
+  tasks = TaskRegistry.names()
+  if task_or_mixture_name in mixtures:
+    if task_or_mixture_name in tasks:
+      logging.warning("%s is both a Task and a Mixture, returning Mixture",
+                      task_or_mixture_name)
+    return MixtureRegistry.get(task_or_mixture_name)
+  if task_or_mixture_name in tasks:
+    return TaskRegistry.get(task_or_mixture_name)
+  else:
+    for available_task in sorted(tasks):
+      logging.info("Available task: %s", available_task)
+    for available_mixture in sorted(mixtures):
+      logging.info("Available mixture: %s", available_mixture)
+    raise ValueError(
+        "No Task or Mixture found with name '%s'." % task_or_mixture_name)
+
+
+def get_subtasks(task_or_mixture):
+  """Returns all the Tasks in a Mixture as a list or the Task itself."""
+  if isinstance(task_or_mixture, Task):
+    return [task_or_mixture]
+  else:
+    return task_or_mixture.tasks
+
+
+def get_dataset(mixture_or_task_name: str,
+                task_feature_lengths: Mapping[str, int],
+                feature_converter: FeatureConverter,
+                dataset_split: str = "train",
+                use_cached: bool = False,
+                shuffle: bool = False,
+                num_epochs: Optional[int] = 1,
+                shard_info: Optional[ShardInfo] = None,
+                verbose: bool = True,
+                seed: Optional[int] = None,
+                batch_size: Optional[int] = None,
+                trim_output_features: bool = True) -> tf.data.Dataset:
+  """Get processed dataset with the model features.
+
+  In order to use options specific to a feature converter, e.g., packing, the
+  `feature_converter` instance should be instantiated with those options
+  before being passed to this function.
+
+  Getting sharded datasets is supported. To use this feature, pass in
+  `shard_info`, with shard_index and num_shards information. Sharding is done
+  before the feature converter stage. Therefore, if packing is used it will be
+  done on the sharded dataset.
+
+  Args:
+    mixture_or_task_name: mixture or task name for the Task API.
+    task_feature_lengths: dict mapping task feature key to its sequence length.
+      This specifies the sequence length of the dataset from the Task API.
+    feature_converter: a feature converter object to use to convert the task
+      features to model features. Must be a subclass of FeatureConverter.
+    dataset_split: the split to use.
+    use_cached: whether to use the cached dataset instead of processing it on
+      the fly.
+    shuffle: whether to shuffle the dataset.
+    num_epochs: the number of times to iterate through the dataset, or `None`
+      to repeat indefinitely. Note that the repeat occurs in the pipeline after
+      offline caching, but before applying potentially stochastic post-cache
+      preprocessors and is therefore typically preferred to calling `repeat()`
+      on the returned dataset. Defaults to `1`.
+    shard_info: number of shards and shard index information.
+    verbose: if True, log the feature shapes.
+    seed: a random seed for shuffling tf.data.
+    batch_size: Optional batch size.
+    trim_output_features: If True, trim output features to the lengths given
+      by `task_feature_lengths`.
+
+  Returns:
+    ds: the processed dataset.
+  """
+  if not isinstance(feature_converter, FeatureConverter):
+    raise TypeError(
+        "feature_converter should be an instance of FeatureConverter.")
+
+  mixture_or_task = get_mixture_or_task(mixture_or_task_name)
+  # Grain support is stubbed out here: `is_grain_task` is hardcoded to False,
+  # so the tf.data branch below is always taken.
+  is_grain_task = False
+  if is_grain_task:
+    ds = mixture_or_task.get_dataset(
+        sequence_length=task_feature_lengths,
+        split=dataset_split,
+        use_cached=use_cached,
+        shuffle=shuffle,
+        seed=seed,
+        shard_info=shard_info,
+        num_epochs=num_epochs,
+        batch_size=batch_size,
+        feature_converter=feature_converter,
+        trim_output_features=trim_output_features)
+  else:
+    ds = mixture_or_task.get_dataset(
+        task_feature_lengths,
+        split=dataset_split,
+        use_cached=use_cached,
+        shuffle=shuffle,
+        seed=seed,
+        shard_info=shard_info,
+        num_epochs=num_epochs,
+        trim_output_features=trim_output_features)
+    ds = feature_converter(ds, task_feature_lengths=task_feature_lengths)
+    if batch_size is not None:
+      ds = ds.batch(batch_size, drop_remainder=True)
+
+  if verbose:
+    logging.info(
+        "The output dataset from seqio.get_dataset has the following features")
+    element_spec = utils.flatten_dict(ds.element_spec, delimiter=".")
+    for feature_name, tensor_spec in element_spec.items():
+      if isinstance(tensor_spec, tf.TensorSpec):
+        logging.info("feature: %s \t shape: %s \t dtype: %s", feature_name,
+                     tensor_spec.shape.as_list(), tensor_spec.dtype.name)
+      else:
+        logging.error("Unknown tensor_spec type %s for feature %s.",
+                      type(tensor_spec), feature_name)
+  return ds
diff --git a/seqio/dataset_providers_helpers.py b/seqio/dataset_providers_helpers.py
deleted file mode 100644
index 3ebd419b..00000000
--- a/seqio/dataset_providers_helpers.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# Copyright 2022 The SeqIO Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Classes for data loading and processing. - -Defines Tasks, TaskRegistry, Mixture, and MixtureRegistry -""" - -import re -from typing import Mapping, Optional, Union - -from absl import logging -from seqio import dataset_providers -from seqio import feature_converters -from seqio import grain_dataset_providers -from seqio import utils -import tensorflow.compat.v2 as tf - -_DEFAULT_FEATURE_KEYS = ["inputs", "targets"] - -_VALID_TASK_NAME_REGEX = re.compile(r"^[\w\d\.\:_]+$") -_MAX_EXAMPLES_TO_MEM_CACHE = 10000 -SHUFFLE_BUFFER_SIZE = 1000 - -Feature = utils.Feature - - -def get_mixture_or_task( - task_or_mixture_name -) -> Union[dataset_providers.Task, dataset_providers.Mixture]: - """Return the Task or Mixture from the appropriate registry.""" - mixtures = dataset_providers.MixtureRegistry.names() - tasks = dataset_providers.TaskRegistry.names() - if task_or_mixture_name in mixtures: - if task_or_mixture_name in tasks: - logging.warning("%s is both a Task and a Mixture, returning Mixture", - task_or_mixture_name) - return dataset_providers.MixtureRegistry.get(task_or_mixture_name) - if task_or_mixture_name in tasks: - return dataset_providers.TaskRegistry.get(task_or_mixture_name) - else: - for available_task in sorted(tasks): - logging.info("Available task: %s", available_task) - for available_mixture in sorted(mixtures): - logging.info("Available mixture: %s", available_mixture) - raise ValueError( - "No Task or Mixture found with name '%s'." % task_or_mixture_name) - - -def get_subtasks(task_or_mixture): - """Returns all the Tasks in a Mixture as a list or the Task itself.""" - if isinstance(task_or_mixture, dataset_providers.Task): - return [task_or_mixture] - else: - return task_or_mixture.tasks - - -def get_dataset(mixture_or_task_name: str, - task_feature_lengths: Mapping[str, int], - feature_converter: feature_converters.FeatureConverter, - dataset_split: str = "train", - use_cached: bool = False, - shuffle: bool = False, - num_epochs: Optional[int] = 1, - shard_info: Optional[dataset_providers.ShardInfo] = None, - verbose: bool = True, - seed: Optional[int] = None, - batch_size: Optional[int] = None, - trim_output_features: bool = True) -> tf.data.Dataset: - """Get processed dataset with the model features. - - In order to use options specific to a feature converter, e.g., packing, - `feature_converter` instance should be instantiated with those options before - being pased to this function. - - Getting sharded datasets is supported. To use this feature, pass in - `shard_info`, with shard_index and num_shards information. Sharding is done - before the feature converter stage. Therefore, if packing is used it will be - done on the sharded dataset. - - Args: - mixture_or_task_name: mixture or task name for the Task API. - task_feature_lengths: dict mapping task feature key to its sequence length. - This specifies the sequence length of the dataset from the Task API. - feature_converter: a feature converter object to use to convert the task - features to model features. Must be a subclass of FeatureConverter. - dataset_split: the split to use. 
- use_cached: whether to use the cached dataset instead of processing it on - the fly. - shuffle: whether to shuffle the dataset. - num_epochs: the number of times to iterate through the dataset, or `None` to - repeat indefinitely. Note that the repeat occurs in the pipeline after - offline caching, but before applying potentially stochastic post-cache - preprocessors and is therefore typically preferred to calling `repeat()` - on the returned dataset. Defaults to `1`. - shard_info: number of shards and shard index information. - verbose: if true, log the feature shapes. - seed: a random seed to for shuffling tf.data. - batch_size: Optional batch size. - trim_output_features: If True, it trims output features to be less than - the length given by `sequence_length`. - - Returns: - ds: the processed dataset. - """ - if not isinstance(feature_converter, feature_converters.FeatureConverter): - raise TypeError( - "feature_converter should be an instance of FeatureConverter.") - - mixture_or_task = get_mixture_or_task(mixture_or_task_name) - is_grain_task = False - if is_grain_task: - ds = mixture_or_task.get_dataset( - sequence_length=task_feature_lengths, - split=dataset_split, - use_cached=use_cached, - shuffle=shuffle, - seed=seed, - shard_info=shard_info, - num_epochs=num_epochs, - batch_size=batch_size, - feature_converter=feature_converter, - trim_output_features=trim_output_features) - else: - ds = mixture_or_task.get_dataset( - task_feature_lengths, - split=dataset_split, - use_cached=use_cached, - shuffle=shuffle, - seed=seed, - shard_info=shard_info, - num_epochs=num_epochs, - trim_output_features=trim_output_features) - ds = feature_converter(ds, task_feature_lengths=task_feature_lengths) - if batch_size is not None: - ds = ds.batch(batch_size, drop_remainder=True) - - if verbose: - logging.info( - "The output dataset from seqio.get_dataset has the following features") - element_spec = utils.flatten_dict(ds.element_spec, delimiter=".") - for feature_name, tensor_spec in element_spec.items(): - if isinstance(tensor_spec, tf.TensorSpec): - logging.info("feature: %s \t shape: %s \t dtype: %s", feature_name, - tensor_spec.shape.as_list(), tensor_spec.dtype.name) - else: - logging.error("Unknown tensor_spec type %s for feature %s.", - type(tensor_spec), feature_name) - return ds diff --git a/seqio/dataset_providers_helpers_test.py b/seqio/dataset_providers_helpers_test.py deleted file mode 100644 index 56126377..00000000 --- a/seqio/dataset_providers_helpers_test.py +++ /dev/null @@ -1,457 +0,0 @@ -# Copyright 2022 The SeqIO Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Tests for seqio.dataset_providers.""" - -import copy -from typing import Callable, Sequence - -from absl.testing import absltest -from absl.testing import parameterized -from seqio import dataset_providers -from seqio import dataset_providers_helpers -from seqio import feature_converters -from seqio import preprocessors -from seqio import test_utils -import tensorflow as tf - -tf.compat.v1.enable_eager_execution() - -TaskRegistry = dataset_providers.TaskRegistry -MixtureRegistry = dataset_providers.MixtureRegistry -mock = absltest.mock -assert_dataset = test_utils.assert_dataset -create_default_dataset = test_utils.create_default_dataset - - -class GetDatasetTest(parameterized.TestCase, tf.test.TestCase): - - def test_get_dataset_enc_dec_unpacked(self): - mixture_or_task_name = "enc_dec_unpacked" - x = [{ - "inputs": [7, 8, 5, 6, 9, 4, 3], - "targets": [3, 9] - }, { - "inputs": [8, 4], - "targets": [4] - }, { - "inputs": [5, 6, 7], - "targets": [6, 5] - }] - ds = create_default_dataset(x) - dataset_fn = lambda split, shuffle_files: ds - register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) - - task_feature_lengths = {"inputs": 7, "targets": 5} - converter = feature_converters.EncDecFeatureConverter(pack=False) - output_ds = dataset_providers_helpers.get_dataset( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=task_feature_lengths, - dataset_split="train", - shuffle=False, - feature_converter=converter) - - expected = [{ - "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1], - "decoder_target_tokens": [3, 9, 1, 0, 0], - "decoder_input_tokens": [0, 3, 9, 1, 0], - "decoder_loss_weights": [1, 1, 1, 0, 0], - }, { - "encoder_input_tokens": [8, 4, 1, 0, 0, 0, 0], - "decoder_target_tokens": [4, 1, 0, 0, 0], - "decoder_input_tokens": [0, 4, 1, 0, 0], - "decoder_loss_weights": [1, 1, 0, 0, 0], - }, { - "encoder_input_tokens": [5, 6, 7, 1, 0, 0, 0], - "decoder_target_tokens": [6, 5, 1, 0, 0], - "decoder_input_tokens": [0, 6, 5, 1, 0], - "decoder_loss_weights": [1, 1, 1, 0, 0], - }] - expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} - assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) - - @parameterized.parameters( - dict( - task_name="enc_dec_partial_trim_both", - task_feature_lengths={ - "inputs": 7, - "targets": 2 - }, - expect_trim_inputs=True, - expect_trim_targets=True), - dict( - task_name="enc_dec_partial_trim_targets", - task_feature_lengths={ - "inputs": None, - "targets": 2 - }, - expect_trim_inputs=False, - expect_trim_targets=True), - dict( - task_name="enc_dec_partial_trim_inputs", - task_feature_lengths={ - "inputs": 7, - "targets": None - }, - expect_trim_inputs=True, - expect_trim_targets=False), - dict( - task_name="enc_dec_partial_trim_neither", - task_feature_lengths={ - "inputs": None, - "targets": None - }, - expect_trim_inputs=False, - expect_trim_targets=False), - dict( - task_name="enc_dec_partial_trim_nothing", - task_feature_lengths=None, - expect_trim_inputs=False, - expect_trim_targets=False)) - def test_partial_sequence_length(self, task_name, task_feature_lengths, - expect_trim_inputs, expect_trim_targets): - x = [{ - "inputs": [7, 8, 5, 6, 9, 4, 3], - "targets": [3, 9] - }, { - "inputs": [8, 4], - "targets": [4] - }, { - "inputs": [5, 6, 7], - "targets": [6, 5] - }] - ds = create_default_dataset(x) - dataset_fn = lambda split, shuffle_files: ds - register_dummy_task(task_name, dataset_fn=dataset_fn) - # Unlike the other tests, don't use a feature converter. 
Instead, test the - # task.get_dataset method directly, which is similar to how evaluation.py - # infers feature lengths w/trimming. - task = dataset_providers_helpers.get_mixture_or_task(task_name) - output_ds = task.get_dataset( - sequence_length=task_feature_lengths, shuffle=False) - - expected = [{ - "inputs": [7, 8, 5, 6, 9, 4, 3, 1], - "targets": [3, 9, 1], - }, { - "inputs": [8, 4, 1], - "targets": [4, 1], - }, { - "inputs": [5, 6, 7, 1], - "targets": [6, 5, 1], - }] - if expect_trim_inputs: - expected[0]["inputs"] = [7, 8, 5, 6, 9, 4, 1] - if expect_trim_targets: - expected[0]["targets"] = [3, 1] - expected[2]["targets"] = [6, 1] - expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} - assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) - - @parameterized.parameters( - dict( - task_name="enc_dec_multidim_trim_both", - task_feature_lengths={ - "inputs": (2, 5), - "targets": 2 - }, - expect_trim_inputs=True, - expect_trim_targets=True, - ), - dict( - task_name="enc_dec_multidim_trim_inputs", - task_feature_lengths={ - "inputs": (2, 5), - "targets": None - }, - expect_trim_inputs=True, - expect_trim_targets=False, - ), - dict( - task_name="enc_dec_multidim_trim_targets", - task_feature_lengths={ - "inputs": None, - "targets": 2 - }, - expect_trim_inputs=False, - expect_trim_targets=True, - ), - dict( - task_name="enc_dec_no_multidim_trim", - task_feature_lengths={ - "inputs": None, - "targets": None - }, - expect_trim_inputs=False, - expect_trim_targets=False)) - def test_multidimension_sequence_length(self, task_name, task_feature_lengths, - expect_trim_inputs, - expect_trim_targets): - x = [{ - "inputs": [[7, 8, 5, 6, 9, 4, 3], [2, 3, 4, 5, 0, 0, 0], - [6, 7, 1, 0, 0, 0, 0]], - "targets": [3, 9] - }, { - "inputs": [[8, 4], [1, 0], [2, 3]], - "targets": [4] - }, { - "inputs": [[5, 6, 7]], - "targets": [6, 5, 1] - }, { - "inputs": [[7, 8, 9, 1, 2, 3, 4, 5, 6]], - "targets": [10, 11, 1] - }] - ds = tf.data.Dataset.from_generator( - lambda: x, - output_types={ - "inputs": tf.int32, - "targets": tf.int32 - }, - output_shapes={ - "inputs": (None, None), - "targets": (None,) - }) - dataset_fn = lambda split, shuffle_files: ds - dataset_providers.TaskRegistry.add( - task_name, - source=dataset_providers.FunctionDataSource( - dataset_fn=dataset_fn, splits=["train", "validation"]), - preprocessors=[ - dataset_providers.CacheDatasetPlaceholder(), - ], - output_features={ - "inputs": - dataset_providers.Feature( - test_utils.sentencepiece_vocab(), rank=2), - "targets": - dataset_providers.Feature(test_utils.sentencepiece_vocab()) - }, - metric_fns=[]) - # Unlike the other tests, don't use a feature converter. Instead, test the - # task.get_dataset method directly, which is similar to how evaluation.py - # infers feature lengths w/trimming. 
- task = dataset_providers_helpers.get_mixture_or_task(task_name) - output_ds = task.get_dataset( - sequence_length=task_feature_lengths, shuffle=False) - - expected = copy.deepcopy(x) - if expect_trim_inputs: - expected[0]["inputs"] = [[7, 8, 5, 6, 9], [2, 3, 4, 5, 0]] - expected[1]["inputs"] = [[8, 4], [1, 0]] - expected[3]["inputs"] = [[7, 8, 9, 1, 2]] - if expect_trim_targets: - expected[2]["targets"] = [6, 5] - expected[3]["targets"] = [10, 11] - expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} - assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) - - def test_get_dataset_enc_dec_packed(self): - mixture_or_task_name = "enc_dec_packed" - x = [{ - "inputs": [7, 8, 5, 6, 9, 4, 3], - "targets": [3, 9] - }, { - "inputs": [8, 4], - "targets": [4] - }, { - "inputs": [5, 6, 7], - "targets": [6, 5] - }] - ds = create_default_dataset(x) - dataset_fn = lambda split, shuffle_files: ds - register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) - - task_feature_lengths = {"inputs": 7, "targets": 5} - converter = feature_converters.EncDecFeatureConverter(pack=True) - output_ds = dataset_providers_helpers.get_dataset( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=task_feature_lengths, - dataset_split="train", - shuffle=False, - feature_converter=converter) - - expected = [ - { - # Example 1 is trimmed - "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1], - "encoder_segment_ids": [1, 1, 1, 1, 1, 1, 1], - "encoder_positions": [0, 1, 2, 3, 4, 5, 6], - "decoder_target_tokens": [3, 9, 1, 0, 0], - "decoder_input_tokens": [0, 3, 9, 0, 0], - "decoder_loss_weights": [1, 1, 1, 0, 0], - "decoder_segment_ids": [1, 1, 1, 0, 0], - "decoder_positions": [0, 1, 2, 0, 0], - }, - { - # Example 2 and 3 are packed together - "encoder_input_tokens": [8, 4, 1, 5, 6, 7, 1], - "encoder_segment_ids": [1, 1, 1, 2, 2, 2, 2], - "encoder_positions": [0, 1, 2, 0, 1, 2, 3], - "decoder_target_tokens": [4, 1, 6, 5, 1], - "decoder_input_tokens": [0, 4, 0, 6, 5], - "decoder_loss_weights": [1, 1, 1, 1, 1], - "decoder_segment_ids": [1, 1, 2, 2, 2], - "decoder_positions": [0, 1, 0, 1, 2], - } - ] - expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} - assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) - - def test_get_dataset_both_train_and_validation_splits(self): - mixture_or_task_name = "both_train_and_validation_splits" - x_train = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]}] - x_val = [{"inputs": [8, 4], "targets": [4]}] - datasets = { - "train": create_default_dataset(x_train), - "validation": create_default_dataset(x_val) - } - dataset_fn = lambda split, shuffle_files: datasets[split] - register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) - - task_feature_lengths = {"inputs": 7, "targets": 5} - output_ds = {} - for split in ["train", "validation"]: - converter = feature_converters.EncDecFeatureConverter(pack=False) - output_ds[split] = dataset_providers_helpers.get_dataset( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=task_feature_lengths, - dataset_split=split, - shuffle=False, - feature_converter=converter) - - expected_train = { - "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1], - "decoder_target_tokens": [3, 9, 1, 0, 0], - "decoder_input_tokens": [0, 3, 9, 1, 0], - "decoder_loss_weights": [1, 1, 1, 0, 0], - } - expected_val = { - "encoder_input_tokens": [8, 4, 1, 0, 0, 0, 0], - "decoder_target_tokens": [4, 1, 0, 0, 0], - "decoder_input_tokens": [0, 4, 1, 0, 0], - "decoder_loss_weights": [1, 1, 0, 
0, 0], - } - expected_dtypes = {feat: tf.int32 for feat in expected_train.keys()} - assert_dataset( - output_ds["train"], expected_train, expected_dtypes=expected_dtypes) - assert_dataset( - output_ds["validation"], expected_val, expected_dtypes=expected_dtypes) - - def test_get_dataset_enc_dec_sharded(self): - mixture_or_task_name = "enc_dec_sharded" - x = [{ - "inputs": [7, 8, 5, 6, 9, 4, 3], - "targets": [3, 9] - }, { - "inputs": [8, 4], - "targets": [4] - }, { - "inputs": [5, 6, 7], - "targets": [6, 5] - }] - ds = create_default_dataset(x) - dataset_fn = lambda split, shuffle_files: ds - register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) - - task_feature_lengths = {"inputs": 7, "targets": 5} - converter = feature_converters.EncDecFeatureConverter(pack=False) - shard_info = dataset_providers.ShardInfo(index=0, num_shards=2) - output_ds = dataset_providers_helpers.get_dataset( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=task_feature_lengths, - dataset_split="train", - shuffle=False, - feature_converter=converter, - shard_info=shard_info) - - # Example index 1 should not be present in the sharded dataset. - expected = [{ - "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1], - "decoder_target_tokens": [3, 9, 1, 0, 0], - "decoder_input_tokens": [0, 3, 9, 1, 0], - "decoder_loss_weights": [1, 1, 1, 0, 0], - }, { - "encoder_input_tokens": [5, 6, 7, 1, 0, 0, 0], - "decoder_target_tokens": [6, 5, 1, 0, 0], - "decoder_input_tokens": [0, 6, 5, 1, 0], - "decoder_loss_weights": [1, 1, 1, 0, 0], - }] - expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} - assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) - - def test_get_dataset_enc_dec_sharded_and_packed(self): - mixture_or_task_name = "enc_dec_sharded_and_packed" - x = [{ - "inputs": [7, 8], - "targets": [3, 9] - }, { - "inputs": [8, 4], - "targets": [4] - }, { - "inputs": [5, 6, 7], - "targets": [6] - }] - ds = create_default_dataset(x) - dataset_fn = lambda split, shuffle_files: ds - register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) - - task_feature_lengths = {"inputs": 7, "targets": 5} - converter = feature_converters.EncDecFeatureConverter(pack=True) - shard_info = dataset_providers.ShardInfo(index=0, num_shards=2) - output_ds = dataset_providers_helpers.get_dataset( - mixture_or_task_name=mixture_or_task_name, - task_feature_lengths=task_feature_lengths, - dataset_split="train", - shuffle=False, - feature_converter=converter, - shard_info=shard_info) - - # Packing should be done after the sharding. 
- expected = { - "encoder_input_tokens": [7, 8, 1, 5, 6, 7, 1], - "encoder_segment_ids": [1, 1, 1, 2, 2, 2, 2], - "encoder_positions": [0, 1, 2, 0, 1, 2, 3], - "decoder_target_tokens": [3, 9, 1, 6, 1], - "decoder_input_tokens": [0, 3, 9, 0, 6], - "decoder_loss_weights": [1, 1, 1, 1, 1], - "decoder_segment_ids": [1, 1, 1, 2, 2], - "decoder_positions": [0, 1, 2, 0, 1], - } - expected_dtypes = {feat: tf.int32 for feat in expected.keys()} - assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) - - -def register_dummy_task( - task_name: str, - dataset_fn: Callable[[str, str], tf.data.Dataset], - output_feature_names: Sequence[str] = ("inputs", "targets") -) -> None: - """Register a dummy task for GetDatasetTest.""" - dataset_providers.TaskRegistry.add( - task_name, - source=dataset_providers.FunctionDataSource( - dataset_fn=dataset_fn, splits=["train", "validation"]), - preprocessors=[ - dataset_providers.CacheDatasetPlaceholder(), - preprocessors.append_eos_after_trim, - ], - output_features={ - feat: dataset_providers.Feature(test_utils.sentencepiece_vocab()) - for feat in output_feature_names - }, - metric_fns=[]) - - -if __name__ == "__main__": - absltest.main() diff --git a/seqio/dataset_providers_test.py b/seqio/dataset_providers_test.py index abb24b0d..d974b6dc 100644 --- a/seqio/dataset_providers_test.py +++ b/seqio/dataset_providers_test.py @@ -18,11 +18,13 @@ import functools import os import shutil -from typing import Any, Mapping, Optional, Sequence +from typing import Any, Callable, Mapping, Optional, Sequence from absl.testing import absltest +from absl.testing import parameterized +import numpy as np from seqio import dataset_providers -from seqio import dataset_providers_helpers +from seqio import feature_converters from seqio import metrics as metrics_lib from seqio import preprocessors from seqio import test_utils @@ -289,8 +291,7 @@ def _get_preps_with_cache_placeholder_buffer_size(self, buffer_size): return preps def _mock_and_assert_cached_source(self, task_name, buffer_size): - cached_task = dataset_providers_helpers.get_mixture_or_task(task_name) - assert isinstance(cached_task, dataset_providers.Task) + cached_task = dataset_providers.get_mixture_or_task(task_name) cached_task._get_cached_source = mock.MagicMock( side_effect=cached_task._get_cached_source) _ = cached_task.get_dataset(None, "train", use_cached=True) @@ -876,9 +877,8 @@ def test_plaintext_to_pretokenized_rename(self): def test_list_shards(self): def _get_formatted_shards_list(task_name, split): - task = dataset_providers_helpers.get_mixture_or_task(task_name) - assert isinstance(task, dataset_providers.Task) - shards = task.source.list_shards(split) + shards = dataset_providers.get_mixture_or_task( + task_name).source.list_shards(split) shards = [s.split("/")[-1] for s in shards] return sorted(shards) @@ -921,6 +921,8 @@ def test_replace(self): self.assertEqual(10000, new_task.shuffle_buffer_size) + + class MixturesTest(test_utils.FakeTaskTest): def setUp(self): @@ -1161,5 +1163,423 @@ def gen_dataset(split, self.assertEqual(expected, actual) +class GetDatasetTest(parameterized.TestCase, tf.test.TestCase): + + def test_get_dataset_enc_dec_unpacked(self): + mixture_or_task_name = "enc_dec_unpacked" + x = [{ + "inputs": [7, 8, 5, 6, 9, 4, 3], + "targets": [3, 9] + }, { + "inputs": [8, 4], + "targets": [4] + }, { + "inputs": [5, 6, 7], + "targets": [6, 5] + }] + ds = create_default_dataset(x) + dataset_fn = lambda split, shuffle_files: ds + register_dummy_task(mixture_or_task_name, 
dataset_fn=dataset_fn) + + task_feature_lengths = {"inputs": 7, "targets": 5} + converter = feature_converters.EncDecFeatureConverter(pack=False) + output_ds = dataset_providers.get_dataset( + mixture_or_task_name=mixture_or_task_name, + task_feature_lengths=task_feature_lengths, + dataset_split="train", + shuffle=False, + feature_converter=converter) + + expected = [{ + "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1], + "decoder_target_tokens": [3, 9, 1, 0, 0], + "decoder_input_tokens": [0, 3, 9, 1, 0], + "decoder_loss_weights": [1, 1, 1, 0, 0], + }, { + "encoder_input_tokens": [8, 4, 1, 0, 0, 0, 0], + "decoder_target_tokens": [4, 1, 0, 0, 0], + "decoder_input_tokens": [0, 4, 1, 0, 0], + "decoder_loss_weights": [1, 1, 0, 0, 0], + }, { + "encoder_input_tokens": [5, 6, 7, 1, 0, 0, 0], + "decoder_target_tokens": [6, 5, 1, 0, 0], + "decoder_input_tokens": [0, 6, 5, 1, 0], + "decoder_loss_weights": [1, 1, 1, 0, 0], + }] + expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} + assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) + + @parameterized.parameters( + dict( + task_name="enc_dec_partial_trim_both", + task_feature_lengths={ + "inputs": 7, + "targets": 2 + }, + expect_trim_inputs=True, + expect_trim_targets=True), + dict( + task_name="enc_dec_partial_trim_targets", + task_feature_lengths={ + "inputs": None, + "targets": 2 + }, + expect_trim_inputs=False, + expect_trim_targets=True), + dict( + task_name="enc_dec_partial_trim_inputs", + task_feature_lengths={ + "inputs": 7, + "targets": None + }, + expect_trim_inputs=True, + expect_trim_targets=False), + dict( + task_name="enc_dec_partial_trim_neither", + task_feature_lengths={ + "inputs": None, + "targets": None + }, + expect_trim_inputs=False, + expect_trim_targets=False), + dict( + task_name="enc_dec_partial_trim_nothing", + task_feature_lengths=None, + expect_trim_inputs=False, + expect_trim_targets=False)) + def test_partial_sequence_length(self, task_name, task_feature_lengths, + expect_trim_inputs, expect_trim_targets): + x = [{ + "inputs": [7, 8, 5, 6, 9, 4, 3], + "targets": [3, 9] + }, { + "inputs": [8, 4], + "targets": [4] + }, { + "inputs": [5, 6, 7], + "targets": [6, 5] + }] + ds = create_default_dataset(x) + dataset_fn = lambda split, shuffle_files: ds + register_dummy_task(task_name, dataset_fn=dataset_fn) + # Unlike the other tests, don't use a feature converter. Instead, test the + # task.get_dataset method directly, which is similar to how evaluation.py + # infers feature lengths w/trimming. 
+ task = dataset_providers.get_mixture_or_task(task_name) + output_ds = task.get_dataset( + sequence_length=task_feature_lengths, shuffle=False) + + expected = [{ + "inputs": [7, 8, 5, 6, 9, 4, 3, 1], + "targets": [3, 9, 1], + }, { + "inputs": [8, 4, 1], + "targets": [4, 1], + }, { + "inputs": [5, 6, 7, 1], + "targets": [6, 5, 1], + }] + if expect_trim_inputs: + expected[0]["inputs"] = [7, 8, 5, 6, 9, 4, 1] + if expect_trim_targets: + expected[0]["targets"] = [3, 1] + expected[2]["targets"] = [6, 1] + expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} + assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) + + @parameterized.parameters( + dict( + task_name="enc_dec_multidim_trim_both", + task_feature_lengths={ + "inputs": (2, 5), + "targets": 2 + }, + expect_trim_inputs=True, + expect_trim_targets=True, + ), + dict( + task_name="enc_dec_multidim_trim_inputs", + task_feature_lengths={ + "inputs": (2, 5), + "targets": None + }, + expect_trim_inputs=True, + expect_trim_targets=False, + ), + dict( + task_name="enc_dec_multidim_trim_targets", + task_feature_lengths={ + "inputs": None, + "targets": 2 + }, + expect_trim_inputs=False, + expect_trim_targets=True, + ), + dict( + task_name="enc_dec_no_multidim_trim", + task_feature_lengths={ + "inputs": None, + "targets": None + }, + expect_trim_inputs=False, + expect_trim_targets=False)) + def test_multidimension_sequence_length(self, task_name, task_feature_lengths, + expect_trim_inputs, + expect_trim_targets): + x = [{ + "inputs": [[7, 8, 5, 6, 9, 4, 3], [2, 3, 4, 5, 0, 0, 0], + [6, 7, 1, 0, 0, 0, 0]], + "targets": [3, 9] + }, { + "inputs": [[8, 4], [1, 0], [2, 3]], + "targets": [4] + }, { + "inputs": [[5, 6, 7]], + "targets": [6, 5, 1] + }, { + "inputs": [[7, 8, 9, 1, 2, 3, 4, 5, 6]], + "targets": [10, 11, 1] + }] + ds = tf.data.Dataset.from_generator( + lambda: x, + output_types={ + "inputs": tf.int32, + "targets": tf.int32 + }, + output_shapes={ + "inputs": (None, None), + "targets": (None,) + }) + dataset_fn = lambda split, shuffle_files: ds + dataset_providers.TaskRegistry.add( + task_name, + source=dataset_providers.FunctionDataSource( + dataset_fn=dataset_fn, splits=["train", "validation"]), + preprocessors=[ + dataset_providers.CacheDatasetPlaceholder(), + ], + output_features={ + "inputs": + dataset_providers.Feature( + test_utils.sentencepiece_vocab(), rank=2), + "targets": + dataset_providers.Feature(test_utils.sentencepiece_vocab()) + }, + metric_fns=[]) + # Unlike the other tests, don't use a feature converter. Instead, test the + # task.get_dataset method directly, which is similar to how evaluation.py + # infers feature lengths w/trimming. 
+ task = dataset_providers.get_mixture_or_task(task_name) + output_ds = task.get_dataset( + sequence_length=task_feature_lengths, shuffle=False) + + expected = copy.deepcopy(x) + if expect_trim_inputs: + expected[0]["inputs"] = [[7, 8, 5, 6, 9], [2, 3, 4, 5, 0]] + expected[1]["inputs"] = [[8, 4], [1, 0]] + expected[3]["inputs"] = [[7, 8, 9, 1, 2]] + if expect_trim_targets: + expected[2]["targets"] = [6, 5] + expected[3]["targets"] = [10, 11] + expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} + assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) + + def test_get_dataset_enc_dec_packed(self): + mixture_or_task_name = "enc_dec_packed" + x = [{ + "inputs": [7, 8, 5, 6, 9, 4, 3], + "targets": [3, 9] + }, { + "inputs": [8, 4], + "targets": [4] + }, { + "inputs": [5, 6, 7], + "targets": [6, 5] + }] + ds = create_default_dataset(x) + dataset_fn = lambda split, shuffle_files: ds + register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) + + task_feature_lengths = {"inputs": 7, "targets": 5} + converter = feature_converters.EncDecFeatureConverter(pack=True) + output_ds = dataset_providers.get_dataset( + mixture_or_task_name=mixture_or_task_name, + task_feature_lengths=task_feature_lengths, + dataset_split="train", + shuffle=False, + feature_converter=converter) + + expected = [ + { + # Example 1 is trimmed + "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1], + "encoder_segment_ids": [1, 1, 1, 1, 1, 1, 1], + "encoder_positions": [0, 1, 2, 3, 4, 5, 6], + "decoder_target_tokens": [3, 9, 1, 0, 0], + "decoder_input_tokens": [0, 3, 9, 0, 0], + "decoder_loss_weights": [1, 1, 1, 0, 0], + "decoder_segment_ids": [1, 1, 1, 0, 0], + "decoder_positions": [0, 1, 2, 0, 0], + }, + { + # Example 2 and 3 are packed together + "encoder_input_tokens": [8, 4, 1, 5, 6, 7, 1], + "encoder_segment_ids": [1, 1, 1, 2, 2, 2, 2], + "encoder_positions": [0, 1, 2, 0, 1, 2, 3], + "decoder_target_tokens": [4, 1, 6, 5, 1], + "decoder_input_tokens": [0, 4, 0, 6, 5], + "decoder_loss_weights": [1, 1, 1, 1, 1], + "decoder_segment_ids": [1, 1, 2, 2, 2], + "decoder_positions": [0, 1, 0, 1, 2], + } + ] + expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} + assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) + + def test_get_dataset_both_train_and_validation_splits(self): + mixture_or_task_name = "both_train_and_validation_splits" + x_train = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]}] + x_val = [{"inputs": [8, 4], "targets": [4]}] + datasets = { + "train": create_default_dataset(x_train), + "validation": create_default_dataset(x_val) + } + dataset_fn = lambda split, shuffle_files: datasets[split] + register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) + + task_feature_lengths = {"inputs": 7, "targets": 5} + output_ds = {} + for split in ["train", "validation"]: + converter = feature_converters.EncDecFeatureConverter(pack=False) + output_ds[split] = dataset_providers.get_dataset( + mixture_or_task_name=mixture_or_task_name, + task_feature_lengths=task_feature_lengths, + dataset_split=split, + shuffle=False, + feature_converter=converter) + + expected_train = { + "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1], + "decoder_target_tokens": [3, 9, 1, 0, 0], + "decoder_input_tokens": [0, 3, 9, 1, 0], + "decoder_loss_weights": [1, 1, 1, 0, 0], + } + expected_val = { + "encoder_input_tokens": [8, 4, 1, 0, 0, 0, 0], + "decoder_target_tokens": [4, 1, 0, 0, 0], + "decoder_input_tokens": [0, 4, 1, 0, 0], + "decoder_loss_weights": [1, 1, 0, 0, 0], + } + 
expected_dtypes = {feat: tf.int32 for feat in expected_train.keys()} + assert_dataset( + output_ds["train"], expected_train, expected_dtypes=expected_dtypes) + assert_dataset( + output_ds["validation"], expected_val, expected_dtypes=expected_dtypes) + + def test_get_dataset_enc_dec_sharded(self): + mixture_or_task_name = "enc_dec_sharded" + x = [{ + "inputs": [7, 8, 5, 6, 9, 4, 3], + "targets": [3, 9] + }, { + "inputs": [8, 4], + "targets": [4] + }, { + "inputs": [5, 6, 7], + "targets": [6, 5] + }] + ds = create_default_dataset(x) + dataset_fn = lambda split, shuffle_files: ds + register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) + + task_feature_lengths = {"inputs": 7, "targets": 5} + converter = feature_converters.EncDecFeatureConverter(pack=False) + shard_info = dataset_providers.ShardInfo(index=0, num_shards=2) + output_ds = dataset_providers.get_dataset( + mixture_or_task_name=mixture_or_task_name, + task_feature_lengths=task_feature_lengths, + dataset_split="train", + shuffle=False, + feature_converter=converter, + shard_info=shard_info) + + # Example index 1 should not be present in the sharded dataset. + expected = [{ + "encoder_input_tokens": [7, 8, 5, 6, 9, 4, 1], + "decoder_target_tokens": [3, 9, 1, 0, 0], + "decoder_input_tokens": [0, 3, 9, 1, 0], + "decoder_loss_weights": [1, 1, 1, 0, 0], + }, { + "encoder_input_tokens": [5, 6, 7, 1, 0, 0, 0], + "decoder_target_tokens": [6, 5, 1, 0, 0], + "decoder_input_tokens": [0, 6, 5, 1, 0], + "decoder_loss_weights": [1, 1, 1, 0, 0], + }] + expected_dtypes = {feat: tf.int32 for feat in expected[0].keys()} + assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) + + def test_get_dataset_enc_dec_sharded_and_packed(self): + mixture_or_task_name = "enc_dec_sharded_and_packed" + x = [{ + "inputs": [7, 8], + "targets": [3, 9] + }, { + "inputs": [8, 4], + "targets": [4] + }, { + "inputs": [5, 6, 7], + "targets": [6] + }] + ds = create_default_dataset(x) + dataset_fn = lambda split, shuffle_files: ds + register_dummy_task(mixture_or_task_name, dataset_fn=dataset_fn) + + task_feature_lengths = {"inputs": 7, "targets": 5} + converter = feature_converters.EncDecFeatureConverter(pack=True) + shard_info = dataset_providers.ShardInfo(index=0, num_shards=2) + output_ds = dataset_providers.get_dataset( + mixture_or_task_name=mixture_or_task_name, + task_feature_lengths=task_feature_lengths, + dataset_split="train", + shuffle=False, + feature_converter=converter, + shard_info=shard_info) + + # Packing should be done after the sharding. 
+ expected = { + "encoder_input_tokens": [7, 8, 1, 5, 6, 7, 1], + "encoder_segment_ids": [1, 1, 1, 2, 2, 2, 2], + "encoder_positions": [0, 1, 2, 0, 1, 2, 3], + "decoder_target_tokens": [3, 9, 1, 6, 1], + "decoder_input_tokens": [0, 3, 9, 0, 6], + "decoder_loss_weights": [1, 1, 1, 1, 1], + "decoder_segment_ids": [1, 1, 1, 2, 2], + "decoder_positions": [0, 1, 2, 0, 1], + } + expected_dtypes = {feat: tf.int32 for feat in expected.keys()} + assert_dataset(output_ds, expected, expected_dtypes=expected_dtypes) + + +def register_dummy_task( + task_name: str, + dataset_fn: Callable[[str, str], tf.data.Dataset], + output_feature_names: Sequence[str] = ("inputs", "targets") +) -> None: + """Register a dummy task for GetDatasetTest.""" + dataset_providers.TaskRegistry.add( + task_name, + source=dataset_providers.FunctionDataSource( + dataset_fn=dataset_fn, splits=["train", "validation"]), + preprocessors=[ + dataset_providers.CacheDatasetPlaceholder(), + preprocessors.append_eos_after_trim, + ], + output_features={ + feat: dataset_providers.Feature(test_utils.sentencepiece_vocab()) + for feat in output_feature_names + }, + metric_fns=[]) + + if __name__ == "__main__": absltest.main() diff --git a/seqio/evaluation.py b/seqio/evaluation.py index 739a2dd0..4cdb42ab 100644 --- a/seqio/evaluation.py +++ b/seqio/evaluation.py @@ -25,7 +25,6 @@ import jax import numpy as np from seqio import dataset_providers -from seqio import dataset_providers_helpers from seqio import feature_converters from seqio import loggers as loggers_lib from seqio import metrics as metrics_lib @@ -353,8 +352,8 @@ def __init__(self, value. """ logging.info("Initializing Evaluator for '%s'", mixture_or_task_name) - eval_tasks = dataset_providers_helpers.get_subtasks( - dataset_providers_helpers.get_mixture_or_task(mixture_or_task_name)) + eval_tasks = dataset_providers.get_subtasks( + dataset_providers.get_mixture_or_task(mixture_or_task_name)) self._eval_tasks = get_valid_eval_tasks(eval_tasks, eval_split) self._metrics_executor = concurrent.futures.ThreadPoolExecutor( diff --git a/seqio/grain_dataset_providers.py b/seqio/grain_dataset_providers.py deleted file mode 100644 index 6ecb2ca1..00000000 --- a/seqio/grain_dataset_providers.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The SeqIO Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Classes for data loading and processing with Grain.""" diff --git a/seqio/grain_dataset_providers_test.py b/seqio/grain_dataset_providers_test.py deleted file mode 100644 index 3d7300fc..00000000 --- a/seqio/grain_dataset_providers_test.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright 2022 The SeqIO Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for grain_dataset_providers.""" diff --git a/seqio/helpers.py b/seqio/helpers.py index 9969c37d..d3a4acc9 100644 --- a/seqio/helpers.py +++ b/seqio/helpers.py @@ -20,7 +20,6 @@ from typing import Mapping, Optional, Sequence, Union from seqio import dataset_providers as dp -from seqio import dataset_providers_helpers from seqio import vocabularies as vc import tensorflow.compat.v2 as tf @@ -81,11 +80,9 @@ def _validate_output_features(og_output_features, new_output_features): f"new_output_features: {new_output_features} incompatible with " f"original output_features: {og_output_features}") - task_or_mixture = dataset_providers_helpers.get_mixture_or_task( - mixture_or_task_name) - if isinstance(task_or_mixture, dp.Task): + if mixture_or_task_name in dp.TaskRegistry.names(): # This is a Task. Create a new Task with the provided vocab/output_features. - og_task: dp.Task = task_or_mixture + og_task: dp.Task = dp.get_mixture_or_task(mixture_or_task_name) if new_vocab: new_output_features = { @@ -109,7 +106,7 @@ def _validate_output_features(og_output_features, new_output_features): # This is a Mixture. Create and register new sub-Tasks/Mixtures with the # provided vocab/output_features, then create a new Mixture. - og_mix: dp.Mixture = task_or_mixture + og_mix: dp.Mixture = dp.get_mixture_or_task(mixture_or_task_name) new_tasks_and_rates = [] for task_name, rate in og_mix._task_to_rate.items(): @@ -230,10 +227,9 @@ def mixture_or_task_with_truncated_data( The new `Task` or `Mixture` object. """ - task_or_mixture = dataset_providers_helpers.get_mixture_or_task( - mixture_or_task_name) - if isinstance(task_or_mixture, dp.Task): - og_task: dp.Task = task_or_mixture + if mixture_or_task_name in dp.TaskRegistry.names(): + # This is a `Task`. + og_task: dp.Task = dp.get_mixture_or_task(mixture_or_task_name) new_task = dp.Task( new_mixture_or_task_name, @@ -252,7 +248,7 @@ def mixture_or_task_with_truncated_data( else: # This is a Mixture. Create and register new sub-Tasks/Mixtures with the # provided vocab/output_features, then create a new Mixture. - og_mix: dp.Mixture = task_or_mixture + og_mix: dp.Mixture = dp.get_mixture_or_task(mixture_or_task_name) new_tasks_and_rates = [] for task_name, rate in og_mix._task_to_rate.items(): @@ -300,12 +296,10 @@ def mixture_with_missing_task_splits_removed( Returns: The new `Mixture` object. 
""" - og_mix = dataset_providers_helpers.get_mixture_or_task(mixture_name) - assert isinstance(og_mix, dp.Mixture) + og_mix: dp.Mixture = dp.get_mixture_or_task(mixture_name) new_tasks_and_rates = [] for task_name, rate in og_mix._task_to_rate.items(): - subtask = dataset_providers_helpers.get_mixture_or_task(task_name) - assert isinstance(subtask, dp.Task) + subtask: dp.Task = dp.get_mixture_or_task(task_name) if split in subtask.splits: new_tasks_and_rates.append((subtask.name, rate)) new_mix = dp.Mixture( diff --git a/seqio/helpers_test.py b/seqio/helpers_test.py index 28ed8173..bacbce57 100644 --- a/seqio/helpers_test.py +++ b/seqio/helpers_test.py @@ -18,7 +18,6 @@ from absl.testing import absltest from seqio import dataset_providers as dp -from seqio import dataset_providers_helpers from seqio import helpers from seqio import preprocessors as pr from seqio import test_utils @@ -50,7 +49,7 @@ def test_task_new_vocab(self): }) helpers.mixture_or_task_with_new_vocab( "my_test_task", "my_new_test_task", new_vocab=VOCAB2) - new_task = dataset_providers_helpers.get_mixture_or_task("my_new_test_task") + new_task = dp.get_mixture_or_task("my_new_test_task") self.assertEqual(new_task.source, test_task.source) self.assertEqual(new_task.preprocessors, test_task.preprocessors) self.assertEqual( @@ -122,15 +121,13 @@ def test_mixture_new_vocab(self): add_to_seqio_registry=True) # Step 4: Get new Tasks and Mixtures from the Registry. - new_mix = dataset_providers_helpers.get_mixture_or_task("my_new_test_mix2") - new_submix = dataset_providers_helpers.get_mixture_or_task( - "my_new_test_mix2.my_test_mix1") - new_submix_subtask1 = dataset_providers_helpers.get_mixture_or_task( + new_mix = dp.get_mixture_or_task("my_new_test_mix2") + new_submix = dp.get_mixture_or_task("my_new_test_mix2.my_test_mix1") + new_submix_subtask1 = dp.get_mixture_or_task( "my_new_test_mix2.my_test_mix1.my_test_task1") - new_submix_subtask2 = dataset_providers_helpers.get_mixture_or_task( + new_submix_subtask2 = dp.get_mixture_or_task( "my_new_test_mix2.my_test_mix1.my_test_task2") - new_subtask = dataset_providers_helpers.get_mixture_or_task( - "my_new_test_mix2.my_test_task1") + new_subtask = dp.get_mixture_or_task("my_new_test_mix2.my_test_task1") # Step 5: Verify mixing rates for new mixtures. self.assertDictEqual(new_mix._task_to_rate, { @@ -207,14 +204,12 @@ def test_mixture_new_output_features(self): # Step 4: Get new Tasks and Mixtures from the Registry. self.assertNotIn("my_new_test_mix2", dp.MixtureRegistry.names()) - new_submix = dataset_providers_helpers.get_mixture_or_task( - "my_new_test_mix2.my_test_mix1") - new_submix_subtask1 = dataset_providers_helpers.get_mixture_or_task( + new_submix = dp.get_mixture_or_task("my_new_test_mix2.my_test_mix1") + new_submix_subtask1 = dp.get_mixture_or_task( "my_new_test_mix2.my_test_mix1.my_test_task1") - new_submix_subtask2 = dataset_providers_helpers.get_mixture_or_task( + new_submix_subtask2 = dp.get_mixture_or_task( "my_new_test_mix2.my_test_mix1.my_test_task2") - new_subtask = dataset_providers_helpers.get_mixture_or_task( - "my_new_test_mix2.my_test_task1") + new_subtask = dp.get_mixture_or_task("my_new_test_mix2.my_test_task1") # Step 5: Verify mixing rates for new mixtures. 
self.assertDictEqual(new_mix._task_to_rate, { @@ -295,7 +290,7 @@ def test_task_with_truncated_data(self): }) helpers.mixture_or_task_with_truncated_data( "my_test_task", "my_new_test_task", split_sizes={"train": 1}) - new_task = dataset_providers_helpers.get_mixture_or_task("my_new_test_task") + new_task = dp.get_mixture_or_task("my_new_test_task") ds = new_task.get_dataset(_SEQUENCE_LENGTH, "train") examples = list(ds.as_numpy_iterator()) self.assertEqual(len(examples), 1) diff --git a/seqio/test_utils.py b/seqio/test_utils.py index 5d42c899..49c2e306 100644 --- a/seqio/test_utils.py +++ b/seqio/test_utils.py @@ -20,14 +20,13 @@ import os import shutil import sys -from typing import Any, Iterator, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Iterator, Mapping, Optional, Sequence, Union, Tuple from absl import flags from absl import logging from absl.testing import absltest import numpy as np from seqio import dataset_providers -from seqio import dataset_providers_helpers from seqio import evaluation from seqio import feature_converters from seqio import preprocessors @@ -616,10 +615,10 @@ class DataInjector(): """ def __init__(self, task_name, per_split_data): - self._task = dataset_providers_helpers.get_mixture_or_task(task_name) + self._task = dataset_providers.get_mixture_or_task(task_name) self.per_split_data = per_split_data - self._saved_source = self._task._source # pytype: disable=attribute-error + self._saved_source = self._task._source def __enter__(self): @@ -655,12 +654,12 @@ def assert_dict_contains(expected, actual): def encode_str(task_name, s, output_feature_name="targets"): - task = dataset_providers_helpers.get_mixture_or_task(task_name) + task = dataset_providers.get_mixture_or_task(task_name) return task.output_features[output_feature_name].vocabulary.encode(s) def create_prediction(task_name, s, output_feature_name="targets"): - task = dataset_providers_helpers.get_mixture_or_task(task_name) + task = dataset_providers.get_mixture_or_task(task_name) return [(0, task.output_features[output_feature_name].vocabulary.encode(s))] @@ -746,7 +745,7 @@ def test_preprocessing( with DataInjector(task_name, raw_data): split = list(raw_data.keys())[0] - task = dataset_providers_helpers.get_mixture_or_task(task_name) + task = dataset_providers.get_mixture_or_task(task_name) iterator = task.get_dataset( sequence_length=sequence_length, split=split, shuffle=False, seed=seed).as_numpy_iterator() @@ -837,7 +836,7 @@ def __call__(self, model_feature_lengths: Optional[Mapping[str, int]] = None): if predict_output is None: return [] - task = dataset_providers_helpers.get_mixture_or_task(task_name) + task = dataset_providers.get_mixture_or_task(task_name) return list( enumerate( task.output_features[target_feature_name].vocabulary.encode(s)
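After this consolidation, `get_mixture_or_task`, `get_subtasks`, and `get_dataset` resolve from `seqio.dataset_providers` directly, and the existing wildcard import in `seqio/__init__.py` keeps exposing them at the top level. A minimal sketch of an updated call site, mirroring the dummy-task setup in `dataset_providers_test.py` above (the task name `helpers_demo_task` is hypothetical, for illustration only):

import seqio
from seqio import preprocessors
from seqio import test_utils

# Register a dummy task over a small in-memory dataset, following
# register_dummy_task in the tests above.
x = [{"inputs": [7, 8, 5, 6, 9, 4, 3], "targets": [3, 9]}]
ds = test_utils.create_default_dataset(x)
seqio.TaskRegistry.add(
    "helpers_demo_task",  # Hypothetical name, for illustration only.
    source=seqio.FunctionDataSource(
        dataset_fn=lambda split, shuffle_files: ds,
        splits=["train", "validation"]),
    preprocessors=[preprocessors.append_eos_after_trim],
    output_features={
        feat: seqio.Feature(test_utils.sentencepiece_vocab())
        for feat in ("inputs", "targets")
    },
    metric_fns=[])

# Former seqio.dataset_providers_helpers.get_dataset call sites now go
# through the top-level package (or seqio.dataset_providers) unchanged.
model_ds = seqio.get_dataset(
    mixture_or_task_name="helpers_demo_task",
    task_feature_lengths={"inputs": 7, "targets": 5},
    dataset_split="train",
    shuffle=False,
    feature_converter=seqio.EncDecFeatureConverter(pack=False))

The keyword arguments are identical to the removed helper's signature, so migrating callers is a pure import change.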