[Misc] Fix OLMoE model layers failing to load when TP > 1 (#18828)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
rongfu.leng 2025-05-29 01:36:21 +08:00 committed by GitHub
parent fced756923
commit c68b5c63eb


@@ -13,6 +13,7 @@
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
 from collections.abc import Iterable
+from functools import partial
 from typing import Any, Optional, Union
 
 import torch
@@ -22,7 +23,10 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.distributed.utils import split_tensor_along_last_dim
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
-        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.tp_size = tp_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
+        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
+                              eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
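
Why the norm width changes (a minimal sketch, not part of the diff): the column-parallel qkv projection leaves each rank with only a 1/tp_size slice of the query and key features, so a norm sized to hidden_size no longer matches the per-rank tensors once tp_size > 1, and the q_norm/k_norm weights cannot be applied to those narrower shards. The sizes below are assumed, illustrative values, not ones read from an OLMoE checkpoint.

    # Illustrative shape arithmetic only; all sizes here are assumptions.
    hidden_size = 2048
    total_num_heads = 16
    total_num_kv_heads = 16
    head_dim = hidden_size // total_num_heads                # 128
    tp_size = 2

    q_width_per_rank = (total_num_heads // tp_size) * head_dim       # 1024
    k_width_per_rank = (total_num_kv_heads // tp_size) * head_dim    # 1024

    # RMSNorm(hidden_size) expects 2048 features, but the per-rank q/k slices
    # are only 1024 wide; sizing the norms to the full q/k widths and gathering
    # before normalizing (see _apply_qk_norm in the next hunk) removes the
    # mismatch.
    print(q_width_per_rank, k_width_per_rank, hidden_size)
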
@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn")
 
+    def _apply_qk_norm(self, q: torch.Tensor,
+                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
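
For reference, here is a minimal single-process sketch of the gather -> normalize -> re-split pattern that _apply_qk_norm implements. The all-gather and the last-dim split are simulated with torch.cat and torch.chunk, the RMSNorm is a plain torch stand-in for vLLM's kernel, and the tensor sizes are assumptions chosen for illustration rather than the model's real shapes.

    import torch

    def rms_norm(x: torch.Tensor, weight: torch.Tensor,
                 eps: float = 1e-5) -> torch.Tensor:
        # Plain RMSNorm over the last dimension (stand-in for vLLM's RMSNorm).
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + eps) * weight

    tp_size = 2
    total_q_width = 8          # stands in for total_num_heads * head_dim
    num_tokens = 4
    weight = torch.ones(total_q_width)

    full_q = torch.randn(num_tokens, total_q_width)
    # Per-rank q shards, as produced by the column-parallel qkv projection.
    shards = list(torch.chunk(full_q, tp_size, dim=-1))

    # Normalizing each shard on its own would use per-shard RMS statistics.
    per_shard = torch.cat(
        [rms_norm(s, weight[:total_q_width // tp_size]) for s in shards],
        dim=-1)

    # The fixed path: "all-gather" the shards, normalize over the full width,
    # then re-split along the last dim and keep this rank's slice.
    gathered = torch.cat(shards, dim=-1)    # tensor_model_parallel_all_gather
    full_norm = rms_norm(gathered, weight)
    resplit = torch.chunk(full_norm, tp_size, dim=-1)  # split_tensor_along_last_dim
    rank0_slice = resplit[0]

    print(torch.allclose(per_shard, full_norm))   # False: the statistics differ
    print(rank0_slice.shape)                      # torch.Size([4, 4])

The comparison shows why the fix gathers before normalizing: the RMS statistic has to be computed over the full query/key width, so normalizing each TP shard independently would not reproduce the single-GPU result.
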