Support AWQMarlin with MLA

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
mgoin 2025-02-10 20:53:29 +00:00
parent 2ae889052c
commit d56ef8b685
5 changed files with 51 additions and 31 deletions

View File

@ -215,7 +215,7 @@ def rms_norm_dynamic_per_token_quant(
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
if envs.VLLM_USE_TRITON_AWQ:
if envs.VLLM_USE_TRITON_AWQ or qweight.dtype != torch.float16:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton)
return awq_dequantize_triton(qweight, scales, zeros)

View File

@ -18,6 +18,8 @@ from vllm.distributed import (get_tensor_model_parallel_world_size,
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@ -227,8 +229,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
def quantization_scheme_supported(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
is_layer_fp8(layer)
return isinstance(layer.quant_method,
(UnquantizedLinearMethod,
AWQMarlinLinearMethod)) or is_layer_fp8(layer)
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
@ -289,6 +292,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return scaled_dequantize(weight, scales,
weight_scale_group_shape)
elif isinstance(layer.quant_method, AWQMarlinLinearMethod):
return layer.quant_method.decompress_weights(layer).T
else:
return layer.weight
@ -296,12 +301,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
quantization_scheme_supported(self.q_proj) and\
quantization_scheme_supported(self.o_proj)):
raise NotImplementedError(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
"Only FP8, AWQ, and Unquantized are supported for MLA"
", please run with VLLM_MLA_DISABLE=1")
weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
def get_layer_dtype(layer):
if hasattr(layer, "weight"):
return layer.weight.dtype
elif hasattr(layer, "qweight"):
return layer.qweight.dtype
else:
raise AttributeError(
f"Layer '{layer}' has neither weight nor qweight")
weight_dtype = get_layer_dtype(self.kv_b_proj)
assert get_layer_dtype(self.o_proj) == weight_dtype
assert get_layer_dtype(self.q_proj) == weight_dtype
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (

View File

@ -989,7 +989,7 @@ class ModelConfig:
return False
if self.quantization is not None and self.quantization not in [\
"fp8", "compressed-tensors"]:
"fp8", "compressed-tensors", "awq_marlin"]:
logger.warning(
"MLA is not supported with %s quantization. "
"Disabling MLA.", self.quantization)

View File

@ -234,6 +234,16 @@ class AWQMarlinLinearMethod(LinearMethodBase):
layer.output_size_per_partition = output_size_per_partition
layer.num_groups = num_groups
def decompress_weights(self, layer: torch.nn.Module) -> torch.Tensor:
"""
Decompress to recover the original unquantized weight.
NOTE: this is only to be used before process_weights_after_loading
"""
# We can use AWQ's dequant since the unprocessed weights
# are in AWQ format
return ops.awq_dequantize(layer.qweight, layer.scales, layer.qzeros, 0,
0, 0)
# TODO: Update this docs
# Checkpoints are serialized in AutoAWQ format, which is different from the
# marlin format. This function is called after the weights are loaded.

View File

@ -153,6 +153,19 @@ def _initialize_model(
return model_class(**kwargs)
def _process_attention_weights_after_loading(
model: nn.Module, model_config: ModelConfig) -> None:
# Currently only used by MLA.
# NOTE: This intentionally happens before other modules so
# we can easily decompress the weights for MLA.
for _, module in model.named_modules():
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
class BaseModelLoader(ABC):
"""Base class for model loaders."""
@ -394,6 +407,8 @@ class DefaultModelLoader(BaseModelLoader):
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
_process_attention_weights_after_loading(model, model_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
@ -404,13 +419,6 @@ class DefaultModelLoader(BaseModelLoader):
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
return model.eval()
@ -436,6 +444,8 @@ class DummyModelLoader(BaseModelLoader):
# random values to the weights.
initialize_dummy_weights(model)
_process_attention_weights_after_loading(model, model_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
@ -447,11 +457,6 @@ class DummyModelLoader(BaseModelLoader):
with device_loading_context(
module, torch.device(device_config.device)):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
return model.eval()
@ -642,16 +647,11 @@ class ShardedStateLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(vllm_config=vllm_config)
_process_attention_weights_after_loading(model, model_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(
model_config.dtype)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
@ -1401,16 +1401,12 @@ class RunaiModelStreamerLoader(BaseModelLoader):
self._get_weights_iterator(model_weights,
model_config.revision))
_process_attention_weights_after_loading(model, model_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
return model.eval()