static_scaled_fp8_quant should not run when scale.numel is not 1 (#20076)

Eldar Kurtić 2025-06-25 21:08:03 +02:00 committed by GitHub
parent 23a04e0895
commit 8b8c209e35


@@ -1276,7 +1276,7 @@ def scaled_fp8_quant(
         torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         # num_token_padding not implemented for this case
-        assert (scale.numel() == 1 or num_token_padding is None)
+        assert (scale.numel() == 1 and num_token_padding is None)
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
     return output, scale
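
For context, a minimal standalone sketch (hypothetical helper names, not vLLM code) of why the `or` guard was too permissive: with `or`, a multi-element (e.g. per-channel) scale would still pass the assert whenever `num_token_padding` was `None`, letting the static per-tensor kernel run on a scale it cannot handle. The `and` form rejects that case.

def guard_old(scale_numel: int, num_token_padding) -> bool:
    # Pre-fix condition: satisfied if EITHER clause holds.
    return scale_numel == 1 or num_token_padding is None

def guard_new(scale_numel: int, num_token_padding) -> bool:
    # Post-fix condition: requires a single-element scale AND no padding.
    return scale_numel == 1 and num_token_padding is None

# Per-channel scale (numel > 1) with no padding requested:
assert guard_old(128, None) is True   # old assert passed -> wrong kernel path ran
assert guard_new(128, None) is False  # fixed assert fires, blocking static quant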