From d47ce8f50f575708d303b70f39438e22044eb9f0 Mon Sep 17 00:00:00 2001 From: Marvin Ritter Date: Thu, 15 Sep 2022 02:09:18 -0700 Subject: [PATCH] Add ArrayRecordDataSource based on Grain. ## Motivation Using Grain and ArrayRecord files allows us to perform global shuffling at the very beginning of the pipeline without a shuffle buffer. This should reduce the size of checkpoints a lot. ## Summary of changes - Add `supports_global_shuffle` property to `DataSource`. - Implement `ArrayRecordDataSource`. - Change `Task.get_dataset()` to disable shuffle buffer if the data source support global shuffling. PiperOrigin-RevId: 474507172 --- seqio/dataset_providers.py | 110 ++++++++++++++++++++++++++++++++++++- 1 file changed, 108 insertions(+), 2 deletions(-) diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index c6151b28..629412e0 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -29,6 +29,7 @@ from absl import logging import clu.metrics +import grain.tensorflow as grain import numpy as np from packaging import version from seqio import metrics as metrics_lib @@ -235,7 +236,17 @@ def caching_permitted(self) -> bool: def splits(self) -> Sequence[str]: return self._splits - @abc.abstractproperty + @property + def supports_global_shuffle(self) -> bool: + """Whether the data source supports global shuffle.""" + # If False the consumer (usually `Task.get_dataset()``) should apply an + # additional shuffle buffer. Even if False we still expect data sources to + # perform some macro level shuffle (usually shuffling the filenames) in + # `get_dataset()`. + return False + + @property + @abc.abstractmethod def supports_arbitrary_sharding(self) -> bool: """Whether supports sharding beyond those available in `list_shards`.""" raise NotImplementedError @@ -272,6 +283,80 @@ def num_input_examples(self, split: str) -> Optional[int]: return self._num_input_examples[split] +class ArrayRecordDataSource(DataSource): + """Data source for ArrayRecord files. + + This is experimental! + This data source support arbitrary sharding and global shuffling. + """ + + def __init__(self, + split_to_filepattern: Mapping[str, Union[str, Iterable[str]]], + num_input_examples: Optional[Mapping[str, int]] = None): + self._split_to_filepattern = split_to_filepattern + super().__init__( + splits=split_to_filepattern.keys(), + num_input_examples=num_input_examples, + caching_permitted=True) + + @property + def supports_arbitrary_sharding(self) -> bool: + return True + + @property + def supports_global_shuffle(self) -> bool: + return True + + def list_shards(self, split: str) -> Sequence[str]: + filepattern = self._split_to_filepattern[split] + if isinstance(filepattern, str): + return [filepattern] + return list(filepattern) + + def get_dataset(self, + split: str, + shuffle: bool = True, + seed: Optional[int] = None, + shard_info: Optional[ShardInfo] = None) -> tf.data.Dataset: + if shuffle and not seed: + raise ValueError( + "ArrayRecordDataSource always runs in deterministic mode and " + "requires a random seed for shuffling.") + source = grain.TfArrayRecordDataSource(self._split_to_filepattern[split]) + if shard_info: + shard_options = grain.ShardOptions(shard_info.index, + shard_info.num_shards) + else: + shard_options = grain.NoSharding() + + # Regarding num_epochs: + # Grain doesn't allow calling repeat() on the returned dataset because + # epochs wouldn't get a different shuffle order. We must pass the correct + # number of epoch to the index sampler here but we don't have it. We work + # around this by using num_epochs=1 for shuffle=False and num_epochs=None + # (repeat forever) otherwise. The consumer (Tash.get_dataset()) can remove + # additional epochs using Dataset.take(). + sampler = grain.TfDefaultIndexSampler( + len(source), + shard_options=shard_options, + num_epochs=None if shuffle else 1, + shuffle=shuffle, + seed=seed) + loader = grain.TfDataLoader( + source=source, sampler=sampler, batch_fn=grain.TfBatchNone()) + # Always start at index 0 as we expect users to checkpoint the + # tf.data iterator. + return loader.as_dataset(start_index=0) + + def num_input_examples(self, split: str) -> int: + if self._num_input_examples is None: + # This will have to open each file once to read the index. It doesn't + # read any data but would still be slow for a large number of files. + source = grain.TfArrayRecordDataSource(self._split_to_filepattern[split]) + return len(source) + return self._num_input_examples[split] + + def _get_name(function): """Returns the name of a (possibly partially applied) function.""" if isinstance(function, functools.partial): @@ -1195,6 +1280,16 @@ def get_dataset(self, ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed) ds = ds.shard(shard_info.num_shards, shard_info.index) + # Remove excess elements. + # ArrayRecordDataSource will always repeat forever iff shuffle=True. + if (isinstance(source, ArrayRecordDataSource) and shuffle and + num_epochs is not None): + assert ds.cardinality() == tf.data.INFINITE_CARDINALITY + logging.warning( + "Using num_epochs=%d and shuffle=True on " + "ArrayRecordDataSource is not recommended.", num_epochs) + ds = ds.take(source.num_input_examples(split) * num_epochs) + if ((use_cached and self.get_cached_stats(split)["examples"] < _MAX_EXAMPLES_TO_MEM_CACHE) or (self.num_input_examples(split) and @@ -1210,6 +1305,9 @@ def get_dataset(self, # We repeat before calling any (potentially) stochastic post-cache # preprocessing in order to take new samples each epoch. + # We do repeat before shuffling (unless the data source support global + # shuffle). This means that the shuffle buffer can contain elements from + # multiple epochs. ds = ds.repeat(num_epochs) # Post cache processing. @@ -1218,7 +1316,7 @@ def get_dataset(self, ds = self._validate_preprocessing(ds) if trim_output_features: ds = self._trim_output_features(ds, sequence_length=sequence_length) - if shuffle: + if shuffle and not source.supports_global_shuffle: if self._shuffle_buffer_size is None: raise ValueError( f"Shuffling is disallowed for Task '{self.name}' since its " @@ -1476,6 +1574,14 @@ def filter_features(ex): trim_output_features=trim_output_features).map( filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE) + if shuffle and task._source.supports_global_shuffle: + # Data sources that support global shuffle do not have a shuffle + # buffer after the preprocessing. The preprocessing could output + # multiple correlated outputs. + # We could add a shuffle buffer here but this defeats the purpose of + # avoiding the shuffle buffer in the first place. + raise NotImplementedError("Using ArrayRecordDataSource in mixtures " + "is currently not supported.") datasets.append(ds) except: logging.error("Failed to load task '%s' as part of mixture '%s'",