From 27b78c73cad00f5c7bb3b2431f02dc680f7034bc Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Wed, 29 Jan 2025 22:07:09 +0800 Subject: [PATCH] [Kernel] add triton fused moe kernel for gptq/awq (#12185) --- tests/kernels/test_moe.py | 91 ++++ .../layers/fused_moe/fused_moe.py | 407 ++++++++++++++--- .../layers/quantization/__init__.py | 7 +- .../layers/quantization/moe_wna16.py | 424 ++++++++++++++++++ 4 files changed, 874 insertions(+), 55 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/moe_wna16.py diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 7fa5de198445..7aa248ed1475 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -18,6 +18,8 @@ from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( marlin_quantize) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + quantize_weights) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform from vllm.scalar_type import scalar_types @@ -55,6 +57,95 @@ def test_fused_moe( rtol=0) +@pytest.mark.parametrize("m", [1, 32, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("group_size", [64, 128]) +@pytest.mark.parametrize("has_zp", [True, False]) +@pytest.mark.parametrize("weight_bits", [4, 8]) +def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype, group_size: int, has_zp: bool, + weight_bits: int): + print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + if weight_bits == 4: + pack_factor = 2 + quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8 + elif weight_bits == 8: + pack_factor = 1 + quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128 + + w1_ref = w1.clone() + w2_ref = w2.clone() + w1_qweight = torch.empty((e, 2 * n, k // pack_factor), + device="cuda", + dtype=torch.uint8) + w2_qweight = torch.empty((e, k, n // pack_factor), + device="cuda", + dtype=torch.uint8) + w1_scales = torch.empty((e, 2 * n, k // group_size), + device="cuda", + dtype=dtype) + w2_scales = torch.empty((e, k, n // group_size), + device="cuda", + dtype=dtype) + w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size), + device="cuda", + dtype=torch.uint8) + w2_qzeros = torch.empty((e, k // pack_factor, n // group_size), + device="cuda", + dtype=torch.uint8) + + for i in range(e * 2): + expert_id = i % e + if i // e == 0: + w, w_ref, w_qweight, w_scales, w_qzeros = \ + w1, w1_ref, w1_qweight, w1_scales, w1_qzeros + else: + w, w_ref, w_qweight, w_scales, w_qzeros = \ + w2, w2_ref, w2_qweight, w2_scales, w2_qzeros + weight, qweight, scales, qzeros = quantize_weights( + w[expert_id].T, quant_type, group_size, has_zp, False) + weight = weight.T + qweight = qweight.T.contiguous().to(torch.uint8) + scales = scales.T + if has_zp: + qzeros = qzeros.T.contiguous().to(torch.uint8) + if weight_bits == 4: + qweight = qweight[:, 1::2] * 16 + qweight[:, ::2] + if has_zp: + 
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] + + w_ref[expert_id] = weight + w_qweight[expert_id] = qweight + w_scales[expert_id] = scales + if has_zp: + w_qzeros[expert_id] = qzeros + + triton_output = fused_moe(a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size]) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @torch.inference_mode() diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 308c1d6ac6db..dbb6c2ce4649 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -19,6 +19,206 @@ from vllm.utils import direct_register_custom_op logger = init_logger(__name__) +@triton.jit +def fused_moe_kernel_gptq_awq( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + b_scale_ptr, + b_zp_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N: tl.constexpr, + K: tl.constexpr, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsk, + stride_bsn, + stride_bze, + stride_bzk, + stride_bzn, + block_k_diviable: tl.constexpr, + group_size: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + has_zp: tl.constexpr, + use_int4_w4a16: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. 
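+    The weight tensor B is stored pre-quantized: with use_int4_w4a16 two
+    int4 values are packed per uint8 along the K dimension and unpacked with
+    a shift/mask, while with use_int8_w8a16 each uint8 holds one weight.
+    Each block of B is dequantized as (w_q - zero_point) * scale before the
+    tl.dot; when has_zp is False, a fixed zero point of 8 (int4) or 128
+    (int8) is used instead of loading b_zp.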
+ """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to( + tl.int64) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64) + + if use_int4_w4a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn + b_shifter = (offs_k[:, None] % 2) * 4 + elif use_int8_w8a16: + b_ptrs = b_ptr + off_experts * stride_be + \ + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + + if not has_zp and use_int4_w4a16: + b_zp_num = 8 + if not has_zp and use_int8_w8a16: + b_zp_num = 128 + elif has_zp and use_int4_w4a16: + b_zp_shifter = (offs_bn[None, :] % 2) * 4 + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. 
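+        # When K is not a multiple of BLOCK_SIZE_K, the scale/zero-point
+        # loads for the trailing partial block are masked and read 0.0
+        # instead of out-of-range memory.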
+ + if not block_k_diviable: + k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K + k_other = 0.0 + else: + k_mask = None + k_other = None + + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs) + if use_int4_w4a16: + b = (b >> b_shifter) & 0xF + + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \ + offs_bn[None, :] * stride_bsn + \ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk + b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other) + b_scale = b_scale.to(tl.float32) + + if has_zp and use_int4_w4a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + (offs_bn[None, :] // 2) * stride_bzn + \ + offs_k_true * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = ((b_zp >> b_zp_shifter) & 0xF) + b_zp = b_zp.to(tl.float32) + elif has_zp and use_int8_w8a16: + offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size + b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \ + offs_bn[None, :] * stride_bzn + \ + offs_k_true * stride_bzk + b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other) + b_zp = b_zp.to(tl.float32) + + # We accumulate along the K dimension. + if has_zp: + b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type) + else: + b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type) + accumulator = tl.dot(a, b, acc=accumulator) + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + if use_int4_w4a16: + b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk + else: + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @triton.jit def fused_moe_kernel( # Pointers to matrices @@ -266,6 +466,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, C: torch.Tensor, A_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor], + B_zp: Optional[torch.Tensor], topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, @@ -277,6 +478,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + use_int4_w4a16: bool, block_shape: Optional[List[int]] = None) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 @@ -292,50 +494,108 @@ def invoke_fused_moe_kernel(A: torch.Tensor, assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] - elif use_int8_w8a16: + elif use_int8_w8a16 or use_int4_w4a16: assert B_scale is not None + assert block_shape is None or block_shape[0] == 0 else: assert A_scale is None assert B_scale is None - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + EM = sorted_token_ids.shape[0] + if A.shape[0] < config["BLOCK_SIZE_M"]: + # optimize for small batch_size. 
+ # We assume that top_ids of each token is unique, so + # so num_valid_experts <= batch_size <= BLOCK_SIZE_M, + # and we can skip some invalid blocks. + EM = min(sorted_token_ids.shape[0], + A.shape[0] * top_k * config['BLOCK_SIZE_M']) + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( + B.shape[1], META['BLOCK_SIZE_N']), ) - fused_moe_kernel[grid]( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - B.shape[1], - B.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - A.stride(0), - A.stride(1), - B.stride(0), - B.stride(2), - B.stride(1), - C.stride(1), - C.stride(2), - A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, - A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, - B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, - B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, - B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, - 0 if block_shape is None else block_shape[0], - 0 if block_shape is None else block_shape[1], - MUL_ROUTED_WEIGHT=mul_routed_weight, - top_k=top_k, - compute_type=compute_type, - use_fp8_w8a8=use_fp8_w8a8, - use_int8_w8a16=use_int8_w8a16, - **config, - ) + if (use_int8_w8a16 or use_int4_w4a16) and \ + block_shape is not None and block_shape[1] > 0: + assert B_scale is not None and B_scale.ndim == 3 + assert B_zp is None or B_zp.ndim == 3 + + fused_moe_kernel_gptq_awq[grid]( + A, + B, + C, + B_scale, + B_zp, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(2), + B_scale.stride(1), + B_zp.stride(0) if B_zp is not None else 0, + B_zp.stride(2) if B_zp is not None else 0, + B_zp.stride(1) if B_zp is not None else 0, + block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0, + group_size=block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + has_zp=B_zp is not None, + use_int4_w4a16=use_int4_w4a16, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + else: + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + A.shape[1], + EM, + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + A_scale.stride(0) + if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) + if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) + if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) + if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) + if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: @@ -432,7 +692,7 @@ def try_get_optimal_moe_config( # NOTE: For block-wise quant, # BLOCK_K must be divisible by block_shape[1] # BLOCK_N and BLOCK_M has no requirements - if block_shape is not None: + if block_shape is not None and block_shape[0] != 0: 
config["BLOCK_SIZE_N"] = block_shape[0] config["BLOCK_SIZE_K"] = block_shape[1] return config @@ -531,12 +791,15 @@ def grouped_topk(hidden_states: torch.Tensor, def get_config_dtype_str(dtype: torch.dtype, + use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): if use_fp8_w8a8: return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" + elif use_int4_w4a16: + return "int4_w8a16" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -551,14 +814,17 @@ def inplace_fused_experts(hidden_states: torch.Tensor, topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, - use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, - a1_scale, a2_scale, block_shape) + use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) def inplace_fused_experts_fake( @@ -569,8 +835,11 @@ def inplace_fused_experts_fake( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> None: @@ -593,14 +862,18 @@ def outplace_fused_experts( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> torch.Tensor: return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, - False, use_fp8_w8a8, use_int8_w8a16, w1_scale, - w2_scale, a1_scale, a2_scale, block_shape) + False, use_fp8_w8a8, use_int8_w8a16, + use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, + a1_scale, a2_scale, block_shape) def outplace_fused_experts_fake( @@ -611,8 +884,11 @@ def outplace_fused_experts_fake( topk_ids: torch.Tensor, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None) -> torch.Tensor: @@ -635,8 +911,11 @@ def fused_experts(hidden_states: torch.Tensor, inplace: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None): @@ -644,16 
+923,15 @@ def fused_experts(hidden_states: torch.Tensor, torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, use_int8_w8a16, - w1_scale, w2_scale, a1_scale, + use_int4_w4a16, w1_scale, + w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) return hidden_states else: - return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2, - topk_weights, topk_ids, - use_fp8_w8a8, - use_int8_w8a16, w1_scale, - w2_scale, a1_scale, - a2_scale, block_shape) + return torch.ops.vllm.outplace_fused_experts( + hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, + use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, + a1_scale, a2_scale, block_shape) def fused_experts_impl(hidden_states: torch.Tensor, @@ -664,13 +942,21 @@ def fused_experts_impl(hidden_states: torch.Tensor, inplace: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None): # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + if use_int4_w4a16: + assert hidden_states.shape[1] // 2 == w1.shape[ + 2], "Hidden size mismatch" + else: + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -687,6 +973,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, M = min(num_tokens, CHUNK_SIZE) config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, dtype=hidden_states.dtype) get_config_func = functools.partial( @@ -755,6 +1042,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache1, a1_scale, w1_scale, + w1_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -766,6 +1054,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, block_shape=block_shape) torch.ops._C.silu_and_mul(intermediate_cache2, @@ -776,6 +1065,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, intermediate_cache3, a2_scale, w2_scale, + w2_zp, curr_topk_weights, curr_topk_ids, sorted_token_ids, @@ -787,6 +1077,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, block_shape=block_shape) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), @@ -808,8 +1099,11 @@ def fused_moe( custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + w1_zp: Optional[torch.Tensor] = None, + w2_zp: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, @@ -834,8 +1128,12 @@ def fused_moe( note: Deepseekv2 model uses grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. 
Defaults to False. - - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. + - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 + activation to compute the inner products for w1 and w2. + Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -873,8 +1171,11 @@ def fused_moe( inplace=inplace, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + use_int4_w4a16=use_int4_w4a16, w1_scale=w1_scale, w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index d2bde13fcf54..bd0fd4799339 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -26,7 +26,8 @@ QUANTIZATION_METHODS: List[str] = [ "experts_int8", "neuron_quant", "ipex", - "quark" + "quark", + "moe_wna16" ] # The customized quantization methods which will be added to this dict. @@ -94,6 +95,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .ipex_quant import IPEXConfig from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config + from .moe_wna16 import MoeWNA16Config from .neuron_quant import NeuronQuantConfig from .qqq import QQQConfig from .tpu_int8 import Int8TpuConfig @@ -121,7 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, "ipex": IPEXConfig, - "quark": QuarkConfig + "quark": QuarkConfig, + "moe_wna16": MoeWNA16Config, } # Update the `method_to_config` with customized quantization methods. 
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py new file mode 100644 index 000000000000..8cd9c0a7ef25 --- /dev/null +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -0,0 +1,424 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch + +from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.awq import (AWQConfig, + AWQLinearMethod) +from vllm.model_executor.layers.quantization.awq_marlin import ( + AWQMarlinConfig, AWQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.gptq import (GPTQConfig, + GPTQLinearMethod) +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig, GPTQMarlinLinearMethod) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + + +class MoeWNA16Config(QuantizationConfig): + """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" + + def __init__(self, linear_quant_method: str, weight_bits: int, + group_size: int, has_zp: bool, lm_head_quantized: bool, + modules_to_not_convert: Optional[List[str]], + full_config: Dict[str, Any]) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.has_zp = has_zp + self.bit8_pack_factor = 8 // self.weight_bits + self.lm_head_quantized = lm_head_quantized + self.linear_quant_method = linear_quant_method + self.full_config = full_config + self.use_marlin = False + if self.linear_quant_method == "gptq": + self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible( + full_config) + elif self.linear_quant_method == "awq": + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + awq_min_capability = AWQConfig.get_min_capability() + if device_capability < awq_min_capability: + raise ValueError( + "The quantization method moe_wna16 + awq is not supported " + "for the current GPU. " + f"Minimum capability: {awq_min_capability}. 
" + f"Current capability: {device_capability}.") + self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible( + full_config) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + if modules_to_not_convert is None: + self.modules_to_not_convert = [] + else: + self.modules_to_not_convert = modules_to_not_convert + + @classmethod + def get_name(cls) -> str: + return "moe_wna16" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config": + linear_quant_method = cls.get_from_keys(config, ["quant_method"]) + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + if linear_quant_method == "gptq": + has_zp = not cls.get_from_keys(config, ["sym"]) + modules_to_not_convert = [] + elif linear_quant_method == "awq": + has_zp = cls.get_from_keys(config, ["zero_point"]) + modules_to_not_convert = cls.get_from_keys( + config, ["modules_to_not_convert"]) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + return cls(linear_quant_method, weight_bits, group_size, has_zp, + lm_head_quantized, modules_to_not_convert, config) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) + if can_convert and user_quant == "moe_wna16": + return cls.get_name() + return None + + @classmethod + def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + desc_act = quant_config.get("desc_act") + + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + awq_min_capability = AWQConfig.get_min_capability() + + gptq_compatible = quant_method == "gptq" and \ + not desc_act and num_bits in [4, 8] + awq_compatible = quant_method == "awq" and num_bits == 4 and \ + device_capability >= awq_min_capability + + return gptq_compatible or awq_compatible + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if is_layer_skipped_quant(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + elif isinstance(layer, FusedMoE): + return MoeWNA16Method(self) + else: + if self.linear_quant_method == "gptq": + if self.use_marlin: + return GPTQMarlinLinearMethod( + GPTQMarlinConfig.from_config(self.full_config)) + else: + return GPTQLinearMethod( + GPTQConfig.from_config(self.full_config)) + elif self.linear_quant_method == "awq": + if self.use_marlin: + return AWQMarlinLinearMethod( + AWQMarlinConfig.from_config(self.full_config)) + else: + return AWQLinearMethod( + AWQConfig.from_config(self.full_config)) + else: + raise ValueError("moe_wna16 only support gptq and awq.") + + +def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]): + return any(module_name in prefix for module_name in modules_to_not_convert) + + +class MoeWNA16Method(FusedMoEMethodBase): + """Linear method for MOE WNA16 (W8A16/W4A16) quantization. 
+
+    Args:
+        quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
+    """
+
+    def __init__(self, quant_config: MoeWNA16Config):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: torch.nn.Module, num_experts: int,
+                       hidden_size: int, intermediate_size_per_partition: int,
+                       params_dtype: torch.dtype, **extra_weight_attrs):
+
+        layer.quant_config = self.quant_config
+        bit8_pack_factor = self.quant_config.bit8_pack_factor
+        group_size = self.quant_config.group_size
+        group_size_div_factor = 1
+
+        # make intermediate_size and hidden_size divisible by group_size
+        # by reducing the group size; the loaded scales/qzeros are repeated
+        # later (see group_size_div_factor) to compensate
+        while intermediate_size_per_partition % group_size or \
+                hidden_size % group_size:
+            group_size = group_size // 2
+            group_size_div_factor *= 2
+            assert group_size >= 32
+        layer.group_size = group_size
+        layer.group_size_div_factor = group_size_div_factor
+
+        strategy = FusedMoeWeightScaleSupported.GROUP.value
+        extra_weight_attrs.update({
+            "quant_method": strategy,
+            "is_transposed": False
+        })
+
+        assert 'weight_loader' in extra_weight_attrs
+        weight_loader = extra_weight_attrs['weight_loader']
+        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
+            layer, weight_loader)
+        extra_weight_attrs['weight_loader'] = wrapped_weight_loader
+
+        # Fused gate_up_proj (column parallel)
+        w13_qweight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size // bit8_pack_factor,
+            dtype=torch.uint8),
+                                         requires_grad=False)
+        layer.register_parameter("w13_qweight", w13_qweight)
+        set_weight_attrs(w13_qweight, extra_weight_attrs)
+
+        # down_proj (row parallel)
+        w2_qweight = torch.nn.Parameter(torch.empty(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition // bit8_pack_factor,
+            dtype=torch.uint8),
+                                        requires_grad=False)
+        layer.register_parameter("w2_qweight", w2_qweight)
+        set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+        w13_scales = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            2 * intermediate_size_per_partition,
+            hidden_size // group_size,
+            dtype=params_dtype),
+                                        requires_grad=False)
+        layer.register_parameter("w13_scales", w13_scales)
+        set_weight_attrs(w13_scales, extra_weight_attrs)
+
+        w2_scales = torch.nn.Parameter(torch.zeros(
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition // group_size,
+            dtype=params_dtype),
+                                       requires_grad=False)
+        layer.register_parameter("w2_scales", w2_scales)
+        set_weight_attrs(w2_scales, extra_weight_attrs)
+
+        if self.quant_config.has_zp:
+            w13_qzeros = torch.nn.Parameter(torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition // bit8_pack_factor,
+                hidden_size // group_size,
+                dtype=torch.uint8),
+                                            requires_grad=False)
+            layer.register_parameter("w13_qzeros", w13_qzeros)
+            set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+            w2_qzeros = torch.nn.Parameter(torch.zeros(
+                num_experts,
+                hidden_size // bit8_pack_factor,
+                intermediate_size_per_partition // group_size,
+                dtype=torch.uint8),
+                                           requires_grad=False)
+            layer.register_parameter("w2_qzeros", w2_qzeros)
+            set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
+        if self.quant_config.linear_quant_method == "gptq":
+            # some params are unused, but we still need to initialize them
+            # so that the checkpoint weights can be loaded
+            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
+            if not self.quant_config.has_zp:
+                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
+            for key in invalid_param_keys:
+                param = torch.nn.Parameter(torch.empty((0, ),
+                                                       dtype=torch.int32),
requires_grad=False) + layer.register_parameter(key, param) + set_weight_attrs(param, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + weight_bits = self.quant_config.weight_bits + has_zp = self.quant_config.has_zp + + return fused_experts(x, + layer.w13_qweight, + layer.w2_qweight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + w1_zp=layer.w13_qzeros if has_zp else None, + w2_zp=layer.w2_qzeros if has_zp else None, + block_shape=[0, layer.group_size]) + + @staticmethod + def get_weight_loader(layer, weight_loader): + + def convert_awq_tensor(tensor, tensor_type): + # convert awq qweight/qzeros to a standard format (assume int4) + # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8) + # qzeros: (k // group_size, n // pack_factor_bit32) -> + # (n // pack_factor_bit8, k // group_size) + # pack_factor_bit32 = 32 // weight_bits + # pack_factor_bit8 = 8 // weight_bits + + # 0. suppose origin shape (a, b), dtype int32 + # 1. convert to uint8, shape (a, b) -> (a, 4 * b) + size0 = tensor.size(0) + tensor = tensor.view(torch.uint8) + + # 2. unpack to uint4 (only when weight_bits == 4) + # shape (a, 4 * b) -> (a, 4 * b, 2) + shifter = torch.tensor([0, 4], + dtype=torch.uint8, + device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + + # 3. change order, see + # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py + # shape -> (a, 4 * b * pack_factor_bit8) + reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7] + tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order] + tensor = tensor.view(size0, -1) + + # 4. transpose, shape -> (4 * b * pack_factor_bit8, a) + tensor = tensor.T.contiguous() + + # 5. 
repack (only when weight_bits == 4) + # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8) + # qzeros shape -> (4 * b, a) + + if tensor_type == "qweight": + tensor = tensor[:, 1::2] * 16 + tensor[:, ::2] + elif tensor_type == "qzeros": + tensor = tensor[1::2, :] * 16 + tensor[::2, :] + return tensor + + def convert_gptq_int4_qzeros(tensor): + tensor = tensor.view(torch.uint8) + shifter = torch.tensor([0, 4], + dtype=torch.uint8, + device=tensor.device) + tensor = (tensor[:, :, None] >> shifter) & 0xF + tensor = tensor + 1 + tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 + return tensor + + def moe_wna16_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, shard_id: str, + expert_id: int): + if "g_idx" in weight_name: + return + if not layer.quant_config.has_zp and "qzeros" in weight_name: + return + + device = get_tp_group().device + tp_rank = get_tensor_model_parallel_rank() + loaded_weight = loaded_weight.to(device) + shard_size = layer.intermediate_size_per_partition + + # convert gptq and awq weight to a standard format + if layer.quant_config.linear_quant_method == "awq": + assert layer.quant_config.weight_bits == 4 + if "weight" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, + "qweight") + elif "zeros" in weight_name: + loaded_weight = convert_awq_tensor(loaded_weight, "qzeros") + else: + loaded_weight = loaded_weight.T + elif layer.quant_config.linear_quant_method == "gptq": + assert layer.quant_config.weight_bits in [4, 8] + if "weight" in weight_name: + loaded_weight = loaded_weight.T.contiguous().view( + torch.uint8) + elif "zeros" in weight_name: + # add 1 to gptq qzeros to align with awq + loaded_weight = loaded_weight.view(torch.uint8) + if layer.quant_config.weight_bits == 4: + loaded_weight = convert_gptq_int4_qzeros( + loaded_weight).T + else: + loaded_weight = loaded_weight.T + 1 + else: + loaded_weight = loaded_weight.T + + # repeat the qzeros/scales to fit new group size + if layer.group_size_div_factor > 1 and \ + "qzeros" in weight_name or "scales" in weight_name: + loaded_weight = loaded_weight.repeat_interleave( + layer.group_size_div_factor, 1) + + if "w13_qzeros" in weight_name: + tensor = loaded_weight.view(layer.tp_size, -1, + loaded_weight.size(1))[tp_rank] + if shard_id == "w1": + param.data[expert_id, :shard_size // 2] = tensor + else: + param.data[expert_id, shard_size // 2:] = tensor + elif "w2_qzeros" in weight_name: + param.data[expert_id] = loaded_weight.view( + loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] + else: + weight_loader(param, loaded_weight, weight_name, shard_id, + expert_id) + + return moe_wna16_weight_loader
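
For reference, a minimal sketch of driving the new int4 path directly, mirroring the argument layout of test_fused_moe_wn16 above. The random packed weights and sizes are illustrative assumptions rather than values from this patch; real callers obtain the packed qweights/scales/qzeros from GPTQ or AWQ checkpoints via MoeWNA16Method.

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

m, n, k, e, topk, group_size = 32, 1024, 1024, 8, 2, 128
dtype = torch.float16

a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)

# int4 weights: two 4-bit values packed per uint8 along the K dimension.
w1_q = torch.randint(0, 256, (e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
w2_q = torch.randint(0, 256, (e, k, n // 2), device="cuda", dtype=torch.uint8)
# Per-group scales; no zero-points here, so the kernel uses a fixed zero point of 8.
w1_s = torch.rand((e, 2 * n, k // group_size), device="cuda", dtype=dtype) / 100
w2_s = torch.rand((e, k, n // group_size), device="cuda", dtype=dtype) / 100

out = fused_moe(a, w1_q, w2_q, score, topk,
                renormalize=False,
                use_int4_w4a16=True,
                w1_scale=w1_s, w2_scale=w2_s,
                w1_zp=None, w2_zp=None,
                block_shape=[0, group_size])  # block_shape[1] carries the group size
assert out.shape == (m, k)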