Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions tests/fsdp_state_dict_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@
import bitsandbytes as bnb


def _current_accelerator_type():
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
return str(torch.accelerator.current_accelerator())
if hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
if torch.cuda.is_available():
return "cuda"
return "cpu"


def _set_device_index(index: int, device_type: str):
if hasattr(torch, "accelerator"):
torch.accelerator.set_device_index(index)
return
if device_type == "cuda":
torch.cuda.set_device(index)
elif device_type == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "set_device"):
torch.xpu.set_device(index)


def _get_device_and_backend():
    """Auto-detect accelerator device and distributed backend."""
    device_type = _current_accelerator_type()
    if device_type == "cuda":
        backend = "nccl"
    elif device_type == "xpu":
        backend = "xccl"
    else:
        # CPU (and any unknown accelerator) falls back to gloo.
        backend = "gloo"
    return device_type, backend


class SimpleQLoRAModel(nn.Module):
"""Minimal model with a frozen 4-bit base layer and a trainable adapter."""

Expand All @@ -33,15 +61,16 @@ def forward(self, x):


def main():
dist.init_process_group(backend="nccl")
device_type, backend = _get_device_and_backend()
dist.init_process_group(backend=backend)
rank = dist.get_rank()
torch.cuda.set_device(rank)
_set_device_index(rank, device_type)

errors = []

for quant_type in ("nf4", "fp4"):
model = SimpleQLoRAModel(quant_type=quant_type)
model = model.to("cuda")
model = model.to(device_type)

# Freeze quantized base weights (as in real QLoRA)
for p in model.base.parameters():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ class Test8BitBlockwiseQuantizeFunctional:
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
iters = 100

if device != "cuda":
if device not in ["cuda", "xpu"]:
iters = 10

# This test is slow in our non-CUDA implementations, so avoid atypical use cases.
# This test is slow in our non-CUDA/non-XPU implementations, so avoid atypical use cases.
if nested:
pytest.skip("Not a typical use case.")
if blocksize != 256:
Expand Down
7 changes: 2 additions & 5 deletions tests/test_linear4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,11 +569,8 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist
assert w.bnb_quantized is True


@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA")
@pytest.mark.skipif(
not getattr(torch.distributed, "is_nccl_available", lambda: False)(),
reason="FSDP test requires NCCL backend",
)
@pytest.mark.skipif(platform.system() == "Windows", reason="FSDP is not supported on Windows")
@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="FSDP requires an accelerator device")
def test_fsdp_state_dict_save_4bit():
"""Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405).

Expand Down
7 changes: 4 additions & 3 deletions tests/test_linear8bitlt.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,9 @@ def test_linear_serialization(
assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5)


@pytest.fixture
def linear8bit(requires_cuda):
@pytest.fixture(params=get_available_devices(no_cpu=True))
def linear8bit(request):
device = request.param
linear = torch.nn.Linear(32, 96)
linear_custom = Linear8bitLt(
linear.in_features,
Expand All @@ -188,7 +189,7 @@ def linear8bit(requires_cuda):
has_fp16_weights=False,
)
linear_custom.bias = linear.bias
linear_custom = linear_custom.cuda()
linear_custom = linear_custom.to(device)
return linear_custom


Expand Down
14 changes: 8 additions & 6 deletions tests/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,34 +448,36 @@ def test_4bit_embedding_warnings(device, caplog):
assert any("inference" in msg for msg in caplog.messages)


def test_4bit_embedding_weight_fsdp_fix(requires_cuda):
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
def test_4bit_embedding_weight_fsdp_fix(device):
num_embeddings = 64
embedding_dim = 32

module = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

module.cuda()
module.to(device)

module.weight.quant_state = None

input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device=device)

module(input_tokens)

assert module.weight.quant_state is not None


def test_4bit_linear_weight_fsdp_fix(requires_cuda):
@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
def test_4bit_linear_weight_fsdp_fix(device):
inp_size = 64
out_size = 32

module = bnb.nn.Linear4bit(inp_size, out_size)

module.cuda()
module.to(device)

module.weight.quant_state = None

input_tensor = torch.randn((1, inp_size), device="cuda")
input_tensor = torch.randn((1, inp_size), device=device)

module(input_tensor)

Expand Down