From 6b87ce2ecd2de67b19b94fd81b024268512b45a3 Mon Sep 17 00:00:00 2001 From: Nikhil Gupta Date: Wed, 24 Sep 2025 02:32:22 +0100 Subject: [PATCH] [fix]: add Arm 4bit fused moe support (#23809) Signed-off-by: Nikhil Gupta Signed-off-by: yewentao256 --- cmake/cpu_extension.cmake | 3 +- csrc/cpu/torch_bindings.cpp | 10 + csrc/moe/dynamic_4bit_int_moe_cpu.cpp | 156 +++++++++ csrc/ops.h | 6 + .../layers/fused_moe/cpu_fused_moe.py | 15 +- vllm/model_executor/layers/fused_moe/layer.py | 2 - .../compressed_tensors_moe.py | 307 +++++++++++++++++- 7 files changed, 488 insertions(+), 11 deletions(-) create mode 100644 csrc/moe/dynamic_4bit_int_moe_cpu.cpp diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 06494463223b..2a2ec08f8695 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -258,7 +258,8 @@ set(VLLM_EXT_SRC "csrc/cpu/layernorm.cpp" "csrc/cpu/mla_decode.cpp" "csrc/cpu/pos_encoding.cpp" - "csrc/cpu/torch_bindings.cpp") + "csrc/cpu/torch_bindings.cpp" + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") if (AVX512_FOUND AND NOT AVX512_DISABLED) set(VLLM_EXT_SRC diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 98c3ebc5a75f..d279c03e0b59 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); + ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1); + ops.def( + "dynamic_4bit_int_moe(" + "Tensor x, Tensor topk_ids, Tensor topk_weights," + "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2," + "int group_size, bool apply_router_weight_on_input, int activation_kind" + ") -> Tensor"); + + ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu); + // PagedAttention V2. ops.def( "paged_attention_v2(" diff --git a/csrc/moe/dynamic_4bit_int_moe_cpu.cpp b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp new file mode 100644 index 000000000000..1d06fc6b5b0a --- /dev/null +++ b/csrc/moe/dynamic_4bit_int_moe_cpu.cpp @@ -0,0 +1,156 @@ +#include +#include +#include + +// _dyn_quant_matmul_4bit is only available on AArch64. 
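+//
+// Overview of dynamic_4bit_int_moe_cpu (implemented below):
+//   1. Bucket the T*K (token, expert) assignments by expert id.
+//   2. For each expert, gather its tokens and run two packed 4-bit
+//      matmuls (W13, then W2) via at::_dyn_quant_matmul_4bit, applying
+//      the requested activation in between.
+//   3. Scale by the router weights (before W13 or after W2, depending on
+//      apply_router_weight_on_input) and scatter-add the per-expert
+//      outputs back into the [T, H] result.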
+#if defined(__aarch64__) + #include +#endif + +inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w, + int64_t group_size_eff, int64_t in_features, + int64_t out_features) { +#if defined(__aarch64__) + return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff, + in_features, out_features); +#else + TORCH_CHECK(false, + "dynamic 4-bit int MoE path requires AArch64 (ARM64); " + "_dyn_quant_matmul_4bit is unavailable on this architecture"); + return {}; +#endif +} + +enum ActivationKind : int64_t { + SwiGLU_Gu = 0, // act = SiLU(g) * u + SwiGLUOAI = 1, // act = SiLU(u) * g + SiLU = 2 // SiLU +}; + +torch::Tensor dynamic_4bit_int_moe_cpu( + torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, + torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I, + int64_t I2, int64_t group_size, bool apply_router_weight_on_input, + int64_t activation_kind) { + TORCH_CHECK(x.dim() == 2, "x must be 2D"); + TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2, + "topk tensors must be [T, K]"); + TORCH_CHECK( + w13_packed.size(0) == w2_packed.size(0), + "w13_packed and w2_packed must have same number of experts in dim 0"); + TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I"); + + const int64_t T = x.size(0); + const int64_t K = topk_ids.size(1); + const int64_t E = w13_packed.size(0); + const int64_t N = T * K; + + auto x_c = x.contiguous(); + auto ids_c = topk_ids.contiguous(); + auto gates_c = topk_weights.to(at::kFloat).contiguous(); + + // bucketing tokens -> experts + c10::SmallVector counts( + E, 0); // Small vector uses stack allocation + { + const auto* ids_ptr = ids_c.data_ptr(); + for (int64_t i = 0; i < N; ++i) { + const int64_t e_id = ids_ptr[i]; + TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range"); + counts[e_id]++; + } + } + c10::SmallVector offsets(E + 1, 0); // ( E +1 ) + for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e]; + + auto expert_tokens = at::empty({offsets[E]}, ids_c.options()); + auto expert_gates = at::empty({offsets[E]}, gates_c.options()); + { + c10::SmallVector cursor(E, 0); + const auto* ids_ptr = ids_c.data_ptr(); + const auto* gts_ptr = gates_c.data_ptr(); + auto* tok_ptr = expert_tokens.data_ptr(); + auto* gate_ptr = expert_gates.data_ptr(); + + for (int64_t t = 0; t < T; ++t) { + const int64_t base = t * K; + for (int64_t k = 0; k < K; ++k) { + const int64_t idx = base + k; + const int64_t e = ids_ptr[idx]; + const int64_t p = offsets[e] + (cursor[e]++); + tok_ptr[p] = t; + gate_ptr[p] = gts_ptr[idx]; + } + } + } + + const int64_t g_eff_13 = (group_size != -1) ? group_size : H; + const int64_t g_eff_2 = (group_size != -1) ? 
group_size : I; + + // Per-expert outputs filled in parallel + std::vector y_list(E); + y_list.resize(E); + + at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) { + for (int64_t e = e_begin; e < e_end; ++e) { + const int64_t te = counts[e]; + if (te == 0) { + y_list[e] = at::empty({0, H}, x_c.options()); + continue; + } + + const int64_t start = offsets[e]; + + auto sel_tokens = + expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); + auto gates_e = + expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te); + + auto x_e = x_c.index_select(/*dim=*/0, sel_tokens); + + if (apply_router_weight_on_input) { + x_e = x_e.mul(gates_e.unsqueeze(1)); + } + + auto w13_e = w13_packed.select(/*dim=*/0, e); + auto w2_e = w2_packed.select(/*dim=*/0, e); + + // W13 + auto y13 = + mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2); + + auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I); + auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I); + + torch::Tensor act; + if (activation_kind == ActivationKind::SwiGLUOAI) { // SwiGLUOAI + constexpr double kAlpha = 1.702; // GPT-OSS default + constexpr double kLimit = 7.0; // GPT-OSS default + auto gate_c = at::clamp_max(g_part, kLimit); + auto up_c = at::clamp(u_part, -kLimit, kLimit); + auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha))); + act = up_c.add(1.0).mul(glu); + } else { // SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul() + act = at::silu(g_part).mul(u_part); + } + + // W2 + auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H); + + if (!apply_router_weight_on_input) { + y = y.mul(gates_e.unsqueeze(1)); + } + + // Store per-expert result + y_list[e] = y; + } + }); + + // Concatenate all expert outputs to match expert_tokens order + auto Y_all = at::cat(y_list, /*dim=*/0); + auto out = at::zeros({T, H}, x.options()); + out = + at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all); + + return out; +} diff --git a/csrc/ops.h b/csrc/ops.h index fd9c55b94895..2ada7905da4b 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const std::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); +torch::Tensor dynamic_4bit_int_moe_cpu( + torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, + torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I, + int64_t I2, int64_t group_size, bool apply_router_weight_on_input, + int64_t activation_kind); + using fptr_t = int64_t; fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, diff --git a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index 0eec93601b3f..114f349538fb 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -98,13 +98,16 @@ def select_experts( e_score_correction_bias=e_score_correction_bias) elif custom_routing_function is None: assert scoring_func == "softmax" - topk_weights = torch.nn.functional.softmax(router_logits, - dim=1, - dtype=torch.float32) - topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1) + topk_logit_vals, topk_idx = torch.topk(router_logits, + k=top_k, + dim=-1, + sorted=False) if renormalize: - topk_weights /= topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids.to(torch.int32) + topk_vals = torch.softmax(topk_logit_vals, dim=-1) + 
else: + logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True) + topk_vals = (topk_logit_vals - logZ).exp() + return topk_vals.to(torch.float32), topk_idx.to(torch.int32) else: return custom_routing_function(hidden_states=hidden_states, gating_output=router_logits, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2bf3bf96baf1..89e0cee08170 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -69,8 +69,6 @@ else: if is_rocm_aiter_moe_enabled(): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 rocm_aiter_grouped_topk as grouped_topk) -elif current_platform.is_cpu(): - pass else: from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk if current_platform.is_tpu(): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 10f9085be4d1..a7d3e920414d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config, int8_w8a16_moe_quant_config, nvfp4_moe_quant_config) +from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( is_valid_flashinfer_cutlass_fused_moe) from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa @@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform +from vllm.platforms import CpuArchEnum, current_platform from vllm.scalar_type import scalar_types from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used @@ -63,7 +64,7 @@ __all__ = [ "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsW8A8Int8MoEMethod", "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsW4A4MoeMethod" + "CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod" ] @@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Int8MoEMethod(quant_config, layer.moe_config) + elif quant_config._is_dynamic_token_w4a8_int(weight_quant, + input_quant): + return CompressedTensorsW4A8Int8MoEMethod(quant_config, + layer.moe_config) else: raise RuntimeError( f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") @@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): expert_map=expert_map, quant_config=self.moe_quant_config, ) + + +class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod): + """ + CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform + - Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles) + - Scales: Fp32 for Channelwise , bf16 for groupwise quantization + - Bias: Same data type as original 
weights + - Activations: FP32/Bf16 dynamic per-token (A8 Int), + quantized inside the kernel + """ + + def __init__( + self, + quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + moe: FusedMoEConfig): + super().__init__(moe) + self.has_bias = self.moe.has_bias + self.quant_config = quant_config + + # Validate scheme: weights=W4 (channel or group), + # activations=dynamic TOKEN (A8) + wq = self.quant_config.target_scheme_map["Linear"].get("weights") + aq = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + # Must be dynamic per-token activations + if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic: + raise ValueError( + "W4A8-int MoE needs dynamic per-token activation quantization." + ) + + # Weight can be channel-wise (group_size=None) or group-wise + self.group_size = wq.group_size if (wq.group_size is not None) else -1 + if wq.num_bits != 4: + raise ValueError( + "This method only supports 4-bit weights (num_bits=4).") + + # CPU only + if not current_platform.is_cpu(): + raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.") + + # Arm: check _dyn ops availability + if current_platform.get_cpu_architecture() == CpuArchEnum.ARM: + try: + _ = torch.ops.aten._dyn_quant_matmul_4bit + _ = torch.ops.aten._dyn_quant_pack_4bit_weight + except AttributeError as err: + raise RuntimeError( + f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops; + install a newer build.""") from err + self.static_input_scales = False # always dynamic per token + + # ---- parameter creation ---- + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Shapes per local rank (TP/EP): + # w13: [E, 2*I_local, H] int8 (int4 values in [-8,7]) + # w2 : [E, H, I_local] int8 + # Scales: + # channel-wise: group_size=-1 -> per-output-row, single scale per row + # group-wise : group_size=g -> + # per-output-row, (in_features/g) scales + + E = num_experts + H = hidden_size + IN = intermediate_size_per_partition + g = self.group_size + + # Per-row scale columns + def _n_scale_cols(in_features: int) -> int: + return 1 if g == -1 else (in_features // g) + + # Register unpacked int4-as-int8 weights the loader will fill. 
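+        # Illustrative example (hypothetical sizes): with E=8, H=4096,
+        # IN=14336 and group_size=32, w13_weight is [8, 28672, 4096] int8
+        # and w13_weight_scale is [8, 28672, 128] bfloat16; with
+        # group_size=-1 the scales collapse to [8, 28672, 1] float32.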
+ w13 = torch.nn.Parameter(torch.empty(E, 2 * IN, H, dtype=torch.int8), + requires_grad=False) + set_weight_attrs(w13, extra_weight_attrs) + layer.register_parameter("w13_weight", w13) + + w2 = torch.nn.Parameter(torch.empty(E, H, IN, dtype=torch.int8), + requires_grad=False) + set_weight_attrs(w2, extra_weight_attrs) + layer.register_parameter("w2_weight", w2) + + # Register scales + # KleidiAI groupwise kernels accepts float32 scales + # KleidiAI groupwise kernels accepts bfloat16 scales + scale_dtype = torch.float32 if g == -1 else torch.bfloat16 + + w13_s = torch.nn.Parameter(torch.ones(E, + 2 * IN, + _n_scale_cols(H), + dtype=scale_dtype), + requires_grad=False) + set_weight_attrs( + w13_s, { + "quant_method": "channel" if g == -1 else "group", + **extra_weight_attrs + }) + layer.register_parameter("w13_weight_scale", w13_s) + + w2_s = torch.nn.Parameter(torch.ones(E, + H, + _n_scale_cols(IN), + dtype=scale_dtype), + requires_grad=False) + set_weight_attrs( + w2_s, { + "quant_method": "channel" if g == -1 else "group", + **extra_weight_attrs + }) + layer.register_parameter("w2_weight_scale", w2_s) + + if self.has_bias: + w13_bias = torch.nn.Parameter(torch.zeros(E, + 2 * IN, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + # Placeholders for packed weights (will be replaced after packing) + layer.register_parameter( + "w13_weight_packed", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs) + + layer.register_parameter( + "w2_weight_packed", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs) + + # dims for 4 bit fused matmuls + layer.w13_in_features = H + layer.w13_out_features = 2 * IN + layer.w2_in_features = IN + layer.w2_out_features = H + layer.group_size = g + + # post-load packing to dyn-4bit KleidiAI kernel's format + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + E = layer.w13_weight.shape[0] + H = layer.w13_in_features + I2 = layer.w13_out_features + IN = layer.w2_in_features + g = layer.group_size + + def _pack_matrix(int4_as_int8_2d: torch.Tensor, + scales_2d: torch.Tensor, + bias_1d: Optional[torch.Tensor], in_features: int, + out_features: int) -> torch.Tensor: + # int4 values are stored as int8 in [-8,7]. + # Shift to unsigned nibble and pack pairs along input-dim. 
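+            # Worked example: int4 values in [-8, 7] map to [0, 15] after
+            # add(8); a pair (even=3, odd=12) packs to (12 << 4) | 3 = 0xC3,
+            # so an [out, in] int8 matrix becomes [out, in // 2] uint8 with
+            # the even column in the low nibble and the odd column in the
+            # high nibble.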
+ tmp = int4_as_int8_2d.add(8) # [out, in] + uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to( + torch.uint8) # [out, in//2] + + # KleidiAI groupwise kernels accepts float32 scales + # KleidiAI groupwise kernels accepts bfloat16 scales + scale_dtype = torch.float32 if g == -1 else torch.bfloat16 + scales = scales_2d.to(scale_dtype) + bias = None if bias_1d is None else bias_1d.to(torch.float32) + return torch.ops.aten._dyn_quant_pack_4bit_weight( + uint8_nibbles, scales, bias, g if g != -1 else in_features, + in_features, out_features) + + # Pack per expert + w13_packed_list = [] + w2_packed_list = [] + + has_w13_bias = hasattr(layer, + "w13_bias") and layer.w13_bias is not None + has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None + + for e in range(E): + w13_packed_list.append( + _pack_matrix( + layer.w13_weight[e], # [2I, H] + layer.w13_weight_scale[e], # [2I, H/g or 1] + layer.w13_bias[e] if has_w13_bias else None, # [2I] + H, + I2)) + w2_packed_list.append( + _pack_matrix( + # w2 shape is [H, IN]; we need [out, in] == [H, IN]. + layer.w2_weight[e], # [H, IN] + layer.w2_weight_scale[e], # [H, IN/g or 1] + layer.w2_bias[e] if has_w2_bias else None, # [H] + IN, + layer.w2_out_features # in_features=IN, out_features=H + )) + + # each packed tensor has identical shape per expert; stack on dim 0 + w13_packed = torch.stack(w13_packed_list, dim=0) + w2_packed = torch.stack(w2_packed_list, dim=0) + + replace_parameter(layer, "w13_weight_packed", + torch.nn.Parameter(w13_packed, requires_grad=False)) + replace_parameter(layer, "w2_weight_packed", + torch.nn.Parameter(w2_packed, requires_grad=False)) + + # free raw tensors/scales/bias now that they're packed into the payload. + replace_parameter( + layer, "w13_weight", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + replace_parameter( + layer, "w2_weight", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + replace_parameter( + layer, "w13_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + replace_parameter( + layer, "w2_weight_scale", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + if has_w13_bias: + replace_parameter( + layer, "w13_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + if has_w2_bias: + replace_parameter( + layer, "w2_bias", + torch.nn.Parameter(torch.empty(0), requires_grad=False)) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: + # CPU dynamic 4-bit MoE path does not use modular kernels or + # fused_experts; quant config is not needed. + return None + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet." 
+        assert activation in (
+            "silu", "swigluoai",
+            "swiglu"), "Only silu, swiglu and swigluoai are supported."
+        assert expert_map is None, """expert_map/EP not implemented
+            for CPU dyn-4bit MoE."""
+
+        def _act_kind(s: str) -> int:
+            # Must match ActivationKind in dynamic_4bit_int_moe_cpu.cpp:
+            # 0 = SwiGLU_Gu (SiLU(g) * u), 1 = SwiGLUOAI (clamped SwiGLU,
+            # GPT-OSS style), 2 = SiLU
+            if s == "swiglu":
+                return 0
+            if s == "swigluoai":
+                return 1
+            if s == "silu":
+                return 2
+            raise ValueError(f"Unknown activation '{s}'")
+
+        # Apply top-k softmax on the router output
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            e_score_correction_bias=e_score_correction_bias,
+        )
+
+        return torch.ops._C.dynamic_4bit_int_moe(
+            x, topk_ids.to(torch.long), topk_weights, layer.w13_weight_packed,
+            layer.w2_weight_packed, layer.w2_out_features,
+            layer.w2_in_features, layer.w13_out_features, layer.group_size,
+            apply_router_weight_on_input, int(_act_kind(activation)))
\ No newline at end of file
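
Illustrative usage sketch (hypothetical names, shapes and values; assumes an
AArch64 CPU build of vLLM that includes this patch, so that
torch.ops._C.dynamic_4bit_int_moe and the aten _dyn_quant_* 4-bit ops are
registered):

    import torch

    T, K, E, H, I = 4, 2, 8, 64, 128   # tokens, top-k, experts, hidden, intermediate
    g = -1                             # -1 -> channelwise scales (float32)

    def pack(w_int4_as_int8, scales, in_features, out_features):
        # Mirrors _pack_matrix() above: shift to unsigned nibbles, pack pairs
        # along the input dim, then hand off to the aten 4-bit packing op.
        tmp = w_int4_as_int8.add(8)
        nib = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(torch.uint8)
        return torch.ops.aten._dyn_quant_pack_4bit_weight(
            nib, scales, None, in_features if g == -1 else g,
            in_features, out_features)

    w13_packed = torch.stack([
        pack(torch.randint(-8, 8, (2 * I, H), dtype=torch.int8),
             torch.rand(2 * I, 1, dtype=torch.float32), H, 2 * I)
        for _ in range(E)])
    w2_packed = torch.stack([
        pack(torch.randint(-8, 8, (H, I), dtype=torch.int8),
             torch.rand(H, 1, dtype=torch.float32), I, H)
        for _ in range(E)])

    x = torch.randn(T, H, dtype=torch.float32)
    topk_ids = torch.randint(0, E, (T, K), dtype=torch.long)
    topk_weights = torch.softmax(torch.rand(T, K), dim=-1)

    out = torch.ops._C.dynamic_4bit_int_moe(
        x, topk_ids, topk_weights, w13_packed, w2_packed,
        H, I, 2 * I, g, False, 2)      # activation_kind 2 == SiLU
    assert out.shape == (T, H)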