mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-30 21:07:09 +08:00
Merge 4f69e85bd9bb538e070a689db18240c7995fdad2 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
bcc71a532a
@ -538,7 +538,259 @@ def fused_moe_kernel(
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(
|
||||
def invoke_fused_moe_triton_kernel_wna16(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: dict[str, Any],
|
||||
block_shape: list[int],
|
||||
):
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
|
||||
M = A.size(0)
|
||||
num_tokens = M * top_k
|
||||
bit = 4
|
||||
|
||||
config = config.copy()
|
||||
config.update(
|
||||
get_moe_wna16_block_config(
|
||||
config=config,
|
||||
use_moe_wna16_cuda=True,
|
||||
num_valid_tokens=num_tokens,
|
||||
size_k=A.size(1),
|
||||
size_n=B.size(1),
|
||||
num_experts=B.size(1),
|
||||
group_size=block_shape[1],
|
||||
real_top_k=top_k,
|
||||
block_size_m=config["BLOCK_SIZE_M"],
|
||||
)
|
||||
)
|
||||
|
||||
ops.moe_wna16_gemm(
|
||||
A,
|
||||
C,
|
||||
B,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights if mul_routed_weight else None,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
top_k,
|
||||
config["BLOCK_SIZE_M"],
|
||||
config["BLOCK_SIZE_N"],
|
||||
config["BLOCK_SIZE_K"],
|
||||
bit,
|
||||
)
|
||||
|
||||
|
||||
def invoke_fused_moe_triton_kernel_gptq_awq(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
B_scale: torch.Tensor | None,
|
||||
B_zp: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
block_shape: list[int],
|
||||
):
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
|
||||
M = A.size(0)
|
||||
num_tokens = M * top_k
|
||||
|
||||
EM = sorted_token_ids.size(0)
|
||||
if A.size(0) < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
# We assume that top_ids of each token is unique,
|
||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||
# and we can skip some invalid blocks.
|
||||
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
|
||||
grid = lambda META: (
|
||||
triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
config = config.copy()
|
||||
config.update(
|
||||
get_moe_wna16_block_config(
|
||||
config=config,
|
||||
use_moe_wna16_cuda=False,
|
||||
num_valid_tokens=num_tokens,
|
||||
size_k=A.size(1),
|
||||
size_n=B.size(1),
|
||||
num_experts=B.size(1),
|
||||
group_size=block_shape[1],
|
||||
real_top_k=top_k,
|
||||
block_size_m=config["BLOCK_SIZE_M"],
|
||||
)
|
||||
)
|
||||
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.size(1),
|
||||
A.size(1),
|
||||
EM,
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
B_scale.stride(0),
|
||||
B_scale.stride(2),
|
||||
B_scale.stride(1),
|
||||
B_zp.stride(0) if B_zp is not None else 0,
|
||||
B_zp.stride(2) if B_zp is not None else 0,
|
||||
B_zp.stride(1) if B_zp is not None else 0,
|
||||
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
|
||||
group_size=block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
has_zp=B_zp is not None,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
def invoke_fused_moe_triton_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: torch.Tensor | None,
|
||||
B_scale: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor | None,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
mul_routed_weight: bool,
|
||||
top_k: int,
|
||||
config: dict[str, Any],
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
block_shape: list[int] | None = None,
|
||||
B_bias: torch.Tensor | None = None,
|
||||
):
|
||||
assert topk_weights is not None or not mul_routed_weight
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or triton.cdiv(
|
||||
B.size(-2), block_shape[0]
|
||||
) == B_scale.size(-2)
|
||||
assert block_shape is None or triton.cdiv(
|
||||
B.size(-1), block_shape[1]
|
||||
) == B_scale.size(-1)
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
M = A.size(0)
|
||||
num_tokens = M * top_k
|
||||
|
||||
EM = sorted_token_ids.size(0)
|
||||
if A.size(0) < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
# We assume that top_ids of each token is unique,
|
||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||
# and we can skip some invalid blocks.
|
||||
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
|
||||
grid = lambda META: (
|
||||
triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
HAS_BIAS = B_bias is not None
|
||||
|
||||
config = config.copy()
|
||||
config["SPLIT_K"] = 1
|
||||
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
|
||||
if block_shape is not None:
|
||||
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_bias,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.size(1),
|
||||
B.size(2),
|
||||
EM,
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_bias.stride(0) if B_bias is not None else 0,
|
||||
B_bias.stride(1) if B_bias is not None else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
HAS_BIAS=HAS_BIAS,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
def dispatch_fused_moe_kernel(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
@ -565,44 +817,13 @@ def invoke_fused_moe_kernel(
|
||||
assert topk_weights is None or topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if use_fp8_w8a8 or use_int8_w8a8:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or triton.cdiv(
|
||||
B.size(-2), block_shape[0]
|
||||
) == B_scale.size(-2)
|
||||
assert block_shape is None or triton.cdiv(
|
||||
B.size(-1), block_shape[1]
|
||||
) == B_scale.size(-1)
|
||||
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
M = A.size(0)
|
||||
num_tokens = M * top_k
|
||||
|
||||
EM = sorted_token_ids.size(0)
|
||||
if A.size(0) < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
# We assume that top_ids of each token is unique,
|
||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||
# and we can skip some invalid blocks.
|
||||
EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
|
||||
grid = lambda META: (
|
||||
triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
HAS_BIAS = B_bias is not None
|
||||
if (
|
||||
(use_int8_w8a16 or use_int4_w4a16)
|
||||
and block_shape is not None
|
||||
and block_shape[1] > 0
|
||||
if (use_int8_w8a16 or use_int4_w4a16) and (
|
||||
block_shape is not None and block_shape[1] > 0
|
||||
):
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
assert B_bias is None
|
||||
|
||||
use_moe_wna16_cuda = should_moe_wna16_use_cuda(
|
||||
num_valid_tokens=num_tokens,
|
||||
@ -610,41 +831,25 @@ def invoke_fused_moe_kernel(
|
||||
num_experts=B.size(0),
|
||||
bit=4 if use_int4_w4a16 else 8,
|
||||
)
|
||||
config = config.copy()
|
||||
config.update(
|
||||
get_moe_wna16_block_config(
|
||||
config=config,
|
||||
use_moe_wna16_cuda=use_moe_wna16_cuda,
|
||||
num_valid_tokens=num_tokens,
|
||||
size_k=A.size(1),
|
||||
size_n=B.size(1),
|
||||
num_experts=B.size(1),
|
||||
group_size=block_shape[1],
|
||||
real_top_k=top_k,
|
||||
block_size_m=config["BLOCK_SIZE_M"],
|
||||
)
|
||||
)
|
||||
|
||||
if use_moe_wna16_cuda:
|
||||
bit = 4 if use_int4_w4a16 else 8
|
||||
ops.moe_wna16_gemm(
|
||||
invoke_fused_moe_triton_kernel_wna16(
|
||||
A,
|
||||
C,
|
||||
B,
|
||||
C,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights if mul_routed_weight else None,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
mul_routed_weight,
|
||||
top_k,
|
||||
config["BLOCK_SIZE_M"],
|
||||
config["BLOCK_SIZE_N"],
|
||||
config["BLOCK_SIZE_K"],
|
||||
bit,
|
||||
config,
|
||||
block_shape,
|
||||
)
|
||||
return
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
invoke_fused_moe_triton_kernel_gptq_awq(
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
@ -654,80 +859,37 @@ def invoke_fused_moe_kernel(
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.size(1),
|
||||
A.size(1),
|
||||
EM,
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
B_scale.stride(0),
|
||||
B_scale.stride(2),
|
||||
B_scale.stride(1),
|
||||
B_zp.stride(0) if B_zp is not None else 0,
|
||||
B_zp.stride(2) if B_zp is not None else 0,
|
||||
B_zp.stride(1) if B_zp is not None else 0,
|
||||
block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
|
||||
group_size=block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
has_zp=B_zp is not None,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
mul_routed_weight,
|
||||
top_k,
|
||||
config,
|
||||
compute_type,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
else:
|
||||
config = config.copy()
|
||||
config["SPLIT_K"] = 1
|
||||
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
|
||||
if block_shape is not None:
|
||||
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
|
||||
fused_moe_kernel[grid](
|
||||
invoke_fused_moe_triton_kernel(
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_bias,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.size(1),
|
||||
B.size(2),
|
||||
EM,
|
||||
num_tokens,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_bias.stride(0) if B_bias is not None else 0,
|
||||
B_bias.stride(1) if B_bias is not None else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
HAS_BIAS=HAS_BIAS,
|
||||
BLOCK_SIZE_K=BLOCK_SIZE_K,
|
||||
**config,
|
||||
mul_routed_weight,
|
||||
top_k,
|
||||
config,
|
||||
compute_type,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
per_channel_quant,
|
||||
block_shape,
|
||||
B_bias,
|
||||
)
|
||||
|
||||
|
||||
@ -1675,6 +1837,7 @@ def fused_experts(
|
||||
quant_config: FusedMoEQuantConfig | None = None,
|
||||
allow_deep_gemm: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# import pdb; pdb.set_trace()
|
||||
if quant_config is None:
|
||||
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
|
||||
|
||||
@ -1966,7 +2129,7 @@ def fused_experts_impl(
|
||||
ignore_invalid_experts=True,
|
||||
)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
dispatch_fused_moe_kernel(
|
||||
qcurr_hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
@ -2025,7 +2188,7 @@ def fused_experts_impl(
|
||||
if expert_map is not None:
|
||||
intermediate_cache3.zero_()
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
dispatch_fused_moe_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
@ -2176,13 +2339,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
|
||||
)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
invoke_fused_moe_triton_kernel(
|
||||
hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1q_scale,
|
||||
self.w1_scale,
|
||||
self.w1_zp,
|
||||
None, # topk_weights
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
@ -2214,13 +2376,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.block_shape,
|
||||
)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
invoke_fused_moe_triton_kernel(
|
||||
qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2q_scale,
|
||||
self.w2_scale,
|
||||
self.w2_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user