From 60fb4f3bcfce9c84e09ba61e4b59bb1abe19953d Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 23 Dec 2024 14:30:45 -0500
Subject: [PATCH] [Bugfix] Add kv cache scales to gemma2.py (#11269)

---
 vllm/model_executor/models/gemma2.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 4664aa53ea092..f4530e4771960 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -31,11 +31,14 @@
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
@@ -326,6 +329,15 @@ class Gemma2Model(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+            if scale_name := get_compressed_tensors_cache_scale(name):
+                # Loading kv cache scales for compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = loaded_weight[0]
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue
@@ -343,6 +355,10 @@ class Gemma2Model(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
                 if is_pp_missing_parameter(name, self):
                     continue
                 param = params_dict[name]
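
For context, the loop this patch extends follows roughly the pattern sketched below. This is a minimal, self-contained illustration of the new kv-cache-scale branch, not vLLM's actual implementation: `fake_cache_scale_name` is a hypothetical stand-in for `get_compressed_tensors_cache_scale`, and the stacked-param and fallback loading paths are elided.

```python
# Minimal sketch of the scale-loading branch added above, under the
# assumption that compressed-tensors checkpoints store kv-cache scales
# as 1-element tensors under names like "...k_proj.output_scale".
from typing import Dict, Iterable, Optional, Tuple

import torch


def fake_cache_scale_name(name: str) -> Optional[str]:
    # Hypothetical stand-in: map a checkpoint name such as
    # "...self_attn.k_proj.output_scale" onto the attention module's
    # registered scale parameter "...self_attn.attn.k_scale".
    for proj, scale in (("k_proj", "k_scale"), ("v_proj", "v_scale")):
        suffix = f".{proj}.output_scale"
        if name.endswith(suffix):
            return name[: -len(suffix)] + f".attn.{scale}"
    return None


def load_weights(params_dict: Dict[str, torch.Tensor],
                 weights: Iterable[Tuple[str, torch.Tensor]]) -> set:
    loaded_params = set()
    for name, loaded_weight in weights:
        if scale_name := fake_cache_scale_name(name):
            # Scales ship as 1-element tensors; index out the scalar
            # before copying it into the registered scale parameter.
            params_dict[scale_name].copy_(loaded_weight[0])
            loaded_params.add(scale_name)
            continue
        # ... stacked-param mapping and fallback loading go here ...
        params_dict[name].copy_(loaded_weight)
        loaded_params.add(name)
    return loaded_params


params = {
    "layers.0.self_attn.attn.k_scale": torch.zeros(()),
    "layers.0.self_attn.attn.v_scale": torch.zeros(()),
}
checkpoint = [
    ("layers.0.self_attn.k_proj.output_scale", torch.tensor([0.0123])),
    ("layers.0.self_attn.v_proj.output_scale", torch.tensor([0.0456])),
]
# Prints the two remapped scale names that were loaded.
print(load_weights(params, checkpoint))
```

The key design point, mirrored in the real patch, is that the scale branch runs before the stacked-parameter loop and ends with `continue`, so scale tensors never fall through to the ordinary weight paths, which would either miss them or fail the name lookup.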