[BugFix] GPT-OSS Attention DP + MoE TP weight loading issue (#24032)

Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
Author: Po-Han Huang (NVIDIA), 2025-10-21 12:03:47 +08:00 (committed by GitHub)
parent 5f6cbf60d6
commit aef368aa08
2 changed files with 30 additions and 13 deletions

vllm/model_executor/layers/fused_moe/config.py

@@ -662,6 +662,17 @@ class FusedMoEParallelConfig:
     def use_deepep_ll_kernels(self):
         return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
 
+    @staticmethod
+    def flatten_tp_across_dp(
+        tp_size: int, dp_size: int, dp_rank: int
+    ) -> tuple[int, int]:
+        tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
+        # There are actually dp_size * tp_size devices. Update tp_size
+        # and tp_rank so we shard across all devices.
+        flatten_tp_size = dp_size * tp_size
+        flatten_tp_rank = dp_rank * tp_size + tp_rank
+        return flatten_tp_size, flatten_tp_rank
+
     @staticmethod
     def make(
         tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
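
The new helper is pure index arithmetic: it treats the dp_size x tp_size devices as one flat TP group. Below is a minimal standalone sketch of that arithmetic; the real method reads tp_rank from the TP process group, while here it is passed in explicitly, and flatten_tp is a name used only for this illustration.

def flatten_tp(tp_size: int, tp_rank: int, dp_size: int, dp_rank: int) -> tuple[int, int]:
    # Shard across all dp_size * tp_size devices instead of only the local TP group.
    return dp_size * tp_size, dp_rank * tp_size + tp_rank

# With DP=2 and TP=2, the four (dp_rank, tp_rank) pairs map onto flattened ranks 0..3:
for dp_rank in range(2):
    for tp_rank in range(2):
        print((dp_rank, tp_rank), "->", flatten_tp(2, tp_rank, 2, dp_rank))
# (0, 0) -> (4, 0)   (0, 1) -> (4, 1)   (1, 0) -> (4, 2)   (1, 1) -> (4, 3)
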
@@ -737,19 +748,13 @@ class FusedMoEParallelConfig:
         between the 4 devices.
         """
 
-        def flatten_tp_across_dp(dp_rank: int):
-            tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
-            # There are actually dp_size_ * tp_size_ devices. Update tp_size
-            # and tp_rank so we shard across all devices.
-            tp_size = dp_size_ * tp_size_
-            tp_rank = dp_rank * tp_size_ + tp_rank
-            return tp_size, tp_rank
-
         use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
 
         dp_size = dp_size_
         dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
-        tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+            tp_size_, dp_size_, dp_rank
+        )
 
         if not use_ep:
             return FusedMoEParallelConfig(
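
To make the control flow concrete, here is a minimal sketch of the branch make() takes, with the process-group lookups replaced by literal example values (DP=2, TP=2, expert parallelism off). The names mirror the hunk above, but every value here is illustrative only.

tp_size_, dp_size_ = 2, 2
enable_expert_parallel = False  # stands in for vllm_parallel_config.enable_expert_parallel
dp_rank = 1                     # stands in for get_dp_group().rank_in_group

use_ep = dp_size_ * tp_size_ > 1 and enable_expert_parallel  # False here

if not use_ep:
    # EP disabled: MoE weights are TP-sharded across all DP*TP devices using
    # the flattened rank; with local tp_rank = 0 this device gets rank 2 of 4.
    flat_tp_size = dp_size_ * tp_size_      # 4
    flat_tp_rank = dp_rank * tp_size_ + 0   # 2 (local tp_rank assumed 0)
    print(flat_tp_size, flat_tp_rank)
else:
    # EP enabled: make() takes the expert-parallel path, which lies outside this hunk.
    pass
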

vllm/model_executor/models/gpt_oss.py

@@ -11,6 +11,7 @@ from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (
+    get_dp_group,
     get_ep_group,
     get_pp_group,
     get_tensor_model_parallel_rank,
@@ -18,6 +19,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -305,8 +307,13 @@ class GptOssModel(nn.Module):
         use_ep = self.parallel_config.enable_expert_parallel
         num_experts = self.config.num_local_experts
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
+        # In MoE, we need to flatten the tensor parallel size across the data
+        # parallel size when EP is disabled.
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+            tp_size=get_tensor_model_parallel_world_size(),
+            dp_size=get_dp_group().world_size,
+            dp_rank=get_dp_group().rank_in_group,
+        )
 
         intermediate_size = self.config.intermediate_size
         intermediate_size_block = intermediate_size // mxfp4_block
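
For orientation, a rough sketch of the block arithmetic in the surviving context lines, assuming mxfp4_block = 32 (MXFP4 groups 32 values per shared scale) and a purely illustrative intermediate_size of 2880; the flattened tp_size and tp_rank computed above then feed into the rest of the loader, which decides which slice of these blocks each device keeps.

mxfp4_block = 32              # assumed MXFP4 scaling-block size
intermediate_size = 2880      # illustrative value only
intermediate_size_block = intermediate_size // mxfp4_block
print(intermediate_size_block)  # 90 blocks, to be split across the flattened TP ranks
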
@@ -488,8 +495,13 @@ class GptOssModel(nn.Module):
         use_ep = self.parallel_config.enable_expert_parallel
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
+        # In MoE, we need to flatten the tensor parallel size across the data
+        # parallel size when EP is disabled.
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+            tp_size=get_tensor_model_parallel_world_size(),
+            dp_size=get_dp_group().world_size,
+            dp_rank=get_dp_group().rank_in_group,
+        )
 
         intermediate_size = self.config.intermediate_size
         per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
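
The last context line is where the flattening actually matters for weight loading. A small sketch of the shard bounds it produces, assuming an illustrative intermediate_size of 2880 on a DP=2 x TP=2 deployment with EP disabled; cdiv here mirrors the ceiling-division helper used in the loader.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring the helper used in the loader.
    return -(-a // b)

intermediate_size = 2880

# Before the fix: only the local TP rank/size were used, so both DP replicas
# loaded the same two halves even though FusedMoE expected 4-way shards.
for tp_rank in range(2):
    per_rank = cdiv(intermediate_size, 2)
    print("unflattened", tp_rank, (tp_rank * per_rank, (tp_rank + 1) * per_rank))

# After the fix: tp_size/tp_rank are flattened across DP, so the four devices
# load four distinct shards that line up with FusedMoEParallelConfig's layout.
for flat_rank in range(4):
    per_rank = cdiv(intermediate_size, 4)
    print("flattened", flat_rank, (flat_rank * per_rank, (flat_rank + 1) * per_rank))
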