Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 00:55:40 +08:00)
[ROCM] Fix ROCm warnings, environment flag access, and GEMM kernel naming for consistency in _aiter_ops.py (#28464)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent 74a9a9faad
commit d8140b9833
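In short, the diff (1) reorders the checks in the `if_aiter_supported` decorator so `vllm.platforms.rocm.on_gfx9` is only imported once the platform is known to be ROCm with AITER installed, (2) adds a cached `_TRITON_UNQUANT_GEMM` flag for `VLLM_ROCM_USE_AITER_TRITON_GEMM` together with an `is_triton_gemm_enabled()` accessor on `rocm_aiter_ops`, and (3) renames the `gemm_w8a8*` wrappers and custom ops to `gemm_a8w8*` so they match the underlying AITER kernels (`gemm_a8w8_CK`, `gemm_a8w8_blockscale`). Call sites in the scaled-MM kernel, the FP8 block-scale linear op, the unquantized GEMM helper, and the ROCm platform config are updated accordingly.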
@@ -32,13 +32,13 @@ def if_aiter_supported(func: Callable) -> Callable:
     def wrapper(*args, **kwargs):
         # checks the platform, device arch and aiter library existance.

-        from vllm.platforms.rocm import on_gfx9
+        if current_platform.is_rocm() and IS_AITER_FOUND:
+            from vllm.platforms.rocm import on_gfx9

-        if current_platform.is_rocm() and on_gfx9() and IS_AITER_FOUND:
-            return func(*args, **kwargs)
-        else:
-            # Return None or do nothing if not supported
-            return None
+            if on_gfx9():
+                return func(*args, **kwargs)
+
+        return None

     return wrapper

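The reworked wrapper defers the `on_gfx9` import until the platform is known to be ROCm with AITER present, which is presumably the source of the warnings mentioned in the title when `_aiter_ops` was imported on non-ROCm builds. A minimal, self-contained sketch of the same gating pattern, with `is_rocm`, `aiter_found`, and `on_gfx9` as hypothetical stand-ins for vLLM's real platform checks:

from collections.abc import Callable
from functools import wraps


def if_supported(is_rocm: Callable[[], bool],
                 aiter_found: bool,
                 on_gfx9: Callable[[], bool]) -> Callable:
    """Hypothetical stand-in for vLLM's if_aiter_supported decorator."""

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Cheap guards first; the arch-specific check only runs once the
            # platform is known to be ROCm, so nothing ROCm-related is touched
            # on other platforms.
            if is_rocm() and aiter_found:
                if on_gfx9():
                    return func(*args, **kwargs)
            return None  # silently no-op when unsupported
        return wrapper

    return decorator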
@@ -296,7 +296,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
     pass


-def _rocm_aiter_gemm_w8a8_impl(
+def _rocm_aiter_gemm_a8w8_impl(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -313,7 +313,7 @@ def _rocm_aiter_gemm_w8a8_impl(
     return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)


-def _rocm_aiter_gemm_w8a8_fake(
+def _rocm_aiter_gemm_a8w8_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -327,7 +327,7 @@ def _rocm_aiter_gemm_w8a8_fake(
     return Y


-def _rocm_aiter_gemm_w8a8_blockscale_impl(
+def _rocm_aiter_gemm_a8w8_blockscale_impl(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -339,7 +339,7 @@ def _rocm_aiter_gemm_w8a8_blockscale_impl(
     return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)


-def _rocm_aiter_gemm_w8a8_blockscale_fake(
+def _rocm_aiter_gemm_a8w8_blockscale_fake(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -419,6 +419,7 @@ class rocm_aiter_ops:
     _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM
     _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE
     _MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
+    _TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM

     @classmethod
     @if_aiter_supported
@@ -494,6 +495,11 @@ class rocm_aiter_ops:
     def is_triton_rotary_embed_enabled(cls) -> bool:
         return cls._AITER_ENABLED and cls._TRITON_ROTARY_EMBED

+    @classmethod
+    @if_aiter_supported
+    def is_triton_gemm_enabled(cls) -> bool:
+        return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM
+
     @staticmethod
     @if_aiter_supported
     def register_ops_once() -> None:
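The new accessor follows the same pattern as the existing ones: it is true only when the master AITER switch and `VLLM_ROCM_USE_AITER_TRITON_GEMM` are both set, and `@if_aiter_supported` turns it into a `None` (falsy) result on unsupported platforms. A hedged usage sketch; `pick_gemm_backend` is illustrative and not part of vLLM:

from vllm._aiter_ops import rocm_aiter_ops


def pick_gemm_backend() -> str:
    # Illustrative only: prefer the AITER Triton GEMM path when the cached
    # flags report it as enabled; otherwise fall back to the default GEMM.
    if rocm_aiter_ops.is_triton_gemm_enabled():
        return "aiter_triton_gemm"
    return "default_gemm"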
@@ -555,18 +561,18 @@ class rocm_aiter_ops:
         )

         direct_register_custom_op(
-            op_name="rocm_aiter_gemm_w8a8",
-            op_func=_rocm_aiter_gemm_w8a8_impl,
+            op_name="rocm_aiter_gemm_a8w8",
+            op_func=_rocm_aiter_gemm_a8w8_impl,
             mutates_args=[],
-            fake_impl=_rocm_aiter_gemm_w8a8_fake,
+            fake_impl=_rocm_aiter_gemm_a8w8_fake,
             dispatch_key=current_platform.dispatch_key,
         )

         direct_register_custom_op(
-            op_name="rocm_aiter_gemm_w8a8_blockscale",
-            op_func=_rocm_aiter_gemm_w8a8_blockscale_impl,
+            op_name="rocm_aiter_gemm_a8w8_blockscale",
+            op_func=_rocm_aiter_gemm_a8w8_blockscale_impl,
             mutates_args=[],
-            fake_impl=_rocm_aiter_gemm_w8a8_blockscale_fake,
+            fake_impl=_rocm_aiter_gemm_a8w8_blockscale_fake,
             dispatch_key=current_platform.dispatch_key,
         )

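Only the names change in these registrations; the mechanics stay the same. For readers unfamiliar with the pattern: `direct_register_custom_op` exposes `op_func` under `torch.ops.vllm.<op_name>`, and `fake_impl` supplies the shape-and-dtype-only variant used when the graph is traced (for example under torch.compile) instead of launching the real kernel. A minimal sketch of what such a fake implementation typically looks like, using a toy op rather than vLLM's actual one (shapes and dtype are assumptions for the sketch):

import torch


def _toy_gemm_a8w8_fake(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A assumed [M, K], B assumed [N, K] (the layout the AITER GEMM expects);
    # return an empty [M, N] tensor so tracing sees the right metadata
    # without running any kernel. Output dtype chosen arbitrarily here.
    return torch.empty((A.shape[0], B.shape[0]),
                       dtype=torch.bfloat16, device=A.device)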
@@ -606,7 +612,7 @@ class rocm_aiter_ops:
         return torch.ops.vllm.rocm_aiter_rms_norm(x, weight, variance_epsilon)

     @staticmethod
-    def gemm_w8a8(
+    def gemm_a8w8(
         A: torch.Tensor,
         B: torch.Tensor,
         As: torch.Tensor,
@@ -614,10 +620,10 @@ class rocm_aiter_ops:
         bias: torch.Tensor | None = None,
         output_dtype: torch.dtype = torch.float16,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8(A, B, As, Bs, bias, output_dtype)
+        return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype)

     @staticmethod
-    def gemm_w8a8_blockscale(
+    def gemm_a8w8_blockscale(
         A: torch.Tensor,
         B: torch.Tensor,
         As: torch.Tensor,
@@ -625,7 +631,7 @@ class rocm_aiter_ops:
         block_size: list[int],
         output_dtype: torch.dtype = torch.float16,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale(
+        return torch.ops.vllm.rocm_aiter_gemm_a8w8_blockscale(
             A, B, As, Bs, output_dtype
         )

@@ -938,5 +944,4 @@ class rocm_aiter_ops:
         return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)


-if IS_AITER_FOUND:
-    rocm_aiter_ops.register_ops_once()
+rocm_aiter_ops.register_ops_once()
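Dropping the `if IS_AITER_FOUND:` guard around the module-level call is consistent with the decorator change above: `register_ops_once` is itself wrapped in `@if_aiter_supported` (visible as context in an earlier hunk), so on builds without ROCm, gfx9, or AITER the call is simply a no-op that returns `None`.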
@@ -117,4 +117,4 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
         # a to be [M, K]
         # b to be [N, K]
         # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
-        return rocm_aiter_ops.gemm_w8a8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
+        return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
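The comments in that hunk spell out why the transpose is needed: the kernel wants activations `a` as [M, K] and weights `b` as [N, K], while `CutlassScaledMMLinearKernel` prepares `w_q` in [K, N], so the call passes `w_q.t()`. Only the wrapper name changes here, from `gemm_w8a8` to `gemm_a8w8`.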
@@ -328,7 +328,7 @@ class W8A8BlockFp8LinearOp:
         if use_triton:
             gemm_a8w8_blockscale_op = rocm_aiter_ops.triton_gemm_a8w8_blockscale
         else:
-            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_w8a8_blockscale
+            gemm_a8w8_blockscale_op = rocm_aiter_ops.gemm_a8w8_blockscale

         if input_scale is not None:
             q_input = input_2d
@@ -8,6 +8,7 @@ import torch

 from vllm import _custom_ops as ops
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
@@ -105,8 +106,7 @@ def default_unquantized_gemm(

 def use_aiter_triton_gemm(n, m, k, dtype):
     if (
-        envs.VLLM_ROCM_USE_AITER == 0
-        or envs.VLLM_ROCM_USE_AITER_TRITON_GEMM == 0
+        not rocm_aiter_ops.is_triton_gemm_enabled()
         # MI300's - fp8nuz=True
         or current_platform.is_fp8_fnuz()
         or dtype not in [torch.float16, torch.bfloat16]
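Since `is_triton_gemm_enabled()` returns `cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM` (and a falsy `None` on unsupported platforms via the decorator), the single `not rocm_aiter_ops.is_triton_gemm_enabled()` test subsumes both of the old environment-variable checks.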
@@ -325,6 +325,7 @@ class RocmPlatform(Platform):

     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        from vllm._aiter_ops import rocm_aiter_ops
         from vllm.config.compilation import CUDAGraphMode

         cache_config = vllm_config.cache_config
@@ -332,9 +333,7 @@ class RocmPlatform(Platform):
         parallel_config = vllm_config.parallel_config
         is_eager_execution = compilation_config == CUDAGraphMode.NONE

-        use_aiter_rms_norm = (
-            envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_RMSNORM
-        )
+        use_aiter_rms_norm = rocm_aiter_ops.is_rmsnorm_enabled()

         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16