Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Add aiter fused_add_rmsnorm_pad to gpt-oss
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
parent 25b5317ce4
commit 577128bd80
@@ -651,6 +651,9 @@ class rocm_aiter_ops:
     _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
     _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
     _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
+    _TRITON_FUSED_ADD_RMSNORM_PAD_ENABLED = (
+        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD
+    )
     _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
     _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
     _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
@@ -686,6 +689,12 @@ class rocm_aiter_ops:
         """ "Verifies device specs and availability of env variable."""
         return cls._AITER_ENABLED and cls._RMSNORM_ENABLED

+    @classmethod
+    @if_aiter_supported
+    def is_triton_fused_add_rmsnorm_pad_enabled(cls) -> bool:
+        """ "Verifies device specs and availability of env variable."""
+        return cls._AITER_ENABLED and cls._TRITON_FUSED_ADD_RMSNORM_PAD_ENABLED
+
     @classmethod
     @if_aiter_supported
     def is_fused_moe_enabled(cls) -> bool:
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_MHA: bool = False
     VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
@@ -956,6 +957,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ROCM_USE_AITER_RMSNORM": lambda: (
         os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")
     ),
+    # Whether to use fused add+rmsnorm+pad kernel for gpt-oss
+    "VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD", "True").lower()
+        in ("true", "1")
+    ),
     # Whether to use aiter mla ops.
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_MLA": lambda: (
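For reference, a minimal sketch (not the vLLM source; env_flag is a hypothetical helper) of how these boolean environment variables behave: any of "true", "True", or "1" enables the kernel, anything else disables it.

import os

def env_flag(name: str, default: str = "True") -> bool:
    # Same string-to-bool convention as the lambdas in the hunk above.
    return os.getenv(name, default).lower() in ("true", "1")

# Example: turn the fused add+rmsnorm+pad path off for one run.
os.environ["VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD"] = "0"
assert env_flag("VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD") is False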
@@ -50,6 +50,12 @@ from .utils import (
     maybe_prefix,
 )

+if (
+    current_platform.is_rocm()
+    and rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled()
+):
+    from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
+

 class OAIAttention(nn.Module):
     def __init__(
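The guard above keeps aiter from being imported at all on non-ROCm builds or when the feature flag is off. A sketch of an alternative fallback binding (an assumption, not part of this diff) that keeps the name defined even when the kernel is unavailable:

# Sketch only: bind the symbol to None when the Triton kernel cannot be imported,
# so a later "if fused_add_rmsnorm_pad is not None" check can choose the unfused path.
try:
    from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
except ImportError:
    fused_add_rmsnorm_pad = None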
@@ -191,7 +197,7 @@ class MLPBlock(torch.nn.Module):
             )
         else:
             g = self.router(x)
-            x = self.experts(hidden_states=x, router_logits=g)
+            x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]

         if self.is_sequence_parallel:
             x = tensor_model_parallel_all_gather(x.contiguous(), 0)
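Because the fused add+rmsnorm+pad path hands the experts a hidden dimension padded up to a multiple of 256, the expert output is sliced back down to the true hidden size. A rough shape sketch (the 2880 hidden size is an assumption about gpt-oss, not taken from this diff):

import torch

hidden_size = 2880                            # assumed gpt-oss hidden size
pad_to = 256
padded = -(-hidden_size // pad_to) * pad_to   # ceil to multiple of 256 -> 3072

x = torch.randn(4, padded)                    # expert output with padded width
x = x[:, :hidden_size]                        # slice back to the real hidden size
assert x.shape == (4, hidden_size)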
@@ -222,6 +228,10 @@ class TransformerBlock(torch.nn.Module):
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

+        self.is_rocm_aiter_fused_add_rmsnorm_pad_enabled = (
+            rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled()
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -237,7 +247,18 @@ class TransformerBlock(torch.nn.Module):
         hidden_states = self.attn(hidden_states, positions)

         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        if self.is_rocm_aiter_fused_add_rmsnorm_pad_enabled:
+            hidden_states, residual = fused_add_rmsnorm_pad(
+                hidden_states,
+                self.post_attention_layernorm.weight,
+                self.post_attention_layernorm.variance_epsilon,
+                residual,
+                x_pad_to_multiple=256,
+            )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
         output = self.mlp(hidden_states)
         return output, residual
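Based on how the kernel is called here, it is expected to perform the residual add, the RMSNorm, and the zero-padding of the normalized output in a single pass. A rough PyTorch reference of those assumed semantics (a sketch, not the aiter Triton implementation):

import torch

def fused_add_rmsnorm_pad_ref(x, weight, eps, residual, x_pad_to_multiple=256):
    # Residual add first, mirroring RMSNorm's (normed, new_residual) contract.
    residual = x + residual
    # RMSNorm over the last (hidden) dimension.
    var = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (residual.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    # Zero-pad the normalized output up to the next multiple of x_pad_to_multiple,
    # so downstream GEMMs see an aligned hidden dimension (assumed kernel behavior).
    hidden = normed.shape[-1]
    padded = -(-hidden // x_pad_to_multiple) * x_pad_to_multiple
    normed = torch.nn.functional.pad(normed, (0, padded - hidden))
    return normed, residual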