Add aiter fused_add_rmsnorm_pad to gpt-oss

Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
This commit is contained in:
Rohan138 2025-12-18 11:44:59 -06:00 committed by Gregory Shtrasberg
parent 25b5317ce4
commit 577128bd80
3 changed files with 38 additions and 2 deletions

View File

@@ -651,6 +651,9 @@ class rocm_aiter_ops:
_AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
_LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
_RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
_TRITON_FUSED_ADD_RMSNORM_PAD_ENABLED = (
envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD
)
_FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
_MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
_PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
@@ -686,6 +689,12 @@ class rocm_aiter_ops:
""" "Verifies device specs and availability of env variable."""
return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
@classmethod
@if_aiter_supported
def is_triton_fused_add_rmsnorm_pad_enabled(cls) -> bool:
    """Verify device specs and availability of the env variable.

    Returns True only when AITER is enabled overall AND the
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD flag is set
    (both captured as class attributes from envs at import time).
    """
    return cls._AITER_ENABLED and cls._TRITON_FUSED_ADD_RMSNORM_PAD_ENABLED
@classmethod
@if_aiter_supported
def is_fused_moe_enabled(cls) -> bool:

View File

@@ -115,6 +115,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = False
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
@@ -956,6 +957,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ROCM_USE_AITER_RMSNORM": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")
),
# Whether to use fused add+rmsnorm+pad kernel for gpt-oss
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD": lambda: (
os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD", "True").lower()
in ("true", "1")
),
# Whether to use aiter mla ops.
# By default is enabled.
"VLLM_ROCM_USE_AITER_MLA": lambda: (

View File

@@ -50,6 +50,12 @@ from .utils import (
maybe_prefix,
)
# Conditionally import the AITER Triton kernel: the aiter package is only
# present on ROCm builds, and the guard also respects the
# VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD env toggle.
if (
    current_platform.is_rocm()
    and rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled()
):
    from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
class OAIAttention(nn.Module):
def __init__(
@@ -191,7 +197,7 @@ class MLPBlock(torch.nn.Module):
)
else:
g = self.router(x)
x = self.experts(hidden_states=x, router_logits=g)
x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]
if self.is_sequence_parallel:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
@@ -222,6 +228,10 @@ class TransformerBlock(torch.nn.Module):
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.is_rocm_aiter_fused_add_rmsnorm_pad_enabled = (
rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled()
)
def forward(
self,
hidden_states: torch.Tensor,
@@ -237,7 +247,18 @@ class TransformerBlock(torch.nn.Module):
hidden_states = self.attn(hidden_states, positions)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
if self.is_rocm_aiter_fused_add_rmsnorm_pad_enabled:
hidden_states, residual = fused_add_rmsnorm_pad(
hidden_states,
self.post_attention_layernorm.weight,
self.post_attention_layernorm.variance_epsilon,
residual,
x_pad_to_multiple=256,
)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
output = self.mlp(hidden_states)
return output, residual