Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion qa/L0_cppunittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

set -e

: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}')
Expand All @@ -17,4 +20,4 @@ cd $TE_PATH/tests/cpp
cmake -GNinja -Bbuild .
cmake --build build
export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS))
ctest --test-dir build -j$NUM_PARALLEL_JOBS
ctest --test-dir build -j$NUM_PARALLEL_JOBS --output-junit $XML_LOG_DIR/ctest_cppunittest.xml
8 changes: 3 additions & 5 deletions tests/pytorch/attention/test_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import ModelConfig, get_available_attention_backends
from utils import ModelConfig, get_available_attention_backends, run_distributed

pytest_logging_level = logging.getLevelName(logging.root.level)

Expand Down Expand Up @@ -125,7 +125,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
if not flash_attn_supported:
pytest.skip("No attention backend available.")

subprocess.run(
run_distributed(
get_bash_arguments(
num_gpus_per_node=num_gpus,
dtype=dtype,
Expand All @@ -135,7 +135,6 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
cp_comm_type=cp_comm_type,
log_level=pytest_logging_level,
),
check=True,
)


Expand Down Expand Up @@ -368,7 +367,7 @@ def test_cp_with_fused_attention(
if not fused_attn_supported:
pytest.skip("No attention backend available.")

subprocess.run(
run_distributed(
get_bash_arguments(
num_gpus_per_node=num_gpus,
dtype=dtype,
Expand All @@ -384,5 +383,4 @@ def test_cp_with_fused_attention(
is_training=is_training,
log_level=pytest_logging_level,
),
check=True,
)
5 changes: 4 additions & 1 deletion tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import sys
import pathlib

sys.path.append(str(pathlib.Path(__file__).resolve().parent.parent))
from utils import run_distributed

import pytest
import torch
from torch import nn
Expand Down Expand Up @@ -1207,7 +1210,7 @@ def test_nvfp4_partial_cast_matches_full(world_size: int) -> None:
current_file,
"--parallel-nvfp4-partial",
]
subprocess.run(command, check=True)
run_distributed(command)


def test_single_gpu_partial_cast_vs_full():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent.parent))
from utils import dtype_tols, make_recipe, str_to_dtype
from utils import dtype_tols, make_recipe, run_distributed, str_to_dtype

# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
Expand Down Expand Up @@ -463,7 +463,7 @@ def test_fuser_ops_with_userbuffers(
env["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

# Launch parallel job
result = subprocess.run(command, check=True, env=env)
run_distributed(command, env=env)


def main() -> None:
Expand Down
12 changes: 8 additions & 4 deletions tests/pytorch/distributed/test_torch_fsdp2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@
# See LICENSE for license information.

import os
import sys
import subprocess
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))
from utils import run_distributed

import pytest
import torch

Expand All @@ -20,7 +24,7 @@
def test_fsdp2_model_tests():
"""All FSDP2 model tests (parametrized internally by recipe, fp8_init, sharding, layer)."""
test_path = _FSDP2_DIR / "run_fsdp2_model.py"
result = subprocess.run(
run_distributed(
[
"torchrun",
f"--nproc_per_node={NUM_PROCS}",
Expand All @@ -32,10 +36,10 @@ def test_fsdp2_model_tests():
"-s",
"--tb=short",
],
valid_returncodes=(0, 5),
env=os.environ,
timeout=600,
)
assert result.returncode in (0, 5), f"Inner pytest failed with exit code {result.returncode}"


@pytest.mark.skipif(NUM_PROCS < 2, reason="Requires 2+ GPUs")
Expand All @@ -44,7 +48,7 @@ def test_fsdp2_fused_adam_tests():
"""All FSDP2 FusedAdam tests (parametrized internally by recipe, test variant)."""
test_path = _FSDP2_DIR / "run_fsdp2_fused_adam.py"
nproc = min(NUM_PROCS, 2)
result = subprocess.run(
run_distributed(
[
"torchrun",
f"--nproc_per_node={nproc}",
Expand All @@ -56,10 +60,10 @@ def test_fsdp2_fused_adam_tests():
"-s",
"--tb=short",
],
valid_returncodes=(0, 5),
env=os.environ,
timeout=600,
)
assert result.returncode in (0, 5), f"Inner pytest failed with exit code {result.returncode}"


def test_dummy() -> None:
Expand Down
34 changes: 33 additions & 1 deletion tests/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import logging
import os
import subprocess
from contextlib import contextmanager
from typing import Optional, Tuple, Dict, Any, List
from typing import Optional, Sequence, Tuple, Dict, Any, List
from packaging.version import Version as PkgVersion

import torch
Expand Down Expand Up @@ -407,3 +408,34 @@ def assert_close_grads(
assert actual is not None
assert expected is not None
assert_close(actual.grad, expected.grad, **kwargs)


def run_distributed(
    args: Sequence[str],
    *,
    valid_returncodes: Sequence[int] = (0,),
    **kwargs,
) -> subprocess.CompletedProcess:
    """Run a distributed subprocess with stderr capture for better error reporting.

    stdout streams to the terminal in real time for interactive debugging.
    On failure, stderr (containing Python tracebacks) is included in the
    AssertionError so pytest writes it into the JUnit XML report.

    Args:
        args: Command and arguments to run.
        valid_returncodes: Return codes considered success (default: (0,)).
            Use (0, 5) for inner pytest runs where 5 means all tests skipped.
        **kwargs: Passed through to subprocess.run (e.g. env, timeout).
            ``stderr`` and ``text`` are always set internally; any values a
            caller passes for them are discarded rather than raising
            ``TypeError: got multiple values for keyword argument``.
    """
    # This helper always captures stderr in text mode. Drop caller-supplied
    # values for these keys so they cannot conflict with the explicit
    # keyword arguments below.
    kwargs.pop("stderr", None)
    kwargs.pop("text", None)
    result = subprocess.run(args, stderr=subprocess.PIPE, text=True, **kwargs)
    if result.returncode not in valid_returncodes:
        cmd_str = " ".join(str(a) for a in args)
        msg = f"Command exited with code {result.returncode}:\n {cmd_str}\n"
        if result.stderr:
            # Keep only the tail of stderr so the JUnit XML stays readable;
            # the traceback of interest is at the end of the stream.
            stderr_tail = result.stderr[-4000:]
            if len(result.stderr) > 4000:
                stderr_tail = "... [truncated] ...\n" + stderr_tail
            msg += f"\n--- stderr ---\n{stderr_tail}"
        raise AssertionError(msg)
    return result
Loading