[Bug] Enforce contiguous input for dynamic_scaled_fp8_quant and static_scaled_fp8_quant (#21773)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye 2025-07-28 15:55:48 -04:00 committed by GitHub
parent b361f14e39
commit e0e58f9729
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1282,10 +1282,11 @@ def scaled_fp8_quant(
output, input.contiguous(), scale, scale_ub)
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
torch.ops._C.dynamic_scaled_fp8_quant(output, input.contiguous(),
scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
torch.ops._C.static_scaled_fp8_quant(output, input.contiguous(), scale)
return output, scale