22 changes: 16 additions & 6 deletions agents/coordinator_guide.md
@@ -121,7 +121,7 @@ project before making changes so you can verify your setup works.
- Information about existing PRs — what they change, whether they look correct
- Anything else the worker agent should know

**5. Your recommended approach.** What you think the fix should look like. Be specific — name files, functions, line numbers. Frame it as guidance, not commands — the worker agent may find things you didn't and should use its own judgment.
**5. Your recommended approach.** What you think the fix should look like. Be specific — name files, functions, line numbers. Frame it as guidance, not commands — the worker agent may find things you didn't and should use its own judgment. Include which specific test file(s) or test function(s) the agent should run to verify its fix — not the full suite.
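For example (illustrative wording, not a fixed template): "run `pytest tests/test_functional.py -k test_4bit_quant` to verify the fix; the full suite runs separately later."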

**6. Completion workflow.** Every prompt file must include this section verbatim, with the issue number filled in:

@@ -130,20 +130,29 @@

After implementing and verifying the fix:

1. **Commit** your changes with a message referencing the issue:
1. **Run only the tests relevant to your change.** Do NOT run the full
test suite — it takes 10+ minutes and will be run separately later.
Instead, run the specific test file(s) that cover the code you changed:

pytest tests/test_autograd.py -v --tb=short -k "relevant_test_name"

If you wrote a new test, run that plus the existing tests in the same
file to check for regressions in that area.

2. **Commit** your changes with a message referencing the issue:

git add <files>
git commit -m "Fix <brief description> (#<NUMBER>)"

2. **Push** the branch:
3. **Push** the branch:

git push -u origin fix/issue-<NUMBER>

3. **Create a pull request** with `gh pr create`. The PR body must
4. **Create a pull request** with `gh pr create`. The PR body must
include "Fixes #<NUMBER>" so GitHub auto-links and auto-closes the
issue on merge. Describe what the fix does and how you verified it.

4. **Post to the bitsandbytes Slack channel** to notify the team.
5. **Post to the bitsandbytes Slack channel** to notify the team.
Write a temporary Python script to `/tmp/slack_notify.py` and run it:

import json, urllib.request, sys
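
The rest of the script is collapsed in this diff view. A minimal sketch of what such a script might look like, assuming the webhook URL is supplied via a hypothetical `SLACK_WEBHOOK_URL` environment variable (the message format and argument order are illustrative, not the guide's exact script):

import json, os, sys, urllib.request

# Assumption: the channel's incoming-webhook URL is provided via an
# environment variable; the real guide may configure this differently.
url = os.environ["SLACK_WEBHOOK_URL"]
issue, pr_url = sys.argv[1], sys.argv[2]
msg = {"text": f"Opened PR for issue #{issue}: {pr_url}"}
req = urllib.request.Request(
    url,
    data=json.dumps(msg).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.status)

Invoked as, e.g., `python /tmp/slack_notify.py 1810 "<PR URL>"`.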
@@ -245,7 +254,8 @@ whether it is correct and complete before implementing from scratch.

## When You Are Done

[the standard completion workflow section with issue number 1810 filled in]
[the standard completion workflow section with issue number 1810 filled in.
Remember: tell the agent to run only the relevant tests, not the full suite.]

## What NOT to Do

193 changes: 97 additions & 96 deletions tests/test_functional.py
@@ -1117,58 +1117,65 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
relerr = (err / (A1.abs().float() + 1e-8)).mean()
err = err.mean()

# The following values were taken from averaging 1k samples per test configuration.
error_dict = dict()
error_dict["fp4"] = dict()
error_dict["nf4"] = dict()
error_dict["fp4"]["err"] = {
32: 0.088918,
64: 0.096545,
128: 0.102947,
256: 0.108685,
512: 0.114087,
1024: 0.119312,
2048: 0.124460,
4096: 0.129573,
# Expected (mean, std) per configuration, from 200 samples on RTX 4090.
# Thresholds are set at mean + N_SIGMA * std to avoid flaky failures
# while still catching real regressions. Worst-case std across dtypes is used.
N_SIGMA = 7
error_stats = {
"fp4": {
"err": {
32: (0.088925, 0.000091),
64: (0.096543, 0.000111),
128: (0.102969, 0.000134),
256: (0.108684, 0.000182),
512: (0.114115, 0.000234),
1024: (0.119333, 0.000320),
2048: (0.124556, 0.000455),
4096: (0.129536, 0.000612),
},
"rel_err": {
32: (0.242443, 0.000330),
64: (0.260125, 0.000379),
128: (0.275817, 0.000433),
256: (0.289831, 0.000497),
512: (0.302881, 0.000583),
1024: (0.315000, 0.000757),
2048: (0.326607, 0.000955),
4096: (0.337169, 0.001239),
},
},
"nf4": {
"err": {
32: (0.067746, 0.000069),
64: (0.072798, 0.000074),
128: (0.076831, 0.000091),
256: (0.080337, 0.000102),
512: (0.083547, 0.000143),
1024: (0.086610, 0.000187),
2048: (0.089592, 0.000251),
4096: (0.092547, 0.000360),
},
"rel_err": {
32: (0.189726, 0.000304),
64: (0.203339, 0.000340),
128: (0.215237, 0.000391),
256: (0.226105, 0.000398),
512: (0.236079, 0.000544),
1024: (0.245370, 0.000600),
2048: (0.254163, 0.000747),
4096: (0.262473, 0.000999),
},
},
}
error_dict["fp4"]["rel_err"] = {
32: 0.242380,
64: 0.260130,
128: 0.275734,
256: 0.289842,
512: 0.302852,
1024: 0.314982,
2048: 0.326402,
4096: 0.337228,
}

error_dict["nf4"]["err"] = {
32: 0.067745,
64: 0.072792,
128: 0.076835,
256: 0.080326,
512: 0.083535,
1024: 0.086603,
2048: 0.089592,
4096: 0.092537,
}
error_dict["nf4"]["rel_err"] = {
32: 0.189700,
64: 0.203299,
128: 0.215252,
256: 0.226044,
512: 0.236021,
1024: 0.245365,
2048: 0.254146,
4096: 0.262457,
}

# Allow higher tolerance for fp32 on CPU with larger block sizes
reltol = 2.8e-3 if dtype == torch.float32 and blocksize >= 128 and device == "cpu" else 1e-3
errtol = 1.2e-3 if dtype == torch.float32 and blocksize >= 1024 and device == "cpu" else 1e-3

assert err < error_dict[quant_type]["err"][blocksize] + errtol
assert relerr < error_dict[quant_type]["rel_err"][blocksize] + reltol
err_mean, err_std = error_stats[quant_type]["err"][blocksize]
relerr_mean, relerr_std = error_stats[quant_type]["rel_err"][blocksize]
assert err < err_mean + N_SIGMA * err_std, (
f"abs error {err:.6f} exceeds {err_mean:.6f} + {N_SIGMA}*{err_std:.6f}"
)
assert relerr < relerr_mean + N_SIGMA * relerr_std, (
f"rel error {relerr:.6f} exceeds {relerr_mean:.6f} + {N_SIGMA}*{relerr_std:.6f}"
)
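# A sketch of how the (mean, std) tables above could be regenerated, assuming
# bitsandbytes' quantize_4bit/dequantize_4bit API. The sample count, tensor
# shape, and device are illustrative, not the exact measurement setup:
#
#     import torch
#     import bitsandbytes as bnb
#
#     def sample_quant_error(quant_type, blocksize, n_samples=200, device="cuda"):
#         errs, relerrs = [], []
#         for _ in range(n_samples):
#             A = torch.randn(1024, 1024, device=device)
#             qA, state = bnb.functional.quantize_4bit(A, blocksize=blocksize, quant_type=quant_type)
#             A2 = bnb.functional.dequantize_4bit(qA, state)
#             err = (A - A2).abs().float()
#             errs.append(err.mean().item())
#             relerrs.append((err / (A.abs().float() + 1e-8)).mean().item())
#         errs, relerrs = torch.tensor(errs), torch.tensor(relerrs)
#         return (errs.mean().item(), errs.std().item()), (relerrs.mean().item(), relerrs.std().item())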

@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@@ -1374,61 +1381,55 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, double_quant, kind):
relratio = relerr2 / relerr3
maxratio = maxerr2 / maxerr3

# for debugging if the tests fails
#
# print('='*80)
# print(f'For matmul: {A.shape}, {B.shape}, {kind}, {dtype}, {storage_type}, double_quant={double_quant}:')
# print(C1.flatten()[-20:])
# print(C2.flatten()[-20:])
# print(f'inference vs training abs: {err1}')
# print(f'inference vs training rel: {relerr1}')
# print(f'inference vs training max: {maxerr1}')
# print(f'inference vs training vs torch err ratio abs: {absratio}')
# print(f'inference vs training vs torch err ratio rel: {relratio}')
# print(f'inference vs training vs torch err ratio max: {maxratio}')
# Expected (mean, std) for err1, relerr1, maxerr1 per dtype/dim group.
# Measured from 100 iterations x all storage_type/kind/DQ combos on RTX 4090.
# std is for individual iterations (not the average), so thresholds are generous
# enough to accommodate GPU architecture differences (e.g., T4, XPU, Blackwell).
N_SIGMA = 7
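# Worked example: for fp16 with dim <= 512, the err1 limit below works out to
# 0.000052 + 7 * 0.0000063 ≈ 0.0000961.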
gemv_thresholds = {
torch.float16: {
"le512": {
"err1": (0.000052, 0.0000063),
"relerr1": (0.00024, 0.000357),
"maxerr1": (0.00042, 0.0000687),
},
"gt512": {
"err1": (0.000018, 0.0000028),
"relerr1": (0.00010, 0.000197),
"maxerr1": (0.00017, 0.0000179),
},
},
torch.float32: {
"le512": {"err1": (2e-8, 2e-9), "relerr1": (8e-7, 1.2e-6), "maxerr1": (6e-8, 2e-8)},
"gt512": {"err1": (1e-8, 2e-9), "relerr1": (5e-7, 1.6e-7), "maxerr1": (4e-8, 1e-8)},
},
torch.bfloat16: {
"le512": {"err1": (0.00042, 0.000059), "relerr1": (0.0041, 0.01153), "maxerr1": (0.0037, 0.000556)},
"gt512": {"err1": (0.00014, 0.0000095), "relerr1": (0.0012, 0.000679), "maxerr1": (0.0010, 0.000137)},
},
}

dim_key = "le512" if dim <= 512 else "gt512"
thresholds = gemv_thresholds[dtype][dim_key]
for metric_name, metric_val in [("err1", err1), ("relerr1", relerr1), ("maxerr1", maxerr1)]:
mean_val, std_val = thresholds[metric_name]
limit = mean_val + N_SIGMA * std_val
assert metric_val < limit, (
f"{metric_name}={metric_val:.8f} exceeds {mean_val:.8f} + {N_SIGMA}*{std_val:.8f} = {limit:.8f} "
f"for {dtype}, dim={dim}, {storage_type}, DQ={double_quant}, {kind}"
)

# Ratios check that gemv_4bit and matmul_4bit produce consistent results.
# These are tight bounds on internal consistency, not absolute accuracy.
if dtype == torch.float16:
if dim <= 512:
assert err1 < 7e-5

# TODO(matthewdouglas): On T4, dim=128-fp16-fc2-fp4-DQ will have relerror ~ 0.00092727
if (
device == "cuda"
and double_quant
and storage_type == "fp4"
and kind == "fc2"
and torch.cuda.get_device_capability() == (7, 5)
):
assert relerr1 < 0.00093
else:
assert relerr1 < 0.0008
else:
assert err1 < 6e-5
assert relerr1 < 2e-4
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.992
assert maxratio < 1.005 and maxratio > 0.992
elif dtype == torch.float32:
if dim <= 512:
assert err1 < 5e-8
assert relerr1 < 1e-6
assert maxerr1 < 1.05e-7
else:
assert err1 < 5e-8
assert relerr1 < 8e-6
assert maxerr1 < 1e-7
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.005 and relratio > 0.995
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.bfloat16:
if dim <= 512:
relerr_thres = 0.013 if hasattr(torch, "xpu") and torch.xpu.is_available() else 0.007
assert err1 < 6e-4
assert relerr1 < relerr_thres
assert maxerr1 < 0.015
else:
assert err1 < 2e-4
assert relerr1 < 0.002
assert maxerr1 < 0.0012
assert absratio < 1.005 and absratio > 0.995
assert relratio < 1.05 and relratio > 0.96
assert maxratio < 1.05 and maxratio > 0.97
4 changes: 1 addition & 3 deletions tests/test_linear4bit.py
@@ -276,9 +276,7 @@ def test_quant_storage_shard_roundtrip(device, quant_type, quant_storage):
reassembled = torch.cat(shards).reshape(qB.shape)

assert reassembled.dtype == qB.dtype
assert torch.equal(
reassembled.view(torch.uint8), qB.view(torch.uint8)
), "Bytes changed after shard roundtrip"
assert torch.equal(reassembled.view(torch.uint8), qB.view(torch.uint8)), "Bytes changed after shard roundtrip"

out = bnb.functional.gemv_4bit(A, reassembled.t(), state=state)
torch.testing.assert_close(out, ref)
72 changes: 49 additions & 23 deletions tests/test_parametrize.py
@@ -70,22 +70,30 @@ def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics,
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
err_mean = err.mean()

# Expected error bounds from test_functional.py
# Expected (mean, std) from 200 samples on RTX 4090. Worst-case std across dtypes.
# Threshold = mean + N_SIGMA * std avoids flaky failures across GPU architectures.
N_SIGMA = 7
expected_errors = {
"nf4": {
64: {"abs": 0.072792, "rel": 0.203299},
128: {"abs": 0.076835, "rel": 0.215252},
256: {"abs": 0.080326, "rel": 0.226044},
64: {"abs": (0.072796, 0.000072), "rel": (0.203353, 0.000326)},
128: {"abs": (0.076839, 0.000093), "rel": (0.215258, 0.000367)},
256: {"abs": (0.080322, 0.000100), "rel": (0.226056, 0.000392)},
},
"fp4": {
64: {"abs": 0.096545, "rel": 0.260130},
128: {"abs": 0.102947, "rel": 0.275734},
256: {"abs": 0.108685, "rel": 0.289842},
64: {"abs": (0.096547, 0.000112), "rel": (0.260144, 0.000379)},
128: {"abs": (0.102949, 0.000138), "rel": (0.275763, 0.000391)},
256: {"abs": (0.108681, 0.000177), "rel": (0.289835, 0.000507)},
},
}

assert err_mean < expected_errors[quant_type][blocksize]["abs"] + 1e-3, f"Mean abs error {err_mean:.6f} too high"
assert relerr < expected_errors[quant_type][blocksize]["rel"] + 1e-3, f"Mean rel error {relerr:.6f} too high"
abs_mean, abs_std = expected_errors[quant_type][blocksize]["abs"]
rel_mean, rel_std = expected_errors[quant_type][blocksize]["rel"]
assert err_mean < abs_mean + N_SIGMA * abs_std, (
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
)
assert relerr < rel_mean + N_SIGMA * rel_std, (
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
)


@pytest.mark.parametrize("device", get_available_devices())
@@ -120,12 +128,17 @@ def __init__(self, device, dtype):
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
err_mean = err.mean()

# Use slightly looser bounds for higher dimensional tensors
abs_bound = 0.085 # NF4 baseline + margin
rel_bound = 0.25 # NF4 baseline + margin
# Expected (mean, std) for NF4 on MoE-shaped tensors (8x512x256), from 200 samples on RTX 4090.
N_SIGMA = 7
abs_mean, abs_std = 0.072802, 0.000072
rel_mean, rel_std = 0.203327, 0.000312

assert err_mean < abs_bound, f"Mean abs error {err_mean:.6f} too high for shape {param_shape}"
assert relerr < rel_bound, f"Mean rel error {relerr:.6f} too high for shape {param_shape}"
assert err_mean < abs_mean + N_SIGMA * abs_std, (
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
)
assert relerr < rel_mean + N_SIGMA * rel_std, (
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
)


@pytest.mark.parametrize("device", get_available_devices())
@@ -349,14 +362,19 @@ def test_different_blocksizes(device, dtype, blocksize):
relerr = (err / (original_param.abs().float() + 1e-8)).mean()
err_mean = err.mean()

# Expected error bounds from functional tests (using NF4 bounds since that's what we're testing)
expected_abs = {64: 0.072792, 128: 0.076835, 256: 0.080326}
expected_rel = {64: 0.203299, 128: 0.215252, 256: 0.226044}
# Expected (mean, std) for NF4, from 200 samples on RTX 4090. Worst-case std across dtypes.
N_SIGMA = 7
expected_abs = {64: (0.072796, 0.000072), 128: (0.076839, 0.000093), 256: (0.080322, 0.000100)}
expected_rel = {64: (0.203353, 0.000326), 128: (0.215258, 0.000367), 256: (0.226056, 0.000392)}

assert err_mean < expected_abs[blocksize] + 0.01, (
f"Mean abs error {err_mean:.6f} too high for blocksize {blocksize}"
)
assert relerr < expected_rel[blocksize] + 0.02, f"Mean rel error {relerr:.6f} too high for blocksize {blocksize}"
abs_mean, abs_std = expected_abs[blocksize]
rel_mean, rel_std = expected_rel[blocksize]
assert err_mean < abs_mean + N_SIGMA * abs_std, (
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f} for blocksize {blocksize}"
)
assert relerr < rel_mean + N_SIGMA * rel_std, (
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f} for blocksize {blocksize}"
)


def test_parametrization_forward_method():
@@ -383,9 +401,17 @@ def test_parametrization_forward_method():
relerr = (err / (original_tensor.abs().float() + 1e-8)).mean()
err_mean = err.mean()

# Use NF4 bounds from functional tests with small margin
assert err_mean < 0.08, f"Mean abs error {err_mean:.6f} too high"
assert relerr < 0.25, f"Mean rel error {relerr:.6f} too high"
# Expected (mean, std) for NF4 on small 64x64 tensor, from 200 samples on RTX 4090.
# Small tensors have higher variance due to fewer blocks in the quantization.
N_SIGMA = 7
abs_mean, abs_std = 0.072842, 0.001180
rel_mean, rel_std = 0.202648, 0.004729
assert err_mean < abs_mean + N_SIGMA * abs_std, (
f"Mean abs error {err_mean:.6f} exceeds {abs_mean:.6f} + {N_SIGMA}*{abs_std:.6f}"
)
assert relerr < rel_mean + N_SIGMA * rel_std, (
f"Mean rel error {relerr:.6f} exceeds {rel_mean:.6f} + {N_SIGMA}*{rel_std:.6f}"
)


@pytest.mark.parametrize("device", get_available_devices())
Expand Down