Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from absl import logging
import clu.metrics
import grain.tensorflow as grain
import numpy as np
from packaging import version
from seqio import metrics as metrics_lib
Expand Down Expand Up @@ -235,7 +236,17 @@ def caching_permitted(self) -> bool:
def splits(self) -> Sequence[str]:
return self._splits

@abc.abstractproperty
@property
def supports_global_shuffle(self) -> bool:
"""Whether the data source supports global shuffle."""
# If False the consumer (usually `Task.get_dataset()``) should apply an
# additional shuffle buffer. Even if False we still expect data sources to
# perform some macro level shuffle (usually shuffling the filenames) in
# `get_dataset()`.
return False

@property
@abc.abstractmethod
def supports_arbitrary_sharding(self) -> bool:
"""Whether supports sharding beyond those available in `list_shards`."""
raise NotImplementedError
Expand Down Expand Up @@ -272,6 +283,80 @@ def num_input_examples(self, split: str) -> Optional[int]:
return self._num_input_examples[split]


class ArrayRecordDataSource(DataSource):
"""Data source for ArrayRecord files.

This is experimental!
This data source support arbitrary sharding and global shuffling.
"""

def __init__(self,
split_to_filepattern: Mapping[str, Union[str, Iterable[str]]],
num_input_examples: Optional[Mapping[str, int]] = None):
self._split_to_filepattern = split_to_filepattern
super().__init__(
splits=split_to_filepattern.keys(),
num_input_examples=num_input_examples,
caching_permitted=True)

@property
def supports_arbitrary_sharding(self) -> bool:
return True

@property
def supports_global_shuffle(self) -> bool:
return True

def list_shards(self, split: str) -> Sequence[str]:
filepattern = self._split_to_filepattern[split]
if isinstance(filepattern, str):
return [filepattern]
return list(filepattern)

def get_dataset(self,
split: str,
shuffle: bool = True,
seed: Optional[int] = None,
shard_info: Optional[ShardInfo] = None) -> tf.data.Dataset:
if shuffle and not seed:
raise ValueError(
"ArrayRecordDataSource always runs in deterministic mode and "
"requires a random seed for shuffling.")
source = grain.TfArrayRecordDataSource(self._split_to_filepattern[split])
if shard_info:
shard_options = grain.ShardOptions(shard_info.index,
shard_info.num_shards)
else:
shard_options = grain.NoSharding()

# Regarding num_epochs:
# Grain doesn't allow calling repeat() on the returned dataset because
# epochs wouldn't get a different shuffle order. We must pass the correct
# number of epoch to the index sampler here but we don't have it. We work
# around this by using num_epochs=1 for shuffle=False and num_epochs=None
# (repeat forever) otherwise. The consumer (Tash.get_dataset()) can remove
# additional epochs using Dataset.take().
sampler = grain.TfDefaultIndexSampler(
len(source),
shard_options=shard_options,
num_epochs=None if shuffle else 1,
shuffle=shuffle,
seed=seed)
loader = grain.TfDataLoader(
source=source, sampler=sampler, batch_fn=grain.TfBatchNone())
# Always start at index 0 as we expect users to checkpoint the
# tf.data iterator.
return loader.as_dataset(start_index=0)

def num_input_examples(self, split: str) -> int:
if self._num_input_examples is None:
# This will have to open each file once to read the index. It doesn't
# read any data but would still be slow for a large number of files.
source = grain.TfArrayRecordDataSource(self._split_to_filepattern[split])
return len(source)
return self._num_input_examples[split]


def _get_name(function):
"""Returns the name of a (possibly partially applied) function."""
if isinstance(function, functools.partial):
Expand Down Expand Up @@ -1195,6 +1280,16 @@ def get_dataset(self,
ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
ds = ds.shard(shard_info.num_shards, shard_info.index)

# Remove excess elements.
# ArrayRecordDataSource will always repeat forever iff shuffle=True.
if (isinstance(source, ArrayRecordDataSource) and shuffle and
num_epochs is not None):
assert ds.cardinality() == tf.data.INFINITE_CARDINALITY
logging.warning(
"Using num_epochs=%d and shuffle=True on "
"ArrayRecordDataSource is not recommended.", num_epochs)
ds = ds.take(source.num_input_examples(split) * num_epochs)

if ((use_cached and
self.get_cached_stats(split)["examples"] < _MAX_EXAMPLES_TO_MEM_CACHE)
or (self.num_input_examples(split) and
Expand All @@ -1210,6 +1305,9 @@ def get_dataset(self,

# We repeat before calling any (potentially) stochastic post-cache
# preprocessing in order to take new samples each epoch.
# We do repeat before shuffling (unless the data source support global
# shuffle). This means that the shuffle buffer can contain elements from
# multiple epochs.
ds = ds.repeat(num_epochs)

# Post cache processing.
Expand All @@ -1218,7 +1316,7 @@ def get_dataset(self,
ds = self._validate_preprocessing(ds)
if trim_output_features:
ds = self._trim_output_features(ds, sequence_length=sequence_length)
if shuffle:
if shuffle and not source.supports_global_shuffle:
if self._shuffle_buffer_size is None:
raise ValueError(
f"Shuffling is disallowed for Task '{self.name}' since its "
Expand Down Expand Up @@ -1476,6 +1574,14 @@ def filter_features(ex):
trim_output_features=trim_output_features).map(
filter_features,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
if shuffle and task._source.supports_global_shuffle:
# Data sources that support global shuffle do not have a shuffle
# buffer after the preprocessing. The preprocessing could output
# multiple correlated outputs.
# We could add a shuffle buffer here but this defeats the purpose of
# avoiding the shuffle buffer in the first place.
raise NotImplementedError("Using ArrayRecordDataSource in mixtures "
"is currently not supported.")
datasets.append(ds)
except:
logging.error("Failed to load task '%s' as part of mixture '%s'",
Expand Down