From 4f0f844b1675419fd2171bc5e981a82386ec552b Mon Sep 17 00:00:00 2001
From: "Po-Han Huang (NVIDIA)" <53919306+nvpohanh@users.noreply.github.com>
Date: Wed, 13 Aug 2025 12:21:50 +0800
Subject: [PATCH] Fix cuda illegal mem access with Llama4 TP8 + rms_norm custom
 op (#22701)

Signed-off-by: Po-Han Huang
---
 vllm/model_executor/models/llama4.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 1f8b9d0744790..308cb3e85e27b 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -224,10 +224,14 @@ class Llama4Attention(nn.Module):
 
         if self.rotary_emb is not None:
             q, k = self.rotary_emb(positions, q, k)
+
         if self.qk_norm is not None:
-            q = q.reshape(-1, self.num_heads, self.head_dim)
+            # Normalization is applied on the head_dim dimension. The rest of
+            # the dimensions are collapsed into a single dimension to support
+            # custom rms_norm cuda kernel.
+            q = q.reshape(-1, self.head_dim)
             q = self.qk_norm(q.float()).reshape(-1, self.q_size).to(q.dtype)
-            k = k.reshape(-1, self.num_kv_heads, self.head_dim)
+            k = k.reshape(-1, self.head_dim)
             k = self.qk_norm(k.float()).reshape(-1, self.kv_size).to(k.dtype)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399)
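
Editor's illustration (not part of the patch): the added comment notes that RMSNorm only
normalizes along head_dim, so the leading (tokens, heads) dimensions can be collapsed into
a single row dimension, giving the custom rms_norm CUDA kernel the 2D (rows, hidden) view
it expects instead of a 3D tensor. The sketch below checks that equivalence in plain
PyTorch; rms_norm_ref and the tensor sizes are illustrative assumptions, not vLLM code.

    # Minimal sketch: RMSNorm over the last dim is unchanged by collapsing
    # all leading dimensions into one, so a kernel that only accepts 2D
    # (rows, hidden) input can be fed the flattened view from the patch.
    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                     eps: float = 1e-6) -> torch.Tensor:
        # Reference RMSNorm: normalize over the last dimension only.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    num_tokens, num_heads, head_dim = 4, 8, 128  # illustrative sizes
    q = torch.randn(num_tokens, num_heads * head_dim)
    w = torch.ones(head_dim)

    # Old path: 3D view (num_tokens, num_heads, head_dim).
    out_3d = rms_norm_ref(q.reshape(-1, num_heads, head_dim), w)

    # New path, as in the patch: 2D view (num_tokens * num_heads, head_dim).
    out_2d = rms_norm_ref(q.reshape(-1, head_dim), w)

    # Both views normalize each head independently and agree exactly.
    assert torch.allclose(out_3d.reshape(num_tokens, -1),
                          out_2d.reshape(num_tokens, -1))

The patched code additionally upcasts to float32 before the norm (q.float()) and casts
back afterwards; that is orthogonal to the reshape. As for the crash itself, the patch
comment only states that the 2D collapse is needed to support the custom kernel;
presumably the 3D view violated an indexing assumption of that kernel under TP8.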