29 changes: 4 additions & 25 deletions tests/test_autograd.py
@@ -12,6 +12,9 @@
)

TRANSPOSE_VALS = [(False, True), (False, False)]
# Keep req_grad[1] (weight grad for B) disabled for test_matmullt.
# The req_grad[1] == True path there is deprecated, so we avoid generating that case.
REQ_GRAD_NO_B_WEIGHT = [flags for flags in BOOLEAN_TRIPLES if not flags[1]]
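For illustration, a minimal sketch of what this filter yields, assuming BOOLEAN_TRIPLES in tests/helpers.py is the full set of (bool, bool, bool) requires-grad combinations (names reconstructed here for clarity, not part of the diff):

import itertools

TRUE_FALSE = (True, False)
# All 8 (requires_grad_A, requires_grad_B, requires_grad_bias) combinations.
BOOLEAN_TRIPLES = list(itertools.product(TRUE_FALSE, repeat=3))
# Drop the 4 combinations that would request a weight gradient for B.
REQ_GRAD_NO_B_WEIGHT = [flags for flags in BOOLEAN_TRIPLES if not flags[1]]
assert all(not flags[1] for flags in REQ_GRAD_NO_B_WEIGHT)  # 4 cases remain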


@pytest.mark.parametrize("device", get_available_devices())
@@ -26,19 +29,13 @@
ids=["func=matmul"],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("req_grad", BOOLEAN_TRIPLES, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("req_grad", REQ_GRAD_NO_B_WEIGHT, ids=id_formatter("req_grad"))
@pytest.mark.parametrize("transpose", TRANSPOSE_VALS, ids=id_formatter("transpose"))
@pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
@pytest.mark.parametrize("has_bias", TRUE_FALSE, ids=id_formatter("has_bias"))
def test_matmullt(
device, dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias
):
if device != "cuda":
if req_grad[1]:
# This will be deprecated for CUDA in the future. We don't expect
# this to work on any other device.
pytest.skip("Deprecated feature with CUDA support only.")

dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device=device)
@@ -111,7 +108,6 @@ def test_matmullt(
loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
loss_bnb.backward()
gradA1 = A.grad
gradB1 = B.grad
A.grad = None
B.grad = None
if has_bias:
@@ -121,7 +117,6 @@
loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
loss_torch.backward()
gradA2 = A.grad
gradB2 = B.grad
A.grad = None
B.grad = None
if has_bias:
@@ -130,22 +125,6 @@

if req_grad[0]:
torch.testing.assert_close(gradA1, gradA2, atol=0.015, rtol=0.1)
if req_grad[1]:
n = gradB1.numel()
if dim2 > 0:
assert torch.abs(gradB1).sum() > 0.0
assert torch.abs(gradB2).sum() > 0.0
else:
assert torch.abs(gradB1).sum() == 0.0
assert torch.abs(gradB2).sum() == 0.0

idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.10

idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
assert (idx == 0).sum().item() <= n * 0.02

torch.testing.assert_close(gradB1, gradB2, atol=0.18, rtol=0.3)

if req_grad[2]:
torch.testing.assert_close(gradBias1, gradBias2)
279 changes: 2 additions & 277 deletions tests/test_functional.py
@@ -1,17 +1,14 @@
import math
import platform
import random
import time

import einops
from packaging import version
import pytest
import torch

import bitsandbytes as bnb
from bitsandbytes import functional as F
from tests.helpers import (
BOOLEAN_TUPLES,
TRUE_FALSE,
describe_dtype,
get_available_devices,
@@ -339,280 +336,6 @@ def test_stable_embedding():
layer.reset_parameters()


def quant(x):
max1 = torch.abs(x).max()
x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)


def dequant(c, maxC):
return c.float() * (maxC / 127)


def mm_dequant(maxA, maxB, C):
return C.float() * (maxA / 127) * (maxB / 127)


def quant_multi(x, dim):
max1 = torch.amax(torch.abs(x), dim=dim, keepdim=True)
max1[max1 == 0] = 1.0
x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)


def quant_multi_chunk(x, dim, chunk_size=32):
if dim == 1:
x_chunked = einops.rearrange(x, "(c a) b -> c a b", c=chunk_size)
max1 = torch.amax(torch.abs(x_chunked), dim=dim + 1, keepdim=True)
max1 = torch.tile(max1, (1, 1, x.shape[1]))
max1 = max1.view(x.shape)
elif dim == 0:
x_chunked = einops.rearrange(x, "a (b c) -> a b c", c=chunk_size)
max1 = torch.amax(torch.abs(x_chunked), dim=dim, keepdim=True)
max1 = torch.tile(max1, (x.shape[0], 1, 1))
max1 = max1.view(x.shape)
max1[max1 == 0] = 1.0
x = torch.round(x / max1 * 127)
return max1, x.to(torch.int8)


def mean(xx):
return sum(xx) / float(len(xx))


methods = {
"linear": (
lambda x, dim: quant(x),
lambda x, dim: quant(x),
dequant,
dequant,
mm_dequant,
),
"vectorwise": (quant_multi, quant_multi, dequant, dequant, mm_dequant),
}


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
class TestIGEMMFunctional:
@pytest.mark.parametrize("dim1", [1024 * 2], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [1024 * 16], ids=id_formatter("dim2"))
@pytest.mark.parametrize("quant_methods", methods.values(), ids=methods.keys())
@pytest.mark.parametrize("batched", TRUE_FALSE, ids=id_formatter("batched"))
def test_approx_igemm(self, dim1, dim2, quant_methods, batched):
dim1 = dim1 - (dim1 % 32)
dim2 = dim2 - (dim2 % 32)
errors = []
relerrors = []
# print("")
for i in range(5):
if batched:
A = torch.normal(0, 0.5, size=(32, dim1, dim2 // 32), device="cuda")
B = torch.normal(0, 0.5, size=(32, dim2 // 32, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 2)
maxB, Bc = quant_methods[1](B, 1)
else:
A = torch.normal(0, 0.5, size=(dim1, dim2), device="cuda")
B = torch.normal(0, 0.5, size=(dim2, dim1), device="cuda")
maxA, Ac = quant_methods[0](A, 1)
maxB, Bc = quant_methods[1](B, 0)
torch.testing.assert_close(quant_methods[2](maxA, Ac), A, atol=0.025, rtol=0.05)
if batched:
out2 = torch.bmm(A, B)
C = torch.bmm(Ac.float(), Bc.float())
else:
out2 = torch.mm(A, B)
C = F.igemm(Ac, Bc)
out = quant_methods[4](maxA, maxB, C)
std = out2.std()
out /= std
out2 /= std
err = torch.abs(out - out2)
relerr = err / torch.abs(out2)
errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
# print(mean(errors))
# print(mean(relerrors))

@pytest.mark.parametrize("hidden_dim", [32, 256], ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", [16, 256], ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("seq_dim", [16, 256], ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_igemm(self, hidden_dim, batch_dim, transpose, seq_dim):
if (
torch.version.cuda == "13.0"
and torch.__version__ >= (2, 10)
and not any(transpose)
and batch_dim == 256
and seq_dim == 256
):
pytest.xfail("Failure due to regression in cuBLAS for CUDA Toolkit 13.0.2.")

hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 16)
seq_dim = seq_dim - (seq_dim % 16)
for i in range(k):
shapeA = (batch_dim, hidden_dim) if not transpose[0] else (hidden_dim, batch_dim)
shapeB = (
(32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B.t())
elif transpose[0] and not transpose[1]:
out2 = torch.matmul(A.t().float(), B.float())
out = F.igemm(A.t(), B)
elif transpose[0] and transpose[1]:
out2 = torch.matmul(A.t().float(), B.t().float())
out = F.igemm(A.t(), B.t())

torch.testing.assert_close(out.float(), out2)

for i in range(k):
shapeA = (batch_dim, seq_dim, hidden_dim)
shapeB = (
(32 * random.randint(1, 4), hidden_dim) if transpose[1] else (hidden_dim, 32 * random.randint(1, 4))
)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)
if not transpose[0] and not transpose[1]:
out2 = torch.matmul(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.matmul(A.float(), B.t().float())
out = F.igemm(A, B.t())

torch.testing.assert_close(out.float(), out2)

@pytest.mark.parametrize("seq_dim", [32, 256, 512], ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", [64, 1024, 4096], ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", [2, 8, 16], ids=id_formatter("batch_dim"))
def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
seq_dim = seq_dim - (seq_dim % 32)
hidden_dim = hidden_dim - (hidden_dim % 32)
batch_dim = batch_dim - (batch_dim % 2)
for i in range(25):
A = torch.randint(-128, 127, size=(batch_dim, seq_dim, hidden_dim), device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=(batch_dim, seq_dim, 1024), device="cuda").to(torch.int8)
out2 = torch.einsum("bsi, bso->io", A.float(), B.float())
iout = torch.empty(A.shape[2], B.shape[2], dtype=torch.int32, device=A.device)
out = F.igemm(A, B, out=iout)

torch.testing.assert_close(out.float(), out2)

@pytest.mark.parametrize("seq_dim", [32, 512], ids=id_formatter("seq_dim"))
@pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
minA = torch.amin(x, dim=2, keepdim=True)
scale = (maxA - minA) / 2.0
return (127 * (x - minA - scale) / scale).to(torch.int8), minA, scale

seq_dim = seq_dim - (seq_dim % 16)
hidden_dim = hidden_dim - (hidden_dim % 16)
batch_dim = batch_dim - (batch_dim % 2)
errs = []
relerrs = []
errs2 = []
relerrs2 = []
for i in range(k):
A = torch.normal(0.0, 0.5, size=(batch_dim, seq_dim, hidden_dim), device="cuda")
if transpose:
B = torch.normal(0, 0.5, size=(256, hidden_dim), device="cuda")
else:
B = torch.normal(0, 0.5, size=(hidden_dim, 256), device="cuda")
Ac, minA, scale = min_max(A)
if transpose:
maxB, Bc = quant_multi(B, dim=(1 if transpose else 0))
out = F.igemm(Ac, Bc.t())
out2 = torch.matmul(A, B.t())
offset = B.t().sum(0) * (minA + scale)
out = out.float()
out = (out * maxB.t() * scale / (127 * 127)) + offset

maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc.t())
out3 = mm_dequant(maxA, maxB.t(), out3)
else:
maxB, Bc = quant_multi(B, dim=0)
offset = B.sum(0) * (minA + scale)
out = F.igemm(Ac, Bc)
out2 = torch.matmul(A, B)
out = out.float()
out = (out * maxB * scale / (127 * 127)) + offset

maxA, Ac = quant_multi(A, dim=2)
out3 = F.igemm(Ac, Bc)
out3 = mm_dequant(maxA, maxB, out3)

std = out2.std()
out2 /= std
out /= std
out3 /= std

err = torch.abs(out - out2)
relerr = err / (torch.abs(out2) + 1e-7)

err2 = torch.abs(out3 - out2)
relerr2 = err2 / (torch.abs(out2) + 1e-7)

errs.append(err.mean().item())
relerrs.append(relerr.mean().item())
errs2.append(err2.mean().item())
relerrs2.append(relerr2.mean().item())
# print(mean(errs))
# print(mean(relerrs))
# print(mean(errs2))
# print(mean(relerrs2))
assert mean(errs) < 0.015

# There's a higher relerr on L40S with torch 2.4+cu118.
is_sm89 = torch.cuda.get_device_capability() == (8, 9)
if torch.version.cuda == "11.8" and is_sm89 and torch.__version__ < (2, 5):
assert mean(relerrs) < 0.41
else:
assert mean(relerrs) < 0.3

@pytest.mark.parametrize("dim1", [1, 64], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 128], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim3", [32, 256], ids=id_formatter("dim3"))
@pytest.mark.parametrize("dim4", [32, 256], ids=id_formatter("dim4"))
@pytest.mark.parametrize("transpose", BOOLEAN_TUPLES, ids=id_formatter("transpose"))
def test_ibmm(self, dim1, dim2, dim3, dim4, transpose):
if torch.version.cuda == "13.0" and torch.__version__ >= (2, 10) and dim1 == 64:
pytest.xfail("Failure due to regression in cuBLAS for CUDA Toolkit 13.0.2.")

dim2 = dim2 - (dim2 % 16)
dim3 = dim3 - (dim3 % 16)
dim4 = dim4 - (dim4 % 16)
for i in range(k):
shapeA = (dim1, dim3, dim2) if transpose[0] else (dim1, dim2, dim3)
shapeB = (dim1, dim4, dim3) if transpose[1] else (dim1, dim3, dim4)
A = torch.randint(-128, 127, size=shapeA, device="cuda").to(torch.int8)
B = torch.randint(-128, 127, size=shapeB, device="cuda").to(torch.int8)

if not transpose[0] and not transpose[1]:
out2 = torch.bmm(A.float(), B.float())
out = F.igemm(A, B)
elif not transpose[0] and transpose[1]:
out2 = torch.bmm(A.float(), B.permute([0, 2, 1]).float())
out = F.igemm(A, B.permute([0, 2, 1]))
elif transpose[0] and not transpose[1]:
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.float())
out = F.igemm(A.permute([0, 2, 1]), B)
elif transpose[0] and transpose[1]:
out2 = torch.bmm(A.permute([0, 2, 1]).float(), B.permute([0, 2, 1]).float())
out = F.igemm(A.permute([0, 2, 1]), B.permute([0, 2, 1]))
torch.testing.assert_close(out.float(), out2.float())


class TestLLMInt8Functional:
@staticmethod
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half):
@@ -723,6 +446,8 @@ def test_dequant_mm(self, device, dim1, dim4, dims, has_bias):
n = C5.numel()
assert_all_approx_close(C1, C4, atol=0.015, rtol=0.1, count=int(0.01 * n))

# Keep CUDA-only coverage for int8_double_quant during deprecation cycle.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
@pytest.mark.parametrize("dim1", [2048, 4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [512, 1024], ids=id_formatter("dim2"))
def test_int8_double_quant(self, dim1, dim2):
Member

This is the main one where I hesitate because we may still have external users calling int8_double_quant - even if they're not actually training int8 weights.

There's a tip in the docs about this, but this is one where we haven't emitted deprecation warnings on either.

<Tip>
    This function is useful for training, but for inference it is advised to use [`int8_vectorwise_quant`] instead.
    This implementation performs additional column-wise transposed calculations which are not optimized.
</Tip>

My ask here is that for now we change this one to test only on CUDA and give it a short deprecation cycle before removing it. This test doesn't necessarily assume int8 training, and I believe it is our only coverage of the op.
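For readers who haven't used the op, here is a rough plain-torch sketch of what the "double" (row-wise plus column-wise) absmax int8 quantization described in the tip computes. This is an illustration only; the actual int8_double_quant signature and return values are whatever the library defines.

import torch

def double_quant_reference(A: torch.Tensor):
    # Row-wise absmax scaling: the part inference-oriented callers need,
    # and the part int8_vectorwise_quant covers.
    row_stats = A.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    out_row = torch.round(A / row_stats * 127).to(torch.int8)
    # The additional column-wise (transposed) scaling mainly matters for the
    # backward pass during int8 training and is the unoptimized part.
    col_stats = A.abs().amax(dim=0, keepdim=True).clamp(min=1e-8)
    out_col = torch.round(A / col_stats * 127).to(torch.int8)
    return out_row, out_col, row_stats.squeeze(1), col_stats.squeeze(0)

Under that reading, an inference-only caller needs just the row-wise half, which is what the tip's recommendation of int8_vectorwise_quant points at.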
