mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 03:54:56 +08:00
[Kernel] add triton fused moe kernel for gptq/awq (#12185)
This commit is contained in:
parent
b02fd288b2
commit
27b78c73ca
@ -18,6 +18,8 @@ from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
fused_moe as iterative_moe)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
marlin_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
quantize_weights)
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
@ -55,6 +57,95 @@ def test_fused_moe(
|
||||
rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 32, 222])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 1024])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.parametrize("has_zp", [True, False])
|
||||
@pytest.mark.parametrize("weight_bits", [4, 8])
|
||||
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
dtype: torch.dtype, group_size: int, has_zp: bool,
|
||||
weight_bits: int):
|
||||
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
if weight_bits == 4:
|
||||
pack_factor = 2
|
||||
quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
|
||||
elif weight_bits == 8:
|
||||
pack_factor = 1
|
||||
quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
|
||||
|
||||
w1_ref = w1.clone()
|
||||
w2_ref = w2.clone()
|
||||
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w2_qweight = torch.empty((e, k, n // pack_factor),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w1_scales = torch.empty((e, 2 * n, k // group_size),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
w2_scales = torch.empty((e, k, n // group_size),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
|
||||
for i in range(e * 2):
|
||||
expert_id = i % e
|
||||
if i // e == 0:
|
||||
w, w_ref, w_qweight, w_scales, w_qzeros = \
|
||||
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
|
||||
else:
|
||||
w, w_ref, w_qweight, w_scales, w_qzeros = \
|
||||
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
|
||||
weight, qweight, scales, qzeros = quantize_weights(
|
||||
w[expert_id].T, quant_type, group_size, has_zp, False)
|
||||
weight = weight.T
|
||||
qweight = qweight.T.contiguous().to(torch.uint8)
|
||||
scales = scales.T
|
||||
if has_zp:
|
||||
qzeros = qzeros.T.contiguous().to(torch.uint8)
|
||||
if weight_bits == 4:
|
||||
qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
|
||||
if has_zp:
|
||||
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
|
||||
|
||||
w_ref[expert_id] = weight
|
||||
w_qweight[expert_id] = qweight
|
||||
w_scales[expert_id] = scales
|
||||
if has_zp:
|
||||
w_qzeros[expert_id] = qzeros
|
||||
|
||||
triton_output = fused_moe(a,
|
||||
w1_qweight,
|
||||
w2_qweight,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
w1_scale=w1_scales,
|
||||
w2_scale=w2_scales,
|
||||
w1_zp=w1_qzeros if has_zp else None,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size])
|
||||
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
|
||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
|
||||
@ -19,6 +19,206 @@ from vllm.utils import direct_register_custom_op
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel_gptq_awq(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
b_scale_ptr,
|
||||
b_zp_ptr,
|
||||
topk_weights_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
num_tokens_post_padded_ptr,
|
||||
# Matrix dimensions
|
||||
N: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
EM,
|
||||
num_valid_tokens,
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
# (A has M rows).
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_be,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_bse,
|
||||
stride_bsk,
|
||||
stride_bsn,
|
||||
stride_bze,
|
||||
stride_bzk,
|
||||
stride_bzn,
|
||||
block_k_diviable: tl.constexpr,
|
||||
group_size: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
compute_type: tl.constexpr,
|
||||
has_zp: tl.constexpr,
|
||||
use_int4_w4a16: tl.constexpr,
|
||||
use_int8_w8a16: tl.constexpr):
|
||||
"""
|
||||
Implements the fused computation for a Mixture of Experts (MOE) using
|
||||
token and expert matrices.
|
||||
|
||||
Key Parameters:
|
||||
- A: The input tensor representing tokens with shape (*, K), where '*' can
|
||||
be any shape representing batches and K is the feature dimension of
|
||||
each token.
|
||||
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
|
||||
the number of experts, K is the input feature dimension, and N is
|
||||
the output feature dimension.
|
||||
- C: The output cache tensor with shape (M, topk, N), where M is the
|
||||
total number of tokens post padding, topk is the number of times
|
||||
each token is repeated, and N is the output feature dimension.
|
||||
- sorted_token_ids: A tensor containing the sorted indices of tokens,
|
||||
repeated topk times and arranged by the expert index they are
|
||||
assigned to.
|
||||
- expert_ids: A tensor containing the indices of the expert for each
|
||||
block. It determines which expert matrix from B should be used for
|
||||
each block in A.
|
||||
This kernel performs the multiplication of a token by its corresponding
|
||||
expert matrix as determined by `expert_ids`. The sorting of
|
||||
`sorted_token_ids` by expert index and padding ensures divisibility by
|
||||
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
|
||||
multiplication across different blocks processed by the same expert.
|
||||
"""
|
||||
# -----------------------------------------------------------
|
||||
# Map program ids `pid` to the block of C it should compute.
|
||||
# This is done in a grouped ordering to promote L2 data reuse.
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# and accumulate
|
||||
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
|
||||
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
|
||||
return
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
|
||||
tl.int64)
|
||||
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N +
|
||||
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
|
||||
offs_k[None, :] * stride_ak)
|
||||
|
||||
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
|
||||
|
||||
if use_int4_w4a16:
|
||||
b_ptrs = b_ptr + off_experts * stride_be + \
|
||||
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
|
||||
b_shifter = (offs_k[:, None] % 2) * 4
|
||||
elif use_int8_w8a16:
|
||||
b_ptrs = b_ptr + off_experts * stride_be + \
|
||||
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
|
||||
|
||||
if not has_zp and use_int4_w4a16:
|
||||
b_zp_num = 8
|
||||
if not has_zp and use_int8_w8a16:
|
||||
b_zp_num = 128
|
||||
elif has_zp and use_int4_w4a16:
|
||||
b_zp_shifter = (offs_bn[None, :] % 2) * 4
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||
# of fp32 values for higher accuracy.
|
||||
# `accumulator` will be converted back to fp16 after the loop.
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the
|
||||
# K dimension.
|
||||
|
||||
if not block_k_diviable:
|
||||
k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
|
||||
k_other = 0.0
|
||||
else:
|
||||
k_mask = None
|
||||
k_other = None
|
||||
|
||||
a = tl.load(a_ptrs,
|
||||
mask=token_mask[:, None] &
|
||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs)
|
||||
if use_int4_w4a16:
|
||||
b = (b >> b_shifter) & 0xF
|
||||
|
||||
b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
|
||||
offs_bn[None, :] * stride_bsn + \
|
||||
((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
|
||||
b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
|
||||
b_scale = b_scale.to(tl.float32)
|
||||
|
||||
if has_zp and use_int4_w4a16:
|
||||
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
|
||||
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
|
||||
(offs_bn[None, :] // 2) * stride_bzn + \
|
||||
offs_k_true * stride_bzk
|
||||
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
|
||||
b_zp = ((b_zp >> b_zp_shifter) & 0xF)
|
||||
b_zp = b_zp.to(tl.float32)
|
||||
elif has_zp and use_int8_w8a16:
|
||||
offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
|
||||
b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
|
||||
offs_bn[None, :] * stride_bzn + \
|
||||
offs_k_true * stride_bzk
|
||||
b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
|
||||
b_zp = b_zp.to(tl.float32)
|
||||
|
||||
# We accumulate along the K dimension.
|
||||
if has_zp:
|
||||
b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
|
||||
else:
|
||||
b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
|
||||
accumulator = tl.dot(a, b, acc=accumulator)
|
||||
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
if use_int4_w4a16:
|
||||
b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
|
||||
else:
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token,
|
||||
mask=token_mask,
|
||||
other=0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
|
||||
accumulator = accumulator.to(compute_type)
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
|
||||
None, :]
|
||||
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_moe_kernel(
|
||||
# Pointers to matrices
|
||||
@ -266,6 +466,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
@ -277,6 +478,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
compute_type: tl.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
assert topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
@ -292,50 +494,108 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
|
||||
assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
|
||||
assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
|
||||
assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
|
||||
elif use_int8_w8a16:
|
||||
elif use_int8_w8a16 or use_int4_w4a16:
|
||||
assert B_scale is not None
|
||||
assert block_shape is None or block_shape[0] == 0
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
|
||||
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
||||
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
|
||||
EM = sorted_token_ids.shape[0]
|
||||
if A.shape[0] < config["BLOCK_SIZE_M"]:
|
||||
# optimize for small batch_size.
|
||||
# We assume that top_ids of each token is unique, so
|
||||
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
|
||||
# and we can skip some invalid blocks.
|
||||
EM = min(sorted_token_ids.shape[0],
|
||||
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
|
||||
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
B.shape[1], META['BLOCK_SIZE_N']), )
|
||||
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
B.shape[2],
|
||||
sorted_token_ids.shape[0],
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
if (use_int8_w8a16 or use_int4_w4a16) and \
|
||||
block_shape is not None and block_shape[1] > 0:
|
||||
assert B_scale is not None and B_scale.ndim == 3
|
||||
assert B_zp is None or B_zp.ndim == 3
|
||||
|
||||
fused_moe_kernel_gptq_awq[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
B_scale,
|
||||
B_zp,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
A.shape[1],
|
||||
EM,
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
B_scale.stride(0),
|
||||
B_scale.stride(2),
|
||||
B_scale.stride(1),
|
||||
B_zp.stride(0) if B_zp is not None else 0,
|
||||
B_zp.stride(2) if B_zp is not None else 0,
|
||||
B_zp.stride(1) if B_zp is not None else 0,
|
||||
block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
|
||||
group_size=block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
has_zp=B_zp is not None,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
|
||||
else:
|
||||
fused_moe_kernel[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
A_scale,
|
||||
B_scale,
|
||||
topk_weights,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
num_tokens_post_padded,
|
||||
B.shape[1],
|
||||
A.shape[1],
|
||||
EM,
|
||||
topk_ids.numel(),
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
B.stride(0),
|
||||
B.stride(2),
|
||||
B.stride(1),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
A_scale.stride(0)
|
||||
if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
A_scale.stride(1)
|
||||
if A_scale is not None and A_scale.ndim == 2 else 0,
|
||||
B_scale.stride(0)
|
||||
if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
B_scale.stride(2)
|
||||
if B_scale is not None and B_scale.ndim == 3 else 0,
|
||||
B_scale.stride(1)
|
||||
if B_scale is not None and B_scale.ndim >= 2 else 0,
|
||||
0 if block_shape is None else block_shape[0],
|
||||
0 if block_shape is None else block_shape[1],
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
top_k=top_k,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
|
||||
@ -432,7 +692,7 @@ def try_get_optimal_moe_config(
|
||||
# NOTE: For block-wise quant,
|
||||
# BLOCK_K must be divisible by block_shape[1]
|
||||
# BLOCK_N and BLOCK_M has no requirements
|
||||
if block_shape is not None:
|
||||
if block_shape is not None and block_shape[0] != 0:
|
||||
config["BLOCK_SIZE_N"] = block_shape[0]
|
||||
config["BLOCK_SIZE_K"] = block_shape[1]
|
||||
return config
|
||||
@ -531,12 +791,15 @@ def grouped_topk(hidden_states: torch.Tensor,
|
||||
|
||||
|
||||
def get_config_dtype_str(dtype: torch.dtype,
|
||||
use_int4_w4a16: Optional[bool] = False,
|
||||
use_int8_w8a16: Optional[bool] = False,
|
||||
use_fp8_w8a8: Optional[bool] = False):
|
||||
if use_fp8_w8a8:
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
return "int8_w8a16"
|
||||
elif use_int4_w4a16:
|
||||
return "int4_w8a16"
|
||||
elif dtype == torch.float:
|
||||
# avoiding cases where kernel fails when float32 MoE
|
||||
# use fp16/bfloat16 configs
|
||||
@ -551,14 +814,17 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
|
||||
use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
|
||||
a1_scale, a2_scale, block_shape)
|
||||
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
|
||||
|
||||
|
||||
def inplace_fused_experts_fake(
|
||||
@ -569,8 +835,11 @@ def inplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> None:
|
||||
@ -593,14 +862,18 @@ def outplace_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
|
||||
w2_scale, a1_scale, a2_scale, block_shape)
|
||||
False, use_fp8_w8a8, use_int8_w8a16,
|
||||
use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
|
||||
a1_scale, a2_scale, block_shape)
|
||||
|
||||
|
||||
def outplace_fused_experts_fake(
|
||||
@ -611,8 +884,11 @@ def outplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None) -> torch.Tensor:
|
||||
@ -635,8 +911,11 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None):
|
||||
@ -644,16 +923,15 @@ def fused_experts(hidden_states: torch.Tensor,
|
||||
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
|
||||
topk_weights, topk_ids,
|
||||
use_fp8_w8a8, use_int8_w8a16,
|
||||
w1_scale, w2_scale, a1_scale,
|
||||
use_int4_w4a16, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
return hidden_states
|
||||
else:
|
||||
return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2,
|
||||
topk_weights, topk_ids,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a16, w1_scale,
|
||||
w2_scale, a1_scale,
|
||||
a2_scale, block_shape)
|
||||
return torch.ops.vllm.outplace_fused_experts(
|
||||
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
|
||||
use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
|
||||
a1_scale, a2_scale, block_shape)
|
||||
|
||||
|
||||
def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
@ -664,13 +942,21 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
inplace: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None):
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
if use_int4_w4a16:
|
||||
assert hidden_states.shape[1] // 2 == w1.shape[
|
||||
2], "Hidden size mismatch"
|
||||
else:
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
@ -687,6 +973,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
@ -755,6 +1042,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
w1_scale,
|
||||
w1_zp,
|
||||
curr_topk_weights,
|
||||
curr_topk_ids,
|
||||
sorted_token_ids,
|
||||
@ -766,6 +1054,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
block_shape=block_shape)
|
||||
|
||||
torch.ops._C.silu_and_mul(intermediate_cache2,
|
||||
@ -776,6 +1065,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
intermediate_cache3,
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
curr_topk_weights,
|
||||
curr_topk_ids,
|
||||
sorted_token_ids,
|
||||
@ -787,6 +1077,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
|
||||
compute_type=compute_type,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
block_shape=block_shape)
|
||||
|
||||
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
@ -808,8 +1099,11 @@ def fused_moe(
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
@ -834,8 +1128,12 @@ def fused_moe(
|
||||
note: Deepseekv2 model uses grouped_topk
|
||||
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
|
||||
activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
|
||||
activation to compute the inner products for w1 and w2.
|
||||
Defaults to False.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
@ -873,8 +1171,11 @@ def fused_moe(
|
||||
inplace=inplace,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape)
|
||||
|
||||
@ -26,7 +26,8 @@ QUANTIZATION_METHODS: List[str] = [
|
||||
"experts_int8",
|
||||
"neuron_quant",
|
||||
"ipex",
|
||||
"quark"
|
||||
"quark",
|
||||
"moe_wna16"
|
||||
]
|
||||
|
||||
# The customized quantization methods which will be added to this dict.
|
||||
@ -94,6 +95,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
from .ipex_quant import IPEXConfig
|
||||
from .marlin import MarlinConfig
|
||||
from .modelopt import ModelOptFp8Config
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .neuron_quant import NeuronQuantConfig
|
||||
from .qqq import QQQConfig
|
||||
from .tpu_int8 import Int8TpuConfig
|
||||
@ -121,7 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
"experts_int8": ExpertsInt8Config,
|
||||
"neuron_quant": NeuronQuantConfig,
|
||||
"ipex": IPEXConfig,
|
||||
"quark": QuarkConfig
|
||||
"quark": QuarkConfig,
|
||||
"moe_wna16": MoeWNA16Config,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
||||
424
vllm/model_executor/layers/quantization/moe_wna16.py
Normal file
424
vllm/model_executor/layers/quantization/moe_wna16.py
Normal file
@ -0,0 +1,424 @@
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
|
||||
AWQLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.awq_marlin import (
|
||||
AWQMarlinConfig, AWQMarlinLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
|
||||
GPTQLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig, GPTQMarlinLinearMethod)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class MoeWNA16Config(QuantizationConfig):
|
||||
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
|
||||
|
||||
def __init__(self, linear_quant_method: str, weight_bits: int,
|
||||
group_size: int, has_zp: bool, lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[List[str]],
|
||||
full_config: Dict[str, Any]) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.has_zp = has_zp
|
||||
self.bit8_pack_factor = 8 // self.weight_bits
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.linear_quant_method = linear_quant_method
|
||||
self.full_config = full_config
|
||||
self.use_marlin = False
|
||||
if self.linear_quant_method == "gptq":
|
||||
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
|
||||
full_config)
|
||||
elif self.linear_quant_method == "awq":
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
awq_min_capability = AWQConfig.get_min_capability()
|
||||
if device_capability < awq_min_capability:
|
||||
raise ValueError(
|
||||
"The quantization method moe_wna16 + awq is not supported "
|
||||
"for the current GPU. "
|
||||
f"Minimum capability: {awq_min_capability}. "
|
||||
f"Current capability: {device_capability}.")
|
||||
self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(
|
||||
full_config)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
if modules_to_not_convert is None:
|
||||
self.modules_to_not_convert = []
|
||||
else:
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "moe_wna16"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
|
||||
linear_quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
if linear_quant_method == "gptq":
|
||||
has_zp = not cls.get_from_keys(config, ["sym"])
|
||||
modules_to_not_convert = []
|
||||
elif linear_quant_method == "awq":
|
||||
has_zp = cls.get_from_keys(config, ["zero_point"])
|
||||
modules_to_not_convert = cls.get_from_keys(
|
||||
config, ["modules_to_not_convert"])
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
return cls(linear_quant_method, weight_bits, group_size, has_zp,
|
||||
lm_head_quantized, modules_to_not_convert, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
|
||||
if can_convert and user_quant == "moe_wna16":
|
||||
return cls.get_name()
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
desc_act = quant_config.get("desc_act")
|
||||
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (-1 if capability_tuple is None else
|
||||
capability_tuple.to_int())
|
||||
awq_min_capability = AWQConfig.get_min_capability()
|
||||
|
||||
gptq_compatible = quant_method == "gptq" and \
|
||||
not desc_act and num_bits in [4, 8]
|
||||
awq_compatible = quant_method == "awq" and num_bits == 4 and \
|
||||
device_capability >= awq_min_capability
|
||||
|
||||
return gptq_compatible or awq_compatible
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return MoeWNA16Method(self)
|
||||
else:
|
||||
if self.linear_quant_method == "gptq":
|
||||
if self.use_marlin:
|
||||
return GPTQMarlinLinearMethod(
|
||||
GPTQMarlinConfig.from_config(self.full_config))
|
||||
else:
|
||||
return GPTQLinearMethod(
|
||||
GPTQConfig.from_config(self.full_config))
|
||||
elif self.linear_quant_method == "awq":
|
||||
if self.use_marlin:
|
||||
return AWQMarlinLinearMethod(
|
||||
AWQMarlinConfig.from_config(self.full_config))
|
||||
else:
|
||||
return AWQLinearMethod(
|
||||
AWQConfig.from_config(self.full_config))
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
|
||||
def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
|
||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||
|
||||
|
||||
class MoeWNA16Method(FusedMoEMethodBase):
|
||||
"""Linear method for MOE WNA16 (W8A16/W4A16) quantization.
|
||||
|
||||
Args:
|
||||
quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: MoeWNA16Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
layer.quant_config = self.quant_config
|
||||
bit8_pack_factor = self.quant_config.bit8_pack_factor
|
||||
group_size = self.quant_config.group_size
|
||||
group_size_div_factor = 1
|
||||
|
||||
# make intermediate_size and hidden_size diviable by group_size
|
||||
# we reduce the group size to ensure that
|
||||
# and we would repeat the loaded_weight later
|
||||
while intermediate_size_per_partition % group_size or \
|
||||
hidden_size % group_size:
|
||||
group_size = group_size // 2
|
||||
group_size_div_factor *= 2
|
||||
assert group_size >= 32
|
||||
layer.group_size = group_size
|
||||
layer.group_size_div_factor = group_size_div_factor
|
||||
|
||||
strategy = FusedMoeWeightScaleSupported.GROUP.value
|
||||
extra_weight_attrs.update({
|
||||
"quant_method": strategy,
|
||||
"is_transposed": False
|
||||
})
|
||||
|
||||
assert 'weight_loader' in extra_weight_attrs
|
||||
weight_loader = extra_weight_attrs['weight_loader']
|
||||
wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
|
||||
layer, weight_loader)
|
||||
extra_weight_attrs['weight_loader'] = wrapped_weight_loader
|
||||
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // bit8_pack_factor,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qweight", w13_qweight)
|
||||
set_weight_attrs(w13_qweight, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_qweight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // bit8_pack_factor,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qweight", w2_qweight)
|
||||
set_weight_attrs(w2_qweight, extra_weight_attrs)
|
||||
|
||||
w13_scales = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size // group_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_scales", w13_scales)
|
||||
set_weight_attrs(w13_scales, extra_weight_attrs)
|
||||
|
||||
w2_scales = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_scales", w2_scales)
|
||||
set_weight_attrs(w2_scales, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.has_zp:
|
||||
w13_qzeros = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition // bit8_pack_factor,
|
||||
hidden_size // group_size,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_qzeros", w13_qzeros)
|
||||
set_weight_attrs(w13_qzeros, extra_weight_attrs)
|
||||
|
||||
w2_qzeros = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size // bit8_pack_factor,
|
||||
intermediate_size_per_partition // group_size,
|
||||
dtype=torch.uint8),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_qzeros", w2_qzeros)
|
||||
set_weight_attrs(w2_qzeros, extra_weight_attrs)
|
||||
|
||||
if self.quant_config.linear_quant_method == "gptq":
|
||||
# some param are unused, but we need to init them in order to
|
||||
# load weights
|
||||
invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
|
||||
if not self.quant_config.has_zp:
|
||||
invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
|
||||
for key in invalid_param_keys:
|
||||
param = torch.nn.Parameter(torch.empty((0, ),
|
||||
dtype=torch.int32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter(key, param)
|
||||
set_weight_attrs(param, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
weight_bits = self.quant_config.weight_bits
|
||||
has_zp = self.quant_config.has_zp
|
||||
|
||||
return fused_experts(x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
w1_scale=layer.w13_scales,
|
||||
w2_scale=layer.w2_scales,
|
||||
w1_zp=layer.w13_qzeros if has_zp else None,
|
||||
w2_zp=layer.w2_qzeros if has_zp else None,
|
||||
block_shape=[0, layer.group_size])
|
||||
|
||||
@staticmethod
|
||||
def get_weight_loader(layer, weight_loader):
|
||||
|
||||
def convert_awq_tensor(tensor, tensor_type):
|
||||
# convert awq qweight/qzeros to a standard format (assume int4)
|
||||
# qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
|
||||
# qzeros: (k // group_size, n // pack_factor_bit32) ->
|
||||
# (n // pack_factor_bit8, k // group_size)
|
||||
# pack_factor_bit32 = 32 // weight_bits
|
||||
# pack_factor_bit8 = 8 // weight_bits
|
||||
|
||||
# 0. suppose origin shape (a, b), dtype int32
|
||||
# 1. convert to uint8, shape (a, b) -> (a, 4 * b)
|
||||
size0 = tensor.size(0)
|
||||
tensor = tensor.view(torch.uint8)
|
||||
|
||||
# 2. unpack to uint4 (only when weight_bits == 4)
|
||||
# shape (a, 4 * b) -> (a, 4 * b, 2)
|
||||
shifter = torch.tensor([0, 4],
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
|
||||
# 3. change order, see
|
||||
# https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
|
||||
# shape -> (a, 4 * b * pack_factor_bit8)
|
||||
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
|
||||
tensor = tensor.view(size0, -1)
|
||||
|
||||
# 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
|
||||
tensor = tensor.T.contiguous()
|
||||
|
||||
# 5. repack (only when weight_bits == 4)
|
||||
# qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
|
||||
# qzeros shape -> (4 * b, a)
|
||||
|
||||
if tensor_type == "qweight":
|
||||
tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
|
||||
elif tensor_type == "qzeros":
|
||||
tensor = tensor[1::2, :] * 16 + tensor[::2, :]
|
||||
return tensor
|
||||
|
||||
def convert_gptq_int4_qzeros(tensor):
|
||||
tensor = tensor.view(torch.uint8)
|
||||
shifter = torch.tensor([0, 4],
|
||||
dtype=torch.uint8,
|
||||
device=tensor.device)
|
||||
tensor = (tensor[:, :, None] >> shifter) & 0xF
|
||||
tensor = tensor + 1
|
||||
tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
|
||||
return tensor
|
||||
|
||||
def moe_wna16_weight_loader(param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
weight_name: str, shard_id: str,
|
||||
expert_id: int):
|
||||
if "g_idx" in weight_name:
|
||||
return
|
||||
if not layer.quant_config.has_zp and "qzeros" in weight_name:
|
||||
return
|
||||
|
||||
device = get_tp_group().device
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
loaded_weight = loaded_weight.to(device)
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
|
||||
# convert gptq and awq weight to a standard format
|
||||
if layer.quant_config.linear_quant_method == "awq":
|
||||
assert layer.quant_config.weight_bits == 4
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight,
|
||||
"qweight")
|
||||
elif "zeros" in weight_name:
|
||||
loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
|
||||
else:
|
||||
loaded_weight = loaded_weight.T
|
||||
elif layer.quant_config.linear_quant_method == "gptq":
|
||||
assert layer.quant_config.weight_bits in [4, 8]
|
||||
if "weight" in weight_name:
|
||||
loaded_weight = loaded_weight.T.contiguous().view(
|
||||
torch.uint8)
|
||||
elif "zeros" in weight_name:
|
||||
# add 1 to gptq qzeros to align with awq
|
||||
loaded_weight = loaded_weight.view(torch.uint8)
|
||||
if layer.quant_config.weight_bits == 4:
|
||||
loaded_weight = convert_gptq_int4_qzeros(
|
||||
loaded_weight).T
|
||||
else:
|
||||
loaded_weight = loaded_weight.T + 1
|
||||
else:
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
# repeat the qzeros/scales to fit new group size
|
||||
if layer.group_size_div_factor > 1 and \
|
||||
"qzeros" in weight_name or "scales" in weight_name:
|
||||
loaded_weight = loaded_weight.repeat_interleave(
|
||||
layer.group_size_div_factor, 1)
|
||||
|
||||
if "w13_qzeros" in weight_name:
|
||||
tensor = loaded_weight.view(layer.tp_size, -1,
|
||||
loaded_weight.size(1))[tp_rank]
|
||||
if shard_id == "w1":
|
||||
param.data[expert_id, :shard_size // 2] = tensor
|
||||
else:
|
||||
param.data[expert_id, shard_size // 2:] = tensor
|
||||
elif "w2_qzeros" in weight_name:
|
||||
param.data[expert_id] = loaded_weight.view(
|
||||
loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
|
||||
else:
|
||||
weight_loader(param, loaded_weight, weight_name, shard_id,
|
||||
expert_id)
|
||||
|
||||
return moe_wna16_weight_loader
|
||||
Loading…
x
Reference in New Issue
Block a user