diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index 6364b89fb837..af289455527c 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
 from collections.abc import Iterable
+from functools import partial
 from typing import Any, Optional, Union
 
 import torch
@@ -22,7 +23,10 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.distributed.utils import split_tensor_along_last_dim
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
-        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.tp_size = tp_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
+        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
+                              eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn")
 
+    def _apply_qk_norm(self, q: torch.Tensor,
+                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
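
For context, below is a minimal single-process sketch of the gather -> norm -> split pattern that `_apply_qk_norm` introduces. It is illustrative only: `torch.cat` and `torch.chunk` stand in for `tensor_model_parallel_all_gather` and `split_tensor_along_last_dim`, and the weightless `rms_norm` helper stands in for vLLM's `RMSNorm` layer. The point it demonstrates is that OLMoE's q/k norms are taken over the full `total_num_heads * head_dim` width, while under tensor parallelism each rank holds only a slice of that dimension, so the normalization statistic must be computed on the gathered tensor before splitting back.

```python
import torch


def rms_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Weightless RMSNorm; the normalization statistic spans the full last dim.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


tp_size = 2
total_q_width = 8                              # total_num_heads * head_dim
full_q = torch.randn(4, total_q_width)
shards = torch.chunk(full_q, tp_size, dim=-1)  # per-rank slices of q

# Patched flow: all-gather (simulated by cat), normalize over the full width,
# then split back and keep only this rank's slice.
per_rank_out = []
for rank in range(tp_size):
    gathered = torch.cat(shards, dim=-1)       # ~ tensor_model_parallel_all_gather
    normed = rms_norm(gathered)
    per_rank_out.append(torch.chunk(normed, tp_size, dim=-1)[rank])

# Stitching the per-rank slices together matches normalizing the full tensor.
assert torch.allclose(torch.cat(per_rank_out, dim=-1), rms_norm(full_q))

# Normalizing each shard independently (no gather) generally does not.
naive = torch.cat([rms_norm(s) for s in shards], dim=-1)
print(torch.allclose(naive, rms_norm(full_q)))  # almost always False
```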