
Warp backend does not support gradient-based optimization (stepper returns zero gradients) #161

@Medyan-Naser


When the Warp backend is used for differentiable LBM, gradients do not propagate through the stepper: the gradient with respect to the stepper's input comes back as 0.0. This rules out gradient descent and other gradient-based optimization methods on Warp.


The Problem

When computing gradients through an LBM simulation:

Component                          Gradient Flow
---------------------------------  -------------
Macroscopic (rho, u)               Works
Stepper (collision + streaming)    Returns 0.0

This breaks inverse problems and differentiable physics with Warp.


Test Script

Run with: python examples/cfd/test_stepper_autodiff.py

test_stepper_autodiff.py
"""
Test: XLB Stepper Autodiff - JAX vs Warp Comparison

This script tests whether gradients propagate through the XLB stepper
for both JAX and Warp backends. It performs identical tests on both
backends and compares the results side-by-side.

Expected result: JAX works, Warp does not (stepper lacks adjoint kernels).

Usage:
    python examples/cfd/test_stepper_autodiff.py
"""
import numpy as np

print()
print("=" * 70)
print("XLB STEPPER AUTODIFF TEST")
print("=" * 70)
print()
print("This test checks if gradients propagate through the LBM stepper.")
print("We run the SAME test on both JAX and Warp backends and compare.")
print()

# Common parameters
grid_shape = (32, 32)
omega = 1.8

print("-" * 70)
print("TEST CONFIGURATION")
print("-" * 70)
print(f"  Grid shape:       {grid_shape}")
print(f"  Omega:            {omega}")
print("  Precision:        FP32FP32")
print("  Boundary:         Periodic (no walls)")
print("  Collision:        BGK")
print("  Test:             Forward 1 step -> Compute rho -> Loss -> Backward")
print()

# =============================================================================
# JAX BACKEND TEST (run first - always works)
# =============================================================================

import jax
import jax.numpy as jnp
from jax import value_and_grad

import xlb
from xlb.compute_backend import ComputeBackend
from xlb.precision_policy import PrecisionPolicy
from xlb.grid import grid_factory
from xlb.operator.stepper import IncompressibleNavierStokesStepper
import xlb.velocity_set

precision_policy = PrecisionPolicy.FP32FP32

jax_velocity_set = xlb.velocity_set.D2Q9(
    precision_policy=precision_policy,
    compute_backend=ComputeBackend.JAX,
)

xlb.init(
    velocity_set=jax_velocity_set,
    default_backend=ComputeBackend.JAX,
    default_precision_policy=precision_policy,
)

jax_grid = grid_factory(grid_shape, compute_backend=ComputeBackend.JAX)

jax_stepper = IncompressibleNavierStokesStepper(
    grid=jax_grid,
    boundary_conditions=[],  # Periodic
    collision_type="BGK",
)

jax_f_0, jax_f_1, jax_bc_mask, jax_missing_mask = jax_stepper.prepare_fields()

def jax_forward_and_loss(f_in):
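    # NOTE: XLB's examples call `f_0, f_1 = stepper(f_0, f_1, ...)` and then swap
    # the two, which suggests the stepper returns (input, output). If so, the first
    # element is the input buffer and this loss never passes through the stepper.
    f_out, _ = jax_stepper(f_in, jax_f_1, jax_bc_mask, jax_missing_mask, omega, 0)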
    rho = jnp.sum(f_out, axis=0)
    return jnp.sum(rho ** 2)

jax_loss_val, jax_grad = value_and_grad(jax_forward_and_loss)(jax_f_0)
jax_loss_val = float(jax_loss_val)
jax_grad_norm = float(jnp.linalg.norm(jax_grad))

# =============================================================================
# WARP BACKEND TEST (may fail due to API compatibility)
# =============================================================================

warp_loss_val = None
warp_f_in_grad_norm = None
warp_f_out_grad_norm = None
warp_rho_grad_norm = None
warp_error = None

try:
    import warp as wp
    wp.init()
    
    from xlb.operator.macroscopic import Macroscopic
    
    warp_velocity_set = xlb.velocity_set.D2Q9(
        precision_policy=precision_policy,
        compute_backend=ComputeBackend.WARP,
    )

    xlb.init(
        velocity_set=warp_velocity_set,
        default_backend=ComputeBackend.WARP,
        default_precision_policy=precision_policy,
    )

    warp_grid = grid_factory(grid_shape, compute_backend=ComputeBackend.WARP)

    warp_stepper = IncompressibleNavierStokesStepper(
        grid=warp_grid,
        boundary_conditions=[],  # Periodic
        collision_type="BGK",
    )

    warp_f_0, warp_f_1, warp_bc_mask, warp_missing_mask = warp_stepper.prepare_fields()

    warp_macro = Macroscopic(
        velocity_set=warp_velocity_set,
        precision_policy=precision_policy,
        compute_backend=ComputeBackend.WARP,
    )

    q = warp_velocity_set.q
    shape_4d = (*grid_shape, 1)

    @wp.kernel
    def warp_loss_kernel(rho: wp.array4d(dtype=wp.float32), loss: wp.array(dtype=wp.float32)):
        i, j, k = wp.tid()
        wp.atomic_add(loss, 0, rho[0, i, j, k] ** 2.0)

    f_in_warp = wp.zeros((q, *shape_4d), dtype=wp.float32, requires_grad=True)
    f_out_warp = wp.zeros((q, *shape_4d), dtype=wp.float32, requires_grad=True)
    rho_warp = wp.zeros((1, *shape_4d), dtype=wp.float32, requires_grad=True)
    u_warp = wp.zeros((2, *shape_4d), dtype=wp.float32, requires_grad=True)
    loss_warp = wp.zeros((1,), dtype=wp.float32, requires_grad=True)
    wp.copy(f_in_warp, warp_f_0)

    with wp.Tape() as tape:
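        # NOTE: if the stepper returns its buffers in argument order, i.e. (f_0, f_1),
        # as XLB's examples assume when they swap f_0 and f_1 after each call, this
        # unpacking rebinds f_out_warp to the *input* buffer and f_in_warp to the output.
        f_out_warp, f_in_warp = warp_stepper(f_in_warp, f_out_warp, warp_bc_mask, warp_missing_mask, omega, 0)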
        rho_warp, u_warp = warp_macro(f_out_warp, rho_warp, u_warp)
        wp.launch(warp_loss_kernel, inputs=[rho_warp], outputs=[loss_warp], dim=rho_warp.shape[1:])

    warp_loss_val = float(loss_warp.numpy()[0])
    loss_warp.grad.fill_(1.0)
    tape.backward()

    warp_f_in_grad = f_in_warp.grad.numpy() if f_in_warp.grad is not None else np.zeros_like(warp_f_0.numpy())
    warp_f_out_grad = f_out_warp.grad.numpy() if f_out_warp.grad is not None else np.zeros_like(warp_f_0.numpy())
    warp_rho_grad = rho_warp.grad.numpy() if rho_warp.grad is not None else np.zeros((1, *shape_4d))

    warp_f_in_grad_norm = float(np.linalg.norm(warp_f_in_grad))
    warp_f_out_grad_norm = float(np.linalg.norm(warp_f_out_grad))
    warp_rho_grad_norm = float(np.linalg.norm(warp_rho_grad))

except Exception as e:
    warp_error = str(e)

# =============================================================================
# SIDE-BY-SIDE RESULTS
# =============================================================================

print("=" * 70)
print("RESULTS: SIDE-BY-SIDE COMPARISON")
print("=" * 70)
print()
print(f"{'Metric':<40} {'WARP':<15} {'JAX':<15}")
print("-" * 70)

if warp_loss_val is not None:
    print(f"{'Loss value':<40} {warp_loss_val:<15.2f} {jax_loss_val:<15.2f}")
    print(f"{'Gradient norm (through stepper)':<40} {warp_f_in_grad_norm:<15.2f} {jax_grad_norm:<15.2f}")
else:
    print(f"{'Loss value':<40} {'ERROR':<15} {jax_loss_val:<15.2f}")
    print(f"{'Gradient norm (through stepper)':<40} {'ERROR':<15} {jax_grad_norm:<15.2f}")
print()

if warp_error:
    print("-" * 70)
    print("WARP ERROR")
    print("-" * 70)
    print(f"  {warp_error}")
    print()

if warp_f_in_grad_norm is not None:
    print("-" * 70)
    print("GRADIENT FLOW ANALYSIS (Warp)")
    print("-" * 70)
    print()
    print("  Checking gradients at each stage:")
    print()
    print("    1. loss.grad (seed)              : 1.0")
    print(f"    2. d(loss)/d(rho) gradient norm  : {warp_rho_grad_norm:.2f}")
    print(f"    3. d(loss)/d(f_out) gradient norm: {warp_f_out_grad_norm:.2f}")
    print(f"    4. d(loss)/d(f_in) gradient norm : {warp_f_in_grad_norm:.2f}  <-- PROBLEM")
    print()
    print("  Gradient flows: loss -> rho -> f_out (Macroscopic works)")
    print("  Gradient STOPS: f_out -> f_in (Stepper broken)")
    print()

print("=" * 70)
print("DIAGNOSIS")
print("=" * 70)
print()

if warp_error:
    print("  Warp backend failed to initialize.")
    print(f"  Error: {warp_error}")
    print()
    print("  This may be due to Warp API changes (e.g., wp.mat removed).")
    print("  Even when Warp works, the stepper gradient is 0.0.")
elif warp_f_in_grad_norm == 0 and jax_grad_norm > 0:
    print("  CONFIRMED: Warp stepper does not propagate gradients.")
    print()
    print("  WHY THIS HAPPENS:")
    print("  Warp's autodiff (wp.Tape) requires either:")
    print("    a) Automatic adjoint generation (simple kernels), or")
    print("    b) Manual @wp.func_grad implementations (complex kernels)")
    print()
    print("  XLB's stepper has patterns that prevent auto-differentiation:")
    print("    - Early returns: 'if _boundary_id == wp.uint8(255): return'")
    print("    - Integer conditionals and mask operations")
    print()
    print("  Warp silently returns 0.0 gradient when it cannot differentiate.")
    print()
    print("  JAX works because jax.grad traces the stepper and applies built-in derivative rules to each primitive.")
else:
    print("  Unexpected result - please investigate.")

print()
print("=" * 70)
print("SUMMARY")
print("=" * 70)
print()
if warp_loss_val is not None:
    warp_status = "BROKEN" if warp_f_in_grad_norm == 0 else "OK"
    print(f"  WARP: Loss={warp_loss_val:.2f}, Gradient={warp_f_in_grad_norm:.2f} --> {warp_status}")
else:
    print("  WARP: ERROR --> BROKEN")
jax_status = "OK" if jax_grad_norm > 0 else "BROKEN"
print(f"  JAX:  Loss={jax_loss_val:.2f}, Gradient={jax_grad_norm:.2f} --> {jax_status}")
print()
print("=" * 70)
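
A check that does not rely on either backend's autodiff is a directional finite difference. Perturb the input along a random direction v, difference the loss, and compare the result with vdot(grad, v). If the tape reports a zero gradient but the finite difference is clearly nonzero, the adjoint is genuinely missing. Below is a minimal NumPy sketch. `directional_check` and `toy_forward` are hypothetical helpers. Wiring the check to the real Warp stepper would need a wrapper (not shown) that copies the NumPy array into the Warp field, runs one step and returns the loss.

```python
import numpy as np


def directional_check(forward, grad, f, eps=1e-3, seed=0):
    """Compare a reported gradient with a central finite difference.

    forward: callable mapping a float array f to a scalar loss (hypothetical
             wrapper, e.g. around one stepper call plus the rho^2 loss).
    grad:    gradient of forward at f, as reported by autodiff.
    eps:     step size; 1e-3 is chosen with float32 fields in mind.
    Returns (finite-difference, autodiff) estimates of d/dt forward(f + t v).
    """
    v = np.random.default_rng(seed).standard_normal(f.shape)
    fd = (forward(f + eps * v) - forward(f - eps * v)) / (2 * eps)
    ad = float(np.vdot(grad, v))
    return fd, ad


# Toy stand-in for the test's loss: rho = sum over populations, loss = sum(rho^2).
def toy_forward(f):
    rho = f.sum(axis=0)
    return float(np.sum(rho ** 2))


f0 = np.full((9, 32, 32), 1.0 / 9.0)                # rho = 1 everywhere
correct = 2.0 * np.broadcast_to(f0.sum(axis=0), f0.shape)
zero = np.zeros_like(f0)                             # what the Warp tape reported

print(directional_check(toy_forward, correct, f0))   # the two estimates agree
print(directional_check(toy_forward, zero, f0))      # FD nonzero, autodiff 0.0
```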


Test Output

======================================================================
XLB STEPPER AUTODIFF TEST
======================================================================

This test checks if gradients propagate through the LBM stepper.
We run the SAME test on both JAX and Warp backends and compare.

----------------------------------------------------------------------
TEST CONFIGURATION
----------------------------------------------------------------------
  Grid shape:       (32, 32)
  Omega:            1.8
  Precision:        FP32FP32
  Boundary:         Periodic (no walls)
  Collision:        BGK
  Test:             Forward 1 step -> Compute rho -> Loss -> Backward

======================================================================
RESULTS: SIDE-BY-SIDE COMPARISON
======================================================================

Metric                                   WARP            JAX            
----------------------------------------------------------------------
Loss value                               1024.00         1024.00        
Gradient norm (through stepper)          0.00            192.00         

----------------------------------------------------------------------
GRADIENT FLOW ANALYSIS (Warp)
----------------------------------------------------------------------

  Checking gradients at each stage:

    1. loss.grad (seed)              : 1.0
    2. d(loss)/d(rho) gradient norm  : 0.00
    3. d(loss)/d(f_out) gradient norm: 192.00
    4. d(loss)/d(f_in) gradient norm : 0.00  <-- PROBLEM

  Gradient flows: loss -> rho -> f_out (Macroscopic works)
  Gradient STOPS: f_out -> f_in (Stepper broken)


======================================================================
SUMMARY
======================================================================

  WARP: Loss=1024.00, Gradient=0.00 --> BROKEN
  JAX:  Loss=1024.00, Gradient=192.00 --> OK

======================================================================
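
The reported numbers can be checked by hand. Assume the fields start at rest with rho = 1 everywhere, i.e. each population equals its D2Q9 weight. This is an assumption about prepare_fields(), but it is consistent with the reported loss. Under that assumption the loss is 32 x 32 = 1024, and every entry of d(loss)/d(f) is 2 rho = 2, so the norm over 9 x 32 x 32 entries is 2 x 96 = 192. BGK collision conserves mass and periodic streaming only permutes populations, so the same 192 is expected whether the loss is computed from the stepper's input or from its output. This value alone therefore does not show that the JAX gradient passed through the stepper.

```python
import numpy as np

# D2Q9 lattice weights; they sum to 1, so populations set to the weights give rho = 1.
w = np.array([4/9] + [1/9] * 4 + [1/36] * 4)
nx, ny = 32, 32
f = np.broadcast_to(w[:, None, None], (9, nx, ny)).copy()

rho = f.sum(axis=0)          # density at each node
loss = np.sum(rho ** 2)      # same loss as the test script

# d(loss)/d(f_i) = 2 * rho for every population i at every node
grad = 2.0 * np.broadcast_to(rho, f.shape)
grad_norm = np.linalg.norm(grad)

print(loss, grad_norm)
```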

Why This Happens

Warp's autodiff (wp.Tape) relies on either:

  1. Adjoint kernels that Warp generates automatically from each forward kernel, OR
  2. Manual @wp.func_grad adjoint functions supplied by the user

XLB's stepper has patterns Warp cannot auto-differentiate:

# From xlb/operator/stepper/nse_stepper.py
if _boundary_id == wp.uint8(255):
    return  # Early return breaks autodiff

Warp silently returns 0.0 when it cannot differentiate (no error thrown).

The Macroscopic operator works because it is a simple summation kernel. The Stepper (collision + streaming) has complex control flow that prevents automatic adjoint generation.
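
For comparison, the gradient that a working stepper should produce can be computed without Warp. The sketch below is a minimal pure-NumPy D2Q9 step, not XLB's implementation. The velocity ordering, pull-form periodic streaming and the stream-then-collide order are assumptions. For loss = sum(rho_out^2), BGK's mass conservation gives d(loss)/d f_i(y) = 2 rho_out(y + c_i). Central finite differences through the full nonlinear step agree with that closed form, so the correct stepper gradient is nonzero.

```python
import numpy as np

# D2Q9 velocities c_i and weights w_i (this ordering is an assumption).
C = np.array([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1],
              [1, 1], [-1, 1], [-1, -1], [1, -1]])
W = np.array([4/9] + [1/9] * 4 + [1/36] * 4)


def equilibrium(rho, u):
    """Second-order BGK equilibrium f_eq_i(rho, u)."""
    cu = np.einsum("qd,dxy->qxy", C, u)        # c_i . u at every node
    usq = np.sum(u ** 2, axis=0)               # |u|^2
    return W[:, None, None] * rho * (1 + 3 * cu + 4.5 * cu ** 2 - 1.5 * usq)


def step(f, omega):
    """One periodic LBM step: pull streaming f_i(x) <- f_i(x - c_i), then BGK."""
    f = np.stack([np.roll(f[i], tuple(C[i]), axis=(0, 1)) for i in range(9)])
    rho = f.sum(axis=0)
    u = np.einsum("qd,qxy->dxy", C, f) / rho
    return f - omega * (f - equilibrium(rho, u))


def loss(f_in, omega):
    rho_out = step(f_in, omega).sum(axis=0)
    return np.sum(rho_out ** 2)


rng = np.random.default_rng(0)
omega = 1.8
f_in = W[:, None, None] * (1 + 0.1 * rng.random((9, 8, 8)))   # non-uniform rho, small u

# Closed form from mass conservation: dL/df_i(y) = 2 * rho_out(y + c_i).
rho_out = step(f_in, omega).sum(axis=0)
grad_exact = 2 * np.stack([np.roll(rho_out, tuple(-C[i]), axis=(0, 1)) for i in range(9)])

# Central finite differences at a few random entries of f_in.
eps = 1e-6
idx = [tuple(int(rng.integers(0, s)) for s in f_in.shape) for _ in range(5)]
grad_fd = []
for ix in idx:
    fp, fm = f_in.copy(), f_in.copy()
    fp[ix] += eps
    fm[ix] -= eps
    grad_fd.append((loss(fp, omega) - loss(fm, omega)) / (2 * eps))

for ix, g in zip(idx, grad_fd):
    print(ix, g, grad_exact[ix])
```

A correct adjoint for the Warp stepper would be expected to reproduce these values for the same inputs.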


How to Fix (Future Work)

Add manual @wp.func_grad adjoint implementations for:

  1. xlb/operator/collision/bgk.py - warp_functional()
  2. xlb/operator/stream/stream.py - warp_functional()
  3. xlb/operator/equilibrium/*.py - warp_functional()
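
These adjoints can be derived by hand and validated before they are ported to Warp. Below is a minimal NumPy sketch, not XLB code. The D2Q9 velocity ordering, pull-form streaming and the helper names (`stream_vjp`, `collide_vjp`) are assumptions. Streaming is a permutation, so its adjoint shifts each gradient back along -c_i. For BGK, writing f_eq in terms of rho and momentum j = rho u gives a closed-form vector-Jacobian product that also covers the equilibrium's derivative. The dot-product test at the end is the usual way to validate an adjoint: <J v, g> must match <v, J^T g>.

```python
import numpy as np

# D2Q9 velocities and weights (ordering is an assumption, not XLB's).
C = np.array([[0, 0], [1, 0], [0, 1], [-1, 0], [0, -1],
              [1, 1], [-1, 1], [-1, -1], [1, -1]])
W = np.array([4/9] + [1/9] * 4 + [1/36] * 4)


def stream(f):
    """Periodic pull streaming: out_i(x) = f_i(x - c_i)."""
    return np.stack([np.roll(f[i], tuple(C[i]), axis=(0, 1)) for i in range(9)])


def stream_vjp(g):
    """Adjoint of stream(): shift each gradient back, out_i(y) = g_i(y + c_i)."""
    return np.stack([np.roll(g[i], tuple(-C[i]), axis=(0, 1)) for i in range(9)])


def moments(f):
    rho = f.sum(axis=0)
    u = np.einsum("qd,qxy->dxy", C, f) / rho
    cu = np.einsum("qd,dxy->qxy", C, u)        # c_i . u
    usq = np.sum(u ** 2, axis=0)               # |u|^2
    return rho, u, cu, usq


def collide(f, omega):
    """BGK: f + omega * (f_eq(rho, u) - f)."""
    rho, u, cu, usq = moments(f)
    feq = W[:, None, None] * rho * (1 + 3 * cu + 4.5 * cu ** 2 - 1.5 * usq)
    return f + omega * (feq - f)


def collide_vjp(f, g, omega):
    """Hand-derived g^T (d collide / d f)."""
    rho, u, cu, usq = moments(f)
    # With f_eq_i = w_i (rho + 3 c_i.j + 4.5 (c_i.j)^2 / rho - 1.5 |j|^2 / rho):
    #   d f_eq_i / d rho = w_i (1 - 4.5 (c_i.u)^2 + 1.5 |u|^2)
    #   d f_eq_i / d j   = w_i (3 c_i + 9 (c_i.u) c_i - 3 u)
    a = np.sum(g * W[:, None, None] * (1 - 4.5 * cu ** 2 + 1.5 * usq), axis=0)
    gw = g * W[:, None, None]
    b = (3 * np.einsum("qxy,qd->dxy", gw, C)
         + 9 * np.einsum("qxy,qxy,qd->dxy", gw, cu, C)
         - 3 * gw.sum(axis=0) * u)
    # rho = sum_k f_k and j = sum_k c_k f_k, so d/d f_k contributes a + b . c_k.
    return (1 - omega) * g + omega * (a + np.einsum("qd,dxy->qxy", C, b))


rng = np.random.default_rng(1)
omega = 1.8
f = W[:, None, None] * (1 + 0.1 * rng.random((9, 16, 16)))
v = rng.standard_normal(f.shape)   # perturbation direction
g = rng.standard_normal(f.shape)   # incoming gradient

# Streaming is linear: <stream(v), g> must equal <v, stream_vjp(g)>.
s_lhs, s_rhs = np.vdot(stream(v), g), np.vdot(v, stream_vjp(g))

# Collision is nonlinear: compare <J v, g> from central differences with <v, J^T g>.
eps = 1e-6
jv = (collide(f + eps * v, omega) - collide(f - eps * v, omega)) / (2 * eps)
c_lhs, c_rhs = np.vdot(jv, g), np.vdot(v, collide_vjp(f, g, omega))

print(s_lhs, s_rhs)
print(c_lhs, c_rhs)
```

A Warp port would apply the same index mapping and formulas in its adjoint. How to register that adjoint depends on how XLB structures its warp_functional() functions.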

Workaround

Use the JAX backend for differentiable LBM applications.
