Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import operator
import os
import re
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, Union
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, Type, Union
from absl import logging
import clu.metrics
import editdistance
Expand Down Expand Up @@ -1782,6 +1782,35 @@ def _check_compatible_features(self) -> None:
"Features across tasks in a mixture must use the same dtype."
)

def get_task_dataset(
    self,
    task: Task,
    output_feature_keys: Set[str],
    sequence_length: Optional[Mapping[str, int]] = None,
    split: str = tfds.Split.TRAIN,
    use_cached: bool = False,
    shuffle: bool = True,
    seed: Optional[int] = None,
    shard_info: Optional[ShardInfo] = None,
    num_epochs: Optional[int] = None,
    trim_output_features: bool = True,
) -> tf.data.Dataset:
  """Returns the dataset of a single task, keeping only the requested features.

  The dataset is obtained from `task.get_dataset`, and every example is then
  reduced to the entries whose key is contained in `output_feature_keys`.

  Args:
    task: the task whose dataset is loaded.
    output_feature_keys: keys of the features to keep in each example; all
      other features are dropped.
    sequence_length: forwarded to `task.get_dataset` as its first positional
      argument.
    split: forwarded to `task.get_dataset`.
    use_cached: forwarded to `task.get_dataset`.
    shuffle: forwarded to `task.get_dataset`.
    seed: forwarded to `task.get_dataset`.
    shard_info: forwarded to `task.get_dataset`.
    num_epochs: forwarded to `task.get_dataset`.
    trim_output_features: forwarded to `task.get_dataset`.

  Returns:
    The task's dataset with each example restricted to `output_feature_keys`.
  """

  def keep_requested_features(example):
    # Examples are feature-name -> value dicts; drop keys that were not asked
    # for so that only `output_feature_keys` can appear in the output.
    return {
        name: value
        for name, value in example.items()
        if name in output_feature_keys
    }

  task_ds = task.get_dataset(
      sequence_length,
      split=split,
      use_cached=use_cached,
      shuffle=shuffle,
      seed=seed,
      shard_info=shard_info,
      num_epochs=num_epochs,
      trim_output_features=trim_output_features,
  )
  return task_ds.map(
      keep_requested_features,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )

def get_dataset( # pytype: disable=signature-mismatch # overriding-parameter-type-checks
self,
sequence_length: Optional[Mapping[str, int]] = None,
Expand Down Expand Up @@ -1846,22 +1875,21 @@ def get_dataset( # pytype: disable=signature-mismatch # overriding-parameter-t
if passthrough_features:
output_feature_keys.update(passthrough_features)

def filter_features(ex):
return {k: v for k, v in ex.items() if k in output_feature_keys}

datasets: List[tf.data.Dataset] = []
for task in tasks:
try:
ds = task.get_dataset(
ds = self.get_task_dataset(
task,
output_feature_keys,
sequence_length,
split=split,
use_cached=use_cached,
shuffle=shuffle,
seed=seed,
shard_info=shard_info,
num_epochs=num_epochs,
trim_output_features=trim_output_features,
).map(filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)
split,
use_cached,
shuffle,
seed,
shard_info,
num_epochs,
trim_output_features,
)
datasets.append(ds)
except:
logging.error(
Expand Down