"""
Test: XLB Stepper Autodiff - JAX vs Warp Comparison
This script tests whether gradients propagate through the XLB stepper
for both JAX and Warp backends. It performs identical tests on both
backends and compares the results side-by-side.
Expected result: JAX works, Warp does not (stepper lacks adjoint kernels).
Usage:
python examples/cfd/test_stepper_autodiff.py
"""
import numpy as np
# Banner explaining what the script checks, printed one entry per line.
_heavy_rule = "=" * 70
_banner_lines = (
    "",
    _heavy_rule,
    "XLB STEPPER AUTODIFF TEST",
    _heavy_rule,
    "",
    "This test checks if gradients propagate through the LBM stepper.",
    "We run the SAME test on both JAX and Warp backends and compare.",
    "",
)
print(*_banner_lines, sep="\n")

# Common parameters shared by the JAX and Warp runs.
grid_shape = (32, 32)
omega = 1.8

# Echo the configuration so the report is self-describing.
_light_rule = "-" * 70
_config_lines = (
    _light_rule,
    "TEST CONFIGURATION",
    _light_rule,
    f" Grid shape: {grid_shape}",
    f" Omega: {omega}",
    " Precision: FP32FP32",
    " Boundary: Periodic (no walls)",
    " Collision: BGK",
    " Test: Forward 1 step -> Compute rho -> Loss -> Backward",
    "",
)
print(*_config_lines, sep="\n")
# =============================================================================
# JAX BACKEND TEST (run first - always works)
# =============================================================================
import jax
import jax.numpy as jnp
from jax import value_and_grad
import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.grid import grid_factory
from xlb.operator.stepper import IncompressibleNavierStokesStepper
import xlb.velocity_set
# --- JAX backend: build a periodic BGK stepper and differentiate one step. ---
precision_policy = PrecisionPolicy.FP32FP32
jax_velocity_set = xlb.velocity_set.D2Q9(
    precision_policy=precision_policy,
    compute_backend=ComputeBackend.JAX,
)
xlb.init(
    velocity_set=jax_velocity_set,
    default_backend=ComputeBackend.JAX,
    default_precision_policy=precision_policy,
)
jax_grid = grid_factory(grid_shape, compute_backend=ComputeBackend.JAX)
jax_stepper = IncompressibleNavierStokesStepper(
    grid=jax_grid,
    boundary_conditions=[],  # Periodic
    collision_type="BGK",
)
jax_f_0, jax_f_1, jax_bc_mask, jax_missing_mask = jax_stepper.prepare_fields()


def jax_forward_and_loss(f_in):
    """Advance ``f_in`` by one stepper update and return ``sum(rho ** 2)``.

    Args:
        f_in: Input populations (the value being differentiated).

    Returns:
        Scalar loss: the sum of squared densities of the post-step populations.
    """
    # NOTE(review): assumes XLB steppers return (f_0, f_1) in argument order
    # with the post-step populations written to f_1 (XLB examples swap the
    # pair after each call) -- confirm against xlb.operator.stepper. Taking the
    # FIRST element would compute the loss from the untouched input and never
    # exercise the stepper's derivative.
    _, f_out = jax_stepper(f_in, jax_f_1, jax_bc_mask, jax_missing_mask, omega, 0)
    # Density: sum over axis 0 (presumably the q populations -- confirm layout).
    rho = jnp.sum(f_out, axis=0)
    return jnp.sum(rho**2)


jax_loss_val, jax_grad = value_and_grad(jax_forward_and_loss)(jax_f_0)
jax_loss_val = float(jax_loss_val)
jax_grad_norm = float(jnp.linalg.norm(jax_grad))
# =============================================================================
# WARP BACKEND TEST (may fail due to API compatibility)
# =============================================================================
# --- Warp backend: same one-step loss, differentiated with wp.Tape. ---
# Every result starts as None; any exception is captured in warp_error so the
# report below still prints the JAX side.
warp_loss_val = None
warp_f_in_grad_norm = None
warp_f_out_grad_norm = None
warp_rho_grad_norm = None
warp_error = None
try:
    import warp as wp

    wp.init()
    from xlb.operator.macroscopic import Macroscopic

    warp_velocity_set = xlb.velocity_set.D2Q9(
        precision_policy=precision_policy,
        compute_backend=ComputeBackend.WARP,
    )
    xlb.init(
        velocity_set=warp_velocity_set,
        default_backend=ComputeBackend.WARP,
        default_precision_policy=precision_policy,
    )
    warp_grid = grid_factory(grid_shape, compute_backend=ComputeBackend.WARP)
    warp_stepper = IncompressibleNavierStokesStepper(
        grid=warp_grid,
        boundary_conditions=[],  # Periodic
        collision_type="BGK",
    )
    warp_f_0, warp_f_1, warp_bc_mask, warp_missing_mask = warp_stepper.prepare_fields()
    warp_macro = Macroscopic(
        velocity_set=warp_velocity_set,
        precision_policy=precision_policy,
        compute_backend=ComputeBackend.WARP,
    )
    q = warp_velocity_set.q
    # 2D grid stored with a trailing singleton axis -- assumes XLB's Warp field
    # layout is (q, nx, ny, 1); TODO confirm.
    shape_4d = (*grid_shape, 1)

    @wp.kernel
    def warp_loss_kernel(rho: wp.array4d(dtype=wp.float32), loss: wp.array(dtype=wp.float32)):
        # loss[0] += rho**2 for every cell (atomic accumulation across threads).
        i, j, k = wp.tid()
        wp.atomic_add(loss, 0, rho[0, i, j, k] ** 2.0)

    # Dedicated grad-enabled buffers so the tape can propagate adjoints into them.
    f_in_warp = wp.zeros((q, *shape_4d), dtype=wp.float32, requires_grad=True)
    f_out_warp = wp.zeros((q, *shape_4d), dtype=wp.float32, requires_grad=True)
    rho_warp = wp.zeros((1, *shape_4d), dtype=wp.float32, requires_grad=True)
    u_warp = wp.zeros((2, *shape_4d), dtype=wp.float32, requires_grad=True)
    loss_warp = wp.zeros((1,), dtype=wp.float32, requires_grad=True)
    wp.copy(f_in_warp, warp_f_0)

    with wp.Tape() as tape:
        # NOTE(review): assumes the stepper returns (f_0, f_1) in argument
        # order with the post-step state in f_1 -- confirm against
        # xlb.operator.stepper. Unpacking as (f_out, f_in) would rename the
        # input array as "f_out", so rho would be computed from the un-stepped
        # input and f_in.grad would read the unused output's gradient.
        f_in_warp, f_out_warp = warp_stepper(
            f_in_warp, f_out_warp, warp_bc_mask, warp_missing_mask, omega, 0
        )
        rho_warp, u_warp = warp_macro(f_out_warp, rho_warp, u_warp)
        wp.launch(warp_loss_kernel, inputs=[rho_warp], outputs=[loss_warp], dim=rho_warp.shape[1:])

    warp_loss_val = float(loss_warp.numpy()[0])
    # Seed d(loss)/d(loss) = 1 before replaying the tape backwards.
    loss_warp.grad.fill_(1.0)
    tape.backward()

    warp_f_in_grad = f_in_warp.grad.numpy() if f_in_warp.grad is not None else np.zeros_like(warp_f_0.numpy())
    warp_f_out_grad = f_out_warp.grad.numpy() if f_out_warp.grad is not None else np.zeros_like(warp_f_0.numpy())
    warp_rho_grad = rho_warp.grad.numpy() if rho_warp.grad is not None else np.zeros((1, *shape_4d))
    warp_f_in_grad_norm = float(np.linalg.norm(warp_f_in_grad))
    warp_f_out_grad_norm = float(np.linalg.norm(warp_f_out_grad))
    warp_rho_grad_norm = float(np.linalg.norm(warp_rho_grad))
except Exception as e:
    # Include the exception type: a message-less exception would otherwise
    # yield "" and be indistinguishable from "no error" in truthiness checks.
    warp_error = f"{type(e).__name__}: {e}"
# =============================================================================
# SIDE-BY-SIDE RESULTS
# =============================================================================
# --- Side-by-side table of loss and stepper-gradient norms. ---
print("=" * 70)
print("RESULTS: SIDE-BY-SIDE COMPARISON")
print("=" * 70)
print()
print(f"{'Metric':<40} {'WARP':<15} {'JAX':<15}")
print("-" * 70)
# The Warp run can fail after the loss is computed (e.g. in tape.backward()),
# leaving the gradient norm None while the loss is set. Format each Warp cell
# on its own so a partial result shows "ERROR" instead of raising TypeError.
warp_loss_cell = "ERROR" if warp_loss_val is None else f"{warp_loss_val:.2f}"
warp_grad_cell = "ERROR" if warp_f_in_grad_norm is None else f"{warp_f_in_grad_norm:.2f}"
print(f"{'Loss value':<40} {warp_loss_cell:<15} {jax_loss_val:<15.2f}")
print(f"{'Gradient norm (through stepper)':<40} {warp_grad_cell:<15} {jax_grad_norm:<15.2f}")
print()
# --- Warp error details (if any) and per-stage gradient breakdown. ---
# `is not None` so an error with an empty message is still reported.
if warp_error is not None:
    print("-" * 70)
    print("WARP ERROR")
    print("-" * 70)
    print(f" {warp_error}")
    print()
if warp_f_in_grad_norm is not None:
    # Conclusions are derived from the measured norms rather than printed
    # unconditionally, so a working stepper is not reported as broken.
    stepper_marker = " <-- PROBLEM" if warp_f_in_grad_norm == 0 else ""
    print("-" * 70)
    print("GRADIENT FLOW ANALYSIS (Warp)")
    print("-" * 70)
    print()
    print(" Checking gradients at each stage:")
    print()
    print(" 1. loss.grad (seed) : 1.0")
    print(f" 2. d(loss)/d(rho) gradient norm : {warp_rho_grad_norm:.2f}")
    print(f" 3. d(loss)/d(f_out) gradient norm: {warp_f_out_grad_norm:.2f}")
    print(f" 4. d(loss)/d(f_in) gradient norm : {warp_f_in_grad_norm:.2f}{stepper_marker}")
    print()
    if warp_f_out_grad_norm > 0:
        print(" Gradient flows: loss -> rho -> f_out (Macroscopic works)")
    else:
        print(" Gradient STOPS: loss -> rho -> f_out (Macroscopic broken)")
    if warp_f_in_grad_norm > 0:
        print(" Gradient flows: f_out -> f_in (Stepper works)")
    else:
        print(" Gradient STOPS: f_out -> f_in (Stepper broken)")
    print()
# --- Interpret the measured gradients. ---
print("=" * 70)
print("DIAGNOSIS")
print("=" * 70)
print()
if warp_error is not None:
    # The exception may come from setup, the forward pass, or tape.backward().
    print(" Warp backend test raised an exception.")
    print(f" Error: {warp_error}")
    print()
    print(" This may be due to Warp API changes (e.g., wp.mat removed).")
    print(" Even when Warp works, the stepper gradient is 0.0.")
elif warp_f_in_grad_norm == 0 and jax_grad_norm > 0:
    print(" CONFIRMED: Warp stepper does not propagate gradients.")
    print()
    print(" WHY THIS HAPPENS:")
    print(" Warp's autodiff (wp.Tape) requires either:")
    print(" a) Automatic adjoint generation (simple kernels), or")
    print(" b) Manual @wp.func_grad implementations (complex kernels)")
    print()
    print(" XLB's stepper has patterns that prevent auto-differentiation:")
    print(" - Early returns: 'if _boundary_id == wp.uint8(255): return'")
    print(" - Integer conditionals and mask operations")
    print()
    print(" Warp silently returns 0.0 gradient when it cannot differentiate.")
    print()
    # JAX traces the jnp program and applies per-primitive derivative rules;
    # the previous "source-code transformation" wording was inaccurate.
    print(" JAX works because every jnp operation in the stepper has a built-in derivative rule.")
elif warp_f_in_grad_norm > 0 and jax_grad_norm > 0:
    print(" Both backends propagate gradients through the stepper.")
else:
    print(" Unexpected result - please investigate.")
print()
# --- One-line verdict per backend. ---
print("=" * 70)
print("SUMMARY")
print("=" * 70)
print()
# Key on the gradient norm, not the loss: the loss can be set while the
# backward pass failed, and formatting None with :.2f raises TypeError.
if warp_f_in_grad_norm is not None:
    warp_status = "BROKEN" if warp_f_in_grad_norm == 0 else "OK"
    print(f" WARP: Loss={warp_loss_val:.2f}, Gradient={warp_f_in_grad_norm:.2f} --> {warp_status}")
else:
    print(" WARP: ERROR --> BROKEN")
jax_status = "OK" if jax_grad_norm > 0 else "BROKEN"
print(f" JAX: Loss={jax_loss_val:.2f}, Gradient={jax_grad_norm:.2f} --> {jax_status}")
print()
print("=" * 70)
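# ---------------------------------------------------------------------------
# When using the Warp backend for differentiable LBM, gradients do not flow
# through the stepper. This prevents gradient descent and other optimization
# methods from working.
#
# The Problem
# When computing gradients through an LBM simulation:
#   - Macroscopic (rho, u): gradients flow back to the populations.
#   - Stepper (collision + streaming): the gradient reaching its input is 0.0.
# This breaks inverse problems and differentiable physics with Warp.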
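#
# Test Script
# Run with:
#   python examples/cfd/test_stepper_autodiff.py
#
# Test Output
# (The script prints a side-by-side WARP vs JAX comparison of the loss value
# and of the gradient norm through the stepper.)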
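#
# Why This Happens
# Warp's autodiff (wp.Tape) needs either automatically generated adjoints
# (simple kernels) or manual @wp.func_grad adjoint functions. XLB's stepper
# has patterns Warp cannot auto-differentiate, for example:
#   - early returns such as `if _boundary_id == wp.uint8(255): return`
#   - integer conditionals and mask operations
# Warp silently returns 0.0 when it cannot differentiate (no error is thrown).
#
# The Macroscopic operator works because it is a simple summation kernel.
# The Stepper (collision + streaming) has complex control flow that prevents
# automatic adjoint generation.
#
# How to Fix (Future Work)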
# Add manual @wp.func_grad adjoint implementations for:
#   - xlb/operator/collision/bgk.py: warp_functional()
#   - xlb/operator/stream/stream.py: warp_functional()
#   - xlb/operator/equilibrium/*.py: warp_functional()
#
# Workaround
# Use the JAX backend for differentiable LBM applications.