mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-23 21:14:30 +08:00)
[Bugfix] Enable padded FP4 quantization (#25947)
Signed-off-by: Roi Koren <roik@nvidia.com>
parent 0d37450eb7
commit 4069db3f2e
@@ -1384,7 +1384,7 @@ def scaled_fp4_quant(
     rounded_m = round_up(m, 128)
     scale_n = n // block_size
     rounded_n = round_up(scale_n, 4)
-    output_scale = torch.empty(
+    output_scale = torch.zeros(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
 
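Note: the scale buffer is padded (rows rounded up to a multiple of 128, scale columns to a multiple of 4), so with torch.empty the region beyond the real data would hold uninitialized memory that downstream FP4 GEMM kernels still read; torch.zeros makes that padding deterministic. A minimal sketch of the allocation follows; the round_up helper and the example shapes are assumptions for illustration, not vLLM code.

# Minimal sketch (assumed shapes, not vLLM's full scaled_fp4_quant) of why the
# padded scale buffer must be zero-initialized rather than left uninitialized.
import torch

def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

m, n, block_size = 200, 512, 16   # illustrative sizes; one scale per 16-element block
device = "cpu"                    # the real path allocates on the GPU

rounded_m = round_up(m, 128)      # rows padded to a multiple of 128 -> 256
scale_n = n // block_size         # unpadded scale columns           -> 32
rounded_n = round_up(scale_n, 4)  # scale columns padded to a multiple of 4

# torch.empty would leave rows m..rounded_m-1 holding garbage that the FP4 GEMM
# kernel still reads; torch.zeros makes the padded region deterministic.
output_scale = torch.zeros(
    (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
assert output_scale[m:].abs().sum().item() == 0  # padding is guaranteed to be zero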
@@ -386,8 +386,6 @@ def flashinfer_scaled_fp4_mm(
     assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
     assert a.stride(-1) == 1 and b.stride(-1) == 1
     assert a.shape[1] == b.shape[1]
-    assert block_scale_a.shape[1] == a.shape[1] // 8
-    assert block_scale_b.shape[1] == b.shape[1] // 8
 
     if backend == "cutlass":
         block_scale_a = block_scale_a.view(torch.uint8)
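Note: the two strict shape asserts are dropped because padded quantization rounds the scale tensor's second dimension up to a multiple of 4, so it can exceed a.shape[1] // 8 and the old equality check would reject valid padded inputs. The sketch below walks through the arithmetic; the packed-FP4 layout assumption, the round_up helper, and the sizes are illustrative, not vLLM code.

# Hedged illustration of why the removed equality check no longer holds once the
# scale tensor is padded. Assumes a packed FP4 layout (two 4-bit values per byte)
# with one scale per 16-element block.
def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

k = 288                                       # unpacked contraction dimension
packed_k = k // 2                             # a.shape[1] for packed FP4 -> 144
scale_cols = k // 16                          # unpadded scale columns    -> 18
padded_scale_cols = round_up(scale_cols, 4)   # padded scale columns      -> 20

# The removed assert required padded_scale_cols == packed_k // 8 (here 18), which
# a padded scale tensor (20 columns) fails even though the inputs are valid.
print(scale_cols, padded_scale_cols, packed_k // 8)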