[ROCM] Fix ROCm warnings, environment flag access, and GEMM kernel naming for consistency in _aiter_ops.py (#28464)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent 74a9a9faad
commit d8140b9833
@@ -32,13 +32,13 @@ def if_aiter_supported(func: Callable) -> Callable:
     def wrapper(*args, **kwargs):
         # checks the platform, device arch and aiter library existance.
-
-        if current_platform.is_rocm() and IS_AITER_FOUND:
-            from vllm.platforms.rocm import on_gfx9
-
-            if on_gfx9():
-                return func(*args, **kwargs)
-
-        return None
+        from vllm.platforms.rocm import on_gfx9
+
+        if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
+            return func(*args, **kwargs)
+        else:
+            # Return None or do nothing if not supported
+            return None
 
     return wrapper
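The decorator above gates every `rocm_aiter_ops` method behind one combined platform check. As a point of reference only, here is a minimal, self-contained sketch of this gate-or-return-None decorator pattern; the probe function and names are illustrative, not taken from vLLM.

from functools import wraps
from typing import Callable


def _capability_probe() -> bool:
    # Hypothetical stand-in for the ROCm / gfx9 / aiter-availability checks.
    return False


def if_supported(func: Callable) -> Callable:
    """Run func only when the capability probe passes; otherwise return None."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        if _capability_probe():
            return func(*args, **kwargs)
        # Callers must treat None as "feature unavailable" (falsy).
        return None

    return wrapper


@if_supported
def is_fast_path_enabled() -> bool:
    return True


print(is_fast_path_enabled())  # None when the probe fails, True otherwise

Because the wrapped function silently becomes a None-returning no-op on unsupported systems, the `is_*_enabled` accessors below can be used directly in boolean expressions without a separate platform check.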
@@ -296,7 +296,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
     pass
 
 
-def _rocm_aiter_gemm_w8a8_impl(
+def _rocm_aiter_gemm_a8w8_impl(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -313,7 +313,7 @@ def _rocm_aiter_gemm_w8a8_impl(
     return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
 
 
-def _rocm_aiter_gemm_w8a8_fake(
+def _rocm_aiter_gemm_a8w8_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -327,7 +327,7 @@ def _rocm_aiter_gemm_w8a8_fake(
     return Y
 
 
-def _rocm_aiter_gemm_w8a8_blockscale_impl(
+def _rocm_aiter_gemm_a8w8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -339,7 +339,7 @@ def _rocm_aiter_gemm_w8a8_blockscale_impl(
     return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
 
 
-def _rocm_aiter_gemm_w8a8_blockscale_fake(
+def _rocm_aiter_gemm_a8w8_blockscale_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -419,6 +419,7 @@ class rocm_aiter_ops:
     _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
     _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
     _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
+    _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
 
     @classmethod
     @if_aiter_supported
@@ -494,6 +495,11 @@ class rocm_aiter_ops:
     def is_triton_rotary_embed_enabled(cls) -> bool:
         return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED
 
+    @classmethod
+    @if_aiter_supported
+    def is_triton_gemm_enabled(cls) -> bool:
+        return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM
+
     @staticmethod
     @if_aiter_supported
     def register_ops_once() -> None:
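The new `_TRITON_UNQUANT_GEMM` flag and `is_triton_gemm_enabled()` accessor follow the existing convention of reading environment flags once at class-definition time and exposing them only through gated classmethods. A rough, self-contained analogue of that pattern (the environment variable name and class below are hypothetical):

import os


class feature_flags:
    # Flags are read once when the class body executes, not on every call.
    _USE_TRITON_GEMM = os.environ.get("MYAPP_USE_TRITON_GEMM", "0") == "1"

    @classmethod
    def is_triton_gemm_enabled(cls) -> bool:
        # Centralised accessor: callers never touch os.environ directly.
        return cls._USE_TRITON_GEMM


if feature_flags.is_triton_gemm_enabled():
    print("triton GEMM path selected")

Centralising the reads is what lets the later hunks replace scattered `envs.VLLM_ROCM_USE_AITER*` comparisons with single accessor calls.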
@@ -555,18 +561,18 @@ class rocm_aiter_ops:
         )
 
         direct_register_custom_op(
-            op_name="rocm_aiter_gemm_w8a8",
-            op_func=_rocm_aiter_gemm_w8a8_impl,
+            op_name="rocm_aiter_gemm_a8w8",
+            op_func=_rocm_aiter_gemm_a8w8_impl,
             mutates_args=[],
-            fake_impl=_rocm_aiter_gemm_w8a8_fake,
+            fake_impl=_rocm_aiter_gemm_a8w8_fake,
             dispatch_key=current_platform.dispatch_key,
         )
 
         direct_register_custom_op(
-            op_name="rocm_aiter_gemm_w8a8_blockscale",
-            op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
+            op_name="rocm_aiter_gemm_a8w8_blockscale",
+            op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
             mutates_args=[],
-            fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
+            fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
             dispatch_key=current_platform.dispatch_key,
         )
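`direct_register_custom_op` is vLLM's helper for registering a real kernel together with a shape-only fake implementation so the op remains traceable with fake tensors under torch.compile. For orientation only, a standalone sketch of the same idea using the public torch.library API (assumes PyTorch 2.4+; the `mylib::gemm_a8w8` name and plain-matmul body are illustrative placeholders, not the aiter kernel):

import torch


@torch.library.custom_op("mylib::gemm_a8w8", mutates_args=())
def gemm_a8w8(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # "Real" implementation: a plain float32 matmul as a placeholder.
    return a.to(torch.float32) @ b.to(torch.float32)


@gemm_a8w8.register_fake
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake implementation: only describes the output shape/dtype for tracing.
    return torch.empty(a.shape[0], b.shape[1], dtype=torch.float32, device=a.device)


out = torch.ops.mylib.gemm_a8w8(torch.randn(4, 8), torch.randn(8, 16))
print(out.shape)  # torch.Size([4, 16])

The fake implementation is what lets the renamed `rocm_aiter_gemm_a8w8` ops be traced and compiled without executing the ROCm kernels.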
@@ -606,7 +612,7 @@ class rocm_aiter_ops:
         return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)
 
     @staticmethod
-    def gemm_w8a8(
+    def gemm_a8w8(
         A: torch.Tensor,
         B: torch.Tensor,
         As: torch.Tensor,
@@ -614,10 +620,10 @@ class rocm_aiter_ops:
         bias: torch.Tensor | None = None,
         output_dtype: torch.dtype = torch.float16,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
+        return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)
 
     @staticmethod
-    def gemm_w8a8_blockscale(
+    def gemm_a8w8_blockscale(
         A: torch.Tensor,
         B: torch.Tensor,
         As: torch.Tensor,
@@ -625,7 +631,7 @@ class rocm_aiter_ops:
         block_size: list[int],
         output_dtype: torch.dtype = torch.float16,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
+        return torch.ops.vllm.rocm_aiter_gemm_a8w8_blockscale(
             A, B, As, Bs, output_dtype
         )
@@ -938,5 +944,4 @@ class rocm_aiter_ops:
         return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
 
 
-if IS_AITER_FOUND:
-    rocm_aiter_ops.register_ops_once()
+rocm_aiter_ops.register_ops_once()
@@ -117,4 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
         # a to be [M, K]
         # b to be [N, K]
         # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
-        return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
+        return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
@@ -328,7 +328,7 @@ class W8A8BlockFp8LinearOp:
         if use_triton:
             gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
         else:
-            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale
 
         if input_scale is not None:
             q_input = input_2d
@@ -8,6 +8,7 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -105,8 +106,7 @@ def default_unquantized_gemm(
 
 def use_aiter_triton_gemm(n, m, k, dtype):
     if (
-        envs.VLLM_ROCM_USE_AITER == 0
-        or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
+        not rocm_aiter_ops.is_triton_gemm_enabled()
         # MI300's - fp8nuz=True
         or current_platform.is_fp8_fnuz()
         or dtype not in [torch.float16, torch.bfloat16]
@@ -325,6 +325,7 @@ class RocmPlatform(Platform):
 
     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm._aiter_ops import rocm_aiter_ops
         from vllm.config.compilation import CUDAGraphMode
 
         cache_config = vllm_config.cache_config
@@ -332,9 +333,7 @@ class RocmPlatform(Platform):
         parallel_config = vllm_config.parallel_config
         is_eager_execution = compilation_config == CUDAGraphMode.NONE
 
-        use_aiter_rms_norm = (
-            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
-        )
+        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()
 
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16