[Bugfix][Hardware][AMD] Fix FP8 dtype in silu_mul quantization (#31179)
Signed-off-by: c0de128 <kevin.mckay@outlook.com>
This commit is contained in:
parent 1ff67df182
commit 66c9887440
@@ -625,8 +625,9 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
     M, N = input.size()
     N_2 = N // 2
 
+    fp8_dtype = current_platform.fp8_dtype()
     if output is None:
-        output = torch.empty((M, N_2), dtype=torch.float8_e4m3fn, device=input.device)
+        output = torch.empty((M, N_2), dtype=fp8_dtype, device=input.device)
 
     output_scales = torch.empty(
         ((N_2 // GROUP_SIZE), M), dtype=torch.float32, device=input.device
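The fix routes the output allocation through vLLM's platform abstraction instead of hardcoding torch.float8_e4m3fn. Below is a minimal standalone sketch of that idea; the import path and the exact resolution behavior are assumptions about vLLM's platform layer, not something this commit shows:

    # Sketch only: current_platform.fp8_dtype() is assumed to resolve to
    # torch.float8_e4m3fnuz on ROCm GPUs that use the fnuz FP8 format and
    # to torch.float8_e4m3fn elsewhere, so the output buffer matches the
    # dtype the quantization kernel actually writes.
    import torch
    from vllm.platforms import current_platform  # assumed import path

    fp8_dtype = current_platform.fp8_dtype()
    output = torch.empty((4, 128), dtype=fp8_dtype, device="cuda")
    print(output.dtype)  # float8_e4m3fnuz on fnuz ROCm, float8_e4m3fn otherwise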
@@ -637,9 +638,12 @@ def silu_mul_per_token_group_quant_fp8_colmajor(
     assert M % BLOCK_M == 0
     assert N_2 % BLOCK_N == 0
 
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    fp8_min = finfo.min
-    fp8_max = finfo.max
+    # Using the default value (240.0) from pytorch will cause accuracy
+    # issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm
+    # platforms that use the torch.float8_e4m3fnuz dtype.
+    finfo = torch.finfo(fp8_dtype)
+    fp8_min = -224.0 if current_platform.is_fp8_fnuz() else finfo.min
+    fp8_max = 224.0 if current_platform.is_fp8_fnuz() else finfo.max
 
     # Force even division so we can avoid edgecases within the kernel.
     assert M % BLOCK_M == 0
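Why clamp to ±224.0 rather than use finfo.max: torch.finfo reports a ±448.0 range for torch.float8_e4m3fn but only ±240.0 for the fnuz variant, and per the comment added above, PyTorch's 240.0 default causes accuracy issues on dynamically quantized models. A hedged standalone sketch of the range selection, with is_fnuz standing in for current_platform.is_fp8_fnuz():

    # Standalone sketch of the fp8_min/fp8_max selection in the diff above.
    # torch.finfo gives float8_e4m3fn -> (-448.0, 448.0) and
    # float8_e4m3fnuz -> (-240.0, 240.0); on fnuz platforms the commit
    # narrows the clamp range to +/-224.0 instead of using finfo.max.
    import torch

    def fp8_range(fp8_dtype: torch.dtype, is_fnuz: bool) -> tuple[float, float]:
        finfo = torch.finfo(fp8_dtype)
        fp8_min = -224.0 if is_fnuz else finfo.min
        fp8_max = 224.0 if is_fnuz else finfo.max
        return fp8_min, fp8_max

    print(fp8_range(torch.float8_e4m3fn, is_fnuz=False))   # (-448.0, 448.0)
    print(fp8_range(torch.float8_e4m3fnuz, is_fnuz=True))  # (-224.0, 224.0)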