diff --git a/seqio/feature_converters.py b/seqio/feature_converters.py
index cc8c5ecf..c9fee4a0 100644
--- a/seqio/feature_converters.py
+++ b/seqio/feature_converters.py
@@ -1126,6 +1126,135 @@ def loss_on_targets_only(self) -> bool:
     return self._loss_on_targets_only
 
 
+class PrefixSuffixLMFeatureConverter(PrefixLMFeatureConverter):
+  """Feature converter for an input + target + suffix language model.
+
+  When the "suffixes" field is empty, this converter behaves identically to
+  PrefixLMFeatureConverter. When the "suffixes" field is non-empty, it merges
+  "targets" and "suffixes" into a single decoder sequence and emits an
+  additional "target_suffix_weights" feature that is nonzero only over the
+  tokens coming from "suffixes", so that the loss can be restricted to the
+  suffix tokens.
+
+  Example: a packed dataset
+  ```
+  ds = [{"inputs": [9, 4, 6], "targets": [3, 9], "suffixes": [2, 1]},
+        {"inputs": [3, 2], "targets": [4], "suffixes": []}]
+
+  task_feature_lengths = {"inputs": 8, "targets": 4, "suffixes": 3}
+
+  converted_ds = {
+      "decoder_target_tokens": [9, 4, 6, 3, 9, 2, 1, 3, 2, 4, 0, 0, 0, 0, 0],
+       "decoder_input_tokens": [0, 9, 4, 6, 3, 9, 2, 0, 3, 2, 0, 0, 0, 0, 0],
+       "decoder_loss_weights": [0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
+      "target_suffix_weights": [0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
+          "decoder_positions": [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 0, 0, 0, 0, 0],
+        "decoder_segment_ids": [1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
+  "decoder_causal_attention": [1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
+  }
+  ```
+  Note that when "suffixes" is empty (the second sequence above),
+  "target_suffix_weights" covers all of that sequence's target tokens.
+  """
+
+  TASK_FEATURES = {
+      "inputs": FeatureConverter.FeatureSpec(dtype=tf.int32),
+      "targets": FeatureConverter.FeatureSpec(dtype=tf.int32),
+      "suffixes": FeatureConverter.FeatureSpec(dtype=tf.int32),
+  }
+  MODEL_FEATURES = {
+      "decoder_target_tokens": FeatureConverter.FeatureSpec(dtype=tf.int32),
+      "decoder_input_tokens": FeatureConverter.FeatureSpec(dtype=tf.int32),
+      "decoder_loss_weights": FeatureConverter.FeatureSpec(dtype=tf.int32),
+      "decoder_causal_attention": FeatureConverter.FeatureSpec(dtype=tf.int32),
+      "target_suffix_weights": FeatureConverter.FeatureSpec(dtype=tf.int32),
+  }
+  PACKING_FEATURE_DTYPES = {
+      "decoder_segment_ids": tf.int32,
+      "decoder_positions": tf.int32,
+  }
+
+  def __init__(self, loss_on_targets_only: bool = True, **kwargs) -> None:
+    # Pass loss_on_targets_only through to the parent; setting the attribute
+    # before calling super().__init__() would let the parent's default
+    # overwrite it.
+    super().__init__(loss_on_targets_only=loss_on_targets_only, **kwargs)
+
+  def _convert_example(
+      self, features: Mapping[str, tf.Tensor]
+  ) -> Mapping[str, tf.Tensor]:
+    """Converts an example into an example with model features."""
+    # First apply the standard PrefixLM conversion.
+    lm_features = super()._convert_example(features)
+    d = dict(lm_features)
+    # Positions marked with weight 2 in _concat_and_add_masks are the ones
+    # that receive a suffix weight of 1.
+    target_suffix_weights = tf.cast(
+        tf.equal(features["target_suffix_weights"], 2),
+        dtype=d["decoder_loss_weights"].dtype)
+    d["target_suffix_weights"] = target_suffix_weights
+    return d
+
+  def _concat_and_add_masks(
+      self, features: Mapping[str, tf.Tensor]
+  ) -> Mapping[str, tf.Tensor]:
+    """Creates concatenated inputs and targets fields and adds masks."""
+    inputs = features["inputs"]
+    targets = features["targets"]
+    suffixes = features["suffixes"]
+    target_suffixes = tf.concat([targets, suffixes], axis=0)
+    # If both targets and suffixes are empty, add one padding token.
+    target_suffixes = tf.cond(
+        tf.size(target_suffixes) > 0,
+        lambda: target_suffixes,
+        lambda: tf.zeros(1, dtype=tf.int32),
+    )
+
+    # Width of the "inputs" portion in the concatenated sequence.
+    width = tf.size(inputs)
+    inputs_width = tf.fill([tf.size(inputs) + tf.size(target_suffixes)], width)
+
+    # Width with an extra position to the right in the inputs mask. See
+    # docstring for PrefixLMFeatureConverter class for details.
+    inputs_width_add_pos = tf.fill(
+        [tf.size(inputs) + tf.size(target_suffixes)], width + 1
+    )
+
+    # Mark suffix tokens with weight 2; when there is no suffix, mark all
+    # target tokens with weight 2 instead.
+    target_weights_with_suffix = tf.concat(
+        [tf.ones_like(inputs),
+         tf.ones_like(targets),
+         tf.fill([tf.size(suffixes)], 2)], axis=-1)
+    target_weights_without_suffix = tf.concat(
+        [tf.ones_like(inputs),
+         tf.fill([tf.size(target_suffixes)], 2)], axis=-1)
+
+    target_weights = tf.cond(
+        tf.size(suffixes) > 0,
+        lambda: target_weights_with_suffix,
+        lambda: target_weights_without_suffix,
+    )
+    return {
+        "targets": tf.concat([inputs, target_suffixes], axis=-1),
+        "inputs_width": inputs_width,
+        "inputs_width_add_pos": inputs_width_add_pos,
+        "target_suffix_weights": target_weights,
+    }
+
+  def get_model_feature_lengths(
+      self, task_feature_lengths: Mapping[str, int]
+  ) -> Mapping[str, int]:
+    """Define the length relationship between task and model features."""
+    decoder_length = sum(task_feature_lengths.values())
+    concat_length = {"targets": decoder_length}
+    lm_model_feature_lengths = super().get_model_feature_lengths(concat_length)
+    model_feature_lengths = dict(lm_model_feature_lengths)
+    model_feature_lengths["decoder_causal_attention"] = decoder_length
+    model_feature_lengths["target_suffix_weights"] = decoder_length
+    return model_feature_lengths
+
+  def _concat_task_feature_lengths(
+      self, task_feature_lengths: Mapping[str, int]
+  ) -> Mapping[str, int]:
+    concat_length = sum(task_feature_lengths.values())
+    return {
+        "targets": concat_length,
+        "inputs_width": concat_length,
+        "inputs_width_add_pos": concat_length,
+        "target_suffix_weights": concat_length,
+    }
+
+
 class DecoderFeatureConverter(FeatureConverter):
   """Wrapper of FeatureConverter that handles both LM and PrefixLM tasks.
 
@@ -1135,6 +1264,8 @@ class DecoderFeatureConverter(FeatureConverter):
   TASK_FEATURES = {
       "inputs": FeatureConverter.FeatureSpec(dtype=tf.int32),  # Optional field
       "targets": FeatureConverter.FeatureSpec(dtype=tf.int32),
+      # Optional field
+      "suffixes": FeatureConverter.FeatureSpec(dtype=tf.int32),
   }
   MODEL_FEATURES = {
       "decoder_target_tokens": FeatureConverter.FeatureSpec(dtype=tf.int32),
@@ -1176,10 +1307,19 @@ def __init__(
         apply_length_check=apply_length_check,
         bos_id=bos_id,
     )
+    self.prefixsuffixlm_feature_converter = PrefixSuffixLMFeatureConverter(
+        loss_on_targets_only=loss_on_targets_only,
+        pack=pack,
+        use_custom_packing_ops=use_custom_packing_ops,
+        apply_length_check=apply_length_check,
+        bos_id=bos_id,
+    )
 
   def __call__(
       self, ds: tf.data.Dataset, task_feature_lengths: Mapping[str, int]
   ) -> tf.data.Dataset:
+    if "suffixes" in task_feature_lengths:
+      return self.prefixsuffixlm_feature_converter(ds, task_feature_lengths)
     if "inputs" in task_feature_lengths:
       return self.prefixlm_feature_converter(ds, task_feature_lengths)
     else:
@@ -1195,8 +1335,13 @@ def get_model_feature_lengths(
       self, task_feature_lengths: Mapping[str, int]
   ) -> Mapping[str, int]:
     """Define the length relationship between task and model features."""
-
-    if "inputs" in task_feature_lengths:
+    if "suffixes" in task_feature_lengths:
+      model_feature_lengths = (
+          self.prefixsuffixlm_feature_converter.get_model_feature_lengths(
+              task_feature_lengths
+          )
+      )
+    elif "inputs" in task_feature_lengths:
       model_feature_lengths = (
           self.prefixlm_feature_converter.get_model_feature_lengths(
               task_feature_lengths
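For context on how the new `target_suffix_weights` feature might be consumed downstream, here is a minimal sketch of restricting the loss to suffix tokens. This is an assumption about downstream use, not part of this change, and the helper name `suffix_only_loss_weights` is hypothetical.

```python
import tensorflow as tf


def suffix_only_loss_weights(batch):
  """Hypothetical helper: zero out the loss on non-suffix target tokens.

  `target_suffix_weights` is 1 on suffix tokens (or on all target tokens
  when "suffixes" is empty) and 0 elsewhere, so multiplying it into
  `decoder_loss_weights` keeps the loss only on the suffix.
  """
  return batch["decoder_loss_weights"] * batch["target_suffix_weights"]


# Values taken from test_prefix_suffix_lm_unpacked below.
batch = {
    "decoder_loss_weights": tf.constant([0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]),
    "target_suffix_weights": tf.constant([0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),
}
print(suffix_only_loss_weights(batch).numpy())  # [0 0 0 0 0 0 1 1 0 0 0 0]
```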
diff --git a/seqio/feature_converters_test.py b/seqio/feature_converters_test.py
index 262660c0..fccc7581 100644
--- a/seqio/feature_converters_test.py
+++ b/seqio/feature_converters_test.py
@@ -1181,5 +1181,114 @@ def test_encoder_decoder_packed(self):
     assert_dataset(converted_ds, expected)
 
 
+class PrefixSuffixLMFeatureConverterTest(tf.test.TestCase):
+
+  def test_prefix_suffix_lm_unpacked(self):
+    x = [{"inputs": [9, 4, 6, 1], "targets": [3, 9], "suffixes": [2, 1]}]
+    ds = create_default_dataset(
+        x, feature_names=("inputs", "targets", "suffixes"))
+
+    task_feature_lengths = {"inputs": 5, "targets": 4, "suffixes": 3}
+    converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=False)
+    converted_ds = converter(ds, task_feature_lengths)
+
+    expected = {
+        "decoder_target_tokens": [9, 4, 6, 1, 3, 9, 2, 1, 0, 0, 0, 0],
+        # The last EOS token is kept if unpacked.
+        "decoder_input_tokens": [0, 9, 4, 6, 1, 3, 9, 2, 1, 0, 0, 0],
+        "decoder_loss_weights": [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0],
+        "decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+        "target_suffix_weights": [0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
+    }
+    assert_dataset(converted_ds, expected)
+
+  def test_prefix_suffix_lm_unpacked_trivial_targets(self):
+    x = [{"inputs": [9, 4, 6, 1], "targets": [], "suffixes": [2, 1]}]
+    ds = create_default_dataset(
+        x, feature_names=("inputs", "targets", "suffixes"))
+
+    task_feature_lengths = {"inputs": 5, "targets": 4, "suffixes": 3}
+    converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=False)
+    converted_ds = converter(ds, task_feature_lengths)
+
+    expected = {
+        "decoder_target_tokens": [9, 4, 6, 1, 2, 1, 0, 0, 0, 0, 0, 0],
+        # The last EOS token is kept if unpacked.
+        "decoder_input_tokens": [0, 9, 4, 6, 1, 2, 1, 0, 0, 0, 0, 0],
+        "decoder_loss_weights": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
+        "decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+        "target_suffix_weights": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
+    }
+    assert_dataset(converted_ds, expected)
+
+  def test_prefix_suffix_lm_unpacked_trivial_suffixes(self):
+    x = [{"inputs": [9, 4, 6, 1], "targets": [2, 1], "suffixes": []}]
+    ds = create_default_dataset(
+        x, feature_names=("inputs", "targets", "suffixes"))
+
+    task_feature_lengths = {"inputs": 5, "targets": 4, "suffixes": 3}
+    converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=False)
+    converted_ds = converter(ds, task_feature_lengths)
+
+    expected = {
+        "decoder_target_tokens": [9, 4, 6, 1, 2, 1, 0, 0, 0, 0, 0, 0],
+        # The last EOS token is kept if unpacked.
+        "decoder_input_tokens": [0, 9, 4, 6, 1, 2, 1, 0, 0, 0, 0, 0],
+        "decoder_loss_weights": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
+        "decoder_causal_attention": [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
+        "target_suffix_weights": [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
+    }
+    assert_dataset(converted_ds, expected)
+
+  def test_prefix_suffix_lm_packed(self):
+    x = [
+        {"inputs": [9, 4, 6], "targets": [3, 9], "suffixes": [2, 1]},
+        {"inputs": [3, 2], "targets": [4], "suffixes": [1]},
+    ]
+    ds = create_default_dataset(
+        x, feature_names=("inputs", "targets", "suffixes"))
+
+    task_feature_lengths = {"inputs": 8, "targets": 4, "suffixes": 3}
+    converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=True)
+    converted_ds = converter(ds, task_feature_lengths)
+
+    expected = {
+        "decoder_target_tokens": [9, 4, 6, 3, 9, 2, 1, 3, 2, 4, 1, 0, 0, 0, 0],
+        "decoder_input_tokens": [0, 9, 4, 6, 3, 9, 2, 0, 3, 2, 4, 0, 0, 0, 0],
+        "decoder_loss_weights": [0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0],
+        "target_suffix_weights": [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0],
+        "decoder_positions": [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 0, 0, 0],
+        "decoder_segment_ids": [1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0, 0, 0],
+        "decoder_causal_attention": [
+            1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
+    }
+    assert_dataset(converted_ds, expected)
+
+  def test_prefix_suffix_lm_packed_trivial_suffixes(self):
+    x = [
+        {"inputs": [9, 4, 6], "targets": [3, 9], "suffixes": [2, 1]},
+        {"inputs": [3, 2], "targets": [4], "suffixes": []},
+    ]
+    ds = create_default_dataset(
+        x, feature_names=("inputs", "targets", "suffixes"))
+
+    task_feature_lengths = {"inputs": 8, "targets": 4, "suffixes": 3}
+    converter = feature_converters.PrefixSuffixLMFeatureConverter(pack=True)
+    converted_ds = converter(ds, task_feature_lengths)
+
+    expected = {
+        "decoder_target_tokens": [9, 4, 6, 3, 9, 2, 1, 3, 2, 4, 0, 0, 0, 0, 0],
+        "decoder_input_tokens": [0, 9, 4, 6, 3, 9, 2, 0, 3, 2, 0, 0, 0, 0, 0],
+        "decoder_loss_weights": [0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
+        "target_suffix_weights": [0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0],
+        "decoder_positions": [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 0, 0, 0, 0, 0],
+        "decoder_segment_ids": [1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0],
+        "decoder_causal_attention": [
+            1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0],
+    }
+    assert_dataset(converted_ds, expected)
+
+
 if __name__ == "__main__":
   tf.test.main()