diff --git a/tests/kernels/quantization/test_machete_mm.py b/tests/kernels/quantization/test_machete_mm.py
index a4fb9874c490..a7cb2a4e7f21 100644
--- a/tests/kernels/quantization/test_machete_mm.py
+++ b/tests/kernels/quantization/test_machete_mm.py
@@ -14,6 +14,8 @@ import torch
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.machete_utils import (
+    query_machete_supported_group_sizes)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_rows, quantize_weights)
 from vllm.platforms import current_platform
@@ -46,8 +48,6 @@ MNK_SHAPES = [
     (1024, 8192, 4096),
 ]
 
-GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
-
 
 @dataclass
 class TypeConfig:
@@ -270,7 +270,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
     if types.group_scale_type is None:
         group_sizes = [None]
     else:
-        group_sizes = GROUP_SIZES_TO_TEST
+        group_sizes = query_machete_supported_group_sizes(types.act_type)
 
     for group_size in group_sizes:
         if not group_size_valid(shape, group_size):
@@ -299,7 +299,7 @@ def test_machete_heuristic(shape, types: TypeConfig):
     if types.group_scale_type is None:
         group_sizes = [None]
     else:
-        group_sizes = GROUP_SIZES_TO_TEST
+        group_sizes = query_machete_supported_group_sizes(types.act_type)
 
     for group_size in group_sizes:
         if not group_size_valid(shape, group_size):
diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
index a75f3ac8d503..12eb9d104bf2 100644
--- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py
@@ -8,7 +8,7 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.machete_utils import (
-    MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
+    check_machete_supports_shape, query_machete_supported_group_sizes,
     query_machete_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
@@ -40,10 +40,10 @@ class MacheteLinearKernel(MPLinearKernel):
                 "Machete, supported types are: "\
                 f"{query_machete_supported_quant_types(c.zero_points)}"
 
-        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
+        if c.group_size not in query_machete_supported_group_sizes(c.act_type):
            return False, f"Group size ({c.group_size}) not supported by "\
                 "Machete, supported group sizes are: "\
-                f"{MACHETE_SUPPORTED_GROUP_SIZES}"
+                f"{query_machete_supported_group_sizes(c.act_type)}"
 
         return check_machete_supports_shape(c.partition_weight_shape[0],
                                             c.partition_weight_shape[1])
diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py
index 580c36a0e2fa..fbb850d22776 100644
--- a/vllm/model_executor/layers/quantization/utils/machete_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py
@@ -7,7 +7,6 @@ import torch
 
 from vllm.scalar_type import ScalarType, scalar_types
 
-MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
 MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
 
 
@@ -22,6 +21,24 @@ def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]:
     return [torch.float16, torch.bfloat16]
 
 
+def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]:
+    """
+    Queries the supported group sizes for Machete based on the activation type.
+
+    Args:
+        act_type: The activation data type (torch.float16, torch.bfloat16).
+
+    Returns:
+        A list of supported group sizes. The group size must
+        be divisible by `TileShapeK = 128 * 8 // num_bits(act_type)`.
+        -1 indicates per-channel quantization.
+    """
+    if act_type in [torch.float16, torch.bfloat16]:
+        return [-1, 64, 128]
+    else:
+        return [-1, 128]
+
+
 def check_machete_supports_shape(in_features: int, out_featrues: int) \
     -> tuple[bool, Optional[str]]:
     if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
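
Side note (not part of the patch): the docstring's divisibility rule can be sanity-checked in isolation. Below is a minimal sketch, assuming `torch.finfo(dtype).bits` as the bit-width query; the `num_bits` helper named in the docstring is not defined in this patch, so a local stand-in is used here, and `torch.float8_e4m3fn` serves only as an example 8-bit activation type (available in recent PyTorch releases).

# Standalone sanity check of the TileShapeK rule from the new docstring.
# `num_bits` is a local stand-in via torch.finfo, not a vLLM helper.
import torch

def num_bits(act_type: torch.dtype) -> int:
    # Bit width of a floating-point dtype, e.g. 16 for float16.
    return torch.finfo(act_type).bits

for dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
    tile_shape_k = 128 * 8 // num_bits(dtype)  # 64 for fp16/bf16, 128 for fp8
    groups = [-1] + [g for g in (64, 128) if g % tile_shape_k == 0]
    print(dtype, tile_shape_k, groups)

# Expected output, matching query_machete_supported_group_sizes:
# torch.float16 64 [-1, 64, 128]
# torch.bfloat16 64 [-1, 64, 128]
# torch.float8_e4m3fn 128 [-1, 128]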