Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-05-11 19:53:35 +08:00)
[FEAT] [ROCm] [AITER]: Add AITER HIP block quant kernel (#21242)
This commit is contained in:
Parent: c7ffe93d9c
Commit: e626d286f5
@ -82,6 +82,13 @@ if current_platform.is_rocm():
|
|||||||
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
|
fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake,
|
||||||
dispatch_key=current_platform.dispatch_key,
|
dispatch_key=current_platform.dispatch_key,
|
||||||
)
|
)
|
||||||
|
if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR
|
||||||
|
and current_platform.is_fp8_fnuz()):
|
||||||
|
|
||||||
|
import aiter as rocm_aiter
|
||||||
|
from aiter import get_hip_quant
|
||||||
|
|
||||||
|
aiter_per1x128_quant = get_hip_quant(rocm_aiter.QuantType.per_1x128)
|
||||||
|
|
||||||
|
|
||||||
def dispatch_w8a8_blockscale_func(
|
def dispatch_w8a8_blockscale_func(
|
||||||
@ -178,8 +185,12 @@ def apply_w8a8_block_fp8_linear(
|
|||||||
block_size, input.dtype)
|
block_size, input.dtype)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
q_input, x_scale = per_token_group_quant_fp8(
|
if use_aiter_and_is_supported:
|
||||||
input_2d, block_size[1], column_major_scales=use_cutlass)
|
q_input, x_scale = aiter_per1x128_quant(
|
||||||
|
input_2d.contiguous(), quant_dtype=rocm_aiter.dtypes.fp8)
|
||||||
|
else:
|
||||||
|
q_input, x_scale = per_token_group_quant_fp8(
|
||||||
|
input_2d, block_size[1], column_major_scales=use_cutlass)
|
||||||
|
|
||||||
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale,
|
||||||
block_size, input.dtype)
|
block_size, input.dtype)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user