diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index c03217b790c5e..e95eeba5e411a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -184,7 +184,6 @@ class FusedMoEParallelConfig: # Adapted from pplx-kernels tests/all_to_all_utils.py @dataclass class MoEConfig: - max_num_tokens: int num_experts: int experts_per_token: int hidden_dim: int @@ -347,33 +346,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - if isinstance(prepare_finalize, (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): logger.debug("BatchedTritonExperts %s", self.moe) - experts = BatchedTritonExperts( + return BatchedTritonExperts( max_num_tokens=MOE_DP_CHUNK_SIZE, world_size=all2all_manager.world_size, # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, ) else: logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts( - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - block_shape=None, - per_channel_quant=False, - ) - return experts + return TritonExperts() def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -472,35 +456,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): activation=activation, apply_router_weight_on_input=apply_router_weight_on_input) - def set_prepare_finalize( - self, - dp_size: int, - world_size: int, - prepare_finalize: FusedMoEPrepareAndFinalize, - ) -> bool: - assert self.fused_experts == fused_experts - - experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - - if isinstance(prepare_finalize, - (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): - logger.debug("BatchedTritonExperts %s", self.moe) - experts = BatchedTritonExperts( - max_num_tokens=MOE_DP_CHUNK_SIZE, - world_size=world_size, - dp_size=dp_size, - ) - else: - logger.debug("TritonExperts %s", self.moe) - experts = TritonExperts() - - self.fused_experts = FusedMoEModularKernel( - prepare_finalize, - experts, - ) - - return True - def forward_cuda( self, layer: torch.nn.Module, @@ -815,16 +770,14 @@ class FusedMoE(torch.nn.Module): from vllm_hpu_extension.ops import DynamicFusedMOE self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) - logger.debug("Model dtype = %s", vllm_config.model_config.dtype) - moe = MoEConfig( - max_num_tokens=MOE_DP_CHUNK_SIZE, num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, num_local_experts=self.local_num_experts, moe_parallel_config=self.moe_parallel_config, - in_dtype=moe.in_dtype, + in_dtype=vllm_config.model_config.dtype, + max_num_tokens=MOE_DP_CHUNK_SIZE, ) self.moe_config = moe self.quant_config = quant_config @@ -1281,7 +1234,7 @@ class FusedMoE(torch.nn.Module): assert (self.batched_hidden_states.size(0) # type: ignore >= chunk_size) - assert (self.batched_router_logits.size(0) # type: ignore + assert (self.batched_router_logits.size(0) # type: ignore >= chunk_size) staged_hidden_states = self.batched_hidden_states[: chunk_size, :] # type: ignore diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index a449ac5ca596f..589629dbfe243 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -108,8 +108,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # There's not much point setting this unless it is != indices.size(0) bound_m: Optional[torch.Tensor] = None - #print(f"SCALE= {a1q_scale.shape}") - self.a2a.dispatch( out_expert_num_tokens=expert_num_tokens, out_expert_x=expert_x, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 456c84496aba5..2fc7c7d7d94f4 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,7 +11,7 @@ from torch.nn.parameter import Parameter import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE, FusedMoEMethodBase, @@ -771,17 +771,26 @@ class Fp8MoEMethod(FusedMoEMethodBase): def select_gemm_impl(self, prepare_finalize): from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedPrepareAndFinalize, + BatchedTritonExperts) + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( "Marlin and ROCm AITER are not supported with all2all yet.") + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + if isinstance(prepare_finalize, (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): logger.debug("BatchedTritonExperts(fp8)") + self.use_pplx_kernels = True return BatchedTritonExperts( max_num_tokens=MOE_DP_CHUNK_SIZE, - world_size=world_size, - dp_size=dp_size, + world_size=all2all_manager.world_size, + dp_size=all2all_manager.tp_group.world_size, qtype=torch.float8_e4m3fn, block_shape=self.quant_config.weight_block_size, per_act_token=False, #? diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index fda0a5b554bcd..0d52bc9a68148 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -492,11 +492,6 @@ class MPClient(EngineCoreClient): (e for e in self.core_engines if e.identity == eng_identity), None) if engine is None: - msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] - logger.debug(f"XXXXXX {status} message from " - f"{'local' if local else 'remote'} " - f"engine {eng_index}") raise RuntimeError(f"Message from engine with unexpected data " f"parallel rank: {eng_index}") msg = msgspec.msgpack.decode(ready_msg_bytes)