mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 02:46:59 +08:00
[ROCm] Add missing gemm_a8w8_blockscale import (#28378)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
parent
30700b1cd7
commit
021143561f
@ -316,38 +316,39 @@ class W8A8BlockFp8LinearOp:
|
|||||||
assert self.act_quant_group_shape == GroupShape(1, 128)
|
assert self.act_quant_group_shape == GroupShape(1, 128)
|
||||||
|
|
||||||
n, k = weight.shape
|
n, k = weight.shape
|
||||||
if input_scale is not None:
|
|
||||||
q_input = input_2d
|
|
||||||
|
|
||||||
# MI350 case uses triton kernel
|
use_triton = (
|
||||||
if (
|
|
||||||
not current_platform.is_fp8_fnuz()
|
not current_platform.is_fp8_fnuz()
|
||||||
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
|
and rocm_aiter_ops.is_triton_gemm_w8a8_tuned(n, k)
|
||||||
):
|
)
|
||||||
|
|
||||||
|
if use_triton:
|
||||||
|
gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
|
||||||
|
else:
|
||||||
|
gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
|
||||||
|
|
||||||
|
if input_scale is not None:
|
||||||
|
q_input = input_2d
|
||||||
|
# MI350 case uses triton kernel
|
||||||
|
elif use_triton:
|
||||||
q_input, input_scale = per_token_group_quant_fp8(
|
q_input, input_scale = per_token_group_quant_fp8(
|
||||||
input_2d,
|
input_2d,
|
||||||
self.act_quant_group_shape.col,
|
self.act_quant_group_shape.col,
|
||||||
column_major_scales=False,
|
column_major_scales=False,
|
||||||
use_ue8m0=False,
|
use_ue8m0=False,
|
||||||
)
|
)
|
||||||
return rocm_aiter_ops.triton_gemm_a8w8_blockscale(
|
|
||||||
q_input,
|
|
||||||
weight,
|
|
||||||
input_scale,
|
|
||||||
weight_scale,
|
|
||||||
input_2d.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# MI300 uses tuned AITER ASM/C++ kernel
|
# MI300 uses tuned AITER ASM/C++ kernel
|
||||||
else:
|
else:
|
||||||
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
|
q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
|
||||||
return rocm_aiter_ops.gemm_w8a8_blockscale(
|
|
||||||
q_input,
|
return gemm_a8w8_blockscale_op(
|
||||||
weight,
|
q_input,
|
||||||
input_scale,
|
weight,
|
||||||
weight_scale,
|
input_scale,
|
||||||
input_2d.dtype,
|
weight_scale,
|
||||||
)
|
list(self.weight_group_shape),
|
||||||
|
output_dtype=input_2d.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
def _run_triton(
|
def _run_triton(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user