diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 38ea6acc0fc50..924736b274f35 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -662,6 +662,17 @@ class FusedMoEParallelConfig:
     def use_deepep_ll_kernels(self):
         return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
 
+    @staticmethod
+    def flatten_tp_across_dp(
+        tp_size: int, dp_size: int, dp_rank: int
+    ) -> tuple[int, int]:
+        tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
+        # There are actually dp_size * tp_size devices. Update tp_size
+        # and tp_rank so we shard across all devices.
+        flatten_tp_size = dp_size * tp_size
+        flatten_tp_rank = dp_rank * tp_size + tp_rank
+        return flatten_tp_size, flatten_tp_rank
+
     @staticmethod
     def make(
         tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
@@ -737,19 +748,13 @@ class FusedMoEParallelConfig:
             between the 4 devices.
         """
 
-        def flatten_tp_across_dp(dp_rank: int):
-            tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
-            # There are actually dp_size_ * tp_size_ devices. Update tp_size
-            # and tp_rank so we shard across all devices.
-            tp_size = dp_size_ * tp_size_
-            tp_rank = dp_rank * tp_size_ + tp_rank
-            return tp_size, tp_rank
-
         use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
 
         dp_size = dp_size_
         dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
-        tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+            tp_size_, dp_size_, dp_rank
+        )
 
         if not use_ep:
             return FusedMoEParallelConfig(
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index 1e32c433cabae..846c8e7669bed 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -11,6 +11,7 @@ from vllm.attention import Attention, AttentionType
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (
+    get_dp_group,
     get_ep_group,
     get_pp_group,
     get_tensor_model_parallel_rank,
@@ -18,6 +19,7 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -305,8 +307,13 @@ class GptOssModel(nn.Module):
         use_ep = self.parallel_config.enable_expert_parallel
         num_experts = self.config.num_local_experts
 
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
+        # In MoE, we need to flatten the tensor parallel size across the data
+        # parallel size when EP is disabled.
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+            tp_size=get_tensor_model_parallel_world_size(),
+            dp_size=get_dp_group().world_size,
+            dp_rank=get_dp_group().rank_in_group,
+        )
 
         intermediate_size = self.config.intermediate_size
         intermediate_size_block = intermediate_size // mxfp4_block
@@ -488,8 +495,13 @@ class GptOssModel(nn.Module):
 
         use_ep = self.parallel_config.enable_expert_parallel
 
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
+        # In MoE, we need to flatten the tensor parallel size across the data
+        # parallel size when EP is disabled.
+        tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
+            tp_size=get_tensor_model_parallel_world_size(),
+            dp_size=get_dp_group().world_size,
+            dp_rank=get_dp_group().rank_in_group,
+        )
 
         intermediate_size = self.config.intermediate_size
         per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
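
For context, a minimal standalone sketch of the rank arithmetic that the new `FusedMoEParallelConfig.flatten_tp_across_dp` helper performs. Unlike the real helper, `tp_rank` is passed in explicitly here (an assumption made so the snippet runs without vLLM's distributed groups being initialized); in the diff above it comes from `get_tensor_model_parallel_rank()`.

```python
def flatten_tp_across_dp(
    tp_size: int, dp_size: int, dp_rank: int, tp_rank: int
) -> tuple[int, int]:
    # There are dp_size * tp_size devices in total; widen the TP group so the
    # MoE weights are sharded across all of them.
    flatten_tp_size = dp_size * tp_size
    flatten_tp_rank = dp_rank * tp_size + tp_rank
    return flatten_tp_size, flatten_tp_rank


if __name__ == "__main__":
    # With TP=2 and DP=2 (the "4 devices" case referenced in the docstring),
    # every (dp_rank, tp_rank) pair maps to a distinct flattened rank 0..3, so
    # each device loads a different shard even though EP is disabled.
    for dp_rank in range(2):
        for tp_rank in range(2):
            size, rank = flatten_tp_across_dp(2, 2, dp_rank, tp_rank)
            print(f"dp_rank={dp_rank}, tp_rank={tp_rank} -> flattened {rank}/{size}")
```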