Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-28 23:09:32 +00:00
parent 468d16654a
commit c169b05541
4 changed files with 17 additions and 62 deletions

View File

@ -184,7 +184,6 @@ class FusedMoEParallelConfig:
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class MoEConfig:
max_num_tokens: int
num_experts: int
experts_per_token: int
hidden_dim: int
@ -347,33 +346,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
return BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
return experts
return TritonExperts()
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -472,35 +456,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
assert self.fused_experts == fused_experts
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts()
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
def forward_cuda(
self,
layer: torch.nn.Module,
@ -815,16 +770,14 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
logger.debug("Model dtype = %s", vllm_config.model_config.dtype)
moe = MoEConfig(
max_num_tokens=MOE_DP_CHUNK_SIZE,
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe.in_dtype,
in_dtype=vllm_config.model_config.dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
self.quant_config = quant_config
@ -1281,7 +1234,7 @@ class FusedMoE(torch.nn.Module):
assert (self.batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
assert (self.batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore

View File

@ -108,8 +108,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# There's not much point setting this unless it is != indices.size(0)
bound_m: Optional[torch.Tensor] = None
#print(f"SCALE= {a1q_scale.shape}")
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,

View File

@ -11,7 +11,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE,
FusedMoEMethodBase,
@ -771,17 +771,26 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(self, prepare_finalize):
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize,
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts(fp8)")
self.use_pplx_kernels = True
return BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
world_size=all2all_manager.world_size,
dp_size=all2all_manager.tp_group.world_size,
qtype=torch.float8_e4m3fn,
block_shape=self.quant_config.weight_block_size,
per_act_token=False, #?

View File

@ -492,11 +492,6 @@ class MPClient(EngineCoreClient):
(e for e in self.core_engines if e.identity == eng_identity),
None)
if engine is None:
msg = msgspec.msgpack.decode(ready_msg_bytes)
status, local = msg["status"], msg["local"]
logger.debug(f"XXXXXX {status} message from "
f"{'local' if local else 'remote'} "
f"engine {eng_index}")
raise RuntimeError(f"Message from engine with unexpected data "
f"parallel rank: {eng_index}")
msg = msgspec.msgpack.decode(ready_msg_bytes)