diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 107fc0e15a880..6a515d302654a 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -537,6 +537,7 @@ def fused_moe_kernel(
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
+
 def invoke_fused_moe_triton_kernel_wna16(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -550,12 +551,12 @@ def invoke_fused_moe_triton_kernel_wna16(
     mul_routed_weight: bool,
     top_k: int,
     config: dict[str, Any],
-    block_shape: list[int] | None = None,
+    block_shape: list[int],
 ):
     assert B_scale is not None and B_scale.ndim == 3
     assert B_zp is None or B_zp.ndim == 3
-    assert block_shape[0] == 0
-
+    assert block_shape is None or block_shape[0] == 0
+
     M = A.size(0)
     num_tokens = M * top_k
     bit = 4
@@ -592,6 +593,7 @@ def invoke_fused_moe_triton_kernel_wna16(
         bit,
     )
 
+
 def invoke_fused_moe_triton_kernel_gptq_awq(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -608,11 +610,11 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
     compute_type: tl.dtype,
     use_int8_w8a16: bool,
     use_int4_w4a16: bool,
-    block_shape: list[int] | None = None,
+    block_shape: list[int],
 ):
     assert B_scale is not None and B_scale.ndim == 3
     assert B_zp is None or B_zp.ndim == 3
-    assert block_shape[0] == 0
+    assert block_shape is None or block_shape[0] == 0
 
     M = A.size(0)
     num_tokens = M * top_k
@@ -642,7 +644,7 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
             block_size_m=config["BLOCK_SIZE_M"],
         )
     )
-
+
     fused_moe_kernel_gptq_awq[grid](
         A,
         B,
@@ -681,6 +683,7 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
         **config,
     )
 
+
 def invoke_fused_moe_triton_kernel(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -717,7 +720,7 @@ def invoke_fused_moe_triton_kernel(
     else:
         assert A_scale is None
         assert B_scale is None
-
+
     M = A.size(0)
     num_tokens = M * top_k
 
@@ -782,6 +785,7 @@ def invoke_fused_moe_triton_kernel(
         **config,
     )
 
+
 def dispatch_fused_moe_kernel(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -812,7 +816,9 @@ def dispatch_fused_moe_kernel(
     M = A.size(0)
     num_tokens = M * top_k
 
-    if ((use_int8_w8a16 or use_int4_w4a16) and (block_shape is not None and block_shape[1] > 0)):
+    if (use_int8_w8a16 or use_int4_w4a16) and (
+        block_shape is not None and block_shape[1] > 0
+    ):
         assert B_bias is None
 
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
@@ -821,9 +827,9 @@ def dispatch_fused_moe_kernel(
             num_experts=B.size(0),
             bit=4 if use_int4_w4a16 else 8,
         )
-
+
         if use_moe_wna16_cuda:
-            invoke_fused_moe_triton_kernel_gptq_awq(
+            invoke_fused_moe_triton_kernel_wna16(
                 A,
                 B,
                 C,
@@ -857,7 +863,7 @@ def dispatch_fused_moe_kernel(
                 use_int4_w4a16,
                 block_shape,
             )
-
+
     else:
         invoke_fused_moe_triton_kernel(
             A,
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index 781345804ec82..82dbccf3fa9da 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -31,10 +31,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.prepare_finalize import (
-    MoEPrepareAndFinalizeNoEP,
-)
 
 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
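
For reference, the dispatch order in dispatch_fused_moe_kernel after this change can be summarized with the sketch below. This is a minimal illustration, not vLLM code: the string-returning helper, its simplified signature, and the gptq/awq fallback branch (which sits between the hunks shown above and is not visible in this diff) are assumptions made for clarity.

# Sketch of the launcher selection implied by the hunks above.
def pick_launcher(
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
    block_shape: list[int] | None,
    use_moe_wna16_cuda: bool,
) -> str:
    # int8/int4 weight-only quantization with a group size (block_shape[1] > 0)
    # takes the wna16 / gptq-awq path; everything else uses the generic kernel.
    if (use_int8_w8a16 or use_int4_w4a16) and (
        block_shape is not None and block_shape[1] > 0
    ):
        if use_moe_wna16_cuda:
            return "invoke_fused_moe_triton_kernel_wna16"
        return "invoke_fused_moe_triton_kernel_gptq_awq"  # assumed fallback, not shown in the diff
    return "invoke_fused_moe_triton_kernel"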