Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Kernel] add triton fused moe kernel for gptq/awq (#12185)
This commit is contained in:
parent b02fd288b2
commit 27b78c73ca
tests/kernels/test_moe.py
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
     fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    quantize_weights)
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
@@ -55,6 +57,95 @@ def test_fused_moe(
         rtol=0)
 
 
+@pytest.mark.parametrize("m", [1, 32, 222])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("group_size", [64, 128])
+@pytest.mark.parametrize("has_zp", [True, False])
+@pytest.mark.parametrize("weight_bits", [4, 8])
+def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
+                        dtype: torch.dtype, group_size: int, has_zp: bool,
+                        weight_bits: int):
+    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    if weight_bits == 4:
+        pack_factor = 2
+        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
+    elif weight_bits == 8:
+        pack_factor = 1
+        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
+
+    w1_ref = w1.clone()
+    w2_ref = w2.clone()
+    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
+                             device="cuda",
+                             dtype=torch.uint8)
+    w2_qweight = torch.empty((e, k, n // pack_factor),
+                             device="cuda",
+                             dtype=torch.uint8)
+    w1_scales = torch.empty((e, 2 * n, k // group_size),
+                            device="cuda",
+                            dtype=dtype)
+    w2_scales = torch.empty((e, k, n // group_size),
+                            device="cuda",
+                            dtype=dtype)
+    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
+                            device="cuda",
+                            dtype=torch.uint8)
+    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
+                            device="cuda",
+                            dtype=torch.uint8)
+
+    for i in range(e * 2):
+        expert_id = i % e
+        if i // e == 0:
+            w, w_ref, w_qweight, w_scales, w_qzeros = \
+                w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
+        else:
+            w, w_ref, w_qweight, w_scales, w_qzeros = \
+                w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
+        weight, qweight, scales, qzeros = quantize_weights(
+            w[expert_id].T, quant_type, group_size, has_zp, False)
+        weight = weight.T
+        qweight = qweight.T.contiguous().to(torch.uint8)
+        scales = scales.T
+        if has_zp:
+            qzeros = qzeros.T.contiguous().to(torch.uint8)
+        if weight_bits == 4:
+            qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
+            if has_zp:
+                qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
+
+        w_ref[expert_id] = weight
+        w_qweight[expert_id] = qweight
+        w_scales[expert_id] = scales
+        if has_zp:
+            w_qzeros[expert_id] = qzeros
+
+    triton_output = fused_moe(a,
+                              w1_qweight,
+                              w2_qweight,
+                              score,
+                              topk,
+                              renormalize=False,
+                              use_int4_w4a16=weight_bits == 4,
+                              use_int8_w8a16=weight_bits == 8,
+                              w1_scale=w1_scales,
+                              w2_scale=w2_scales,
+                              w1_zp=w1_qzeros if has_zp else None,
+                              w2_zp=w2_qzeros if has_zp else None,
+                              block_shape=[0, group_size])
+    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+
+
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
 @torch.inference_mode()
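A minimal sketch (not part of the diff) of the int4 packing convention the test uses: two 4-bit values are packed into one uint8 along K, low nibble first, which is exactly what the kernel later undoes with `(b >> shift) & 0xF` where the shift is `(k % 2) * 4`.

import torch

q = torch.randint(0, 16, (4, 8), dtype=torch.uint8)   # unpacked int4 values
packed = q[:, 1::2] * 16 + q[:, ::2]                   # shape (4, 4), one byte per pair

# Unpacking mirrors the kernel's shift-and-mask.
low = packed & 0xF
high = (packed >> 4) & 0xF
unpacked = torch.stack([low, high], dim=-1).reshape(4, 8)
assert torch.equal(unpacked, q)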
vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -19,6 +19,206 @@ from vllm.utils import direct_register_custom_op
 logger = init_logger(__name__)
 
 
+@triton.jit
+def fused_moe_kernel_gptq_awq(
+        # Pointers to matrices
+        a_ptr,
+        b_ptr,
+        c_ptr,
+        b_scale_ptr,
+        b_zp_ptr,
+        topk_weights_ptr,
+        sorted_token_ids_ptr,
+        expert_ids_ptr,
+        num_tokens_post_padded_ptr,
+        # Matrix dimensions
+        N: tl.constexpr,
+        K: tl.constexpr,
+        EM,
+        num_valid_tokens,
+        # The stride variables represent how much to increase the ptr by when
+        # moving by 1 element in a particular dimension. E.g. `stride_am` is
+        # how much to increase `a_ptr` by to get the element one row down
+        # (A has M rows).
+        stride_am,
+        stride_ak,
+        stride_be,
+        stride_bk,
+        stride_bn,
+        stride_cm,
+        stride_cn,
+        stride_bse,
+        stride_bsk,
+        stride_bsn,
+        stride_bze,
+        stride_bzk,
+        stride_bzn,
+        block_k_diviable: tl.constexpr,
+        group_size: tl.constexpr,
+        # Meta-parameters
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr,
+        GROUP_SIZE_M: tl.constexpr,
+        MUL_ROUTED_WEIGHT: tl.constexpr,
+        top_k: tl.constexpr,
+        compute_type: tl.constexpr,
+        has_zp: tl.constexpr,
+        use_int4_w4a16: tl.constexpr,
+        use_int8_w8a16: tl.constexpr):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+        total number of tokens post padding, topk is the number of times
+        each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
+
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
+        tl.int64)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    offs_bn = (pid_n * BLOCK_SIZE_N +
+               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
+                      offs_k[None, :] * stride_ak)
+
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+
+    if use_int4_w4a16:
+        b_ptrs = b_ptr + off_experts * stride_be + \
+            (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
+        b_shifter = (offs_k[:, None] % 2) * 4
+    elif use_int8_w8a16:
+        b_ptrs = b_ptr + off_experts * stride_be + \
+            offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+
+    if not has_zp and use_int4_w4a16:
+        b_zp_num = 8
+    if not has_zp and use_int8_w8a16:
+        b_zp_num = 128
+    elif has_zp and use_int4_w4a16:
+        b_zp_shifter = (offs_bn[None, :] % 2) * 4
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+
+        if not block_k_diviable:
+            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
+            k_other = 0.0
+        else:
+            k_mask = None
+            k_other = None
+
+        a = tl.load(a_ptrs,
+                    mask=token_mask[:, None] &
+                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                    other=0.0)
+        b = tl.load(b_ptrs)
+        if use_int4_w4a16:
+            b = (b >> b_shifter) & 0xF
+
+        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
+            offs_bn[None, :] * stride_bsn + \
+            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
+        b_scale = b_scale.to(tl.float32)
+
+        if has_zp and use_int4_w4a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
+                (offs_bn[None, :] // 2) * stride_bzn + \
+                offs_k_true * stride_bzk
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = ((b_zp >> b_zp_shifter) & 0xF)
+            b_zp = b_zp.to(tl.float32)
+        elif has_zp and use_int8_w8a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = b_zp_ptr + off_experts * stride_bze + \
+                offs_bn[None, :] * stride_bzn + \
+                offs_k_true * stride_bzk
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = b_zp.to(tl.float32)
+
+        # We accumulate along the K dimension.
+        if has_zp:
+            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
+        else:
+            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
+        accumulator = tl.dot(a, b, acc=accumulator)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        if use_int4_w4a16:
+            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
+        else:
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token,
+                             mask=token_mask,
+                             other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
+        None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
 @triton.jit
 def fused_moe_kernel(
     # Pointers to matrices
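For reference, a minimal sketch (not part of the diff) of what the kernel's per-group weight dequantization computes for one expert, written in plain PyTorch. Shapes follow the stored layout used above: packed int4 qweight of shape (N, K // 2), with one scale and zero point per group of `group_size` positions along K.

import torch

def dequant_int4(qweight, scales, qzeros, group_size):
    # unpack two 4-bit values per byte along K (low nibble first)
    low = qweight & 0xF
    high = (qweight >> 4) & 0xF
    w = torch.stack([low, high], dim=-1).reshape(qweight.shape[0], -1).float()
    # broadcast one scale / zero point per group of `group_size` columns
    scales = scales.repeat_interleave(group_size, dim=1).float()
    zeros = qzeros.repeat_interleave(group_size, dim=1).float()
    # matches `(b - b_zp) * b_scale` in the kernel; without zero points the
    # kernel subtracts a fixed 8 (int4) or 128 (int8) instead
    return (w - zeros) * scales

N, K, group_size = 8, 64, 32
qweight = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)
scales = torch.rand(N, K // group_size)
qzeros = torch.randint(0, 16, (N, K // group_size)).float()
w_fp = dequant_int4(qweight, scales, qzeros, group_size)   # (N, K), float32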
@@ -266,6 +466,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             C: torch.Tensor,
                             A_scale: Optional[torch.Tensor],
                             B_scale: Optional[torch.Tensor],
+                            B_zp: Optional[torch.Tensor],
                             topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
@@ -277,6 +478,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                             compute_type: tl.dtype,
                             use_fp8_w8a8: bool,
                             use_int8_w8a16: bool,
+                            use_int4_w4a16: bool,
                             block_shape: Optional[List[int]] = None) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -292,50 +494,108 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
         assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
         assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
         assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a16:
+    elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
+        assert block_shape is None or block_shape[0] == 0
     else:
         assert A_scale is None
         assert B_scale is None
 
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
-
-    fused_moe_kernel[grid](
-        A,
-        B,
-        C,
-        A_scale,
-        B_scale,
-        topk_weights,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        B.shape[1],
-        B.shape[2],
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        A.stride(0),
-        A.stride(1),
-        B.stride(0),
-        B.stride(2),
-        B.stride(1),
-        C.stride(1),
-        C.stride(2),
-        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
-        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
-        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
-        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
-        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
-        0 if block_shape is None else block_shape[0],
-        0 if block_shape is None else block_shape[1],
-        MUL_ROUTED_WEIGHT=mul_routed_weight,
-        top_k=top_k,
-        compute_type=compute_type,
-        use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a16=use_int8_w8a16,
-        **config,
-    )
+    EM = sorted_token_ids.shape[0]
+    if A.shape[0] < config["BLOCK_SIZE_M"]:
+        # optimize for small batch_size.
+        # We assume that top_ids of each token is unique, so
+        # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
+        # and we can skip some invalid blocks.
+        EM = min(sorted_token_ids.shape[0],
+                 A.shape[0] * top_k * config['BLOCK_SIZE_M'])
+    grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
+        B.shape[1], META['BLOCK_SIZE_N']), )
+
+    if (use_int8_w8a16 or use_int4_w4a16) and \
+            block_shape is not None and block_shape[1] > 0:
+        assert B_scale is not None and B_scale.ndim == 3
+        assert B_zp is None or B_zp.ndim == 3
+
+        fused_moe_kernel_gptq_awq[grid](
+            A,
+            B,
+            C,
+            B_scale,
+            B_zp,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            A.shape[1],
+            EM,
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            B_scale.stride(0),
+            B_scale.stride(2),
+            B_scale.stride(1),
+            B_zp.stride(0) if B_zp is not None else 0,
+            B_zp.stride(2) if B_zp is not None else 0,
+            B_zp.stride(1) if B_zp is not None else 0,
+            block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
+            group_size=block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            has_zp=B_zp is not None,
+            use_int4_w4a16=use_int4_w4a16,
+            use_int8_w8a16=use_int8_w8a16,
+            **config,
+        )
+
+    else:
+        fused_moe_kernel[grid](
+            A,
+            B,
+            C,
+            A_scale,
+            B_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            A.shape[1],
+            EM,
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            A_scale.stride(0)
+            if A_scale is not None and A_scale.ndim == 2 else 0,
+            A_scale.stride(1)
+            if A_scale is not None and A_scale.ndim == 2 else 0,
+            B_scale.stride(0)
+            if B_scale is not None and B_scale.ndim >= 2 else 0,
+            B_scale.stride(2)
+            if B_scale is not None and B_scale.ndim == 3 else 0,
+            B_scale.stride(1)
+            if B_scale is not None and B_scale.ndim >= 2 else 0,
+            0 if block_shape is None else block_shape[0],
+            0 if block_shape is None else block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            **config,
+        )
 
 
 def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
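A minimal sketch (not part of the diff) of how the small-batch optimization above shrinks the launch grid; the numbers are illustrative.

import math

num_tokens = 2            # A.shape[0]
top_k = 2
block_size_m = 64
block_size_n = 64
N = 4096                  # B.shape[1]
sorted_len = 512          # sorted_token_ids.shape[0] after per-expert padding

EM = sorted_len
if num_tokens < block_size_m:
    # each token contributes at most top_k entries, so only a few
    # BLOCK_SIZE_M-sized blocks can hold valid tokens
    EM = min(sorted_len, num_tokens * top_k * block_size_m)

grid = math.ceil(EM / block_size_m) * math.ceil(N / block_size_n)
print(EM, grid)   # EM drops from 512 to 256, halving the M-blocks launched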
@@ -432,7 +692,7 @@ def try_get_optimal_moe_config(
         # NOTE: For block-wise quant,
         # BLOCK_K must be divisible by block_shape[1]
         # BLOCK_N and BLOCK_M has no requirements
-        if block_shape is not None:
+        if block_shape is not None and block_shape[0] != 0:
             config["BLOCK_SIZE_N"] = block_shape[0]
             config["BLOCK_SIZE_K"] = block_shape[1]
     return config
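Note: block_shape is overloaded in this change. For fp8 block-wise quantization it still carries [BLOCK_N, BLOCK_K]; for the GPTQ/AWQ path the callers pass [0, group_size], so block_shape[0] == 0 signals that no block-wise activation quantization is in play while block_shape[1] carries the weight group size consumed by fused_moe_kernel_gptq_awq.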
@@ -531,12 +791,15 @@ def grouped_topk(hidden_states: torch.Tensor,
 
 
 def get_config_dtype_str(dtype: torch.dtype,
+                         use_int4_w4a16: Optional[bool] = False,
                          use_int8_w8a16: Optional[bool] = False,
                          use_fp8_w8a8: Optional[bool] = False):
     if use_fp8_w8a8:
         return "fp8_w8a8"
     elif use_int8_w8a16:
         return "int8_w8a16"
+    elif use_int4_w4a16:
+        return "int4_w8a16"
     elif dtype == torch.float:
         # avoiding cases where kernel fails when float32 MoE
         # use fp16/bfloat16 configs
@@ -551,14 +814,17 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           topk_ids: torch.Tensor,
                           use_fp8_w8a8: bool = False,
                           use_int8_w8a16: bool = False,
+                          use_int4_w4a16: bool = False,
                           w1_scale: Optional[torch.Tensor] = None,
                           w2_scale: Optional[torch.Tensor] = None,
+                          w1_zp: Optional[torch.Tensor] = None,
+                          w2_zp: Optional[torch.Tensor] = None,
                           a1_scale: Optional[torch.Tensor] = None,
                           a2_scale: Optional[torch.Tensor] = None,
                           block_shape: Optional[List[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
-                       use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale,
-                       a1_scale, a2_scale, block_shape)
+                       use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
+                       w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
 
 
 def inplace_fused_experts_fake(
@@ -569,8 +835,11 @@ def inplace_fused_experts_fake(
         topk_ids: torch.Tensor,
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
        w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None) -> None:
@@ -593,14 +862,18 @@ def outplace_fused_experts(
         topk_ids: torch.Tensor,
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None) -> torch.Tensor:
     return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
-                              False, use_fp8_w8a8, use_int8_w8a16, w1_scale,
-                              w2_scale, a1_scale, a2_scale, block_shape)
+                              False, use_fp8_w8a8, use_int8_w8a16,
+                              use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
+                              a1_scale, a2_scale, block_shape)
 
 
 def outplace_fused_experts_fake(
@@ -611,8 +884,11 @@ def outplace_fused_experts_fake(
         topk_ids: torch.Tensor,
         use_fp8_w8a8: bool = False,
         use_int8_w8a16: bool = False,
+        use_int4_w4a16: bool = False,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
+        w1_zp: Optional[torch.Tensor] = None,
+        w2_zp: Optional[torch.Tensor] = None,
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None) -> torch.Tensor:
@@ -635,8 +911,11 @@ def fused_experts(hidden_states: torch.Tensor,
                   inplace: bool = False,
                   use_fp8_w8a8: bool = False,
                   use_int8_w8a16: bool = False,
+                  use_int4_w4a16: bool = False,
                   w1_scale: Optional[torch.Tensor] = None,
                   w2_scale: Optional[torch.Tensor] = None,
+                  w1_zp: Optional[torch.Tensor] = None,
+                  w2_zp: Optional[torch.Tensor] = None,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None,
                   block_shape: Optional[List[int]] = None):
@@ -644,16 +923,15 @@ def fused_experts(hidden_states: torch.Tensor,
         torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
                                              topk_weights, topk_ids,
                                              use_fp8_w8a8, use_int8_w8a16,
-                                             w1_scale, w2_scale, a1_scale,
+                                             use_int4_w4a16, w1_scale,
+                                             w2_scale, w1_zp, w2_zp, a1_scale,
                                              a2_scale, block_shape)
         return hidden_states
     else:
-        return torch.ops.vllm.outplace_fused_experts(hidden_states, w1, w2,
-                                                     topk_weights, topk_ids,
-                                                     use_fp8_w8a8,
-                                                     use_int8_w8a16, w1_scale,
-                                                     w2_scale, a1_scale,
-                                                     a2_scale, block_shape)
+        return torch.ops.vllm.outplace_fused_experts(
+            hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
+            use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
+            a1_scale, a2_scale, block_shape)
 
 
 def fused_experts_impl(hidden_states: torch.Tensor,
@@ -664,13 +942,21 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                        inplace: bool = False,
                        use_fp8_w8a8: bool = False,
                        use_int8_w8a16: bool = False,
+                       use_int4_w4a16: bool = False,
                        w1_scale: Optional[torch.Tensor] = None,
                        w2_scale: Optional[torch.Tensor] = None,
+                       w1_zp: Optional[torch.Tensor] = None,
+                       w2_zp: Optional[torch.Tensor] = None,
                        a1_scale: Optional[torch.Tensor] = None,
                        a2_scale: Optional[torch.Tensor] = None,
                        block_shape: Optional[List[int]] = None):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    if use_int4_w4a16:
+        assert hidden_states.shape[1] // 2 == w1.shape[
+            2], "Hidden size mismatch"
+    else:
+        assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
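Note: with use_int4_w4a16 two weight values are packed per byte along the hidden dimension, so w1.shape[2] stores hidden_size // 2 bytes; the relaxed assertion above accounts for that packed layout.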
@@ -687,6 +973,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
     M = min(num_tokens, CHUNK_SIZE)
     config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
                                         use_int8_w8a16=use_int8_w8a16,
+                                        use_int4_w4a16=use_int4_w4a16,
                                         dtype=hidden_states.dtype)
 
     get_config_func = functools.partial(
@@ -755,6 +1042,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 intermediate_cache1,
                                 a1_scale,
                                 w1_scale,
+                                w1_zp,
                                 curr_topk_weights,
                                 curr_topk_ids,
                                 sorted_token_ids,
@@ -766,6 +1054,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 compute_type=compute_type,
                                 use_fp8_w8a8=use_fp8_w8a8,
                                 use_int8_w8a16=use_int8_w8a16,
+                                use_int4_w4a16=use_int4_w4a16,
                                 block_shape=block_shape)
 
         torch.ops._C.silu_and_mul(intermediate_cache2,
@@ -776,6 +1065,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 intermediate_cache3,
                                 a2_scale,
                                 w2_scale,
+                                w2_zp,
                                 curr_topk_weights,
                                 curr_topk_ids,
                                 sorted_token_ids,
@@ -787,6 +1077,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
                                 compute_type=compute_type,
                                 use_fp8_w8a8=use_fp8_w8a8,
                                 use_int8_w8a16=use_int8_w8a16,
+                                use_int4_w4a16=use_int4_w4a16,
                                 block_shape=block_shape)
 
     ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
@@ -808,8 +1099,11 @@ def fused_moe(
     custom_routing_function: Optional[Callable] = None,
     use_fp8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -834,8 +1128,12 @@ def fused_moe(
     note: Deepseekv2 model uses grouped_topk
     - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
-    - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
-        products for w1 and w2. Defaults to False.
+    - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
+        activation to compute the inner products for w1 and w2.
+        Defaults to False.
+    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
+        activation to compute the inner products for w1 and w2.
+        Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
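A sketch (not part of the diff) of calling the extended fused_moe API with packed int4 weights, mirroring test_fused_moe_wn16. It assumes a CUDA device and imports fused_moe from the module this hunk modifies; the random tensors only exercise shapes, real checkpoints provide meaningful packed weights and scales.

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

m, n, k, e, topk, group_size = 32, 1024, 512, 8, 2, 128
a = torch.randn(m, k, device="cuda", dtype=torch.float16)
score = torch.randn(m, e, device="cuda", dtype=torch.float16)
w1_q = torch.randint(0, 256, (e, 2 * n, k // 2), device="cuda", dtype=torch.uint8)
w2_q = torch.randint(0, 256, (e, k, n // 2), device="cuda", dtype=torch.uint8)
w1_s = torch.rand(e, 2 * n, k // group_size, device="cuda", dtype=torch.float16)
w2_s = torch.rand(e, k, n // group_size, device="cuda", dtype=torch.float16)

out = fused_moe(a, w1_q, w2_q, score, topk,
                renormalize=False,
                use_int4_w4a16=True,
                w1_scale=w1_s, w2_scale=w2_s,
                w1_zp=None, w2_zp=None,       # symmetric, no zero points
                block_shape=[0, group_size])  # block_shape[1] = group size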
@@ -873,8 +1171,11 @@ def fused_moe(
                          inplace=inplace,
                          use_fp8_w8a8=use_fp8_w8a8,
                          use_int8_w8a16=use_int8_w8a16,
+                         use_int4_w4a16=use_int4_w4a16,
                          w1_scale=w1_scale,
                          w2_scale=w2_scale,
+                         w1_zp=w1_zp,
+                         w2_zp=w2_zp,
                          a1_scale=a1_scale,
                          a2_scale=a2_scale,
                          block_shape=block_shape)
vllm/model_executor/layers/quantization/__init__.py
@@ -26,7 +26,8 @@ QUANTIZATION_METHODS: List[str] = [
     "experts_int8",
     "neuron_quant",
     "ipex",
-    "quark"
+    "quark",
+    "moe_wna16"
 ]
 
 # The customized quantization methods which will be added to this dict.
@@ -94,6 +95,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .ipex_quant import IPEXConfig
     from .marlin import MarlinConfig
     from .modelopt import ModelOptFp8Config
+    from .moe_wna16 import MoeWNA16Config
     from .neuron_quant import NeuronQuantConfig
     from .qqq import QQQConfig
     from .tpu_int8 import Int8TpuConfig
@@ -121,7 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "experts_int8": ExpertsInt8Config,
         "neuron_quant": NeuronQuantConfig,
         "ipex": IPEXConfig,
-        "quark": QuarkConfig
+        "quark": QuarkConfig,
+        "moe_wna16": MoeWNA16Config,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
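A sketch (not part of the diff) of opting in to the new method for a compatible GPTQ/AWQ checkpoint. The model name is a placeholder; per override_quantization_method above, moe_wna16 is selected only when explicitly requested and the checkpoint is compatible (GPTQ without desc_act in 4/8 bit, or 4-bit AWQ on a supported GPU).

from vllm import LLM

llm = LLM(model="some-org/Mixtral-8x7B-GPTQ",   # hypothetical checkpoint id
          quantization="moe_wna16")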
vllm/model_executor/layers/quantization/moe_wna16.py (new file, 424 lines)
@@ -0,0 +1,424 @@
from typing import Any, Callable, Dict, List, Optional

import torch

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                         AWQLinearMethod)
from vllm.model_executor.layers.quantization.awq_marlin import (
    AWQMarlinConfig, AWQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
                                                          GPTQLinearMethod)
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig, GPTQMarlinLinearMethod)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform


class MoeWNA16Config(QuantizationConfig):
    """Config class for MOE WNA16 (W8A16/W4A16) quantization."""

    def __init__(self, linear_quant_method: str, weight_bits: int,
                 group_size: int, has_zp: bool, lm_head_quantized: bool,
                 modules_to_not_convert: Optional[List[str]],
                 full_config: Dict[str, Any]) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.has_zp = has_zp
        self.bit8_pack_factor = 8 // self.weight_bits
        self.lm_head_quantized = lm_head_quantized
        self.linear_quant_method = linear_quant_method
        self.full_config = full_config
        self.use_marlin = False
        if self.linear_quant_method == "gptq":
            self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
                full_config)
        elif self.linear_quant_method == "awq":
            capability_tuple = current_platform.get_device_capability()
            device_capability = (-1 if capability_tuple is None else
                                 capability_tuple.to_int())
            awq_min_capability = AWQConfig.get_min_capability()
            if device_capability < awq_min_capability:
                raise ValueError(
                    "The quantization method moe_wna16 + awq is not supported "
                    "for the current GPU. "
                    f"Minimum capability: {awq_min_capability}. "
                    f"Current capability: {device_capability}.")
            self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(
                full_config)
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        if modules_to_not_convert is None:
            self.modules_to_not_convert = []
        else:
            self.modules_to_not_convert = modules_to_not_convert

    @classmethod
    def get_name(cls) -> str:
        return "moe_wna16"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
        linear_quant_method = cls.get_from_keys(config, ["quant_method"])
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        if linear_quant_method == "gptq":
            has_zp = not cls.get_from_keys(config, ["sym"])
            modules_to_not_convert = []
        elif linear_quant_method == "awq":
            has_zp = cls.get_from_keys(config, ["zero_point"])
            modules_to_not_convert = cls.get_from_keys(
                config, ["modules_to_not_convert"])
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        return cls(linear_quant_method, weight_bits, group_size, has_zp,
                   lm_head_quantized, modules_to_not_convert, config)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
        if can_convert and user_quant == "moe_wna16":
            return cls.get_name()
        return None

    @classmethod
    def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        desc_act = quant_config.get("desc_act")

        capability_tuple = current_platform.get_device_capability()
        device_capability = (-1 if capability_tuple is None else
                             capability_tuple.to_int())
        awq_min_capability = AWQConfig.get_min_capability()

        gptq_compatible = quant_method == "gptq" and \
            not desc_act and num_bits in [4, 8]
        awq_compatible = quant_method == "awq" and num_bits == 4 and \
            device_capability >= awq_min_capability

        return gptq_compatible or awq_compatible

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
            return UnquantizedLinearMethod()
        elif isinstance(layer, FusedMoE):
            return MoeWNA16Method(self)
        else:
            if self.linear_quant_method == "gptq":
                if self.use_marlin:
                    return GPTQMarlinLinearMethod(
                        GPTQMarlinConfig.from_config(self.full_config))
                else:
                    return GPTQLinearMethod(
                        GPTQConfig.from_config(self.full_config))
            elif self.linear_quant_method == "awq":
                if self.use_marlin:
                    return AWQMarlinLinearMethod(
                        AWQMarlinConfig.from_config(self.full_config))
                else:
                    return AWQLinearMethod(
                        AWQConfig.from_config(self.full_config))
            else:
                raise ValueError("moe_wna16 only support gptq and awq.")


def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
    return any(module_name in prefix for module_name in modules_to_not_convert)


class MoeWNA16Method(FusedMoEMethodBase):
    """Linear method for MOE WNA16 (W8A16/W4A16) quantization.

    Args:
        quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
    """

    def __init__(self, quant_config: MoeWNA16Config):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        layer.quant_config = self.quant_config
        bit8_pack_factor = self.quant_config.bit8_pack_factor
        group_size = self.quant_config.group_size
        group_size_div_factor = 1

        # make intermediate_size and hidden_size diviable by group_size
        # we reduce the group size to ensure that
        # and we would repeat the loaded_weight later
        while intermediate_size_per_partition % group_size or \
                hidden_size % group_size:
            group_size = group_size // 2
            group_size_div_factor *= 2
            assert group_size >= 32
        layer.group_size = group_size
        layer.group_size_div_factor = group_size_div_factor

        strategy = FusedMoeWeightScaleSupported.GROUP.value
        extra_weight_attrs.update({
            "quant_method": strategy,
            "is_transposed": False
        })

        assert 'weight_loader' in extra_weight_attrs
        weight_loader = extra_weight_attrs['weight_loader']
        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(
            layer, weight_loader)
        extra_weight_attrs['weight_loader'] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // bit8_pack_factor,
            dtype=torch.uint8),
                                         requires_grad=False)
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // bit8_pack_factor,
            dtype=torch.uint8),
                                        requires_grad=False)
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        w13_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // group_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(torch.zeros(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // group_size,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        if self.quant_config.has_zp:
            w13_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition // bit8_pack_factor,
                hidden_size // group_size,
                dtype=torch.uint8),
                                            requires_grad=False)
            layer.register_parameter("w13_qzeros", w13_qzeros)
            set_weight_attrs(w13_qzeros, extra_weight_attrs)

            w2_qzeros = torch.nn.Parameter(torch.zeros(
                num_experts,
                hidden_size // bit8_pack_factor,
                intermediate_size_per_partition // group_size,
                dtype=torch.uint8),
                                           requires_grad=False)
            layer.register_parameter("w2_qzeros", w2_qzeros)
            set_weight_attrs(w2_qzeros, extra_weight_attrs)

        if self.quant_config.linear_quant_method == "gptq":
            # some param are unused, but we need to init them in order to
            # load weights
            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
            if not self.quant_config.has_zp:
                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
            for key in invalid_param_keys:
                param = torch.nn.Parameter(torch.empty((0, ),
                                                       dtype=torch.int32),
                                           requires_grad=False)
                layer.register_parameter(key, param)
                set_weight_attrs(param, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        weight_bits = self.quant_config.weight_bits
        has_zp = self.quant_config.has_zp

        return fused_experts(x,
                             layer.w13_qweight,
                             layer.w2_qweight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_int4_w4a16=weight_bits == 4,
                             use_int8_w8a16=weight_bits == 8,
                             w1_scale=layer.w13_scales,
                             w2_scale=layer.w2_scales,
                             w1_zp=layer.w13_qzeros if has_zp else None,
                             w2_zp=layer.w2_qzeros if has_zp else None,
                             block_shape=[0, layer.group_size])

    @staticmethod
    def get_weight_loader(layer, weight_loader):

        def convert_awq_tensor(tensor, tensor_type):
            # convert awq qweight/qzeros to a standard format (assume int4)
            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
            # qzeros: (k // group_size, n // pack_factor_bit32) ->
            #     (n // pack_factor_bit8, k // group_size)
            # pack_factor_bit32 = 32 // weight_bits
            # pack_factor_bit8 = 8 // weight_bits

            # 0. suppose origin shape (a, b), dtype int32
            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
            size0 = tensor.size(0)
            tensor = tensor.view(torch.uint8)

            # 2. unpack to uint4 (only when weight_bits == 4)
            #    shape (a, 4 * b) -> (a, 4 * b, 2)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF

            # 3. change order, see
            # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
            # shape -> (a, 4 * b * pack_factor_bit8)
            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
            tensor = tensor.view(size0, -1)

            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
            tensor = tensor.T.contiguous()

            # 5. repack (only when weight_bits == 4)
            # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
            # qzeros shape -> (4 * b, a)

            if tensor_type == "qweight":
                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
            elif tensor_type == "qzeros":
                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
            return tensor

        def convert_gptq_int4_qzeros(tensor):
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4],
                                   dtype=torch.uint8,
                                   device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor

        def moe_wna16_weight_loader(param: torch.nn.Parameter,
                                    loaded_weight: torch.Tensor,
                                    weight_name: str, shard_id: str,
                                    expert_id: int):
            if "g_idx" in weight_name:
                return
            if not layer.quant_config.has_zp and "qzeros" in weight_name:
                return

            device = get_tp_group().device
            tp_rank = get_tensor_model_parallel_rank()
            loaded_weight = loaded_weight.to(device)
            shard_size = layer.intermediate_size_per_partition

            # convert gptq and awq weight to a standard format
            if layer.quant_config.linear_quant_method == "awq":
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight,
                                                       "qweight")
                elif "zeros" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
                else:
                    loaded_weight = loaded_weight.T
            elif layer.quant_config.linear_quant_method == "gptq":
                assert layer.quant_config.weight_bits in [4, 8]
                if "weight" in weight_name:
                    loaded_weight = loaded_weight.T.contiguous().view(
                        torch.uint8)
                elif "zeros" in weight_name:
                    # add 1 to gptq qzeros to align with awq
                    loaded_weight = loaded_weight.view(torch.uint8)
                    if layer.quant_config.weight_bits == 4:
                        loaded_weight = convert_gptq_int4_qzeros(
                            loaded_weight).T
                    else:
                        loaded_weight = loaded_weight.T + 1
                else:
                    loaded_weight = loaded_weight.T

            # repeat the qzeros/scales to fit new group size
            if layer.group_size_div_factor > 1 and \
                    "qzeros" in weight_name or "scales" in weight_name:
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1)

            if "w13_qzeros" in weight_name:
                tensor = loaded_weight.view(layer.tp_size, -1,
                                            loaded_weight.size(1))[tp_rank]
                if shard_id == "w1":
                    param.data[expert_id, :shard_size // 2] = tensor
                else:
                    param.data[expert_id, shard_size // 2:] = tensor
            elif "w2_qzeros" in weight_name:
                param.data[expert_id] = loaded_weight.view(
                    loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank]
            else:
                weight_loader(param, loaded_weight, weight_name, shard_id,
                              expert_id)

        return moe_wna16_weight_loader
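A tiny sketch (not part of the diff) of step 3 of convert_awq_tensor in isolation. AWQ stores eight 4-bit values per int32 in an interleaved order; indexing with reverse_awq_pack_order restores the logical order. Both permutations are spelled out here so the inverse relationship can be checked directly.

import torch

logical = torch.arange(8, dtype=torch.uint8)   # values 0..7 in logical order
awq_order = [0, 2, 4, 6, 1, 3, 5, 7]           # interleaving used when packing
stored = logical[awq_order]                    # tensor([0, 2, 4, 6, 1, 3, 5, 7])

reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
restored = stored[reverse_awq_pack_order]      # undoes the interleaving
assert torch.equal(restored, logical)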