Skip to content

Commit 780bbbe

Browse files
TimDettmers and claude
authored and committed
Fix matmul_4bit out parameter not writing to output tensor (#1235)
The `out` kwarg in `matmul_4bit()` was accepted but ignored in the `MatMul4Bit.forward()` path (2D+ inputs). The computed result was returned as a new tensor without being copied into `out`. Added `out.copy_(output)` after computing the linear result so the caller's pre-allocated tensor is populated as expected. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6a18715 commit 780bbbe

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -322,7 +322,12 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
322322
out.copy_(output)
323323
output = out
324324

325-
# 3. Save state
325+
# 3. Write to out tensor if provided
326+
if out is not None:
327+
out.copy_(output)
328+
output = out
329+
330+
# 4. Save state
326331
ctx.state = quant_state
327332
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
328333

tests/test_autograd.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -258,3 +258,45 @@ def test_matmul_4bit(
258258

259259
if req_grad[2]:
260260
torch.testing.assert_close(gradBias1, gradBias2)
261+
262+
263+
@pytest.mark.parametrize("device", get_available_devices())
264+
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"], ids=id_formatter("quant_type"))
265+
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=describe_dtype)
266+
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
267+
def test_matmul_4bit_out_parameter(device, quant_type, dtype, has_bias):
268+
"""Test that matmul_4bit(A, B, out=output) writes the result into output (issue #1235)."""
269+
M, K, N = 32, 64, 48
270+
271+
# Create weight matrix (K, N) and quantize — matmul_4bit computes A @ dequant(B)
272+
W = torch.randn(K, N, device=device, dtype=dtype)
273+
torch.nn.init.xavier_uniform_(W)
274+
B_quant, quant_state = bnb.functional.quantize_4bit(W, quant_type=quant_type)
275+
276+
bias = None
277+
if has_bias:
278+
bias = torch.randn(N, device=device, dtype=dtype)
279+
280+
# --- Test 2D input (matrix path through MatMul4Bit) ---
281+
A_2d = torch.randn(M, K, device=device, dtype=dtype)
282+
expected = bnb.matmul_4bit(A_2d, B_quant, quant_state, bias=bias)
283+
284+
out_2d = torch.zeros(M, N, device=device, dtype=dtype)
285+
returned = bnb.matmul_4bit(A_2d, B_quant, quant_state, out=out_2d, bias=bias)
286+
287+
# out tensor should contain the result
288+
torch.testing.assert_close(out_2d, expected)
289+
# returned value should be the same object as out
290+
assert returned.data_ptr() == out_2d.data_ptr(), "returned tensor should share storage with out"
291+
292+
# --- Test 1D input (gemv path) if on CUDA and blocksize divides K ---
293+
# Skip bias for 1D: the gemv path has a pre-existing shape bug with bias when K != N.
294+
if device == "cuda" and K % quant_state.blocksize == 0 and not has_bias:
295+
A_1d = torch.randn(K, device=device, dtype=dtype)
296+
expected_1d = bnb.matmul_4bit(A_1d, B_quant, quant_state)
297+
298+
out_1d = torch.zeros_like(expected_1d)
299+
returned_1d = bnb.matmul_4bit(A_1d, B_quant, quant_state, out=out_1d)
300+
301+
torch.testing.assert_close(out_1d, expected_1d)
302+
assert returned_1d.data_ptr() == out_1d.data_ptr(), "returned tensor should share storage with out"

0 commit comments

Comments (0)