From e90fbec4e68490c78e2d09927767a22972a288f1 Mon Sep 17 00:00:00 2001 From: SeqIO Team Date: Fri, 14 Apr 2023 11:25:22 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 524341684 --- seqio/dataset_providers.py | 54 +++++++++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 13 deletions(-) diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index b46626d1..0a9360f2 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -28,7 +28,7 @@ import operator import os import re -from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, Type, Union from absl import logging import clu.metrics import editdistance @@ -1782,6 +1782,35 @@ def _check_compatible_features(self) -> None: "Features across tasks in a mixture must use the same dtype." ) + def get_task_dataset( + self, + task: Task, + output_feature_keys: Set[str], + sequence_length: Optional[Mapping[str, int]] = None, + split: str = tfds.Split.TRAIN, + use_cached: bool = False, + shuffle: bool = True, + seed: Optional[int] = None, + shard_info: Optional[ShardInfo] = None, + num_epochs: Optional[int] = None, + trim_output_features: bool = True, + ) -> tf.data.Dataset: + """Returns the task's dataset with features filtered to `output_feature_keys`.""" + + def filter_features(ex): + return {k: v for k, v in ex.items() if k in output_feature_keys} + + return task.get_dataset( + sequence_length, + split=split, + use_cached=use_cached, + shuffle=shuffle, + seed=seed, + shard_info=shard_info, + num_epochs=num_epochs, + trim_output_features=trim_output_features, + ).map(filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE) + def get_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks self, sequence_length: Optional[Mapping[str, int]] = None, @@ -1846,22 +1875,21 @@ def get_dataset( # pytype: disable=signature-mismatch # overriding-parameter-t if 
passthrough_features: output_feature_keys.update(passthrough_features) - def filter_features(ex): - return {k: v for k, v in ex.items() if k in output_feature_keys} - datasets: List[tf.data.Dataset] = [] for task in tasks: try: - ds = task.get_dataset( + ds = self.get_task_dataset( + task, + output_feature_keys, sequence_length, - split=split, - use_cached=use_cached, - shuffle=shuffle, - seed=seed, - shard_info=shard_info, - num_epochs=num_epochs, - trim_output_features=trim_output_features, - ).map(filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE) + split, + use_cached, + shuffle, + seed, + shard_info, + num_epochs, + trim_output_features, + ) datasets.append(ds) except: logging.error(