Add aiter fused_add_rmsnorm_pad to gpt-oss
Signed-off-by: Rohan138 <rohanpotdar138@gmail.com>
parent 25b5317ce4
commit 577128bd80
@@ -651,6 +651,9 @@ class rocm_aiter_ops:
     _AITER_ENABLED = envs.VLLM_ROCM_USE_AITER
     _LINEAR_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR
     _RMSNORM_ENABLED = envs.VLLM_ROCM_USE_AITER_RMSNORM
+    _TRITON_FUSED_ADD_RMSNORM_PAD_ENABLED = (
+        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD
+    )
     _FMOE_ENABLED = envs.VLLM_ROCM_USE_AITER_MOE
     _MLA_ENABLED = envs.VLLM_ROCM_USE_AITER_MLA
     _PG_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
@@ -686,6 +689,12 @@ class rocm_aiter_ops:
         """Verifies device specs and availability of env variable."""
         return cls._AITER_ENABLED and cls._RMSNORM_ENABLED
 
+    @classmethod
+    @if_aiter_supported
+    def is_triton_fused_add_rmsnorm_pad_enabled(cls) -> bool:
+        """Verifies device specs and availability of env variable."""
+        return cls._AITER_ENABLED and cls._TRITON_FUSED_ADD_RMSNORM_PAD_ENABLED
+
     @classmethod
     @if_aiter_supported
     def is_fused_moe_enabled(cls) -> bool:
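The new gate mirrors the existing is_*_enabled helpers: @if_aiter_supported short-circuits on unsupported devices, and the body checks the master AITER switch plus the feature flag. A minimal sketch of a call site, combining the flag with a platform check the way the gpt-oss import further down does (the vllm._aiter_ops module path is an assumption, not shown in this diff):

# Sketch only; rocm_aiter_ops' import path is assumed.
from vllm.platforms import current_platform
from vllm._aiter_ops import rocm_aiter_ops

use_fused_add_rmsnorm_pad = (
    current_platform.is_rocm()
    and rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled()
)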
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD: bool = True
     VLLM_ROCM_USE_AITER_MLA: bool = True
     VLLM_ROCM_USE_AITER_MHA: bool = False
     VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
@@ -956,6 +957,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ROCM_USE_AITER_RMSNORM": lambda: (
         os.getenv("VLLM_ROCM_USE_AITER_RMSNORM", "True").lower() in ("true", "1")
     ),
+    # Whether to use fused add+rmsnorm+pad kernel for gpt-oss
+    "VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD": lambda: (
+        os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD", "True").lower()
+        in ("true", "1")
+    ),
     # Whether to use aiter mla ops.
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_MLA": lambda: (
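The lambda parses the variable the same way as the neighboring AITER flags: unset, "True"/"true", or "1" enables the kernel; anything else disables it. A standalone sketch of that parsing rule, mirroring the lambda rather than importing vllm.envs:

import os

def _aiter_flag(name: str, default: str = "True") -> bool:
    # Case-insensitive "true"/"1" check, same rule as the lambda above.
    return os.getenv(name, default).lower() in ("true", "1")

os.environ["VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD"] = "0"
assert _aiter_flag("VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD") is False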
@@ -50,6 +50,12 @@ from .utils import (
     maybe_prefix,
 )
 
+if (
+    current_platform.is_rocm()
+    and rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled()
+):
+    from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
+
 
 class OAIAttention(nn.Module):
     def __init__(
@@ -191,7 +197,7 @@ class MLPBlock(torch.nn.Module):
             )
         else:
             g = self.router(x)
-            x = self.experts(hidden_states=x, router_logits=g)
+            x = self.experts(hidden_states=x, router_logits=g)[:, : self.hidden_size]
 
         if self.is_sequence_parallel:
             x = tensor_model_parallel_all_gather(x.contiguous(), 0)
@@ -222,6 +228,10 @@ class TransformerBlock(torch.nn.Module):
         self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
         self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+
+        self.is_rocm_aiter_fused_add_rmsnorm_pad_enabled = (
+            rocm_aiter_ops.is_triton_fused_add_rmsnorm_pad_enabled()
+        )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -237,7 +247,18 @@ class TransformerBlock(torch.nn.Module):
         hidden_states = self.attn(hidden_states, positions)
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        if self.is_rocm_aiter_fused_add_rmsnorm_pad_enabled:
+            hidden_states, residual = fused_add_rmsnorm_pad(
+                hidden_states,
+                self.post_attention_layernorm.weight,
+                self.post_attention_layernorm.variance_epsilon,
+                residual,
+                x_pad_to_multiple=256,
+            )
+        else:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
         output = self.mlp(hidden_states)
         return output, residual
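For reference, an unfused sketch of what the call above appears to compute, inferred from the kernel name, the call site, and vLLM's RMSNorm(hidden_states, residual) contract; the actual Triton kernel's behavior (pad value, dtype handling) may differ:

import torch
import torch.nn.functional as F

def add_rmsnorm_pad_reference(x, weight, eps, residual, x_pad_to_multiple=256):
    # Residual add first; the updated residual is the second return value,
    # matching vLLM's fused-add RMSNorm convention.
    residual = x + residual
    # RMSNorm over the hidden dimension, computed in fp32 for stability.
    var = residual.float().pow(2).mean(-1, keepdim=True)
    out = (residual.float() * torch.rsqrt(var + eps)).to(x.dtype) * weight
    # Zero-pad the hidden dimension up to the next multiple of 256 so the
    # downstream MoE kernels see an aligned width (trimmed again in MLPBlock).
    pad = (-out.shape[-1]) % x_pad_to_multiple
    return (F.pad(out, (0, pad)) if pad else out), residual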