[Bugfix] Allow fallback to AWQ from AWQMarlin at per-layer granularity (#13119)

This commit is contained in:
Michael Goin 2025-02-12 12:19:53 -05:00 committed by GitHub
parent 36a08630e8
commit 09972e716c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 61 additions and 32 deletions

View File

@ -290,29 +290,30 @@ class ColumnParallelLinear(LinearBase):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None, output_sizes: Optional[list[int]] = None,
prefix: str = ""): prefix: str = ""):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix) quant_config, prefix)
self.gather_output = gather_output self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_method is not None
self.output_size_per_partition = divide(self.output_size, tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, tp_size)
for output_size in self.output_sizes
]
if output_sizes is None: if output_sizes is None:
output_sizes = [output_size] output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights( self.quant_method.create_weights(
layer=self, layer=self,
input_size_per_partition=self.input_size, input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes, output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size, input_size=self.input_size,
output_size=self.output_size, output_size=self.output_size,
@ -1044,22 +1045,24 @@ class RowParallelLinear(LinearBase):
reduce_results: bool = True, reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""): prefix: str = ""):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix) quant_config, prefix)
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results self.reduce_results = reduce_results
# Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights( self.quant_method.create_weights(
layer=self, layer=self,
input_size_per_partition=self.input_size_per_partition, input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size], output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size, input_size=self.input_size,
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,

View File

@ -13,15 +13,17 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, check_marlin_supports_layer, marlin_make_empty_g_idx,
marlin_permute_scales, moe_awq_to_marlin_zero_points, marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
verify_marlin_supported, verify_marlin_supports_shape) moe_awq_to_marlin_zero_points, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter) PackedvLLMParameter)
@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig):
8: scalar_types.uint8, 8: scalar_types.uint8,
} }
def __init__(self, def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
weight_bits: int,
group_size: int,
zero_point: bool,
lm_head_quantized: bool, lm_head_quantized: bool,
modules_to_not_convert: Optional[List[str]] = None) -> None: modules_to_not_convert: Optional[List[str]],
full_config: Dict[str, Any]) -> None:
self.pack_factor = 32 // weight_bits # packed into int32 self.pack_factor = 32 // weight_bits # packed into int32
self.group_size = group_size self.group_size = group_size
self.zero_point = zero_point self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized self.lm_head_quantized = lm_head_quantized
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.modules_to_not_convert = modules_to_not_convert or [] self.modules_to_not_convert = modules_to_not_convert or []
self.full_config = full_config
if self.weight_bits not in self.TYPE_MAP: if self.weight_bits not in self.TYPE_MAP:
raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@ -96,7 +97,7 @@ class AWQMarlinConfig(QuantizationConfig):
modules_to_not_convert = cls.get_from_keys_or( modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None) config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, lm_head_quantized, return cls(weight_bits, group_size, zero_point, lm_head_quantized,
modules_to_not_convert) modules_to_not_convert, config)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(cls, hf_quant_cfg,
@ -124,6 +125,13 @@ class AWQMarlinConfig(QuantizationConfig):
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert): if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMarlin. "
"Falling back to unoptimized AWQ kernels.")
return AWQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self) return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return AWQMoEMethod(self) return AWQMoEMethod(self)

View File

@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig) GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -87,8 +89,8 @@ class MoeWNA16Config(QuantizationConfig):
modules_to_not_convert = [] modules_to_not_convert = []
elif linear_quant_method == "awq": elif linear_quant_method == "awq":
has_zp = cls.get_from_keys(config, ["zero_point"]) has_zp = cls.get_from_keys(config, ["zero_point"])
modules_to_not_convert = cls.get_from_keys( modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"]) config, ["modules_to_not_convert"], None)
else: else:
raise ValueError("moe_wna16 only support gptq and awq.") raise ValueError("moe_wna16 only support gptq and awq.")
@ -135,7 +137,8 @@ class MoeWNA16Config(QuantizationConfig):
return GPTQConfig.from_config( return GPTQConfig.from_config(
self.full_config).get_quant_method(layer, prefix) self.full_config).get_quant_method(layer, prefix)
elif self.linear_quant_method == "awq": elif self.linear_quant_method == "awq":
if self.use_marlin: if self.use_marlin and check_marlin_supports_layer(
layer, self.group_size):
return AWQMarlinConfig.from_config( return AWQMarlinConfig.from_config(
self.full_config).get_quant_method(layer, prefix) self.full_config).get_quant_method(layer, prefix)
else: else:

View File

@ -6,6 +6,7 @@ import numpy
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int,
return True, None return True, None
def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
        -> bool:
    """Report whether the Marlin shape check accepts ``layer``.

    The layer's per-partition sizes are used when the attributes exist and
    are truthy; otherwise the layer's full ``output_size`` / ``input_size``
    are used in their place.

    Args:
        layer: Linear layer whose dimensions are checked.
        group_size: Quantization group size forwarded to the shape check.

    Returns:
        The first element of the result of ``check_marlin_supports_shape``.
    """
    # A missing, ``None`` or otherwise falsy per-partition attribute falls
    # back to the corresponding full dimension of the layer.
    out_features = getattr(layer, "output_size_per_partition", None)
    if not out_features:
        out_features = layer.output_size

    in_features = getattr(layer, "input_size_per_partition", None)
    if not in_features:
        in_features = layer.input_size

    # Only the leading element of the result is reported to the caller;
    # anything after it is discarded.
    verdict = check_marlin_supports_shape(
        output_size_per_partition=out_features,
        input_size_per_partition=in_features,
        input_size=layer.input_size,
        group_size=group_size)
    return verdict[0]
def marlin_make_workspace(output_size_per_partition: int, def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor: device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition // max_workspace_size = (output_size_per_partition //