[Fix] Load kv-cache dtype from hf_quant_config.json automatically (fix for reverted PR) (#30785)

Signed-off-by: <>
Co-authored-by: root <root@gpu-937.slurm-workers-slurm.slurm.svc.cluster.local>
danielafrimi 2025-12-17 11:56:38 +02:00 committed by GitHub
parent 9db1db5949
commit 7b966ae2ba
2 changed files with 83 additions and 1 deletion
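
At a glance (editor's sketch, not part of the commit message): ModelOpt-quantized checkpoints ship an hf_quant_config.json whose quantization section can name a kv_cache_quant_algo. With this change, a --kv-cache-dtype left at its default of "auto" is resolved from that entry instead of simply following the model dtype. A minimal, standalone illustration of the intended mapping; the config dict below is invented for the example:

# Hypothetical ModelOpt-style quantization config, roughly as it might appear
# in hf_quant_config.json / the HF quantization_config (invented for the example).
quant_cfg = {
    "quant_method": "modelopt",
    "quantization": {"kv_cache_quant_algo": "FP8"},
}

# The mapping this commit introduces (ModelOpt algo name -> vLLM cache dtype string).
MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {"fp8": "fp8_e4m3"}

kv_algo = quant_cfg["quantization"]["kv_cache_quant_algo"].lower()
print(MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP.get(kv_algo, "auto"))  # -> fp8_e4m3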

vllm/engine/arg_utils.py

@@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip
from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
@@ -106,6 +107,7 @@ else:
    LoadFormats = Any
    UsageContext = Any

logger = init_logger(__name__)
# object is used to allow for special typing forms
@@ -1361,12 +1363,17 @@ class EngineArgs:
                f"dcp_size={self.decode_context_parallel_size}."
            )

        # Resolve "auto" kv_cache_dtype to actual value from model config
        resolved_cache_dtype = resolve_kv_cache_dtype_string(
            self.kv_cache_dtype, model_config
        )
        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
            kv_cache_memory_bytes=self.kv_cache_memory_bytes,
            swap_space=self.swap_space,
            cache_dtype=self.kv_cache_dtype,
            cache_dtype=resolved_cache_dtype,
            is_attention_free=model_config.is_attention_free,
            num_gpu_blocks_override=self.num_gpu_blocks_override,
            sliding_window=sliding_window,

vllm/utils/torch_utils.py

@@ -24,6 +24,10 @@ else:
    ModelConfig = object
    IntermediateTensors = object

import logging

logger = logging.getLogger(__name__)

STR_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}

MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
    # TODO: Add more ModelOpt KV cache dtype mappings here
    # once they are supported by an attention backend
    # (for example, nvfp4).
    "fp8": "fp8_e4m3",
}
T = TypeVar("T")
@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype(
    return torch_dtype


def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None:
    """Get the KV cache quantization algorithm string from the quantization config.

    Maps various FP8 format names to vLLM's standard cache dtype strings.
    Returns None if no kv_cache_quant_algo is specified.
    Returns "auto" if the value is not recognized/supported.
    """
    # Mapping from model config values to vLLM cache_dtype strings
    quant_method = quant_cfg.get("quant_method", "")
    if quant_method.startswith("modelopt"):
        quantization_inner = quant_cfg.get("quantization", quant_cfg)
        # Check if quant config is specified and use kv cache quant algo
        kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get(
            "kv_cache_quant_algo"
        )
        if isinstance(kv_algo, str):
            kv_algo_lower = kv_algo.lower()
            # Try to map to vLLM's standard format
            if kv_algo_lower in MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP:
                return MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP[kv_algo_lower]
            else:
                # Unknown/unsupported format - return "auto" as safe fallback
                logger.warning(
                    "Unknown kv_cache_quant_algo '%s' in model config. "
                    "Supported values: %s. Falling back to 'auto'.",
                    kv_algo,
                    list(MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP.keys()),
                )
                return "auto"
    return None
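
# (Editor's note, not part of this diff: a sketch of the helper's expected
#  behavior, given MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP as defined above; the
#  example config dicts are invented for illustration.)
#   {"quant_method": "modelopt",
#    "quantization": {"kv_cache_quant_algo": "FP8"}}   -> "fp8_e4m3"
#   {"quant_method": "modelopt",
#    "kv_cache_quant_algo": "NVFP4"}                   -> "auto"  (logs a warning)
#   {"quant_method": "compressed-tensors"}             -> None    (not a ModelOpt config)
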
def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None:
    """Get the KV cache quantization algorithm dtype from the quantization config."""
    kv_algo_str = get_kv_cache_quant_algo_string(quant_cfg)
    if kv_algo_str is not None and kv_algo_str != "auto":
        # Only convert if we have a valid dtype string (not "auto" fallback)
        return STR_DTYPE_TO_TORCH_DTYPE[kv_algo_str]
    return None


def resolve_kv_cache_dtype_string(
    kv_cache_dtype: str, model_config: ModelConfig
) -> str:
    """Resolve 'auto' kv_cache_dtype to the actual string value from model config.

    Returns the resolved cache_dtype string.
    """
    if kv_cache_dtype != "auto":
        return kv_cache_dtype

    hf_cfg = getattr(model_config, "hf_config", None)
    if hf_cfg is not None:
        quant_cfg = getattr(hf_cfg, "quantization_config", None)
        if quant_cfg is not None:
            kv_algo_str = get_kv_cache_quant_algo_string(quant_cfg)
            if kv_algo_str is not None:
                return kv_algo_str
    # Default to auto (will be handled by downstream code)
    return "auto"
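
# (Editor's note, not part of this diff: resolution order of the function
#  above, illustrated for a hypothetical model_config.)
#   resolve_kv_cache_dtype_string("fp8_e5m2", model_config)  -> "fp8_e5m2"
#       an explicitly requested kv-cache dtype is never overridden
#   resolve_kv_cache_dtype_string("auto", model_config)      -> "fp8_e4m3"
#       when the checkpoint's quantization_config maps to an FP8 KV cache
#   resolve_kv_cache_dtype_string("auto", model_config)      -> "auto"
#       when no kv_cache_quant_algo is found in the model config
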
def kv_cache_dtype_str_to_dtype(
    kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype: