From 9e0b558a0923004758a9dc91f5d4920ded48b42a Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 23 Jul 2024 00:11:50 -0400 Subject: [PATCH] [Misc] Support FP8 kv cache scales from compressed-tensors (#6528) --- tests/quantization/test_compressed_tensors.py | 7 ++ vllm/attention/layer.py | 23 +++--- .../compressed_tensors/compressed_tensors.py | 63 +++++++++++++-- .../quantization/compressed_tensors/utils.py | 17 ++++ .../model_executor/layers/quantization/fp8.py | 63 ++------------- .../layers/quantization/kv_cache.py | 78 +++++++++++++++++++ vllm/model_executor/models/llama.py | 10 +++ 7 files changed, 186 insertions(+), 75 deletions(-) create mode 100644 vllm/model_executor/layers/quantization/kv_cache.py diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 888e20e51a842..c5a01b73f4a80 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -150,3 +150,10 @@ def test_compressed_tensors_fp8(vllm_runner): output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output + + +def test_compressed_tensors_kv_cache(vllm_runner): + model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" + with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: + output = llm.generate_greedy("Hello world!", max_tokens=20) + assert output diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 643a845899c37..5fa552f2f4eca 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -9,7 +9,7 @@ from vllm.attention.selector import get_attn_backend from vllm.config import CacheConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import Fp8KVCacheMethod +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod class Attention(nn.Module): @@ -59,19 +59,18 @@ class Attention(nn.Module): quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: - assert isinstance(quant_method, Fp8KVCacheMethod) + assert isinstance(quant_method, BaseKVCacheMethod) # TODO (mgoin): kv cache dtype should be specified in the FP8 # checkpoint config and become the "auto" behavior - if "fp8" in self.kv_cache_dtype: - if self.kv_cache_dtype == "fp8_e5m2": - raise ValueError("fp8_e5m2 kv-cache is not supported with " - "fp8 checkpoints.") - # When FP8 quantization is enabled, we make a parameter - # "kv_scale" so that it can be loaded from FP8 checkpoint. - # The k/v_scale will then be converted back to - # self._kv_scale in a native float32 value after weight loading - self.quant_method = quant_method - self.quant_method.create_weights(self) + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) # During model initialization, the default dtype is set as the model # weight and activation dtype. 
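For context, the new test above maps directly onto the public API. A minimal sketch of the end-user path (assuming the nm-testing checkpoint from the test is reachable and that `kv_cache_dtype` is forwarded to the engine as usual; not part of this patch) could look like:

    # Sketch only: mirrors test_compressed_tensors_kv_cache using the
    # public vLLM entry point instead of the vllm_runner test fixture.
    from vllm import LLM, SamplingParams

    # "fp8" requests an fp8 (e4m3) KV cache; with this patch the per-tensor
    # k/v scales are read from the checkpoint's compressed-tensors
    # kv_cache_scheme instead of silently defaulting to 1.0.
    llm = LLM(model="nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme",
              kv_cache_dtype="fp8")
    outputs = llm.generate("Hello world!", SamplingParams(max_tokens=20))
    print(outputs[0].outputs[0].text)

With the default "auto" KV cache dtype, the loaded scales are not used and the scales are enforced to 1.0, matching the comment in process_weights_after_loading further down in this patch.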
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 0accc94231b9c..c4d0c9cb981da 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensorsScheme, CompressedTensorsUnquantized, @@ -15,18 +15,23 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, QuantizationType, find_matched_target, is_activation_quantization_format, should_ignore_layer) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import current_platform class CompressedTensorsConfig(QuantizationConfig): - def __init__(self, target_scheme_map: Dict[str, Any], ignore: List[str], - quant_format: str): + def __init__(self, + target_scheme_map: Dict[str, Any], + ignore: List[str], + quant_format: str, + kv_cache_scheme: Optional[Dict[str, Any]] = None): self.ignore = ignore self.quant_format = quant_format # Map from [target -> scheme] self.target_scheme_map = target_scheme_map + self.kv_cache_scheme = kv_cache_scheme def get_linear_method(self) -> "CompressedTensorsLinearMethod": return CompressedTensorsLinearMethod(self) @@ -50,9 +55,12 @@ class CompressedTensorsConfig(QuantizationConfig): self, layer: torch.nn.Module, prefix: str, - ) -> Optional["CompressedTensorsLinearMethod"]: + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): return CompressedTensorsLinearMethod(self) + if isinstance(layer, Attention): + return CompressedTensorsKVCacheMethod(self) return None @classmethod @@ -85,7 +93,8 @@ class CompressedTensorsConfig(QuantizationConfig): return cls(target_scheme_map=target_scheme_map, ignore=ignore, - quant_format=quant_format) + quant_format=quant_format, + kv_cache_scheme=config.get("kv_cache_scheme")) @classmethod def get_config_filenames(cls) -> List[str]: @@ -309,3 +318,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase): if scheme is None: raise ValueError("A scheme must be defined for each layer") return scheme.apply_weights(layer, x, bias=bias) + + +class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from compressed-tensors + checkpoints. + """ + + def __init__(self, quant_config: CompressedTensorsConfig): + self.validate_kv_cache_scheme(quant_config.kv_cache_scheme) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]): + """ + Validator for the kv cache scheme. 
Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_scheme: the compressed-tensors kv cache scheme + """ + if kv_cache_scheme is None: + return + + type_ = kv_cache_scheme.get("type") + num_bits = kv_cache_scheme.get("num_bits") + + if type_ != "float" and num_bits != 8: + raise NotImplementedError( + "Currently supported kv cache quantization is " + "num_bits=8, type=float, however " + f"received num_bits={num_bits}, type={type_}") + + strategy = kv_cache_scheme.get("strategy") + if strategy != "tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for compressed-tensors KV cache. " + f"Expected strategy: tensor, found strategy: {strategy}") + + is_symmetric = kv_cache_scheme.get("symmetric") + if not is_symmetric: + raise NotImplementedError( + "Only support symmetric scaling factor " + "for compressed-tensors KV cache. " + f"However found symmetric: {is_symmetric}") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index b3110ce653308..7e8e70806a0fc 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -209,6 +209,23 @@ def _find_first_match(value: str, return None +def get_compressed_tensors_cache_scale(name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + # If no matches, return None + return None + + def _is_equal_or_regex_match(value: str, target: str, check_contains: bool = False) -> bool: diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index d4498f452cc06..b2a1b0a9534e8 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( @@ -400,64 +401,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_group=topk_group) -class Fp8KVCacheMethod(QuantizeMethodBase): - """Supports loading kv-cache scaling factors from FP8 checkpoints. +class Fp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. """ def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module): - """Create "weight" (aka k_scale and v_scale) for an attention layer. - - Args: - layer: The layer that is using the QuantizeMethodBase factory. 
- """ - # Initialize the KV cache scales to -1.0, which is an invalid value. - # If the k/v_scale appears in the checkpoint, it will be - # overwritten when loading weights. - layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False) - layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False) - - def apply(self, layer: torch.nn.Module) -> torch.Tensor: - raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") - - def process_weights_after_loading(self, layer: Module) -> None: - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. - if layer.kv_cache_dtype != "auto": - if layer.k_scale > 0.0 and layer.v_scale > 0.0: - # We prefer to use separate k_scale and v_scale if present - k_scale = layer.k_scale.to("cpu").tolist() - v_scale = layer.v_scale.to("cpu").tolist() - elif layer.k_scale < 0.0 and layer.v_scale < 0.0: - # If no scales were loaded (both scales are invalid negative - # values), use the default value of 1.0 - k_scale = Parameter(torch.tensor(1.0), requires_grad=False) - v_scale = Parameter(torch.tensor(1.0), requires_grad=False) - else: - # If we find a single kv_scale in the checkpoint, we remap - # kv_scale to k_scale during weight loading, and duplicate - # k_scale to v_scale here - assert layer.k_scale > 0.0 - scale_to_duplicate = max(layer.k_scale, layer.v_scale) - k_scale = scale_to_duplicate.to("cpu").tolist() - v_scale = scale_to_duplicate.to("cpu").tolist() - - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") - - # These are used in the final Attention.forward() - layer._k_scale = k_scale - layer._v_scale = v_scale - if (layer._k_scale == 1.0 and layer._v_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): - print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint.") - - del layer.k_scale - del layer.v_scale + super().__init__(quant_config) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py new file mode 100644 index 0000000000000..c1495711447fa --- /dev/null +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -0,0 +1,78 @@ +import torch + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.utils import print_warning_once + + +class BaseKVCacheMethod(QuantizeMethodBase): + """ + Quant method that adds `_k_scale` and `_v_scale` attributes to the + Attention layer to support loading those scaling factors from checkpoints. + The k/v_scale will be used to: + - quantize k/v_cache entries before saving them to the cache + - dequantize k/v_cache entries before fetching them from the cache + + :param quant_config: the appropriate QuantizationConfig + """ + + def __init__(self, quant_config: QuantizationConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """ + Create "weight" (aka k_scale and v_scale) for an attention layer. + """ + # Initialize the KV cache scales to -1.0, which is an invalid value. + # If the k/v_scale appears in the checkpoint, it will be + # overwritten when loading weights. 
+ layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError( + f"{self.__class__.__name__}.apply should not be called.") + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + if layer.kv_cache_dtype != "auto": + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = torch.nn.Parameter(torch.tensor(1.0), + requires_grad=False) + v_scale = torch.nn.Parameter(torch.tensor(1.0), + requires_grad=False) + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + + # These are used in the final Attention.forward() + layer._k_scale = k_scale + layer._v_scale = v_scale + if (layer._k_scale == 1.0 and layer._v_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + del layer.k_scale + del layer.v_scale diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d052113e79892..2052c443a8885 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -39,6 +39,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -467,6 +469,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue
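Two small sketches to make the new wiring concrete. First, the accepted kv_cache_scheme shape can be read off validate_kv_cache_scheme above; a configuration that passes it (field names inferred from those checks; a real compressed-tensors config may carry additional keys) would be:

    # Hypothetical kv_cache_scheme entry shaped to satisfy
    # CompressedTensorsKVCacheMethod.validate_kv_cache_scheme:
    kv_cache_scheme = {
        "type": "float",       # 8-bit float (fp8) KV cache entries
        "num_bits": 8,
        "strategy": "tensor",  # a single scale per tensor
        "symmetric": True,     # no zero point
    }

Second, during weight loading llama.py now routes compressed-tensors output scales on k_proj/v_proj to the attention layer's k_scale/v_scale parameters. The renaming is a pure string transform, so its behavior is easy to illustrate (the parameter names below are illustrative, not taken from a specific checkpoint):

    from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
        get_compressed_tensors_cache_scale)

    # ".k_proj.output_scale" / ".v_proj.output_scale" map onto the
    # ".attn.k_scale" / ".attn.v_scale" parameters created by
    # BaseKVCacheMethod.create_weights; anything else returns None and
    # falls through to the regular weight-loading path.
    name = "model.layers.0.self_attn.k_proj.output_scale"
    assert get_compressed_tensors_cache_scale(name) == \
        "model.layers.0.self_attn.attn.k_scale"
    assert get_compressed_tensors_cache_scale(
        "model.layers.0.self_attn.q_proj.weight") is None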