Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,48 @@ class ShardInfo:
num_shards: int


class DatasetProviderBase(metaclass=abc.ABCMeta):
class DatasetProvider(metaclass=abc.ABCMeta):
  """Interface for classes that provide a tf.data.Dataset.

  This class only declares abstract members; it carries no implementation.
  Concrete providers must implement `output_features`, `splits`,
  `get_dataset` and `num_input_examples`.
  """

  # `abc.abstractmethod` is applied as the innermost decorator so that the
  # resulting property is registered as abstract.
  @property
  @abc.abstractmethod
  def output_features(self) -> Mapping[str, Feature]:
    """Returns a mapping from string keys to `Feature` objects.

    NOTE(review): keys are presumably feature names -- confirm against the
    concrete implementations.
    """
    raise NotImplementedError

  @property
  @abc.abstractmethod
  def splits(self) -> Sequence[str]:
    """Returns the split identifiers (strings) exposed by this provider.

    NOTE(review): presumably these are the values accepted by the `split`
    argument of `get_dataset` -- confirm.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def get_dataset(
      self,
      sequence_length: Optional[Mapping[str, int]] = None,
      split: str = tfds.Split.TRAIN,
      use_cached: bool = False,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      num_epochs: Optional[int] = 1,
  ) -> tf.data.Dataset:
    """Returns the requested tf.data.Dataset.

    The exact semantics of each argument are defined by the implementing
    class; only the types and defaults below are fixed by this interface.

    Args:
      sequence_length: optional mapping from string keys to ints. Presumably
        per-feature lengths -- TODO confirm. Defaults to None.
      split: the split to load. Defaults to `tfds.Split.TRAIN`.
      use_cached: boolean flag, defaults to False. Presumably selects a
        cached version of the data -- TODO confirm.
      shuffle: boolean flag, defaults to True.
      seed: optional int, defaults to None. Presumably the shuffling seed --
        TODO confirm.
      shard_info: optional `ShardInfo`, defaults to None.
      num_epochs: optional int, defaults to 1.

    Returns:
      A `tf.data.Dataset`.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def num_input_examples(self, split: str) -> Optional[int]:
    """Returns an optional int for the given `split`.

    NOTE(review): presumably the number of input examples in `split`, with
    None meaning the count is unknown -- confirm against implementations.
    """
    raise NotImplementedError


class DatasetProviderBase(DatasetProvider, metaclass=abc.ABCMeta):
"""Abstract base for classes that provide a tf.data.Dataset."""

@abc.abstractproperty
@property
@abc.abstractmethod
def output_features(self) -> Mapping[str, Feature]:
raise NotImplementedError

@abc.abstractproperty
@property
@abc.abstractmethod
def splits(self) -> Sequence[str]:
raise NotImplementedError

Expand Down Expand Up @@ -110,8 +144,8 @@ class DatasetProviderRegistry(object):
"""

# Class variables must be defined in subclasses.
_REGISTRY: MutableMapping[str, DatasetProviderBase]
_PROVIDER_TYPE: Type[DatasetProviderBase]
_REGISTRY: MutableMapping[str, DatasetProvider]
_PROVIDER_TYPE: Type[DatasetProvider]

@classmethod
def add_provider(cls, name: str, provider):
Expand Down Expand Up @@ -228,7 +262,7 @@ def get_dataset(
class DataSource(DatasetProviderBase):
"""A `DatasetProvider` that provides raw data from an input source.

Inherits all abstract methods and properties of `DatasetProviderBase` except
Inherits all abstract methods and properties of `DatasetProvider` except
those overridden below.
"""

Expand Down Expand Up @@ -264,7 +298,7 @@ def supports_arbitrary_sharding(self) -> bool:

@property
def output_features(self) -> Mapping[str, Feature]:
"""Override unused property of `DatasetProviderBase`."""
"""Override unused property of `DatasetProvider`."""
raise NotImplementedError

@abc.abstractmethod
Expand Down Expand Up @@ -2146,7 +2180,7 @@ def get_dataset(

mixture_or_task = (
get_mixture_or_task(mixture_or_task_name)
if not isinstance(mixture_or_task_name, DatasetProviderBase)
if not isinstance(mixture_or_task_name, DatasetProvider)
else mixture_or_task_name
)
is_grain_task = False
Expand Down