diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index bf51554341607..d229b2db880e2 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -538,7 +538,259 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-def invoke_fused_moe_kernel(
+def invoke_fused_moe_triton_kernel_wna16(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    B_scale: torch.Tensor | None,
+    B_zp: torch.Tensor | None,
+    topk_weights: torch.Tensor | None,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: dict[str, Any],
+    block_shape: list[int],
+):
+    assert B_scale is not None and B_scale.ndim == 3
+    assert B_zp is None or B_zp.ndim == 3
+    assert block_shape is None or block_shape[0] == 0
+
+    M = A.size(0)
+    num_tokens = M * top_k
+    bit = 4
+
+    config = config.copy()
+    config.update(
+        get_moe_wna16_block_config(
+            config=config,
+            use_moe_wna16_cuda=True,
+            num_valid_tokens=num_tokens,
+            size_k=A.size(1),
+            size_n=B.size(1),
+            num_experts=B.size(1),
+            group_size=block_shape[1],
+            real_top_k=top_k,
+            block_size_m=config["BLOCK_SIZE_M"],
+        )
+    )
+
+    ops.moe_wna16_gemm(
+        A,
+        C,
+        B,
+        B_scale,
+        B_zp,
+        topk_weights if mul_routed_weight else None,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        top_k,
+        config["BLOCK_SIZE_M"],
+        config["BLOCK_SIZE_N"],
+        config["BLOCK_SIZE_K"],
+        bit,
+    )
+
+
+def invoke_fused_moe_triton_kernel_gptq_awq(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    B_scale: torch.Tensor | None,
+    B_zp: torch.Tensor | None,
+    topk_weights: torch.Tensor | None,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: dict[str, Any],
+    compute_type: tl.dtype,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+    block_shape: list[int],
+):
+    assert B_scale is not None and B_scale.ndim == 3
+    assert B_zp is None or B_zp.ndim == 3
+    assert block_shape is None or block_shape[0] == 0
+
+    M = A.size(0)
+    num_tokens = M * top_k
+
+    EM = sorted_token_ids.size(0)
+    if A.size(0) < config["BLOCK_SIZE_M"]:
+        # optimize for small batch_size.
+        # We assume that top_ids of each token is unique,
+        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
+        # and we can skip some invalid blocks.
+        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
+    grid = lambda META: (
+        triton.cdiv(EM, META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
+    )
+    config = config.copy()
+    config.update(
+        get_moe_wna16_block_config(
+            config=config,
+            use_moe_wna16_cuda=False,
+            num_valid_tokens=num_tokens,
+            size_k=A.size(1),
+            size_n=B.size(1),
+            num_experts=B.size(1),
+            group_size=block_shape[1],
+            real_top_k=top_k,
+            block_size_m=config["BLOCK_SIZE_M"],
+        )
+    )
+
+    fused_moe_kernel_gptq_awq[grid](
+        A,
+        B,
+        C,
+        B_scale,
+        B_zp,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        B.size(1),
+        A.size(1),
+        EM,
+        num_tokens,
+        A.stride(0),
+        A.stride(1),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(1),
+        C.stride(2),
+        B_scale.stride(0),
+        B_scale.stride(2),
+        B_scale.stride(1),
+        B_zp.stride(0) if B_zp is not None else 0,
+        B_zp.stride(2) if B_zp is not None else 0,
+        B_zp.stride(1) if B_zp is not None else 0,
+        block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
+        group_size=block_shape[1],
+        MUL_ROUTED_WEIGHT=mul_routed_weight,
+        top_k=top_k,
+        compute_type=compute_type,
+        has_zp=B_zp is not None,
+        use_int4_w4a16=use_int4_w4a16,
+        use_int8_w8a16=use_int8_w8a16,
+        **config,
+    )
+
+
+def invoke_fused_moe_triton_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    A_scale: torch.Tensor | None,
+    B_scale: torch.Tensor | None,
+    topk_weights: torch.Tensor | None,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8_w8a8: bool,
+    use_int8_w8a8: bool,
+    use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
+    per_channel_quant: bool,
+    block_shape: list[int] | None = None,
+    B_bias: torch.Tensor | None = None,
+):
+    assert topk_weights is not None or not mul_routed_weight
+    assert topk_weights is None or topk_weights.stride(1) == 1
+    assert sorted_token_ids.stride(0) == 1
+
+    if use_fp8_w8a8 or use_int8_w8a8:
+        assert B_scale is not None
+        assert block_shape is None or triton.cdiv(
+            B.size(-2), block_shape[0]
+        ) == B_scale.size(-2)
+        assert block_shape is None or triton.cdiv(
+            B.size(-1), block_shape[1]
+        ) == B_scale.size(-1)
+    elif use_int8_w8a16 or use_int4_w4a16:
+        assert B_scale is not None
+        assert block_shape is None or block_shape[0] == 0
+    else:
+        assert A_scale is None
+        assert B_scale is None
+
+    M = A.size(0)
+    num_tokens = M * top_k
+
+    EM = sorted_token_ids.size(0)
+    if A.size(0) < config["BLOCK_SIZE_M"]:
+        # optimize for small batch_size.
+        # We assume that top_ids of each token is unique,
+        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
+        # and we can skip some invalid blocks.
+        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
+    grid = lambda META: (
+        triton.cdiv(EM, META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
+    )
+    HAS_BIAS = B_bias is not None
+
+    config = config.copy()
+    config["SPLIT_K"] = 1
+    BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
+    if block_shape is not None:
+        BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
+    fused_moe_kernel[grid](
+        A,
+        B,
+        C,
+        B_bias,
+        A_scale,
+        B_scale,
+        topk_weights,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        B.size(1),
+        B.size(2),
+        EM,
+        num_tokens,
+        A.stride(0),
+        A.stride(1),
+        B.stride(0),
+        B.stride(2),
+        B.stride(1),
+        C.stride(1),
+        C.stride(2),
+        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+        B_bias.stride(0) if B_bias is not None else 0,
+        B_bias.stride(1) if B_bias is not None else 0,
+        0 if block_shape is None else block_shape[0],
+        0 if block_shape is None else block_shape[1],
+        MUL_ROUTED_WEIGHT=mul_routed_weight,
+        top_k=top_k,
+        compute_type=compute_type,
+        use_fp8_w8a8=use_fp8_w8a8,
+        use_int8_w8a8=use_int8_w8a8,
+        use_int8_w8a16=use_int8_w8a16,
+        per_channel_quant=per_channel_quant,
+        HAS_BIAS=HAS_BIAS,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        **config,
+    )
+
+
+def dispatch_fused_moe_kernel(
     A: torch.Tensor,
     B: torch.Tensor,
     C: torch.Tensor,
@@ -565,44 +817,13 @@ def invoke_fused_moe_kernel(
     assert topk_weights is None or topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
-    if use_fp8_w8a8 or use_int8_w8a8:
-        assert B_scale is not None
-        assert block_shape is None or triton.cdiv(
-            B.size(-2), block_shape[0]
-        ) == B_scale.size(-2)
-        assert block_shape is None or triton.cdiv(
-            B.size(-1), block_shape[1]
-        ) == B_scale.size(-1)
-
-    elif use_int8_w8a16 or use_int4_w4a16:
-        assert B_scale is not None
-        assert block_shape is None or block_shape[0] == 0
-    else:
-        assert A_scale is None
-        assert B_scale is None
-
     M = A.size(0)
     num_tokens = M * top_k
 
-    EM = sorted_token_ids.size(0)
-    if A.size(0) < config["BLOCK_SIZE_M"]:
-        # optimize for small batch_size.
-        # We assume that top_ids of each token is unique,
-        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
-        # and we can skip some invalid blocks.
-        EM = min(sorted_token_ids.size(0), A.size(0) * top_k * config["BLOCK_SIZE_M"])
-    grid = lambda META: (
-        triton.cdiv(EM, META["BLOCK_SIZE_M"])
-        * triton.cdiv(B.size(1), META["BLOCK_SIZE_N"]),
-    )
-    HAS_BIAS = B_bias is not None
-    if (
-        (use_int8_w8a16 or use_int4_w4a16)
-        and block_shape is not None
-        and block_shape[1] > 0
+    if (use_int8_w8a16 or use_int4_w4a16) and (
+        block_shape is not None and block_shape[1] > 0
     ):
-        assert B_scale is not None and B_scale.ndim == 3
-        assert B_zp is None or B_zp.ndim == 3
+        assert B_bias is None
 
         use_moe_wna16_cuda = should_moe_wna16_use_cuda(
             num_valid_tokens=num_tokens,
@@ -610,41 +831,25 @@ def invoke_fused_moe_kernel(
             num_experts=B.size(0),
             bit=4 if use_int4_w4a16 else 8,
         )
-        config = config.copy()
-        config.update(
-            get_moe_wna16_block_config(
-                config=config,
-                use_moe_wna16_cuda=use_moe_wna16_cuda,
-                num_valid_tokens=num_tokens,
-                size_k=A.size(1),
-                size_n=B.size(1),
-                num_experts=B.size(1),
-                group_size=block_shape[1],
-                real_top_k=top_k,
-                block_size_m=config["BLOCK_SIZE_M"],
-            )
-        )
 
         if use_moe_wna16_cuda:
-            bit = 4 if use_int4_w4a16 else 8
-            ops.moe_wna16_gemm(
+            invoke_fused_moe_triton_kernel_wna16(
                 A,
-                C,
                 B,
+                C,
                 B_scale,
                 B_zp,
-                topk_weights if mul_routed_weight else None,
+                topk_weights,
                 sorted_token_ids,
                 expert_ids,
                 num_tokens_post_padded,
+                mul_routed_weight,
                 top_k,
-                config["BLOCK_SIZE_M"],
-                config["BLOCK_SIZE_N"],
-                config["BLOCK_SIZE_K"],
-                bit,
+                config,
+                block_shape,
             )
             return
 
-        fused_moe_kernel_gptq_awq[grid](
+        invoke_fused_moe_triton_kernel_gptq_awq(
             A,
             B,
             C,
@@ -654,80 +859,37 @@ def invoke_fused_moe_kernel(
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            B.size(1),
-            A.size(1),
-            EM,
-            num_tokens,
-            A.stride(0),
-            A.stride(1),
-            B.stride(0),
-            B.stride(2),
-            B.stride(1),
-            C.stride(1),
-            C.stride(2),
-            B_scale.stride(0),
-            B_scale.stride(2),
-            B_scale.stride(1),
-            B_zp.stride(0) if B_zp is not None else 0,
-            B_zp.stride(2) if B_zp is not None else 0,
-            B_zp.stride(1) if B_zp is not None else 0,
-            block_k_diviable=A.size(1) % config["BLOCK_SIZE_K"] == 0,
-            group_size=block_shape[1],
-            MUL_ROUTED_WEIGHT=mul_routed_weight,
-            top_k=top_k,
-            compute_type=compute_type,
-            has_zp=B_zp is not None,
-            use_int4_w4a16=use_int4_w4a16,
-            use_int8_w8a16=use_int8_w8a16,
-            **config,
+            mul_routed_weight,
+            top_k,
+            config,
+            compute_type,
+            use_int8_w8a16,
+            use_int4_w4a16,
+            block_shape,
         )
+
     else:
-        config = config.copy()
-        config["SPLIT_K"] = 1
-        BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
-        if block_shape is not None:
-            BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0], block_shape[1]))
-        fused_moe_kernel[grid](
+        invoke_fused_moe_triton_kernel(
            A,
            B,
            C,
-            B_bias,
            A_scale,
            B_scale,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
-            B.size(1),
-            B.size(2),
-            EM,
-            num_tokens,
-            A.stride(0),
-            A.stride(1),
-            B.stride(0),
-            B.stride(2),
-            B.stride(1),
-            C.stride(1),
-            C.stride(2),
-            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
-            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
-            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
-            B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
-            B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
-            B_bias.stride(0) if B_bias is not None else 0,
-            B_bias.stride(1) if B_bias is not None else 0,
-            0 if block_shape is None else block_shape[0],
-            0 if block_shape is None else block_shape[1],
-            MUL_ROUTED_WEIGHT=mul_routed_weight,
-            top_k=top_k,
-            compute_type=compute_type,
-            use_fp8_w8a8=use_fp8_w8a8,
-            use_int8_w8a8=use_int8_w8a8,
-            use_int8_w8a16=use_int8_w8a16,
-            per_channel_quant=per_channel_quant,
-            HAS_BIAS=HAS_BIAS,
-            BLOCK_SIZE_K=BLOCK_SIZE_K,
-            **config,
+            mul_routed_weight,
+            top_k,
+            config,
+            compute_type,
+            use_fp8_w8a8,
+            use_int8_w8a8,
+            use_int8_w8a16,
+            use_int4_w4a16,
+            per_channel_quant,
+            block_shape,
+            B_bias,
         )
 
 
@@ -1966,7 +2129,7 @@ def fused_experts_impl(
         ignore_invalid_experts=True,
     )
 
-    invoke_fused_moe_kernel(
+    dispatch_fused_moe_kernel(
         qcurr_hidden_states,
         w1,
         intermediate_cache1,
@@ -2025,7 +2188,7 @@ def fused_experts_impl(
     if expert_map is not None:
         intermediate_cache3.zero_()
 
-    invoke_fused_moe_kernel(
+    dispatch_fused_moe_kernel(
         qintermediate_cache2,
         w2,
         intermediate_cache3,
@@ -2176,13 +2339,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
         )
 
-        invoke_fused_moe_kernel(
+        invoke_fused_moe_triton_kernel(
             hidden_states,
             w1,
             intermediate_cache1,
             a1q_scale,
             self.w1_scale,
-            self.w1_zp,
             None,  # topk_weights
             sorted_token_ids,
             expert_ids,
@@ -2214,13 +2376,12 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
             self.block_shape,
         )
 
-        invoke_fused_moe_kernel(
+        invoke_fused_moe_triton_kernel(
             qintermediate_cache2,
             w2,
             intermediate_cache3,
             a2q_scale,
             self.w2_scale,
-            self.w2_zp,
             topk_weights,
             sorted_token_ids,
             expert_ids,