mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 02:57:01 +08:00
pre-commit
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
parent
f13eb40d18
commit
fb567f60d0
@ -537,6 +537,7 @@ def fused_moe_kernel(
|
|||||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
def invoke_fused_moe_triton_kernel_wna16(
|
def invoke_fused_moe_triton_kernel_wna16(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@ -550,12 +551,12 @@ def invoke_fused_moe_triton_kernel_wna16(
|
|||||||
mul_routed_weight: bool,
|
mul_routed_weight: bool,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
config: dict[str, Any],
|
config: dict[str, Any],
|
||||||
block_shape: list[int] | None = None,
|
block_shape: list[int],
|
||||||
):
|
):
|
||||||
assert B_scale is not None and B_scale.ndim == 3
|
assert B_scale is not None and B_scale.ndim == 3
|
||||||
assert B_zp is None or B_zp.ndim == 3
|
assert B_zp is None or B_zp.ndim == 3
|
||||||
assert block_shape[0] == 0
|
assert block_shape is None or block_shape[0] == 0
|
||||||
|
|
||||||
M = A.size(0)
|
M = A.size(0)
|
||||||
num_tokens = M * top_k
|
num_tokens = M * top_k
|
||||||
bit = 4
|
bit = 4
|
||||||
@ -592,6 +593,7 @@ def invoke_fused_moe_triton_kernel_wna16(
|
|||||||
bit,
|
bit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def invoke_fused_moe_triton_kernel_gptq_awq(
|
def invoke_fused_moe_triton_kernel_gptq_awq(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@ -608,11 +610,11 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
|
|||||||
compute_type: tl.dtype,
|
compute_type: tl.dtype,
|
||||||
use_int8_w8a16: bool,
|
use_int8_w8a16: bool,
|
||||||
use_int4_w4a16: bool,
|
use_int4_w4a16: bool,
|
||||||
block_shape: list[int] | None = None,
|
block_shape: list[int],
|
||||||
):
|
):
|
||||||
assert B_scale is not None and B_scale.ndim == 3
|
assert B_scale is not None and B_scale.ndim == 3
|
||||||
assert B_zp is None or B_zp.ndim == 3
|
assert B_zp is None or B_zp.ndim == 3
|
||||||
assert block_shape[0] == 0
|
assert block_shape is None or block_shape[0] == 0
|
||||||
|
|
||||||
M = A.size(0)
|
M = A.size(0)
|
||||||
num_tokens = M * top_k
|
num_tokens = M * top_k
|
||||||
@ -642,7 +644,7 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
|
|||||||
block_size_m=config["BLOCK_SIZE_M"],
|
block_size_m=config["BLOCK_SIZE_M"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
fused_moe_kernel_gptq_awq[grid](
|
fused_moe_kernel_gptq_awq[grid](
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
@ -681,6 +683,7 @@ def invoke_fused_moe_triton_kernel_gptq_awq(
|
|||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def invoke_fused_moe_triton_kernel(
|
def invoke_fused_moe_triton_kernel(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@ -717,7 +720,7 @@ def invoke_fused_moe_triton_kernel(
|
|||||||
else:
|
else:
|
||||||
assert A_scale is None
|
assert A_scale is None
|
||||||
assert B_scale is None
|
assert B_scale is None
|
||||||
|
|
||||||
M = A.size(0)
|
M = A.size(0)
|
||||||
num_tokens = M * top_k
|
num_tokens = M * top_k
|
||||||
|
|
||||||
@ -782,6 +785,7 @@ def invoke_fused_moe_triton_kernel(
|
|||||||
**config,
|
**config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def dispatch_fused_moe_kernel(
|
def dispatch_fused_moe_kernel(
|
||||||
A: torch.Tensor,
|
A: torch.Tensor,
|
||||||
B: torch.Tensor,
|
B: torch.Tensor,
|
||||||
@ -812,7 +816,9 @@ def dispatch_fused_moe_kernel(
|
|||||||
M = A.size(0)
|
M = A.size(0)
|
||||||
num_tokens = M * top_k
|
num_tokens = M * top_k
|
||||||
|
|
||||||
if ((use_int8_w8a16 or use_int4_w4a16) and (block_shape is not None and block_shape[1] > 0)):
|
if (use_int8_w8a16 or use_int4_w4a16) and (
|
||||||
|
block_shape is not None and block_shape[1] > 0
|
||||||
|
):
|
||||||
assert B_bias is None
|
assert B_bias is None
|
||||||
|
|
||||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
||||||
@ -821,9 +827,9 @@ def dispatch_fused_moe_kernel(
|
|||||||
num_experts=B.size(0),
|
num_experts=B.size(0),
|
||||||
bit=4 if use_int4_w4a16 else 8,
|
bit=4 if use_int4_w4a16 else 8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_moe_wna16_cuda:
|
if use_moe_wna16_cuda:
|
||||||
invoke_fused_moe_triton_kernel_gptq_awq(
|
invoke_fused_moe_triton_kernel_wna16(
|
||||||
A,
|
A,
|
||||||
B,
|
B,
|
||||||
C,
|
C,
|
||||||
@ -857,7 +863,7 @@ def dispatch_fused_moe_kernel(
|
|||||||
use_int4_w4a16,
|
use_int4_w4a16,
|
||||||
block_shape,
|
block_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
invoke_fused_moe_triton_kernel(
|
invoke_fused_moe_triton_kernel(
|
||||||
A,
|
A,
|
||||||
|
|||||||
@ -31,10 +31,6 @@ from vllm.model_executor.utils import set_weight_attrs
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
||||||
MoEPrepareAndFinalizeNoEP,
|
|
||||||
)
|
|
||||||
|
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
from .fused_batched_moe import BatchedTritonExperts
|
from .fused_batched_moe import BatchedTritonExperts
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user