diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6001b6d83c398..9b4d77a060c29 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -520,6 +520,10 @@ class FusedMoE(CustomOp):
             self._init_aiter_shared_experts_topK_buffer(
                 vllm_config=vllm_config, dp_size=dp_size_
             )
+        if self.use_ep and self.rocm_aiter_fmoe_enabled:
+            assert self.expert_mask is None or torch.all(
+                (self.expert_mask == 0) | (self.expert_mask == 1)
+            ), "Aiter Fused MoE kernel only supports expert_map with 0s and 1s."
 
         assert intermediate_size % self.tp_size == 0
         self.hidden_size = hidden_size
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 8be0299eaa66f..9e2b2134310fc 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -633,6 +633,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
                 topk_ids=topk_ids,
                 activation=activation,
                 quant_config=self.moe_quant_config,
+                expert_map=expert_map,
             )
         else:
             from vllm.model_executor.layers.fused_moe import fused_experts
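
For context, the new assertion guards the aiter fused-MoE path under expert parallelism: the layer's `expert_mask` must be a binary (0/1) tensor. The sketch below only illustrates that constraint; the `build_binary_expert_mask` helper and the `-1` convention for experts owned by other EP ranks are assumptions for illustration, not vLLM's actual implementation.

```python
import torch


def build_binary_expert_mask(expert_map: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: turn an expert_map (local expert index, or -1 for
    experts owned by another EP rank) into the 0/1 mask the aiter kernel expects."""
    return (expert_map != -1).to(torch.int32)


# Example: 8 global experts split across 2 EP ranks; rank 0 owns experts 0-3.
expert_map = torch.tensor([0, 1, 2, 3, -1, -1, -1, -1])
expert_mask = build_binary_expert_mask(expert_map)

# Mirrors the assertion added in FusedMoE.__init__: only 0s and 1s are allowed.
assert torch.all((expert_mask == 0) | (expert_mask == 1))
```

The second hunk simply threads `expert_map` through the Quark OCP MX MoE method into the aiter fused-experts call, so the kernel sees the same mapping the layer validated.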