[Bugfix] Enable padded FP4 quantization (#25947)

Signed-off-by: Roi Koren <roik@nvidia.com>
roikoren755 authored 2025-10-09 20:59:41 +03:00 (committed by GitHub)
parent 0d37450eb7
commit 4069db3f2e
2 changed files with 1 addition and 3 deletions


@@ -1384,7 +1384,7 @@ def scaled_fp4_quant(
     rounded_m = round_up(m, 128)
     scale_n = n // block_size
     rounded_n = round_up(scale_n, 4)
-    output_scale = torch.empty(
+    output_scale = torch.zeros(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
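
Why this one-line change fixes padded quantization: the scale buffer is allocated with padded dimensions (rows rounded up to a multiple of 128, scale columns to a multiple of 4), so with torch.empty the padded region carried uninitialized memory into downstream kernels. A minimal standalone sketch of the sizing logic, assuming block_size = 16 (the NVFP4 block size) and a round_up helper that behaves like the one used in the file:

    import torch

    def round_up(x: int, multiple: int) -> int:
        # Assumed to match vLLM's round_up helper.
        return ((x + multiple - 1) // multiple) * multiple

    m, n, block_size = 100, 256, 16
    scale_n = n // block_size           # one scale per 16-element block
    rounded_m = round_up(m, 128)        # row padding required by the scale layout
    rounded_n = round_up(scale_n, 4)    # four packed scales per int32

    # torch.empty would leave rows m..rounded_m-1 holding garbage values;
    # torch.zeros makes the padded region deterministically zero.
    output_scale = torch.zeros(
        (rounded_m, rounded_n // 4), device="cpu", dtype=torch.int32
    )
    assert int(output_scale[m:].abs().sum()) == 0  # padding is all zeros

The trade-off is a memset that torch.empty avoided, but it only touches the scale tensor, which is small relative to the quantized activations.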


@@ -386,8 +386,6 @@ def flashinfer_scaled_fp4_mm(
     assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
     assert a.stride(-1) == 1 and b.stride(-1) == 1
     assert a.shape[1] == b.shape[1]
-    assert block_scale_a.shape[1] == a.shape[1] // 8
-    assert block_scale_b.shape[1] == b.shape[1] // 8
     if backend == "cutlass":
         block_scale_a = block_scale_a.view(torch.uint8)
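
The two dropped assertions encoded the unpadded relationship between the packed operand width and the scale width. Assuming the layout used on this path (two FP4 values packed per uint8 and one scale per 16-element block, so a.shape[1] // 8 equals the unpadded per-row block count), a padded scale tensor is rounded up to a multiple of 4 blocks and can legitimately be wider, which is exactly what the old checks rejected. A rough sketch of the mismatch:

    def round_up(x: int, multiple: int) -> int:
        return ((x + multiple - 1) // multiple) * multiple

    m, k = 100, 2080                      # 2080 // 16 = 130 scale blocks per row
    a_cols = k // 2                       # packed uint8 operand: 2 FP4 values/byte
    unpadded_blocks = a_cols // 8         # what the removed asserts required: 130
    padded_blocks = round_up(k // 16, 4)  # width of a padded scale tensor: 132

    # A valid padded input fails the removed check, so the asserts had to go.
    assert padded_blocks != unpadded_blocks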