mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-11 15:54:37 +08:00
[BugFix] GPT-OSS Attention DP + MoE TP weight loading issue (#24032)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
parent
5f6cbf60d6
commit
aef368aa08
@ -662,6 +662,17 @@ class FusedMoEParallelConfig:
|
||||
def use_deepep_ll_kernels(self):
|
||||
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
|
||||
|
||||
@staticmethod
|
||||
def flatten_tp_across_dp(
|
||||
tp_size: int, dp_size: int, dp_rank: int
|
||||
) -> tuple[int, int]:
|
||||
tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size * tp_size devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
flatten_tp_size = dp_size * tp_size
|
||||
flatten_tp_rank = dp_rank * tp_size + tp_rank
|
||||
return flatten_tp_size, flatten_tp_rank
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
|
||||
@ -737,19 +748,13 @@ class FusedMoEParallelConfig:
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
def flatten_tp_across_dp(dp_rank: int):
|
||||
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
tp_size = dp_size_ * tp_size_
|
||||
tp_rank = dp_rank * tp_size_ + tp_rank
|
||||
return tp_size, tp_rank
|
||||
|
||||
use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
|
||||
|
||||
dp_size = dp_size_
|
||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
||||
tp_size_, dp_size_, dp_rank
|
||||
)
|
||||
|
||||
if not use_ep:
|
||||
return FusedMoEParallelConfig(
|
||||
|
||||
@ -11,6 +11,7 @@ from vllm.attention import Attention, AttentionType
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
@ -18,6 +19,7 @@ from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@ -305,8 +307,13 @@ class GptOssModel(nn.Module):
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
num_experts = self.config.num_local_experts
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
# In MoE, we need to flatten the tensor parallel size across the data
|
||||
# parallel size when EP is disabled.
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
||||
tp_size=get_tensor_model_parallel_world_size(),
|
||||
dp_size=get_dp_group().world_size,
|
||||
dp_rank=get_dp_group().rank_in_group,
|
||||
)
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
intermediate_size_block = intermediate_size // mxfp4_block
|
||||
@ -488,8 +495,13 @@ class GptOssModel(nn.Module):
|
||||
|
||||
use_ep = self.parallel_config.enable_expert_parallel
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
# In MoE, we need to flatten the tensor parallel size across the data
|
||||
# parallel size when EP is disabled.
|
||||
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
|
||||
tp_size=get_tensor_model_parallel_world_size(),
|
||||
dp_size=get_dp_group().world_size,
|
||||
dp_rank=get_dp_group().rank_in_group,
|
||||
)
|
||||
|
||||
intermediate_size = self.config.intermediate_size
|
||||
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user