Enable group size 64 for Machete (#20290)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
parent e81fbefe8a
commit 3abfe22154
@@ -14,6 +14,8 @@ import torch
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.machete_utils import (
+    query_machete_supported_group_sizes)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_rows, quantize_weights)
 from vllm.platforms import current_platform
@@ -46,8 +48,6 @@ MNK_SHAPES = [
     (1024, 8192, 4096),
 ]
 
-GROUP_SIZES_TO_TEST: list[Optional[int]] = [128, -1]
-
 
 @dataclass
 class TypeConfig:
@@ -270,7 +270,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
     if types.group_scale_type is None:
         group_sizes = [None]
     else:
-        group_sizes = GROUP_SIZES_TO_TEST
+        group_sizes = query_machete_supported_group_sizes(types.act_type)
 
     for group_size in group_sizes:
         if not group_size_valid(shape, group_size):
@@ -299,7 +299,7 @@ def test_machete_heuristic(shape, types: TypeConfig):
     if types.group_scale_type is None:
         group_sizes = [None]
     else:
-        group_sizes = GROUP_SIZES_TO_TEST
+        group_sizes = query_machete_supported_group_sizes(types.act_type)
 
     for group_size in group_sizes:
        if not group_size_valid(shape, group_size):
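
With this change the tests no longer rely on the hard-coded GROUP_SIZES_TO_TEST list; they ask query_machete_supported_group_sizes for the sizes valid for the activation dtype under test. A minimal sketch of what the tests now iterate over (assumes a vLLM install; the expected values follow from the helper added later in this commit):

# Sketch: group sizes the tests now cover per activation dtype.
import torch

from vllm.model_executor.layers.quantization.utils.machete_utils import (
    query_machete_supported_group_sizes)

for act_type in (torch.float16, torch.bfloat16):
    group_sizes = query_machete_supported_group_sizes(act_type)
    assert group_sizes == [-1, 64, 128]  # group size 64 is newly included
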
@@ -8,7 +8,7 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.machete_utils import (
-    MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
+    check_machete_supports_shape, query_machete_supported_group_sizes,
     query_machete_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
@@ -40,10 +40,10 @@ class MacheteLinearKernel(MPLinearKernel):
                           "Machete, supported types are: "\
                           f"{query_machete_supported_quant_types(c.zero_points)}"
 
-        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
+        if c.group_size not in query_machete_supported_group_sizes(c.act_type):
             return False, f"Group size ({c.group_size}) not supported by "\
                             "Machete, supported group sizes are: "\
-                            f"{MACHETE_SUPPORTED_GROUP_SIZES}"
+                            f"{query_machete_supported_group_sizes(c.act_type)}"
 
         return check_machete_supports_shape(c.partition_weight_shape[0],
                                             c.partition_weight_shape[1])
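
The can_implement check above now keys the group-size gate off c.act_type rather than a module-level constant. A standalone sketch of the same gate, where act_type and group_size stand in for the config fields c.act_type and c.group_size (an approximation for illustration, not the vLLM API itself):

# Standalone sketch of the group-size gate after this change; the body
# mirrors query_machete_supported_group_sizes from machete_utils.
import torch

def machete_group_size_ok(act_type: torch.dtype, group_size: int) -> bool:
    if act_type in [torch.float16, torch.bfloat16]:
        return group_size in [-1, 64, 128]
    return group_size in [-1, 128]

assert machete_group_size_ok(torch.float16, 64)      # newly accepted
assert not machete_group_size_ok(torch.float32, 64)  # falls back to [-1, 128]
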
@@ -7,7 +7,6 @@ import torch
 
 from vllm.scalar_type import ScalarType, scalar_types
 
-MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
 MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
 
 
@@ -22,6 +21,24 @@ def query_machete_supported_act_types(zero_points: bool) -> list[ScalarType]:
     return [torch.float16, torch.bfloat16]
 
 
+def query_machete_supported_group_sizes(act_type: torch.dtype) -> list[int]:
+    """
+    Queries the supported group sizes for Machete based on the activation type.
+
+    Args:
+        act_type: The activation data type (torch.float16, torch.bfloat16).
+
+    Returns:
+        A list of supported group sizes. The group size must
+        be divisible by `TileShapeK = 128 * 8 // num_bits(act_type)`.
+        -1 indicates per-channel quantization.
+    """
+    if act_type in [torch.float16, torch.bfloat16]:
+        return [-1, 64, 128]
+    else:
+        return [-1, 128]
+
+
 def check_machete_supports_shape(in_features: int, out_featrues: int) \
     -> tuple[bool, Optional[str]]:
     if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
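
The docstring's divisibility rule is what makes 64 legal for 16-bit activations: TileShapeK = 128 * 8 // num_bits(act_type) = 1024 // 16 = 64, and both 64 and 128 are multiples of 64. A worked check of that arithmetic (tile_shape_k is a local illustration of the docstring's formula, not an exported vLLM symbol):

# Worked example of the TileShapeK divisibility rule from the docstring.
import torch

def tile_shape_k(act_type: torch.dtype) -> int:
    num_bits = torch.finfo(act_type).bits  # 16 for float16 and bfloat16
    return 128 * 8 // num_bits             # 1024 // 16 == 64

for act_type in (torch.float16, torch.bfloat16):
    k = tile_shape_k(act_type)
    assert k == 64
    # Every supported finite group size must be a multiple of TileShapeK;
    # -1 (per-channel) is exempt from the rule.
    assert all(g % k == 0 for g in (64, 128))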