diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index 0398f0943a70..8324a563edd6 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -606,8 +606,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index f9e0443b9a50..a91ed4158a73 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -545,8 +545,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 2902e6999c2f..8623da99574b 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -452,8 +452,9 @@ class LlamaModel(nn.Module):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index caae0b65d7d1..a7cf65a0e36e 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -565,8 +565,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 2b545d1b28bd..637fba23611f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1136,7 +1136,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self.prompt_adapter_manager.create_prompt_adapter_manager(
                     self.model))
 
-        if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
+        if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
+                                             or current_platform.is_cuda()):
            # Currently only ROCm accepts kv-cache scaling factors
            # via quantization_param_path and this will be deprecated
            # in the future.
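Note: the four model files above receive the same change, so the pattern only needs to be read once. Below is a minimal standalone sketch of that pattern, not the actual vLLM API (AttentionLike and apply_kv_cache_scale are illustrative names): the loaded per-layer factor is doubled on ROCm and then written to both the new _k_scale and _v_scale attributes that replace the former single _kv_scale.

# Minimal sketch of the scale-loading pattern shown in the diff above.
# AttentionLike and apply_kv_cache_scale are illustrative names, not vLLM API.


class AttentionLike:
    """Stand-in for an attention layer exposing separate k/v scales."""

    def __init__(self) -> None:
        # Separate per-tensor scales replace the former single _kv_scale.
        self._k_scale: float = 1.0
        self._v_scale: float = 1.0


def apply_kv_cache_scale(attn, scaling_factor: float, is_rocm: bool) -> None:
    """Apply a loaded FP8 KV-cache scaling factor to one attention layer."""
    if is_rocm:
        # Convention assumed: quantized_value * scaling_factor ~= true_value,
        # i.e. scaling_factor = tensor_amax / FPtype_max; ROCm doubles the
        # loaded factor, mirroring the diff above.
        scaling_factor *= 2
    if hasattr(attn, "_k_scale"):
        # The same loaded factor is applied to both key and value caches.
        attn._k_scale = scaling_factor
        attn._v_scale = scaling_factor
    else:
        raise RuntimeError("Self attention has no KV cache scaling "
                           "factor attribute!")


if __name__ == "__main__":
    layer = AttentionLike()
    apply_kv_cache_scale(layer, scaling_factor=0.25, is_rocm=True)
    assert layer._k_scale == layer._v_scale == 0.5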