# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the marlin kernel.

Run `pytest tests/kernels/quantization/test_marlin_gemm.py`.
"""

import itertools

import pytest
import torch

from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL,
    GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
)
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    per_token_quant_int8,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
    marlin_permute_scales,
    query_marlin_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
    rand_marlin_weight_mxfp4_like,
    rand_marlin_weight_nvfp4_like,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    marlin_quant_fp8_torch,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace,
    awq_marlin_quantize,
    get_weight_perm,
    marlin_quantize,
    marlin_weights,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    awq_pack,
    gptq_pack,
    gptq_quantize_weights,
    quantize_weights,
    sort_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

if current_platform.is_rocm():
    pytest.skip(
        "These tests require gptq_marlin_repack, "
        "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm, "
        "or gptq_marlin_gemm, which are not supported on ROCm.",
        allow_module_level=True,
    )

ACT_ORDER_OPTS = [False, True]
K_FULL_OPTS = [False, True]
USE_ATOMIC_ADD_OPTS = [False, True]
USE_FP32_REDUCE_OPTS = [True]

MARLIN_K_CHUNKS = [128]
MARLIN_N_CHUNKS = [64, 256]

MARLIN_24_K_CHUNKS = [128]
MARLIN_24_N_CHUNKS = [512]

HQQ_SUPPORTED_GROUP_SIZES = [64]

MARLIN_REPACK_NK_FACTORS = [
    (4, 8),
    (7, 5),
    (13, 11),
]

MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (26, 37, 13),
    (257, 13, 11),
]

DTYPES = [torch.float16, torch.bfloat16]

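# Each entry below describes one weight-quantization scheme to exercise:
#   - "b_type": the packed weight type (required).
#   - "a_type"/"c_type": activation/output types; when omitted they default to
#     float16 and bfloat16 (see marlin_generate_valid_test_cases below).
#   - "group_blocks": quantization group sizes in units of 16 elements of K;
#     -1 means channelwise quantization (one group spanning all of K).
#   - "support_act_order": whether GPTQ activation reordering is also tested.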
DENSE_MARLIN_QUANT_TEST_CONFIGS = [
    # AWQ-INT4
    {"b_type": scalar_types.uint4, "group_blocks": [-1, 2, 4, 8]},
    # GPTQ-INT4
    {
        "b_type": scalar_types.uint4b8,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT8
    {
        "b_type": scalar_types.uint8b128,
        "support_act_order": True,
        "group_blocks": [-1, 2, 4, 8],
    },
    # FP8
    {"b_type": scalar_types.float8_e4m3fn, "group_blocks": [-1, 8]},
    # NVFP4
    {"b_type": scalar_types.float4_e2m1f, "group_blocks": [1]},
    # MXFP4
    {
        "a_type": [scalar_types.bfloat16],
        "b_type": scalar_types.float4_e2m1f,
        "group_blocks": [2],
    },
    # AWQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with INT8 activation
    {
        "a_type": [scalar_types.int8],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # GPTQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4b8,
        "group_blocks": [-1, 2, 4, 8],
    },
    # AWQ-INT4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.uint4,
        "group_blocks": [-1, 2, 4, 8],
    },
    # MXFP4 with FP8 activation
    {
        "a_type": [scalar_types.float8_e4m3fn],
        "b_type": scalar_types.float4_e2m1f,
        "c_type": [scalar_types.bfloat16],
        "group_blocks": [2],
    },
]


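# Relative error metric used by the GEMM tests below: the mean absolute
# difference between the kernel output and the dense reference, normalized by
# the mean absolute value of the reference. The tests assert it stays < 0.04.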
def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )


def rand_data(shape, dtype=torch.float16):
    return torch.randn(shape, dtype=dtype, device="cuda")


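# The next two tests check ops.marlin_int4_fp8_preprocess against a pure-torch
# reference. Two int4 values are packed per int8 byte (even columns in the
# high nibble), and the reference remaps each value q to q - 8 when q >= 8 and
# to 15 - q otherwise; the AWQ variant subtracts the per-group zero point
# before the remapping.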
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_without_zp():
    qweight_unpacked = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )
    qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
    qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)

    cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed)

    torch_res = torch.where(
        qweight_unpacked >= 8, qweight_unpacked - 8, 15 - qweight_unpacked
    )
    torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
    torch_res = torch_res.to(torch.int8).view(torch.int32)

    assert (cuda_res == torch_res).all()


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
def test_marlin_int4_fp8_preprocess_awq():
    group_size = 128

    qweight_unpacked = torch.randint(
        0, 16, size=(2048, 2048), dtype=torch.int32, device="cuda"
    )
    qzeros_unpacked = torch.randint(
        0, 16, size=(2048 // group_size, 2048), dtype=torch.int32, device="cuda"
    )

    qweight_packed = qweight_unpacked[:, ::2] * 16 + qweight_unpacked[:, 1::2]
    qweight_packed = qweight_packed.to(torch.int8).view(torch.int32)
    qzeros_packed = qzeros_unpacked[:, ::2] * 16 + qzeros_unpacked[:, 1::2]
    qzeros_packed = qzeros_packed.to(torch.int8).view(torch.int32)

    cuda_res = ops.marlin_int4_fp8_preprocess(qweight_packed, qzeros_packed)

    repeated_zp = qzeros_unpacked.repeat_interleave(group_size, 0)
    torch_res = qweight_unpacked - repeated_zp
    torch_res[torch_res < 0] = 15 - qweight_unpacked[torch_res < 0]
    torch_res = torch_res[:, ::2] * 16 + torch_res[:, 1::2]
    torch_res = torch_res.to(torch.int8).view(torch.int32)

    assert (cuda_res == torch_res).all()


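# Repack tests: quantize a random weight, pack it in the original GPTQ/AWQ
# layout, and check that the GPU repack kernels produce the same Marlin-format
# tensor as the reference marlin_weights() packing path.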
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(False, False))
@pytest.mark.parametrize("act_order", ACT_ORDER_OPTS)
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_gptq_marlin_repack(
    k_chunk, n_chunk, quant_type, act_order, is_a_8bit, nk_factors
):
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
    group_size = 128

    # Filter act_order
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if is_a_8bit:
            return

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        b_weight, quant_type, group_size, act_order
    )

    # Pack to GPTQ format
    q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=b_weight.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.gptq_marlin_repack,
        (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
        q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types(True))
@pytest.mark.parametrize("is_a_8bit", [True, False])
@pytest.mark.parametrize("nk_factors", MARLIN_REPACK_NK_FACTORS)
def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, is_a_8bit, nk_factors):
    n_factor, k_factor = nk_factors

    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    group_size = 128

    # Create input
    b_weight = rand_data((size_k, size_n))

    # Quantize
    w_ref, q_w, s, zp = quantize_weights(
        b_weight, quant_type, group_size, zero_points=True
    )

    # Pack to AWQ format
    q_w_awq = awq_pack(q_w, quant_type.size_bits, size_k, size_n)

    # Pack to Marlin format
    weight_perm = get_weight_perm(quant_type.size_bits, is_a_8bit)
    marlin_q_w_1 = marlin_weights(
        q_w, size_k, size_n, quant_type.size_bits, weight_perm, is_a_8bit
    )

    opcheck(
        torch.ops._C.awq_marlin_repack,
        (q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit),
    )

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.awq_marlin_repack(
        q_w_awq, size_k, size_n, quant_type.size_bits, is_a_8bit
    )
    torch.cuda.synchronize()

    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


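# Expands DENSE_MARLIN_QUANT_TEST_CONFIGS into concrete parameter tuples for
# test_gptq_marlin_gemm, dropping combinations the kernel cannot run (e.g.
# act_order with channelwise groups, bf16 atomic-add on pre-SM90 GPUs, or FP8
# activations on devices whose capability is not in [89, 120]). Note that,
# despite its name, the inner is_invalid() helper returns True for the
# combinations that should be kept.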
def marlin_generate_valid_test_cases():
    all_combinations = itertools.product(
        DENSE_MARLIN_QUANT_TEST_CONFIGS,
        MNK_FACTORS,
        MARLIN_N_CHUNKS,
        MARLIN_K_CHUNKS,
        ACT_ORDER_OPTS,
        K_FULL_OPTS,
        USE_ATOMIC_ADD_OPTS,
        USE_FP32_REDUCE_OPTS,
    )

    def is_invalid(
        a_type,
        b_type,
        c_type,
        group_blocks,
        size_m,
        size_n,
        size_k,
        act_order,
        is_k_full,
        use_atomic_add,
        use_fp32_reduce,
    ):
        if use_atomic_add:
            if use_fp32_reduce:
                return False
            if (
                c_type == scalar_types.bfloat16
                and torch.cuda.get_device_capability()[0] < 9
            ):
                return False

        group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
        if group_size > 0 and size_k % group_size != 0:
            return False

        if act_order and group_size in [-1, size_k]:
            return False
        if group_size == size_k:
            return False
        if not act_order and is_k_full:
            return False

        return a_type.size_bits < 16 or a_type is c_type

    cases = []
    for case in all_combinations:
        quant_test_config, mnk_factors, n_chunk, k_chunk, act_order, *_ = case
        size_m = mnk_factors[0]
        size_n = mnk_factors[1] * n_chunk
        size_k = mnk_factors[2] * k_chunk

        if act_order and not quant_test_config.get("support_act_order", False):
            continue

        f16_types = [scalar_types.float16, scalar_types.bfloat16]
        inner_combinations = itertools.product(
            quant_test_config.get("a_type", f16_types),
            [quant_test_config["b_type"]],
            quant_test_config.get("c_type", f16_types),
            quant_test_config["group_blocks"],
        )

        for sub_case in inner_combinations:
            if (
                sub_case[0] == scalar_types.float8_e4m3fn
                and current_platform.get_device_capability() not in [89, 120]
            ):
                continue
            args = sub_case + (size_m, size_n, size_k) + case[4:]
            if is_invalid(*args):
                cases.append(args)
    return cases


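# Main correctness test for ops.gptq_marlin_gemm: quantize a random weight for
# the given (a_type, b_type, c_type, group) combination, optionally quantize
# the activations to INT8/FP8, run the fused kernel, and compare against a
# plain torch.matmul on the dequantized reference weight.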
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize(
    (
        "a_type, b_type, c_type, group_blocks,"
        "size_m, size_n, size_k, act_order, is_k_full,"
        "use_atomic_add, use_fp32_reduce"
    ),
    marlin_generate_valid_test_cases(),
)
def test_gptq_marlin_gemm(
    a_type,
    b_type,
    c_type,
    group_blocks,
    size_m,
    size_n,
    size_k,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    has_zp = b_type in [scalar_types.uint4, scalar_types.uint8]

    group_size = group_blocks if group_blocks <= 0 else group_blocks * 16

    if c_type == scalar_types.float16:
        dtype = torch.float16
    elif c_type == scalar_types.bfloat16:
        dtype = torch.bfloat16
    else:
        raise RuntimeError("unsupported c_type")

    if a_type == scalar_types.int8:
        a_dtype = torch.int8
    elif a_type == scalar_types.float8_e4m3fn:
        a_dtype = torch.float8_e4m3fn
    else:
        a_dtype = dtype

    a_input = rand_data((size_m, size_k), dtype=dtype)
    b_weight = rand_data((size_k, size_n), dtype=dtype)

    if b_type == scalar_types.float4_e2m1f:
        if group_size == 16:
            w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_nvfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
        else:
            w_ref, marlin_q_w, marlin_s = rand_marlin_weight_mxfp4_like(
                b_weight.T, group_size, input_dtype=a_dtype
            )
            marlin_s2 = None

        g_idx = None
        sort_indices = None
        marlin_zp = None
    elif b_type == scalar_types.float8_e4m3fn:
        w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(
            b_weight.T, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_zp = None
        marlin_s2 = None
    elif has_zp:
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, b_type, group_size, input_dtype=a_dtype
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, b_type, group_size, act_order, input_dtype=a_dtype
        )

        marlin_zp = None
        marlin_s2 = None

    workspace = marlin_make_workspace_new(w_ref.device)

    if a_type == scalar_types.int8:
        a_input, a_scales = per_token_quant_int8(a_input)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)

        if group_size != -1:
            a_scales = a_scales / 4096 * marlin_s.max()
            a_scales = a_scales.float()
            marlin_s = marlin_s / marlin_s.max() * 4096
            marlin_s = marlin_s.round().to(torch.int16).view(dtype)
    elif a_type == scalar_types.float8_e4m3fn:
        a_input, a_scales = ops.scaled_fp8_quant(a_input, use_per_token_if_dynamic=True)
        a_input_ref = a_input.to(a_scales.dtype) * a_scales.view(-1, 1)
        a_input_ref = a_input_ref.to(dtype)
    else:
        assert a_type.size_bits == 16
        a_input_ref = a_input
        a_scales = None

    output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        output,
        marlin_q_w,
        None,
        marlin_s,
        a_scales,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        b_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input_ref, w_ref)

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04


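# Wrapping the 2:4 sparse GEMM in torch.compile(fullgraph=True) additionally
# exercises tracing of the custom op without graph breaks.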
# TODO: find a better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(
    a_input,
    marlin_24_q_w_comp,
    marlin_24_meta,
    marlin_24_s,
    scratch,
    quant_type,
    size_m,
    size_n,
    size_k,
):
    return ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )


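# Correctness test for the GPTQ Marlin 2:4 sparse GEMM, compared against a
# dense matmul on the dequantized 2:4 reference weight.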
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@pytest.mark.parametrize("quant_type", GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
@pytest.mark.parametrize("group_size", GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, mnk_factors):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))

    (w_24_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = marlin_24_quantize(
        b_weight, quant_type, group_size
    )

    workspace_24 = MarlinWorkspace(
        size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
    )

    output_ref = torch.matmul(a_input, w_24_ref)

    opcheck(
        torch.ops._C.gptq_marlin_24_gemm,
        (
            a_input,
            marlin_24_q_w_comp,
            marlin_24_meta,
            marlin_24_s,
            workspace_24.scratch,
            quant_type.id,
            a_input.shape[0],
            b_weight.shape[1],
            a_input.shape[1],
        ),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
    )

    output = marlin_24_gemm_tester(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        workspace_24.scratch,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
    )

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


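# HQQ path: 4-bit weights with floating-point zero points (is_zp_float=True);
# the reference dequantizes as (q - zero) * scale per group.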
@pytest.mark.skipif(
    not is_quant_method_supported("gptq_marlin"),
    reason="Marlin is not supported on this GPU type.",
)
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("group_size", HQQ_SUPPORTED_GROUP_SIZES)
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS)
def test_hqq_marlin_gemm(
    k_chunk,
    n_chunk,
    group_size,
    mnk_factors,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    quant_type = scalar_types.uint4

    a_input = rand_data((size_m, size_k))
    dev = a_input.device

    b_weight = torch.randint(0, 10, (size_n, size_k), dtype=torch.uint8, device=dev)
    scale = rand_data((size_n, size_k // group_size))
    zero = rand_data((size_n, size_k // group_size))

    gptq_w_q = gptq_pack(b_weight.transpose(1, 0), 4, size_k, size_n)

    sort_indices = torch.empty(0, dtype=torch.int, device=dev)
    marlin_w_q = ops.gptq_marlin_repack(gptq_w_q, sort_indices, size_k, size_n, 4).to(
        dev
    )
    marlin_s = marlin_permute_scales(
        scale.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)
    marlin_zp = marlin_permute_scales(
        zero.transpose(1, 0), size_k, size_n, group_size
    ).to(dev)

    g_idx = marlin_make_empty_g_idx(dev)
    g_idx_sort_indices = marlin_make_empty_g_idx(dev)

    workspace = marlin_make_workspace_new(b_weight.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_w_q,
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        g_idx_sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[0],
        a_input.shape[1],
        is_k_full=True,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=True,
    )

    b_flat = b_weight.reshape(-1, group_size)
    zp_flat = zero.reshape(-1, 1)
    s_flat = scale.reshape(-1, 1)
    dequant = (b_flat - zp_flat) * s_flat

    output_ref = torch.matmul(a_input, dequant.reshape(b_weight.shape).transpose(1, 0))

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


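# Non-contiguous activations: the input A is a strided slice of a larger
# tensor, so this exercises the kernel with a row stride larger than size_k.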
def test_marlin_gemm_subset_input():
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_m, size_k, size_n = 32, 1024, 2048
    big_m = size_m * 2
    big_k = size_k * 2

    a_input = rand_data((big_m, big_k))[8 : size_m + 8, 8 : size_k + 8]
    b_weight = rand_data((size_k, size_n))

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        None,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04


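# Bias fusion: the bias is pre-permuted with marlin_permute_bias() and passed
# to the kernel; the reference simply adds it after the matmul.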
@pytest.mark.parametrize("size_m", [1, 256])
def test_marlin_gemm_with_bias(size_m):
    quant_type = scalar_types.uint4b8
    group_size = 128

    size_k, size_n = 1024, 2048
    a_input = rand_data((size_m, size_k))
    b_weight = rand_data((size_k, size_n))
    b_bias = rand_data((size_n,)) * 10

    marlin_bias = marlin_permute_bias(b_bias)

    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
        b_weight, quant_type, group_size, False
    )

    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
    workspace = marlin_make_workspace_new(a_input.device)

    output = ops.gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_bias,
        marlin_s,
        None,
        None,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=True,
        use_atomic_add=False,
        use_fp32_reduce=True,
        is_zp_float=False,
    )
    output_ref = torch.matmul(a_input, w_ref) + b_bias.view(1, -1)

    torch.cuda.synchronize()

    max_diff = compute_max_diff(output, output_ref)

    assert max_diff < 0.04