[Bugfix] Fix mismatched nvfp4 gemm output shape (#29742)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-11-30 17:15:01 +08:00 committed by GitHub
parent 2afcec4dec
commit 47539cfd3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -184,7 +184,7 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
return out
output_dtype = x.dtype
- output_shape = [x.shape[0], layer.weight_packed.shape[0]]
+ output_shape = [*x.shape[:-1], layer.weight_packed.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)