Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Change kv scaling factor by param json on nvidia gpu (#11688)

Signed-off-by: bjmsong <bjmsong@126.com>
Co-authored-by: bjmsong <bjmsong@126.com>

parent b55ed6ef8a
commit 187e32997c
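
This commit touches four model implementations (Exaone, Granite, Llama, Solar) and the GPU model runner: the single _kv_scale attribute on the attention layer is replaced by separate _k_scale and _v_scale attributes, and the kv-cache scaling factors loaded from the quantization param JSON are now applied on CUDA as well as ROCm. The hunks below cite the convention scaling_factor = tensor_amax / FPtype_max; as a minimal, illustrative sketch of that convention (not code from this commit), a per-tensor scale for an OCP float8_e4m3fn KV cache could be computed as:

import torch

def fp8_kv_scale(kv_tensor: torch.Tensor) -> float:
    # Illustrative only, not taken from this commit: the diff comments cite
    # the convention scaling_factor = tensor_amax / FPtype_max, where
    # FPtype_max is the largest finite value of the fp8 format
    # (448.0 for torch.float8_e4m3fn).
    tensor_amax = kv_tensor.abs().max().item()
    return tensor_amax / torch.finfo(torch.float8_e4m3fn).max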
@@ -606,8 +606,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
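
The same two-line change recurs in the Granite, Llama, and Solar hunks below. As a standalone sketch of the new pattern (hypothetical names; the real code iterates the scales produced by vLLM's loader for quantization_param_path), the per-layer assignment now writes both the key and the value scale:

def apply_kv_cache_scales(decoder_layers, layer_scales):
    # layer_scales: mapping of layer index -> float scaling factor, e.g. read
    # from the quantization_param_path JSON. Names here are illustrative,
    # not vLLM's API.
    for layer_idx, scaling_factor in layer_scales.items():
        attn = decoder_layers[layer_idx].self_attn.attn
        if hasattr(attn, "_k_scale"):
            # After this commit, key and value caches carry separate scales.
            attn._k_scale = scaling_factor
            attn._v_scale = scaling_factor
        else:
            raise RuntimeError("Self attention has no KV cache scaling "
                               "factor attribute!")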
@@ -545,8 +545,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
@@ -452,8 +452,9 @@ class LlamaModel(nn.Module):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
@@ -565,8 +565,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # which is consistent with the practice of setting
                 # scaling_factor = tensor_amax / FPtype_max
                 scaling_factor *= 2
-            if hasattr(layer_self_attn, "kv_scale"):
-                layer_self_attn.attn._kv_scale = scaling_factor
+            if hasattr(layer_self_attn.attn, "_k_scale"):
+                layer_self_attn.attn._k_scale = scaling_factor
+                layer_self_attn.attn._v_scale = scaling_factor
             else:
                 raise RuntimeError("Self attention has no KV cache scaling "
                                    "factor attribute!")
@@ -1136,7 +1136,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             self.prompt_adapter_manager.create_prompt_adapter_manager(
                 self.model))
 
-        if self.kv_cache_dtype == "fp8" and current_platform.is_rocm():
+        if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
+                                             or current_platform.is_cuda()):
             # Currently only ROCm accepts kv-cache scaling factors
             # via quantization_param_path and this will be deprecated
             # in the future.
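
With the runner change above, the fp8 KV-cache scaling factors read from the param JSON are applied on NVIDIA GPUs as well, not only on ROCm. A minimal usage sketch, assuming the engine arguments available at the time of this commit and the comment's note that quantization_param_path is slated for deprecation; the model name and JSON path are placeholders:

from vllm import LLM, SamplingParams

# Illustrative invocation, not taken from the commit: enable an fp8 KV cache
# and supply externally calibrated per-layer scales via quantization_param_path,
# which is now honored on CUDA as well as ROCm.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",           # placeholder model
    kv_cache_dtype="fp8",
    quantization_param_path="path/to/kv_cache_scales.json",  # placeholder path
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)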