[Misc] fix olmoe model layer can't load in tp gt 1 (#18828)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
commit c68b5c63eb
parent fced756923
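The fix targets the QK RMSNorm in OlmoeAttention: the norm weights span the full head dimension, but with tensor parallelism the column-parallel QKV projection leaves each rank with only its own slice of the heads, so a full-width norm cannot be applied directly to the per-rank q/k. A minimal, hypothetical illustration of that shape mismatch, using torch.nn.RMSNorm (PyTorch >= 2.4) as a stand-in for vLLM's RMSNorm and assumed sizes:

import torch
from torch import nn

# Stand-in sketch with assumed sizes; not vLLM code.
hidden_size, tp_size, num_tokens = 2048, 2, 4

# Pre-fix layout: q_norm weight spans the full head dimension ...
q_norm = nn.RMSNorm(hidden_size, eps=1e-5)
# ... but under TP this rank's q covers only hidden_size // tp_size features.
q_local = torch.randn(num_tokens, hidden_size // tp_size)

try:
    q_norm(q_local)
except Exception as err:
    # The normalized_shape no longer matches the per-rank partition.
    print(f"mismatch under tp_size={tp_size}: {err}")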
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
 from collections.abc import Iterable
+from functools import partial
 from typing import Any, Optional, Union
 
 import torch
@@ -22,7 +23,10 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.distributed.utils import split_tensor_along_last_dim
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
-        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.tp_size = tp_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
+        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
+                              eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module):
                               quant_config=quant_config,
                               prefix=f"{prefix}.attn")
 
+    def _apply_qk_norm(self, q: torch.Tensor,
+                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
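The added _apply_qk_norm restores a full-width view for the norms under tensor parallelism: it all-gathers the sharded q/k across ranks, applies q_norm/k_norm, then splits along the last dimension and keeps the local partition. A standalone sketch of that gather-normalize-split pattern, with assumed sizes and the collectives simulated by plain torch.cat / Tensor.chunk in place of tensor_model_parallel_all_gather and split_tensor_along_last_dim (torch.nn.RMSNorm again standing in for vLLM's RMSNorm):

import torch
from torch import nn

# Standalone sketch: assumed sizes, simulated collectives; not vLLM code.
tp_size = 2
total_num_heads, head_dim, num_tokens = 4, 8, 3
q_norm = nn.RMSNorm(total_num_heads * head_dim, eps=1e-5)

# Per-rank q shards, as a column-parallel QKV projection would produce them.
shards = [torch.randn(num_tokens, total_num_heads // tp_size * head_dim)
          for _ in range(tp_size)]

def apply_q_norm(rank: int) -> torch.Tensor:
    # "all_gather": every rank reconstructs the full-width q ...
    full_q = torch.cat(shards, dim=-1)
    # ... so the norm sees all heads, matching its weight shape.
    normed = q_norm(full_q)
    # Split along the last dim and keep only this rank's partition
    # (the role split_tensor_along_last_dim plays in the real code).
    return normed.chunk(tp_size, dim=-1)[rank]

print(apply_q_norm(rank=0).shape)  # torch.Size([3, 16]): this rank's heads

The extra gather/split only runs when tp_size > 1, so single-GPU execution keeps the original single-pass norm.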