Fix cuda illegal mem access with Llama4 TP8 + rms_norm custom op (#22701)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
commit 4f0f844b16
parent c5830381af
@@ -224,10 +224,14 @@ class Llama4Attention(nn.Module):
         if self.rotary_emb is not None:
             q, k = self.rotary_emb(positions, q, k)
 
         if self.qk_norm is not None:
-            q = q.reshape(-1, self.num_heads, self.head_dim)
+            # Normalization is applied on the head_dim dimension. The rest of
+            # the dimensions are collapsed into a single dimension to support
+            # custom rms_norm cuda kernel.
+            q = q.reshape(-1, self.head_dim)
             q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
-            k = k.reshape(-1, self.num_kv_heads, self.head_dim)
+            k = k.reshape(-1, self.head_dim)
             k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
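Why the 2-D collapse is safe: RMSNorm acts only on the head_dim dimension, so how the leading dimensions are grouped does not change the math, while the custom rms_norm CUDA kernel presumably expects a plain 2-D (rows, head_dim) input rather than a 3-D view whose leading shape varies with the TP8 sharding. The minimal sketch below (not vLLM code) checks that equivalence; rms_norm_ref is a hypothetical stand-in for the custom kernel and the shapes are illustrative.

# A minimal sketch, assuming the kernel normalizes over the last dimension.
# `rms_norm_ref` is a hypothetical reference implementation, not vLLM's op.
import torch

def rms_norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last dimension (scale weight omitted for brevity).
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

num_tokens, num_heads, head_dim = 4, 8, 16  # illustrative shapes
q = torch.randn(num_tokens, num_heads * head_dim)

# Old path: expose heads as a 3-D view before normalizing.
out_3d = rms_norm_ref(q.reshape(-1, num_heads, head_dim))
out_3d = out_3d.reshape(-1, num_heads * head_dim)

# New path: collapse all leading dims so the kernel sees a 2-D tensor.
out_2d = rms_norm_ref(q.reshape(-1, head_dim))
out_2d = out_2d.reshape(-1, num_heads * head_dim)

assert torch.allclose(out_3d, out_2d)  # numerically identical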
||||