Fix num_token_padding support for static per-tensor scaled_fp8_quant (#20188)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin, 2025-06-28 14:48:13 +09:00 (committed by GitHub)
parent e53be6f00a
commit a29e62ea34


@@ -1274,8 +1274,7 @@ def scaled_fp8_quant(
             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
-        # num_token_padding not implemented for this case
-        assert (scale.numel() == 1 and num_token_padding is None)
+        assert scale.numel() == 1
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
     return output, scale
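
For illustration, here is a minimal pure-PyTorch sketch of what the static per-tensor path does once num_token_padding is accepted: the output buffer is allocated with the padded number of token rows, and the real rows are quantized with the caller-supplied per-tensor scale. The helper name static_scaled_fp8_quant_sketch and the clamp-and-cast approximation of the fused torch.ops._C.static_scaled_fp8_quant kernel are assumptions made for this sketch, not the actual vLLM implementation.

    import torch
    from typing import Optional, Tuple


    def static_scaled_fp8_quant_sketch(
        x: torch.Tensor,
        scale: torch.Tensor,
        num_token_padding: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Per-tensor scale only -- the same invariant the fixed assert keeps.
        assert scale.numel() == 1

        num_tokens, hidden = x.shape
        # Pad the token (row) dimension up to num_token_padding when allocating
        # the output, which is the case the old assert rejected.
        out_rows = max(num_tokens, num_token_padding or 0)
        out = torch.zeros(out_rows, hidden,
                          dtype=torch.float8_e4m3fn, device=x.device)

        # Approximate the fused kernel: divide by the static scale and saturate
        # to the representable fp8 e4m3 range before casting.
        finfo = torch.finfo(torch.float8_e4m3fn)
        q = (x.float() / scale.float()).clamp(finfo.min, finfo.max)
        out[:num_tokens] = q.to(torch.float8_e4m3fn)
        return out, scale


    if __name__ == "__main__":
        x = torch.randn(3, 8)
        scale = torch.tensor([0.05])
        out, _ = static_scaled_fp8_quant_sketch(x, scale, num_token_padding=8)
        print(out.shape)  # torch.Size([8, 8]): 3 real rows plus 5 padding rows

The padding rows are left as zeros; only the first num_tokens rows carry quantized data, which matches how callers that rely on num_token_padding consume the result.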