pre-commit

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Yongye Zhu 2025-12-20 00:21:11 +00:00
parent f13eb40d18
commit fb567f60d0
2 changed files with 17 additions and 15 deletions

View File

@@ -537,6 +537,7 @@ def fused_moe_kernel(
c_mask = token_mask[:, None] & (offs_cn[None, :] < N) c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask) tl.store(c_ptrs, accumulator, mask=c_mask)
def invoke_fused_moe_triton_kernel_wna16( def invoke_fused_moe_triton_kernel_wna16(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
@@ -550,12 +551,12 @@ def invoke_fused_moe_triton_kernel_wna16(
mul_routed_weight: bool, mul_routed_weight: bool,
top_k: int, top_k: int,
config: dict[str, Any], config: dict[str, Any],
block_shape: list[int] | None = None, block_shape: list[int],
): ):
assert B_scale is not None and B_scale.ndim == 3 assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3 assert B_zp is None or B_zp.ndim == 3
assert block_shape[0] == 0 assert block_shape is None or block_shape[0] == 0
M = A.size(0) M = A.size(0)
num_tokens = M * top_k num_tokens = M * top_k
bit = 4 bit = 4
@@ -592,6 +593,7 @@ def invoke_fused_moe_triton_kernel_wna16(
bit, bit,
) )
def invoke_fused_moe_triton_kernel_gptq_awq( def invoke_fused_moe_triton_kernel_gptq_awq(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
@@ -608,11 +610,11 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
compute_type: tl.dtype, compute_type: tl.dtype,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
block_shape: list[int] | None = None, block_shape: list[int],
): ):
assert B_scale is not None and B_scale.ndim == 3 assert B_scale is not None and B_scale.ndim == 3
assert B_zp is None or B_zp.ndim == 3 assert B_zp is None or B_zp.ndim == 3
assert block_shape[0] == 0 assert block_shape is None or block_shape[0] == 0
M = A.size(0) M = A.size(0)
num_tokens = M * top_k num_tokens = M * top_k
@@ -642,7 +644,7 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
block_size_m=config["BLOCK_SIZE_M"], block_size_m=config["BLOCK_SIZE_M"],
) )
) )
fused_moe_kernel_gptq_awq[grid]( fused_moe_kernel_gptq_awq[grid](
A, A,
B, B,
@@ -681,6 +683,7 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
**config, **config,
) )
def invoke_fused_moe_triton_kernel( def invoke_fused_moe_triton_kernel(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
@@ -717,7 +720,7 @@ def invoke_fused_moe_triton_kernel(
else: else:
assert A_scale is None assert A_scale is None
assert B_scale is None assert B_scale is None
M = A.size(0) M = A.size(0)
num_tokens = M * top_k num_tokens = M * top_k
@@ -782,6 +785,7 @@ def invoke_fused_moe_triton_kernel(
**config, **config,
) )
def dispatch_fused_moe_kernel( def dispatch_fused_moe_kernel(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
@@ -812,7 +816,9 @@ def dispatch_fused_moe_kernel(
M = A.size(0) M = A.size(0)
num_tokens = M * top_k num_tokens = M * top_k
if ((use_int8_w8a16 or use_int4_w4a16) and (block_shape is not None and block_shape[1] > 0)): if (use_int8_w8a16 or use_int4_w4a16) and (
block_shape is not None and block_shape[1] > 0
):
assert B_bias is None assert B_bias is None
use_moe_wna16_cuda = should_moe_wna16_use_cuda( use_moe_wna16_cuda = should_moe_wna16_use_cuda(
@@ -821,9 +827,9 @@ def dispatch_fused_moe_kernel(
num_experts=B.size(0), num_experts=B.size(0),
bit=4 if use_int4_w4a16 else 8, bit=4 if use_int4_w4a16 else 8,
) )
if use_moe_wna16_cuda: if use_moe_wna16_cuda:
invoke_fused_moe_triton_kernel_gptq_awq( invoke_fused_moe_triton_kernel_wna16(
A, A,
B, B,
C, C,
@@ -857,7 +863,7 @@ def dispatch_fused_moe_kernel(
use_int4_w4a16, use_int4_w4a16,
block_shape, block_shape,
) )
else: else:
invoke_fused_moe_triton_kernel( invoke_fused_moe_triton_kernel(
A, A,

View File

@@ -31,10 +31,6 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts from .fused_batched_moe import BatchedTritonExperts