Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 147 additions & 2 deletions seqio/feature_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,135 @@ def loss_on_targets_only(self) -> bool:
return self._loss_on_targets_only


class PrefixSuffixLMFeatureConverter(PrefixLMFeatureConverter):
"""Feature converter for a input + target + suffix language model.

When "suffixes" field is empty, it is identical as PrefixLMFeatureConverter.
When "suffixes" field is not empty, it merges "targets" and "suffixes" but
computes the loss only over tokens from "suffixes".

Example: a packed dataset
```
ds = [{"inputs": [9, 4, 6], "targets": [3, 9], "suffixes": [2, 1]},
{"inputs": [3, 2,], "targets": [4,], "suffixes": []}]

task_feature_lengths = {"inputs": 7, "targets": 8}

converted_ds = {
"decoder_target_tokens": [9, 4, 6, 3, 9, 2, 1, 3, 2, 4, 0, 0, 0, 0, 0],
"decoder_input_tokens": [0, 9, 4, 6, 3, 9, 2, 0, 3, 2, 0, 0, 0, 0, 0],
"decoder_loss_weights": [0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
"target_suffix_weights": [0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
"decoder_positions": [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 0, 0, 0, 0, 0],
"decoder_segment_ids": [1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
"decoder_causal_attention": [1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
}
```
"""

TASK_FEATURES = {
"inputs": FeatureConverter.FeatureSpec(dtype=tf.int32),
"targets": FeatureConverter.FeatureSpec(dtype=tf.int32),
"suffixes": FeatureConverter.FeatureSpec(dtype=tf.int32),
}
MODEL_FEATURES = {
"decoder_target_tokens": FeatureConverter.FeatureSpec(dtype=tf.int32),
"decoder_input_tokens": FeatureConverter.FeatureSpec(dtype=tf.int32),
"decoder_loss_weights": FeatureConverter.FeatureSpec(dtype=tf.int32),
"decoder_causal_attention": FeatureConverter.FeatureSpec(dtype=tf.int32),
"target_suffix_weights": FeatureConverter.FeatureSpec(dtype=tf.int32),
}
PACKING_FEATURE_DTYPES = {
"decoder_segment_ids": tf.int32,
"decoder_positions": tf.int32,
}

def __init__(self, loss_on_targets_only: bool = True, **kwargs) -> None:
self._loss_on_targets_only = loss_on_targets_only
super().__init__(**kwargs)

def _convert_example(
self, features: Mapping[str, tf.Tensor]
) -> Mapping[str, tf.Tensor]:
"""Convert a example into an example with model features."""
# First use the standard LM conversion.
lm_features = super()._convert_example(features)
d = dict(lm_features)
target_suffix_weights = tf.cast(
tf.equal(features["target_suffix_weights"], 2),
dtype=d["decoder_loss_weights"].dtype)
d["target_suffix_weights"] = target_suffix_weights
return d

def _concat_and_add_masks(
self, features: Mapping[str, tf.Tensor]
) -> Mapping[str, tf.Tensor]:
"""Creates concatenated inputs and targets fields and adds masks."""
inputs = features["inputs"]
targets = features["targets"]
suffixes = features["suffixes"]
target_suffixes = tf.concat([targets, suffixes], axis=0)
# If the targets are empty, we add one padding target.
target_suffixes = tf.cond(
tf.size(target_suffixes) > 0,
lambda: target_suffixes,
lambda: tf.zeros(1, dtype="int32"),
)

# Width of the "inputs" portion in the concatenated sequence.
width = tf.size(inputs)
inputs_width = tf.fill([tf.size(inputs) + tf.size(target_suffixes)], width)

# Width with an extra position to the right in the inputs mask. See
# docstring for PrefixLMFeatureConverter class for details.
inputs_width_add_pos = tf.fill(
[tf.size(inputs) + tf.size(target_suffixes)], width + 1
)

target_weights_with_suffix = tf.concat(
[tf.ones_like(inputs),
tf.ones_like(targets),
tf.fill([tf.size(suffixes),], 2)], axis=-1)
target_weights_without_suffix = tf.concat(
[tf.ones_like(inputs),
tf.fill([tf.size(target_suffixes),], 2)], axis=-1)

target_weights = tf.cond(
tf.size(suffixes) > 0,
lambda: target_weights_with_suffix,
lambda: target_weights_without_suffix,
)
return {
"targets": tf.concat([inputs, target_suffixes], axis=-1),
"inputs_width": inputs_width,
"inputs_width_add_pos": inputs_width_add_pos,
"target_suffix_weights": target_weights
}

def get_model_feature_lengths(
self, task_feature_lengths: Mapping[str, int]
) -> Mapping[str, int]:
"""Define the length relationship between task and model features."""
decoder_length = sum(task_feature_lengths.values())
concat_length = {"targets": decoder_length}
lm_model_feature_lengths = super().get_model_feature_lengths(concat_length)
model_feature_lengths = dict(lm_model_feature_lengths)
model_feature_lengths["decoder_causal_attention"] = decoder_length
model_feature_lengths["target_suffix_weights"] = decoder_length
return model_feature_lengths

def _concat_task_feature_lengths(
self, task_feature_lengths: Mapping[str, int]
) -> Mapping[str, int]:
concat_length = sum(task_feature_lengths.values())
return {
"targets": concat_length,
"inputs_width": concat_length,
"inputs_width_add_pos": concat_length,
"target_suffix_weights": concat_length,
}


class DecoderFeatureConverter(FeatureConverter):
"""Wrapper of FeatureConverter that handles both LM and PrefixLM tasks.

Expand All @@ -1135,6 +1264,8 @@ class DecoderFeatureConverter(FeatureConverter):
TASK_FEATURES = {
"inputs": FeatureConverter.FeatureSpec(dtype=tf.int32), # Optional field
"targets": FeatureConverter.FeatureSpec(dtype=tf.int32),
# Optional field
"suffixes": FeatureConverter.FeatureSpec(dtype=tf.int32),
}
MODEL_FEATURES = {
"decoder_target_tokens": FeatureConverter.FeatureSpec(dtype=tf.int32),
Expand Down Expand Up @@ -1176,10 +1307,19 @@ def __init__(
apply_length_check=apply_length_check,
bos_id=bos_id,
)
self.prefixsuffixlm_feature_converter = PrefixSuffixLMFeatureConverter(
loss_on_targets_only=loss_on_targets_only,
pack=pack,
use_custom_packing_ops=use_custom_packing_ops,
apply_length_check=apply_length_check,
bos_id=bos_id,
)

def __call__(
self, ds: tf.data.Dataset, task_feature_lengths: Mapping[str, int]
) -> tf.data.Dataset:
if "suffixes" in task_feature_lengths:
return self.prefixsuffixlm_feature_converter(ds, task_feature_lengths)
if "inputs" in task_feature_lengths:
return self.prefixlm_feature_converter(ds, task_feature_lengths)
else:
Expand All @@ -1195,8 +1335,13 @@ def get_model_feature_lengths(
self, task_feature_lengths: Mapping[str, int]
) -> Mapping[str, int]:
"""Define the length relationship between task and model features."""

if "inputs" in task_feature_lengths:
if "suffixes" in task_feature_lengths:
model_feature_lengths = (
self.prefixsuffixlm_feature_converter.get_model_feature_lengths(
task_feature_lengths
)
)
elif "inputs" in task_feature_lengths:
model_feature_lengths = (
self.prefixlm_feature_converter.get_model_feature_lengths(
task_feature_lengths
Expand Down
109 changes: 109 additions & 0 deletions seqio/feature_converters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,5 +1181,114 @@ def test_encoder_decoder_packed(self):
assert_dataset(converted_ds, expected)


class PrefixSuffixLMFeatureConverter(tf.test.TestCase):

def test_prefix_suffix_lm_unpacked(self):
x = [{"inputs": [9, 4, 6, 1], "targets": [3, 9], "suffixes": [2, 1]}]
ds = create_default_dataset(
x, feature_names=("inputs", "targets", "suffixes"))

task_feature_lengths = {"inputs": 5, "targets": 4, "suffixes": 3}
converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=False)
converted_ds = converter(ds, task_feature_lengths)

expected = {
"decoder_target_tokens": [9, 4, 6, 1, 3, 9, 2, 1, 0, 0, 0, 0],
# The last EOS token is kept if unpacked.
"decoder_input_tokens": [0, 9, 4, 6, 1, 3, 9, 2, 1, 0, 0, 0],
"decoder_loss_weights": [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
"decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
"target_suffix_weights": [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
}
assert_dataset(converted_ds, expected)

def test_prefix_suffix_lm_unpacked_trivial_targets(self):
x = [{"inputs": [9, 4, 6, 1], "targets": [], "suffixes": [2, 1]}]
ds = create_default_dataset(
x, feature_names=("inputs", "targets", "suffixes"))

task_feature_lengths = {"inputs": 5, "targets": 4, "suffixes": 3}

converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=False)
converted_ds = converter(ds, task_feature_lengths)
expected = {
"decoder_target_tokens": [9, 4, 6, 1, 2, 1, 0, 0, 0, 0, 0, 0],
# The last EOS token is kept if unpacked.
"decoder_input_tokens": [0, 9, 4, 6, 1, 2, 1, 0, 0, 0, 0, 0],
"decoder_loss_weights": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
"decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
"target_suffix_weights": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
}
assert_dataset(converted_ds, expected)

def test_prefix_suffix_lm_unpacked_trivial_suffixes(self):
x = [{"inputs": [9, 4, 6, 1], "targets": [2, 1], "suffixes": []}]
ds = create_default_dataset(
x, feature_names=("inputs", "targets", "suffixes"))

task_feature_lengths = {"inputs": 5, "targets": 4, "suffixes": 3}

converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=False)
converted_ds = converter(ds, task_feature_lengths)

expected = {
"decoder_target_tokens": [9, 4, 6, 1, 2, 1, 0, 0, 0, 0, 0, 0],
# The last EOS token is kept if unpacked.
"decoder_input_tokens": [0, 9, 4, 6, 1, 2, 1, 0, 0, 0, 0, 0],
"decoder_loss_weights": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
"decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
"target_suffix_weights": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
}
assert_dataset(converted_ds, expected)

def test_prefix_suffix_lm_packed(self):
x = [
{"inputs": [9, 4, 6], "targets": [3, 9], "suffixes": [2, 1]},
{"inputs": [3, 2,], "targets": [4,], "suffixes": [1]}
]
ds = create_default_dataset(
x, feature_names=("inputs", "targets", "suffixes"))

task_feature_lengths = {"inputs": 8, "targets": 4, "suffixes": 3}
converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=True)
converted_ds = converter(ds, task_feature_lengths)

expected = {
"decoder_target_tokens": [9, 4, 6, 3, 9, 2, 1, 3, 2, 4, 1, 0, 0, 0, 0],
"decoder_input_tokens": [0, 9, 4, 6, 3, 9, 2, 0, 3, 2, 4, 0, 0, 0, 0],
"decoder_loss_weights": [0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
"target_suffix_weights": [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0],
"decoder_positions": [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 0, 0, 0],
"decoder_segment_ids": [1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0],
"decoder_causal_attention": [
1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
}
assert_dataset(converted_ds, expected)

def test_prefix_suffix_lm_packed_trivial_suffoxes(self):
x = [
{"inputs": [9, 4, 6], "targets": [3, 9], "suffixes": [2, 1]},
{"inputs": [3, 2,], "targets": [4,], "suffixes": []}
]
ds = create_default_dataset(
x, feature_names=("inputs", "targets", "suffixes"))

task_feature_lengths = {"inputs": 8, "targets": 4, "suffixes": 3}
converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=True)
converted_ds = converter(ds, task_feature_lengths)

expected = {
"decoder_target_tokens": [9, 4, 6, 3, 9, 2, 1, 3, 2, 4, 0, 0, 0, 0, 0],
"decoder_input_tokens": [0, 9, 4, 6, 3, 9, 2, 0, 3, 2, 0, 0, 0, 0, 0],
"decoder_loss_weights": [0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
"target_suffix_weights": [0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
"decoder_positions": [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 0, 0, 0, 0, 0],
"decoder_segment_ids": [1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
"decoder_causal_attention": [
1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
}
assert_dataset(converted_ds, expected)


if __name__ == "__main__":
tf.test.main()