[Bugfix][Hardware][AMD] Fix FP8 dtype in silu_mul quantization (#31179)

Signed-off-by: c0de128 <kevin.mckay@outlook.com>
This commit is contained in:
Kevin McKay 2025-12-24 09:37:11 -06:00 committed by GitHub
parent 1ff67df182
commit 66c9887440
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -625,8 +625,9 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
M, N = input.size()
N_2 = N // 2
fp8_dtype = current_platform.fp8_dtype()
if output is None:
output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device)
output = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device)
output_scales = torch.empty(
((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
@ -637,9 +638,12 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
assert M % BLOCK_M == 0
assert N_2 % BLOCK_N == 0
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_min = finfo.min
fp8_max = finfo.max
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
# platforms that use the torch.float8_e4m3fnuz dtype.
finfo = torch.finfo(fp8_dtype)
fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
# Force even division so we can avoid edgecases within the kernel.
assert M % BLOCK_M == 0