mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-25 15:37:59 +08:00
[Kernel] Use fused rmsnorm for some models like qwen3 series (#17735)
Signed-off-by: evian <eviantai@u.nus.edu> Co-authored-by: evian <eviantai@u.nus.edu>
This commit is contained in:
parent
1a45a61387
commit
f80ae5bdcf
@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
|||||||
torch::Tensor& input, // [..., hidden_size]
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
torch::Tensor& weight, // [hidden_size]
|
torch::Tensor& weight, // [hidden_size]
|
||||||
double epsilon) {
|
double epsilon) {
|
||||||
|
TORCH_CHECK(out.is_contiguous());
|
||||||
|
TORCH_CHECK(input.is_contiguous());
|
||||||
|
TORCH_CHECK(weight.is_contiguous());
|
||||||
|
|
||||||
int hidden_size = input.size(-1);
|
int hidden_size = input.size(-1);
|
||||||
int num_tokens = input.numel() / hidden_size;
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
|
|||||||
@ -186,7 +186,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
|||||||
# layer norm ops
|
# layer norm ops
|
||||||
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
|
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
|
||||||
epsilon: float) -> None:
|
epsilon: float) -> None:
|
||||||
torch.ops._C.rms_norm(out, input, weight, epsilon)
|
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
|
||||||
|
input_contiguous = input.contiguous()
|
||||||
|
torch.ops._C.rms_norm(out, input_contiguous, weight, epsilon)
|
||||||
|
|
||||||
|
|
||||||
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
|
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
|
||||||
|
|||||||
@ -190,8 +190,8 @@ class InternParallelAttention(nn.Module):
|
|||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
q = tensor_model_parallel_all_gather(q.contiguous())
|
q = tensor_model_parallel_all_gather(q.contiguous())
|
||||||
k = tensor_model_parallel_all_gather(k.contiguous())
|
k = tensor_model_parallel_all_gather(k.contiguous())
|
||||||
q = self.q_norm.forward_native(q)
|
q = self.q_norm(q)
|
||||||
k = self.k_norm.forward_native(k)
|
k = self.k_norm(k)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
splitter = partial(split_tensor_along_last_dim,
|
splitter = partial(split_tensor_along_last_dim,
|
||||||
num_partitions=self.tp_size)
|
num_partitions=self.tp_size)
|
||||||
@ -264,10 +264,8 @@ class InternSdpaAttention(nn.Module):
|
|||||||
|
|
||||||
if self.qk_normalization:
|
if self.qk_normalization:
|
||||||
B_, N_, H_, D_ = q.shape
|
B_, N_, H_, D_ = q.shape
|
||||||
q = self.q_norm.forward_native(q.flatten(-2,
|
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
|
||||||
-1)).view(B_, N_, H_, D_)
|
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
|
||||||
k = self.k_norm.forward_native(k.flatten(-2,
|
|
||||||
-1)).view(B_, N_, H_, D_)
|
|
||||||
q = q.transpose(1, 2)
|
q = q.transpose(1, 2)
|
||||||
k = k.transpose(1, 2)
|
k = k.transpose(1, 2)
|
||||||
v = v.transpose(1, 2)
|
v = v.transpose(1, 2)
|
||||||
|
|||||||
@ -438,8 +438,8 @@ class MolmoAttention(nn.Module):
|
|||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
q = tensor_model_parallel_all_gather(q.contiguous())
|
q = tensor_model_parallel_all_gather(q.contiguous())
|
||||||
k = tensor_model_parallel_all_gather(k.contiguous())
|
k = tensor_model_parallel_all_gather(k.contiguous())
|
||||||
q = self.q_norm.forward_native(q)
|
q = self.q_norm(q)
|
||||||
k = self.k_norm.forward_native(k)
|
k = self.k_norm(k)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
splitter = partial(split_tensor_along_last_dim,
|
splitter = partial(split_tensor_along_last_dim,
|
||||||
num_partitions=self.tp_size)
|
num_partitions=self.tp_size)
|
||||||
|
|||||||
@ -139,8 +139,8 @@ class Olmo2Attention(nn.Module):
|
|||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
q = tensor_model_parallel_all_gather(q.contiguous())
|
q = tensor_model_parallel_all_gather(q.contiguous())
|
||||||
k = tensor_model_parallel_all_gather(k.contiguous())
|
k = tensor_model_parallel_all_gather(k.contiguous())
|
||||||
q = self.q_norm.forward_native(q)
|
q = self.q_norm(q)
|
||||||
k = self.k_norm.forward_native(k)
|
k = self.k_norm(k)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
splitter = partial(split_tensor_along_last_dim,
|
splitter = partial(split_tensor_along_last_dim,
|
||||||
num_partitions=self.tp_size)
|
num_partitions=self.tp_size)
|
||||||
|
|||||||
@ -133,11 +133,11 @@ class Qwen3Attention(nn.Module):
|
|||||||
# Add qk-norm
|
# Add qk-norm
|
||||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||||
self.head_dim)
|
self.head_dim)
|
||||||
q_by_head = self.q_norm.forward_native(q_by_head)
|
q_by_head = self.q_norm(q_by_head)
|
||||||
q = q_by_head.view(q.shape)
|
q = q_by_head.view(q.shape)
|
||||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||||
self.head_dim)
|
self.head_dim)
|
||||||
k_by_head = self.k_norm.forward_native(k_by_head)
|
k_by_head = self.k_norm(k_by_head)
|
||||||
k = k_by_head.view(k.shape)
|
k = k_by_head.view(k.shape)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
|
|||||||
@ -225,12 +225,12 @@ class Qwen3MoeAttention(nn.Module):
|
|||||||
# Add qk-norm
|
# Add qk-norm
|
||||||
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim,
|
||||||
self.head_dim)
|
self.head_dim)
|
||||||
q_by_head = self.q_norm.forward_native(q_by_head)
|
q_by_head = self.q_norm(q_by_head)
|
||||||
q = q_by_head.view(q.shape)
|
q = q_by_head.view(q.shape)
|
||||||
|
|
||||||
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
|
||||||
self.head_dim)
|
self.head_dim)
|
||||||
k_by_head = self.k_norm.forward_native(k_by_head)
|
k_by_head = self.k_norm(k_by_head)
|
||||||
k = k_by_head.view(k.shape)
|
k = k_by_head.view(k.shape)
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user