[Kernel] Use fused rmsnorm for some models like qwen3 series (#17735)

Signed-off-by: evian <eviantai@u.nus.edu>
Co-authored-by: evian <eviantai@u.nus.edu>
Wanrui Dai 2025-05-07 14:10:02 +08:00 committed by GitHub
parent 1a45a61387
commit f80ae5bdcf
7 changed files with 19 additions and 15 deletions
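
Context for the change (not part of the diff): in vLLM, RMSNorm is a CustomOp, so calling the module directly (self.q_norm(q)) lets it dispatch to the platform-specific implementation — on CUDA that is the fused rms_norm kernel exposed through torch.ops._C — while forward_native always runs the unfused pure-PyTorch reference. The model-side edits below therefore switch from forward_native to the plain module call. A minimal sketch of the reference math the fused kernel replaces (illustrative only; vLLM's forward_native differs in small dtype-handling details):

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                       epsilon: float = 1e-6) -> torch.Tensor:
    # Root-mean-square normalization over the last dim, then a learned scale.
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + epsilon)
    return (x * weight.to(torch.float32)).to(orig_dtype)

# One RMSNorm over a [num_tokens, hidden_size] activation.
x = torch.randn(4, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
print(rms_norm_reference(x, w).shape)  # torch.Size([4, 8])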


@@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
               torch::Tensor& input,  // [..., hidden_size]
               torch::Tensor& weight, // [hidden_size]
               double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
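
The new TORCH_CHECK guards make the kernel's contiguity requirement explicit: a transposed or sliced view must be materialized before the call, which is what the Python wrapper change below does. An illustrative PyTorch-only snippet (not part of the diff):

import torch

x = torch.randn(4, 8)
xt = x.t()                    # transposed view: same storage, swapped strides
print(x.is_contiguous())      # True
print(xt.is_contiguous())     # False -> would now trip TORCH_CHECK(input.is_contiguous())
xc = xt.contiguous()          # copies into a fresh row-major buffer
print(xc.is_contiguous())     # True -> safe to pass to the kernel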


@@ -186,7 +186,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
 # layer norm ops
 def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
              epsilon: float) -> None:
-    torch.ops._C.rms_norm(out, input, weight, epsilon)
+    # TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
+    input_contiguous = input.contiguous()
+    torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
 def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
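
For reference, the wrapper writes into a caller-provided output tensor. A hedged usage sketch (assumes a CUDA build of vLLM so the _C extension is registered; shapes and names are illustrative):

import torch
from vllm import _custom_ops as ops

hidden_size = 4096
x = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")
weight = torch.ones(hidden_size, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)  # the kernel writes the normalized result here

# The wrapper now calls input.contiguous() itself, so a strided view of x
# is also accepted, at the cost of an extra copy.
ops.rms_norm(out, x, weight, 1e-6)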


@@ -190,8 +190,8 @@ class InternParallelAttention(nn.Module):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)
@@ -264,10 +264,8 @@ class InternSdpaAttention(nn.Module):
         if self.qk_normalization:
             B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
+            q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
+            k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)


@@ -438,8 +438,8 @@ class MolmoAttention(nn.Module):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)


@@ -139,8 +139,8 @@ class Olmo2Attention(nn.Module):
         if self.tp_size > 1:
             q = tensor_model_parallel_all_gather(q.contiguous())
             k = tensor_model_parallel_all_gather(k.contiguous())
-        q = self.q_norm.forward_native(q)
-        k = self.k_norm.forward_native(k)
+        q = self.q_norm(q)
+        k = self.k_norm(k)
         if self.tp_size > 1:
             splitter = partial(split_tensor_along_last_dim,
                                num_partitions=self.tp_size)
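
In the Intern, Molmo and Olmo2 attention blocks above, q and k are sharded across tensor-parallel ranks along the last dimension while q_norm/k_norm keep the full-width weight, so the code all-gathers before normalizing and splits the result back afterwards. A single-process sketch of that gather, normalize, split round trip (not part of the diff; two "ranks" are simulated with plain tensor ops, no distributed setup required):

import torch

tp_size, num_tokens, hidden = 2, 3, 8
full_q = torch.randn(num_tokens, hidden)
weight = torch.rand(hidden) + 0.5          # full-width RMSNorm weight

def rms_norm(x, w, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

shards = torch.chunk(full_q, tp_size, dim=-1)   # what each rank holds locally

gathered = torch.cat(shards, dim=-1)            # stands in for tensor_model_parallel_all_gather
normalized = rms_norm(gathered, weight)         # self.q_norm(q), now the fused path in vLLM
rank0_shard = torch.chunk(normalized, tp_size, dim=-1)[0]  # split_tensor_along_last_dim + index

print(rank0_shard.shape)  # torch.Size([3, 4]), back to the per-rank width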


@@ -133,11 +133,11 @@ class Qwen3Attention(nn.Module):
         # Add qk-norm
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                            self.head_dim)
-        q_by_head = self.q_norm.forward_native(q_by_head)
+        q_by_head = self.q_norm(q_by_head)
         q = q_by_head.view(q.shape)
         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                            self.head_dim)
-        k_by_head = self.k_norm.forward_native(k_by_head)
+        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
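
Qwen3Attention (and Qwen3MoeAttention below) normalize q and k per attention head over head_dim rather than over the full hidden width, so the flat [..., num_heads * head_dim] tensor is viewed as [..., num_heads, head_dim], normalized, and viewed back. A pure-PyTorch sketch of that reshape (not part of the diff; rms_norm here is a stand-in for the RMSNorm module):

import torch

num_tokens, num_heads, head_dim = 4, 8, 16
q = torch.randn(num_tokens, num_heads * head_dim)
q_norm_weight = torch.ones(head_dim)  # RMSNorm weight is sized per head_dim

def rms_norm(x, w, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

# Same reshape as in the diff: split the last dim into heads, normalize each
# head over its head_dim slice, then restore the original flat layout.
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
q_by_head = rms_norm(q_by_head, q_norm_weight)
q = q_by_head.view(q.shape)

print(q.shape)  # torch.Size([4, 128])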


@@ -225,12 +225,12 @@ class Qwen3MoeAttention(nn.Module):
         # Add qk-norm
         q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
                            self.head_dim)
-        q_by_head = self.q_norm.forward_native(q_by_head)
+        q_by_head = self.q_norm(q_by_head)
         q = q_by_head.view(q.shape)
         k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
                            self.head_dim)
-        k_by_head = self.k_norm.forward_native(k_by_head)
+        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)