Skip to content

[JAX] Add warning if using BSHD and max_segments_per_seq > 1#2796

Open
jberchtold-nvidia wants to merge 4 commits intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/te-max-segments-per-seq-warning-bshd
Open

[JAX] Add warning if using BSHD and max_segments_per_seq > 1#2796
jberchtold-nvidia wants to merge 4 commits intoNVIDIA:mainfrom
jberchtold-nvidia:jberchtold/te-max-segments-per-seq-warning-bshd

Conversation

@jberchtold-nvidia
Copy link
Collaborator

Description

Adds a small warning if the user tries to use BSHD with max_segments_per_seq > 1

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Adds a warning if the user tries to use BSHD with max_segments_per_seq > 1
  • Adds a new test to validate this warning is shown correctly

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci jax

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 24, 2026

Greptile Summary

This PR adds a UserWarning to fused_attn when max_segments_per_seq > 1 is combined with a non-THD layout (e.g. BS3HD, BSHD_BSHD_BSHD), since sequence packing only applies to THD layouts and the parameter is silently ignored otherwise. A new parametrized test class validates the warning fires before GPU/cuDNN dispatch for both affected layout types.

  • transformer_engine/jax/attention.py: Warning is inserted after the legacy early-return guard and before the _fused_attn call, ensuring it fires on the normal code path. Includes stacklevel=2 so the warning points to the caller's file/line, not library internals. Message names the applicable THD layouts for user guidance.
  • tests/jax/test_fused_attn.py: TestMaxSegmentsPerSeqWarning follows the existing @staticmethod + @pytest.mark.parametrize convention used throughout the file. Wraps the dispatch in try/except so the test passes on machines without GPU/cuDNN, and asserts both UserWarning category and "max_segments_per_seq" message content.

Confidence Score: 5/5

  • PR is safe to merge — the change is a non-breaking, informational warning with no effect on the computation path.
  • The previously flagged issues (duplicate if statement, missing stacklevel) are fully resolved. The warning placement, condition, and message are all correct. The test follows established patterns in the file and will pass without GPU hardware. No logic is altered for the happy path.
  • No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/jax/attention.py Adds a UserWarning (with stacklevel=2) immediately after the legacy early-return path in fused_attn when max_segments_per_seq > 1 and the layout is not THD — correct placement, correct guard condition, and correct idiom for library-level warnings.
tests/jax/test_fused_attn.py Adds TestMaxSegmentsPerSeqWarning with two parametrized cases (BS3HD, BSHD_BSHD_BSHD); follows the existing @staticmethod + @pytest.mark.parametrize pattern used throughout the file, captures warnings before GPU dispatch, and asserts both category and message content.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fused_attn(qkv, ..., max_segments_per_seq, qkv_layout)"] --> B{"sequence_descriptor is None\nor isinstance ndarray?"}
    B -- Yes --> C["warnings.warn(DeprecationWarning)\n→ _legacy_fused_attn(...)"]
    B -- No --> D{"max_segments_per_seq > 1\nAND NOT qkv_layout.is_thd()?"}
    D -- Yes --> E["warnings.warn(UserWarning, stacklevel=2)\n'max_segments_per_seq has no effect\nwith non-THD layouts'"]
    D -- No --> F["_fused_attn(...)"]
    E --> F
    C --> G["return output (legacy)"]
    F --> H["return output"]
Loading

Reviews (3): Last reviewed commit: "Update transformer_engine/jax/attention...." | Re-trigger Greptile

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
@jberchtold-nvidia
Copy link
Collaborator Author

/te-ci jax

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant