From 577128bd80773b95a3e1ae29f6315f5be03ea2d0 Mon Sep 17 00:00:00 2001 From: Rohan138 Date: Thu, 18 Dec 2025 11:44:59 -0600 Subject: [PATCH] Add aiter fused_add_rmsnorm_pad to gpt-oss Signed-off-by: Rohan138 --- vllm/_aiter_ops.py | 9 +++++++++ vllm/envs.py | 6 ++++++ vllm/model_executor/models/gpt_oss.py | 25 +++++++++++++++++++++++-- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 61592eb65efa7..59cfd8627cc5c 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -651,6 +651,9 @@ class rocm_aiter_ops: _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM + _TRITON_FUSED_ADD_RMSNORM_PAD_ENABLED = ( + envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD + ) + _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN @@ -686,6 +689,12 @@ class rocm_aiter_ops: """ "Verifies device specs and availability of env variable.""" return cls._AITER_ENABLED and cls._RMSNORM_ENABLED + @classmethod + @if_aiter_supported + def is_triton_fused_add_rmsnorm_pad_enabled(cls) -> bool: + """Verifies device specs and availability of env variable.""" + return cls._AITER_ENABLED and cls._TRITON_FUSED_ADD_RMSNORM_PAD_ENABLED + @classmethod @if_aiter_supported def is_fused_moe_enabled(cls) -> bool: diff --git a/vllm/envs.py b/vllm/envs.py index 47cd5ebc6a85c..23b53f5377e13 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -115,6 +115,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_LINEAR: bool = True VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True + VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = False VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False @@ -956,6 +957,11 @@ environment_variables: dict[str, Callable[[], Any]] = { 
"VLLM_ROCM_USE_AITER_RMSNORM": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1") ), + # Whether to use fused add+rmsnorm+pad kernel for gpt-oss + "VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD", "True").lower() + in ("true", "1") + ), # Whether to use aiter mla ops. # By default is enabled. "VLLM_ROCM_USE_AITER_MLA": lambda: ( diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index a423d54a31889..f9d4cce06cc4a 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -50,6 +50,12 @@ from .utils import ( maybe_prefix, ) +if ( + current_platform.is_rocm() + and rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled() +): + from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad + class OAIAttention(nn.Module): def __init__( @@ -191,7 +197,7 @@ class MLPBlock(torch.nn.Module): ) else: g = self.router(x) - x = self.experts(hidden_states=x, router_logits=g) + x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size] if self.is_sequence_parallel: x = tensor_model_parallel_all_gather(x.contiguous(), 0) @@ -222,6 +228,10 @@ class TransformerBlock(torch.nn.Module): self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.is_rocm_aiter_fused_add_rmsnorm_pad_enabled = ( + rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled() + ) + def forward( self, hidden_states: torch.Tensor, @@ -237,7 +247,18 @@ class TransformerBlock(torch.nn.Module): hidden_states = self.attn(hidden_states, positions) # Fully Connected - hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + if self.is_rocm_aiter_fused_add_rmsnorm_pad_enabled: + hidden_states, residual = fused_add_rmsnorm_pad( + hidden_states, + self.post_attention_layernorm.weight, + 
self.post_attention_layernorm.variance_epsilon, + residual, + x_pad_to_multiple=256, + ) + else: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) output = self.mlp(hidden_states) return output, residual