[Bugfix] Enforce contiguous input for dynamic_per_token FP8/INT8 quant (#19452)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin, 2025-06-12 15:39:15 -04:00 (committed by GitHub)
parent 9d880f594d
commit a3319f4f04


@@ -1270,7 +1270,7 @@ def scaled_fp8_quant(
                                 device=input.device,
                                 dtype=torch.float32)
             torch.ops._C.dynamic_per_token_scaled_fp8_quant(
-                output, input, scale, scale_ub)
+                output, input.contiguous(), scale, scale_ub)
         else:
             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
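
For context, a minimal sketch of the failure mode this hunk guards against: slicing or transposing an activation yields a non-contiguous view, and a custom op that reads the buffer as packed row-major memory picks up the wrong elements. Calling .contiguous() materializes a packed copy before the kernel launch. This sketch assumes a CUDA build of vLLM with FP8 support; the use_per_token_if_dynamic keyword is taken from the wrapper this file defines, not from the diff itself.

    # Sketch only, not part of the change.
    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(16, 256, dtype=torch.float16, device="cuda")
    view = x[:, :128]                 # sliced view: same storage, stride (256, 1)
    assert not view.is_contiguous()   # the kind of strided input that hit this bug

    # The wrapper now calls .contiguous() itself, so the view is packed into
    # row-major memory before the custom op reads it.
    out, scales = ops.scaled_fp8_quant(view, use_per_token_if_dynamic=True)
    print(out.shape, scales.shape)    # torch.Size([16, 128]) torch.Size([16, 1])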
@@ -1379,8 +1379,8 @@ def scaled_int8_quant(
                                    dtype=torch.float32)
         input_azp = None if symmetric else torch.empty_like(input_scales,
                                                             dtype=torch.int32)
-        torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
-                                               input_azp)
+        torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(),
+                                               input_scales, input_azp)
         return output, input_scales, input_azp
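
The INT8 path has the same hazard. Below is a hedged usage sketch under the same assumptions (CUDA build of vLLM); the symmetric keyword mirrors the parameter visible in the hunk's context.

    # Sketch only, not part of the change.
    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(8, 512, dtype=torch.float16, device="cuda")
    t = x.t()                          # transposed view: strides (1, 512)
    print(t.is_contiguous())           # False
    print(t.contiguous().stride())     # (8, 1): the packed layout the kernel expects

    # dynamic_scaled_int8_quant assumes packed rows; the wrapper now makes that
    # copy, so transposed or sliced activations quantize correctly.
    out, scales, azp = ops.scaled_int8_quant(t, symmetric=True)
    print(out.dtype, scales.shape, azp)  # torch.int8 torch.Size([512, 1]) None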