From 9a1f1da5d1f2e94adb49e0d82b464dc3c1318cc7 Mon Sep 17 00:00:00 2001
From: Shane A
Date: Fri, 21 Feb 2025 22:07:45 -0800
Subject: [PATCH] [Bugfix][Model] OLMo 2: split qkv correctly for GQA and MQA
 (#13687)

---
 vllm/model_executor/models/olmo2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/olmo2.py b/vllm/model_executor/models/olmo2.py
index 4b0455098eedb..d06f894123ac8 100644
--- a/vllm/model_executor/models/olmo2.py
+++ b/vllm/model_executor/models/olmo2.py
@@ -157,7 +157,7 @@ class Olmo2Attention(nn.Module):
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)