From f0b2da72a84a7e481ca7ae1e84cedae5bc645611 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Fri, 14 Feb 2025 01:19:22 -0500 Subject: [PATCH] Expand MLA to support most types of quantization (#13181) --- vllm/attention/backends/mla/utils.py | 69 ++++++----------- vllm/config.py | 32 +------- vllm/model_executor/model_loader/loader.py | 90 ++++++++-------------- 3 files changed, 60 insertions(+), 131 deletions(-) diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index a41140ec83782..e9b4dff74f427 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -26,7 +26,7 @@ from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.utils.fp8_utils import ( apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - scaled_dequantize, scaled_quantize) + scaled_quantize) from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) @@ -220,16 +220,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): .view(-1, self.num_heads, self.kv_lora_rank) def process_weights_after_loading(self, act_dtype: torch.dtype): - - def is_layer_fp8(layer: LinearBase) -> bool: - return isinstance(layer.quant_method, Fp8LinearMethod) or\ - (isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ - and isinstance(layer.scheme, CompressedTensorsW8A8Fp8)) - - def quantization_scheme_supported(layer: LinearBase) -> bool: - return isinstance(layer.quant_method, UnquantizedLinearMethod) or \ - is_layer_fp8(layer) - # TODO(lucas) This is very gross, we need a more wide scale refactor of # all the FP8 code with a more standard way of # defining schemes/group-shapes, we should also potentially force @@ -239,7 +229,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ Tuple[Tuple[int, int], Tuple[int, int]]: if isinstance(layer.quant_method, Fp8LinearMethod): - if layer.quant_method.block_quant is not None: + if layer.quant_method.block_quant: weight_block_size = \ layer.quant_method.quant_config.weight_block_size # per-token-group (1, X), block-quantized (X, Y) @@ -267,41 +257,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" ) - def get_scales(layer: LinearBase) -> torch.Tensor: - if hasattr(layer, "weight_scale_inv"): - return layer.weight_scale_inv - return layer.weight_scale + def get_layer_weight(layer): + if hasattr(layer, "weight"): + return layer.weight + elif hasattr(layer, "qweight"): + return layer.qweight + else: + raise AttributeError( + f"Layer '{layer}' has neither weight nor qweight") def get_and_maybe_dequant_weights(layer: LinearBase): - if is_layer_fp8(layer): - if isinstance(layer.quant_method, \ - CompressedTensorsLinearMethod) and \ - isinstance(layer.scheme, CompressedTensorsW8A8Fp8): - # NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8` - # seems to store weights as (input, output) instead of - # (output, input) so we need to transpose - weight = layer.weight.T # standardize to (output, input) - else: - weight = layer.weight - _, weight_scale_group_shape = \ - get_scale_group_shapes_for_fp8(layer) - scales = get_scales(layer) + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, 
+ dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight - return scaled_dequantize(weight, scales, - weight_scale_group_shape) - else: - return layer.weight - - if not (quantization_scheme_supported(self.kv_b_proj) and\ - quantization_scheme_supported(self.q_proj) and\ - quantization_scheme_supported(self.o_proj)): - raise NotImplementedError( - "Only FP8 and UnquantizedLinearMethod are supported for MLA" - ", please run with VLLM_MLA_DISABLE=1") - - weight_dtype = self.kv_b_proj.weight.dtype - assert self.o_proj.weight.dtype == weight_dtype - assert self.q_proj.weight.dtype == weight_dtype + weight_dtype = get_layer_weight(self.kv_b_proj).dtype + assert get_layer_weight(self.o_proj).dtype == weight_dtype + assert get_layer_weight(self.q_proj).dtype == weight_dtype kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( diff --git a/vllm/config.py b/vllm/config.py index 10004b8f62919..87ceb19056ef5 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -991,37 +991,7 @@ class ModelConfig: @property def use_mla(self) -> bool: - if not self.is_deepseek_mla or envs.VLLM_MLA_DISABLE: - return False - - if self.quantization is not None and self.quantization not in [\ - "fp8", "compressed-tensors"]: - logger.warning( - "MLA is not supported with %s quantization. " - "Disabling MLA.", self.quantization) - return False - - # If using a "compressed-tensors" checkpoint, check that all groups - # have fp8 for both weights and activations. - if self.quantization == "compressed-tensors": - quant_config = self._parse_quant_hf_config() - for group_name, cfg in quant_config.get("config_groups", { - "": {} - }).items(): - act_cfg = cfg.get("input_activations", {}) - act_type = None if act_cfg is None else act_cfg.get("type", "") - w_cfg = cfg.get("weights", {}) - w_type = None if w_cfg is None else w_cfg.get("type", "") - if act_type != "fp8" or w_type != "fp8": - logger.warning( - "compressed-tensors MLA support requires fp8 " - "activations and weights in group '%s', but got " - "activations type '%s' and weights type '%s'.\n " - "Full config: %s", group_name, act_type, w_type, - quant_config) - return False - - return True + return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE @property def supported_runner_types(self) -> Set[RunnerType]: diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 2a2c2523b725d..230484a36dec2 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -153,6 +153,30 @@ def _initialize_model( return model_class(**kwargs) +def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig, + target_device: torch.device) -> None: + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + + # Currently only used by MLA. 
+    # NOTE: This intentionally happens after other modules so we can easily
+    # decompress the weights for MLA.
+    for _, module in model.named_modules():
+        if isinstance(module, Attention) and \
+            hasattr(module, "process_weights_after_loading"):
+            # TODO(lucas): see if there is a way to unify the signatures
+            # of process_weights_after_loading
+            module.process_weights_after_loading(model_config.dtype)
+
+
 class BaseModelLoader(ABC):
     """Base class for model loaders."""
 
@@ -376,7 +400,7 @@ class DefaultModelLoader(BaseModelLoader):
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
         target_device = torch.device(device_config.device)
 
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
@@ -394,23 +417,8 @@ class DefaultModelLoader(BaseModelLoader):
                     "Following weights were not initialized from "
                     f"checkpoint: {weights_not_loaded}")
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if isinstance(quant_method, QuantizeMethodBase):
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
-                    with device_loading_context(module, target_device):
-                        quant_method.process_weights_after_loading(module)
-                if isinstance(module, Attention) and \
-                    hasattr(module, "process_weights_after_loading"):
-                    # When attention modules need to process weights after
-                    # currently only used by MLA
-                    # TODO(lucas): see if there is a way to unify the signatures
-                    # of process_weights_after_loading
-                    module.process_weights_after_loading(model_config.dtype)
+            _process_weights_after_loading(model, model_config, target_device)
+
         return model.eval()
 
 
@@ -429,29 +437,15 @@ class DummyModelLoader(BaseModelLoader):
     def load_model(self, vllm_config: VllmConfig) -> nn.Module:
         device_config = vllm_config.device_config
         model_config = vllm_config.model_config
+        target_device = torch.device(device_config.device)
 
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
                 # NOTE(woosuk): For accurate performance evaluation, we assign
                 # random values to the weights.
                 initialize_dummy_weights(model)
 
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    # When quant methods need to process weights after loading
-                    # (for repacking, quantizing, etc), they expect parameters
-                    # to be on the global target device. This scope is for the
-                    # case where cpu offloading is used, where we will move the
-                    # parameters onto device for processing and back off after.
- with device_loading_context( - module, torch.device(device_config.device)): - quant_method.process_weights_after_loading(module) - if isinstance(module, Attention) and \ - hasattr(module, "process_weights_after_loading"): - # When attention modules need to process weights after - # currently only used by MLA - module.process_weights_after_loading(model_config.dtype) + _process_weights_after_loading(model, model_config, target_device) return model.eval() @@ -632,6 +626,7 @@ class ShardedStateLoader(BaseModelLoader): def load_model(self, vllm_config: VllmConfig) -> nn.Module: device_config = vllm_config.device_config model_config = vllm_config.model_config + target_device = torch.device(device_config.device) from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank @@ -640,18 +635,10 @@ class ShardedStateLoader(BaseModelLoader): model_config.revision) with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with target_device: model = _initialize_model(vllm_config=vllm_config) - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - quant_method.process_weights_after_loading(module) - if isinstance(module, Attention) and \ - hasattr(module, "process_weights_after_loading"): - # When attention modules need to process weights after - # currently only used by MLA - module.process_weights_after_loading( - model_config.dtype) + _process_weights_after_loading(model, model_config, + target_device) rank = get_tensor_model_parallel_rank() pattern = os.path.join( local_model_path, @@ -1401,16 +1388,7 @@ class RunaiModelStreamerLoader(BaseModelLoader): self._get_weights_iterator(model_weights, model_config.revision)) - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - with device_loading_context(module, target_device): - quant_method.process_weights_after_loading(module) - if isinstance(module, Attention) and \ - hasattr(module, "process_weights_after_loading"): - # When attention modules need to process weights after - # currently only used by MLA - module.process_weights_after_loading(model_config.dtype) + _process_weights_after_loading(model, model_config, target_device) return model.eval()
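For readers unfamiliar with the scale group shapes that get_scale_group_shapes_for_fp8 returns: a group shape (rows, cols) means one quantization scale is shared by each rows-by-cols tile of the tensor, so per-token-group activation quantization uses (1, X) and block-quantized weights use the checkpoint's weight_block_size (X, Y). The sketch below is illustrative only, using plain torch rather than vLLM's quantization utilities, and the 448.0 constant assumes fp8 e4m3; it just shows how the group shape determines the shape of the resulting scale tensor.

# Illustrative sketch (not vLLM code): how a scale group shape maps to scales.
import torch


def scales_shape(tensor_shape, group_shape):
    # One scale per (gr x gc) tile, so the scale tensor is (rows/gr, cols/gc).
    rows, cols = tensor_shape
    gr, gc = group_shape
    assert rows % gr == 0 and cols % gc == 0, "tiles must divide the tensor"
    return (rows // gr, cols // gc)


def absmax_scales(x: torch.Tensor, group_shape):
    # Compute one absmax-based scale per tile (448.0 ~ max of fp8 e4m3).
    gr, gc = group_shape
    rows, cols = x.shape
    tiles = x.reshape(rows // gr, gr, cols // gc, gc)
    return tiles.abs().amax(dim=(1, 3)) / 448.0


if __name__ == "__main__":
    act = torch.randn(4, 256)       # (num_tokens, hidden_size)
    weight = torch.randn(512, 256)  # (output_size, input_size)

    # per-token-group (1, X): one scale per token per 128-wide group
    print(scales_shape(act.shape, (1, 128)),
          absmax_scales(act, (1, 128)).shape)

    # block-quantized weights (X, Y), e.g. weight_block_size = [128, 128]
    print(scales_shape(weight.shape, (128, 128)),
          absmax_scales(weight, (128, 128)).shape)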
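The key trick in the new get_and_maybe_dequant_weights is scheme-agnostic dequantization: pushing an identity matrix through layer.quant_method.apply() returns the dequantized weight (transposed), so MLA no longer needs per-scheme handling and the FP8-only restriction can be dropped. Below is a minimal standalone sketch of the idea with a hand-rolled int8 per-channel stand-in for quant_method.apply; none of the helper names are vLLM APIs.

# Illustrative sketch (not vLLM code): recover a dequantized weight by
# pushing an identity matrix through the layer's quantized matmul.
import torch


def fake_quantized_apply(qweight: torch.Tensor, scale: torch.Tensor,
                         x: torch.Tensor) -> torch.Tensor:
    # Stand-in for layer.quant_method.apply(layer, x): int8 weight with
    # per-output-channel scales, computed as x @ W_dequant.T.
    w_dequant = qweight.to(torch.float32) * scale[:, None]
    return x @ w_dequant.T


def dequant_via_identity(qweight: torch.Tensor, scale: torch.Tensor,
                         input_size: int) -> torch.Tensor:
    # O(N^3) in the input size, so only sensible as a one-off at load time.
    eye = torch.eye(input_size, dtype=torch.float32)
    out = fake_quantized_apply(qweight, scale, eye)  # shape (input, output)
    return out.T  # standardize to (output, input), as in the patch


if __name__ == "__main__":
    out_features, in_features = 8, 16
    weight = torch.randn(out_features, in_features)
    scale = weight.abs().amax(dim=1) / 127.0
    qweight = torch.clamp((weight / scale[:, None]).round(), -127,
                          127).to(torch.int8)

    recovered = dequant_via_identity(qweight, scale, in_features)
    reference = qweight.to(torch.float32) * scale[:, None]
    assert torch.allclose(recovered, reference, atol=1e-5)
    print("max abs error:", (recovered - reference).abs().max().item())

As the NOTE in the patch says, this costs a full identity-sized matmul per projection, which is acceptable because it runs once at weight-loading time.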
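The two-pass structure of _process_weights_after_loading is deliberate: quant methods repack and finalize their layers in the first loop, and only then do Attention (MLA) modules run their own process_weights_after_loading, so that MLA can decompress kv_b_proj through an already-usable quant_method.apply. The toy sketch below only mirrors that ordering contract; the Toy* classes and run_post_load_processing are stand-ins, not vLLM classes, and device_loading_context handling is omitted.

# Illustrative sketch (not vLLM code): quant methods first, MLA second.
import torch
import torch.nn as nn


class ToyQuantMethod:
    """Stand-in for a QuantizeMethodBase that repacks weights after loading."""

    def process_weights_after_loading(self, layer: nn.Linear) -> None:
        layer.repacked = True  # pretend the weight was repacked/quantized

    def apply(self, layer: nn.Linear, x: torch.Tensor) -> torch.Tensor:
        assert getattr(layer, "repacked", False), "weights not processed yet"
        return x @ layer.weight.T


class ToyMLAAttention(nn.Module):
    """Stand-in for MLA: relies on the finalized quant method to dequantize."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.kv_b_proj = nn.Linear(hidden, hidden, bias=False)
        self.kv_b_proj.quant_method = ToyQuantMethod()

    def process_weights_after_loading(self, act_dtype: torch.dtype) -> None:
        eye = torch.eye(self.kv_b_proj.in_features, dtype=act_dtype)
        # Works only because pass 1 already processed kv_b_proj.
        self.dequant_kv_b = self.kv_b_proj.quant_method.apply(
            self.kv_b_proj, eye).T


def run_post_load_processing(model: nn.Module,
                             act_dtype: torch.dtype) -> None:
    # Pass 1: let quant methods repack/quantize their layers.
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
    # Pass 2: attention (MLA) modules, which depend on pass 1 being done.
    for _, module in model.named_modules():
        if isinstance(module, ToyMLAAttention) and \
                hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading(act_dtype)


if __name__ == "__main__":
    model = nn.Sequential(ToyMLAAttention(hidden=4))
    run_post_load_processing(model, torch.float32)
    print("dequantized kv_b_proj shape:", model[0].dequant_kv_b.shape)

Swapping the two loops makes the assert in ToyQuantMethod.apply fire, which is exactly why the MLA pass intentionally happens after the quant-method pass.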