From d56ef8b6850882a24cf0400b363ed5943f33f46d Mon Sep 17 00:00:00 2001
From: mgoin
Date: Mon, 10 Feb 2025 20:53:29 +0000
Subject: [PATCH] Support AWQMarlin with MLA

Signed-off-by: mgoin
---
 vllm/_custom_ops.py                        |  2 +-
 vllm/attention/backends/mla/utils.py       | 26 +++++++++---
 vllm/config.py                             |  2 +-
 .../layers/quantization/awq_marlin.py      | 10 +++++
 vllm/model_executor/model_loader/loader.py | 42 +++++++++----------
 5 files changed, 51 insertions(+), 31 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a682350167675..73f41c3f317e3 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -215,7 +215,7 @@ def rms_norm_dynamic_per_token_quant(
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                    zeros: torch.Tensor, split_k_iters: int, thx: int,
                    thy: int) -> torch.Tensor:
-    if envs.VLLM_USE_TRITON_AWQ:
+    if envs.VLLM_USE_TRITON_AWQ or qweight.dtype != torch.float16:
         from vllm.model_executor.layers.quantization.awq_triton import (
             awq_dequantize_triton)
         return awq_dequantize_triton(qweight, scales, zeros)
diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py
index a41140ec83782..5a9b4b03b5b2e 100644
--- a/vllm/attention/backends/mla/utils.py
+++ b/vllm/attention/backends/mla/utils.py
@@ -18,6 +18,8 @@ from vllm.distributed import (get_tensor_model_parallel_world_size,
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization.awq_marlin import (
+    AWQMarlinLinearMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
@@ -227,8 +229,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
 
         def quantization_scheme_supported(layer: LinearBase) -> bool:
-            return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
-                is_layer_fp8(layer)
+            return isinstance(layer.quant_method,
+                              (UnquantizedLinearMethod,
+                               AWQMarlinLinearMethod)) or is_layer_fp8(layer)
 
         # TODO(lucas) This is very gross, we need a more wide scale refactor of
         # all the FP8 code with a more standard way of
@@ -289,6 +292,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
 
                 return scaled_dequantize(weight, scales,
                                          weight_scale_group_shape)
+            elif isinstance(layer.quant_method, AWQMarlinLinearMethod):
+                return layer.quant_method.decompress_weights(layer).T
             else:
                 return layer.weight
 
         if not (quantization_scheme_supported(self.kv_b_proj) and\
                 quantization_scheme_supported(self.q_proj) and\
                 quantization_scheme_supported(self.o_proj)):
             raise NotImplementedError(
-                "Only FP8 and UnquantizedLinearMethod are supported for MLA"
+                "Only FP8, AWQ, and Unquantized are supported for MLA"
                 ", please run with VLLM_MLA_DISABLE=1")
 
-        weight_dtype = self.kv_b_proj.weight.dtype
-        assert self.o_proj.weight.dtype == weight_dtype
-        assert self.q_proj.weight.dtype == weight_dtype
+        def get_layer_dtype(layer):
+            if hasattr(layer, "weight"):
+                return layer.weight.dtype
+            elif hasattr(layer, "qweight"):
+                return layer.qweight.dtype
+            else:
+                raise AttributeError(
+                    f"Layer '{layer}' has neither weight nor qweight")
+
+        weight_dtype = get_layer_dtype(self.kv_b_proj)
+        assert get_layer_dtype(self.o_proj) == weight_dtype
+        assert get_layer_dtype(self.q_proj) == weight_dtype
 
         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
         assert kv_b_proj_weight.shape == (
diff --git a/vllm/config.py b/vllm/config.py
index 426ba38080270..d2a104aad6eba 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -989,7 +989,7 @@ class ModelConfig:
             return False
 
         if self.quantization is not None and self.quantization not in [\
-                "fp8", "compressed-tensors"]:
+                "fp8", "compressed-tensors", "awq_marlin"]:
             logger.warning(
                 "MLA is not supported with %s quantization. "
                 "Disabling MLA.", self.quantization)
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 8849ba2928228..3de9397e57eff 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -234,6 +234,16 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.output_size_per_partition = output_size_per_partition
         layer.num_groups = num_groups
 
+    def decompress_weights(self, layer: torch.nn.Module) -> torch.Tensor:
+        """
+        Decompress to recover the original unquantized weight.
+        NOTE: this is only to be used before process_weights_after_loading
+        """
+        # We can use AWQ's dequant since the unprocessed weights
+        # are in AWQ format
+        return ops.awq_dequantize(layer.qweight, layer.scales, layer.qzeros, 0,
+                                  0, 0)
+
     # TODO: Update this docs
     # Checkpoints are serialized in AutoAWQ format, which is different from the
     # marlin format. This function is called after the weights are loaded.
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 2a2c2523b725d..69e6a56106293 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -153,6 +153,19 @@ def _initialize_model(
     return model_class(**kwargs)
 
 
+def _process_attention_weights_after_loading(
+        model: nn.Module, model_config: ModelConfig) -> None:
+    # Currently only used by MLA.
+    # NOTE: This intentionally happens before other modules so
+    # we can easily decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+            hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
+
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
 
@@ -394,6 +407,8 @@ class DefaultModelLoader(BaseModelLoader):
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")
 
+            _process_attention_weights_after_loading(model, model_config)
+
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
                 if isinstance(quant_method, QuantizeMethodBase):
@@ -404,13 +419,6 @@ class DefaultModelLoader(BaseModelLoader):
                     # parameters onto device for processing and back off after.
                     with device_loading_context(module, target_device):
                         quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    # TODO(lucas): see if there is a way to unify the signatures
-                    # of process_weights_after_loading
-                    module.process_weights_after_loading(model_config.dtype)
 
         return model.eval()
 
@@ -436,6 +444,8 @@ class DummyModelLoader(BaseModelLoader):
                 # random values to the weights.
                 initialize_dummy_weights(model)
 
+                _process_attention_weights_after_loading(model, model_config)
+
                 for _, module in model.named_modules():
                     quant_method = getattr(module, "quant_method", None)
                     if quant_method is not None:
@@ -447,11 +457,6 @@ class DummyModelLoader(BaseModelLoader):
                         with device_loading_context(
                                 module, torch.device(device_config.device)):
                             quant_method.process_weights_after_loading(module)
-                    if isinstance(module, Attention) and \
-                        hasattr(module, "process_weights_after_loading"):
-                        # When attention modules need to process weights after
-                        # currently only used by MLA
-                        module.process_weights_after_loading(model_config.dtype)
 
         return model.eval()
 
@@ -642,16 +647,11 @@ class ShardedStateLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(vllm_config=vllm_config)
+                _process_attention_weights_after_loading(model, model_config)
                 for _, module in model.named_modules():
                     quant_method = getattr(module, "quant_method", None)
                     if quant_method is not None:
                         quant_method.process_weights_after_loading(module)
-                    if isinstance(module, Attention) and \
-                        hasattr(module, "process_weights_after_loading"):
-                        # When attention modules need to process weights after
-                        # currently only used by MLA
-                        module.process_weights_after_loading(
-                            model_config.dtype)
             rank = get_tensor_model_parallel_rank()
             pattern = os.path.join(
                 local_model_path,
@@ -1401,16 +1401,12 @@ class RunaiModelStreamerLoader(BaseModelLoader):
                     self._get_weights_iterator(model_weights,
                                                model_config.revision))
 
+            _process_attention_weights_after_loading(model, model_config)
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
                 if quant_method is not None:
                     with device_loading_context(module, target_device):
                         quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    module.process_weights_after_loading(model_config.dtype)
 
         return model.eval()
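
Usage note for reviewers (not part of the diff): a minimal sketch of how the AWQMarlin + MLA path added here would be exercised, assuming an AWQ-quantized MLA checkpoint (the model ID below is a placeholder) and vLLM's offline LLM API. Since ModelConfig now keeps MLA enabled for "awq_marlin", no extra environment variable is needed.

    # Sketch only: exercises AWQMarlin with MLA end to end.
    # "some-org/DeepSeek-V2-Lite-AWQ" is a placeholder for any AWQ MLA checkpoint.
    from vllm import LLM, SamplingParams

    llm = LLM(model="some-org/DeepSeek-V2-Lite-AWQ",
              quantization="awq_marlin",
              trust_remote_code=True)
    outputs = llm.generate(["What is multi-head latent attention?"],
                           SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)

Setting VLLM_MLA_DISABLE=1 (as referenced in the NotImplementedError message above) would fall back to the non-MLA attention path if the quantization scheme is not supported.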