diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index 2e977b13..c4cfb54f 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -72,14 +72,48 @@ class ShardInfo: num_shards: int -class DatasetProviderBase(metaclass=abc.ABCMeta): +class DatasetProvider(metaclass=abc.ABCMeta): + """Interface for classes that provide a tf.data.Dataset.""" + + @property + @abc.abstractmethod + def output_features(self) -> Mapping[str, Feature]: + raise NotImplementedError + + @property + @abc.abstractmethod + def splits(self) -> Sequence[str]: + raise NotImplementedError + + @abc.abstractmethod + def get_dataset( + self, + sequence_length: Optional[Mapping[str, int]] = None, + split: str = tfds.Split.TRAIN, + use_cached: bool = False, + shuffle: bool = True, + seed: Optional[int] = None, + shard_info: Optional[ShardInfo] = None, + num_epochs: Optional[int] = 1, + ) -> tf.data.Dataset: + """Returns the requested tf.data.Dataset.""" + raise NotImplementedError + + @abc.abstractmethod + def num_input_examples(self, split: str) -> Optional[int]: + raise NotImplementedError + + +class DatasetProviderBase(DatasetProvider, metaclass=abc.ABCMeta): """Abstract base for classes that provide a tf.data.Dataset.""" - @abc.abstractproperty + @property + @abc.abstractmethod def output_features(self) -> Mapping[str, Feature]: raise NotImplementedError - @abc.abstractproperty + @property + @abc.abstractmethod def splits(self) -> Sequence[str]: raise NotImplementedError @@ -110,8 +144,8 @@ class DatasetProviderRegistry(object): """ # Class variables must be defined in subclasses. - _REGISTRY: MutableMapping[str, DatasetProviderBase] - _PROVIDER_TYPE: Type[DatasetProviderBase] + _REGISTRY: MutableMapping[str, DatasetProvider] + _PROVIDER_TYPE: Type[DatasetProvider] @classmethod def add_provider(cls, name: str, provider): @@ -228,7 +262,7 @@ def get_dataset( class DataSource(DatasetProviderBase): """A `DatasetProvider` that provides raw data from an input source. - Inherits all abstract methods and properties of `DatasetProviderBase` except + Inherits all abstract methods and properties of `DatasetProvider` except those overidden below. """ @@ -264,7 +298,7 @@ def supports_arbitrary_sharding(self) -> bool: @property def output_features(self) -> Mapping[str, Feature]: - """Override unused property of `DatasetProviderBase`.""" + """Override unused property of `DatasetProvider`.""" raise NotImplementedError @abc.abstractmethod @@ -2146,7 +2180,7 @@ def get_dataset( mixture_or_task = ( get_mixture_or_task(mixture_or_task_name) - if not isinstance(mixture_or_task_name, DatasetProviderBase) + if not isinstance(mixture_or_task_name, DatasetProvider) else mixture_or_task_name ) is_grain_task = False