diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 2c86b8ed32f32..1667bfd4c7eb9 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1102,6 +1102,7 @@ class FusedMoE(CustomOp):
         self.params_dtype = params_dtype
 
         vllm_config = get_current_vllm_config()
+        self.vllm_config = vllm_config
 
         # FIXME (varun): We should have a better way of inferring the activation
         # datatype. This works for now as the tensor datatype entering the MoE
@@ -1342,26 +1343,6 @@ class FusedMoE(CustomOp):
 
         self.batched_hidden_states: torch.Tensor | None = None
         self.batched_router_logits: torch.Tensor | None = None
-        if self.use_dp_chunking:
-            states_shape: tuple[int, ...]
-            logits_shape: tuple[int, ...]
-
-            # Note here we use `num_experts` which is logical expert count
-            if vllm_config.parallel_config.enable_dbo:
-                states_shape = (2, moe.max_num_tokens, self.hidden_size)
-                logits_shape = (2, moe.max_num_tokens, num_experts)
-            else:
-                states_shape = (moe.max_num_tokens, self.hidden_size)
-                logits_shape = (moe.max_num_tokens, num_experts)
-
-            self.batched_hidden_states = torch.zeros(
-                states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
-            )
-
-            self.batched_router_logits = torch.zeros(
-                logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
-            )
-
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return None
@@ -1420,8 +1401,6 @@ class FusedMoE(CustomOp):
 
     @property
     def use_dp_chunking(self) -> bool:
-        # Route to the chunked forward path using the FlashInfer Cutlass kernel
-        # only when data parallelism (DP) is enabled.
         return (
             self.moe_parallel_config.use_pplx_kernels
             or self.moe_parallel_config.use_deepep_ll_kernels
@@ -1988,12 +1967,40 @@ class FusedMoE(CustomOp):
         self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
         self.logical_replica_count = logical_replica_count[moe_layer_idx]
 
-    def ensure_moe_quant_config(self):
+    def ensure_moe_quant_config_init(self):
         if self.quant_method.moe_quant_config is None:
             self.quant_method.moe_quant_config = (
                 self.quant_method.get_fused_moe_quant_config(self)
             )
 
+        if self.moe_quant_config is None:
+            self.moe_quant_config = self.quant_method.moe_quant_config
+
+    def ensure_dp_chunking_init(self):
+        if not self.use_dp_chunking or self.batched_hidden_states is not None:
+            return
+
+        states_shape: tuple[int, ...]
+        logits_shape: tuple[int, ...]
+
+        moe = self.moe_config
+
+        # Note here we use `num_experts` which is logical expert count
+        if self.vllm_config.parallel_config.enable_dbo:
+            states_shape = (2, moe.max_num_tokens, self.hidden_size)
+            logits_shape = (2, moe.max_num_tokens, moe.num_experts)
+        else:
+            states_shape = (moe.max_num_tokens, self.hidden_size)
+            logits_shape = (moe.max_num_tokens, moe.num_experts)
+
+        self.batched_hidden_states = torch.zeros(
+            states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+        )
+
+        self.batched_router_logits = torch.zeros(
+            logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+        )
+
     @staticmethod
     def select_experts(
         hidden_states: torch.Tensor,
@@ -2224,8 +2231,6 @@ class FusedMoE(CustomOp):
         assert self.batched_hidden_states.size(-1) == full_hidden_states.size(-1)
         assert self.batched_router_logits.size(-1) == full_router_logits.size(-1)
 
-        self.ensure_moe_quant_config()
-
         full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
         if self.shared_experts is not None:
             full_shared_final_hidden_states = torch.empty_like(full_hidden_states)
@@ -2383,7 +2388,8 @@ class FusedMoE(CustomOp):
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.quant_method is not None
 
-        self.ensure_moe_quant_config()
+        self.ensure_moe_quant_config_init()
+        self.ensure_dp_chunking_init()
 
         has_separate_shared_experts = (
             not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
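
For context, the diff moves the DP-chunking staging buffers out of FusedMoE.__init__ into an idempotent helper (ensure_dp_chunking_init) that the forward path calls alongside the renamed ensure_moe_quant_config_init. Below is a minimal standalone sketch of that lazy, allocate-once pattern. The class and method names (ChunkedBuffers, ensure_init) are hypothetical, it is not vLLM code; the real helper reads sizes from self.moe_config and allocates on torch.cuda.current_device().

# Minimal sketch of the lazy, idempotent buffer-init pattern used above.
# Hypothetical names; illustrative only.
import torch


class ChunkedBuffers:
    def __init__(self, max_num_tokens: int, hidden_size: int, num_experts: int):
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        # Nothing is allocated here; construction stays cheap.
        self.hidden_states: torch.Tensor | None = None
        self.router_logits: torch.Tensor | None = None

    def ensure_init(self, use_chunking: bool, enable_dbo: bool) -> None:
        # Idempotent guard: return early if chunking is disabled or the
        # buffers were already created by an earlier call.
        if not use_chunking or self.hidden_states is not None:
            return
        # A leading dim of 2 double-buffers the staging tensors, mirroring
        # the enable_dbo branch in ensure_dp_chunking_init().
        lead = (2,) if enable_dbo else ()
        self.hidden_states = torch.zeros(*lead, self.max_num_tokens, self.hidden_size)
        self.router_logits = torch.zeros(*lead, self.max_num_tokens, self.num_experts)


bufs = ChunkedBuffers(max_num_tokens=8, hidden_size=16, num_experts=4)
bufs.ensure_init(use_chunking=True, enable_dbo=False)  # allocates on first call
bufs.ensure_init(use_chunking=True, enable_dbo=False)  # no-op afterwards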