Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Fix] Load kv-cache dtype from hf_quant_config.json automatically (fix for reverted PR) (#30785)
Signed-off-by: <>
Co-authored-by: root <root@gpu-937.slurm-workers-slurm.slurm.svc.cluster.local>
parent 9db1db5949
commit 7b966ae2ba
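For context on what this commit detects: ModelOpt-quantized checkpoints ship an hf_quant_config.json whose quantization section can name a KV-cache quantization algorithm. The dict below is an illustrative sketch of that shape after parsing (the keys match what the new helper reads; the concrete values are assumed, not taken from a specific checkpoint), and it is what allows kv_cache_dtype="auto" to resolve to "fp8_e4m3":

# Illustrative, assumed example of a parsed ModelOpt quantization config.
# Only the keys read by the new get_kv_cache_quant_algo_string() matter here.
quant_cfg = {
    "quant_method": "modelopt",
    "quantization": {
        "quant_algo": "FP8",           # weight/activation algo (assumed field value)
        "kv_cache_quant_algo": "FP8",  # mapped to vLLM's "fp8_e4m3"
    },
}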
@@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage
 from vllm.utils.argparse_utils import FlexibleArgumentParser
 from vllm.utils.mem_constants import GiB_bytes
 from vllm.utils.network_utils import get_ip
+from vllm.utils.torch_utils import resolve_kv_cache_dtype_string
 from vllm.v1.sample.logits_processor import LogitsProcessor

 if TYPE_CHECKING:
@@ -106,6 +107,7 @@ else:
     LoadFormats = Any
     UsageContext = Any


 logger = init_logger(__name__)

 # object is used to allow for special typing forms
@@ -1361,12 +1363,17 @@ class EngineArgs:
                 f"dcp_size={self.decode_context_parallel_size}."
             )

+        # Resolve "auto" kv_cache_dtype to actual value from model config
+        resolved_cache_dtype = resolve_kv_cache_dtype_string(
+            self.kv_cache_dtype, model_config
+        )
+
         cache_config = CacheConfig(
             block_size=self.block_size,
             gpu_memory_utilization=self.gpu_memory_utilization,
             kv_cache_memory_bytes=self.kv_cache_memory_bytes,
             swap_space=self.swap_space,
-            cache_dtype=self.kv_cache_dtype,
+            cache_dtype=resolved_cache_dtype,
             is_attention_free=model_config.is_attention_free,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
             sliding_window=sliding_window,
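The EngineArgs hunk above resolves "auto" once, up front, so CacheConfig already receives a concrete cache dtype string; an explicitly requested kv_cache_dtype passes through unchanged. The hunks that follow, presumably in vllm/utils/torch_utils.py given the import added above, define the helpers that do the resolution. A minimal usage sketch of the call EngineArgs now makes, with a SimpleNamespace standing in for vLLM's ModelConfig and an assumed ModelOpt-style quantization config:

# Minimal sketch; SimpleNamespace stands in for ModelConfig / hf_config, and the
# quantization_config dict is an assumed ModelOpt-style example.
from types import SimpleNamespace

from vllm.utils.torch_utils import resolve_kv_cache_dtype_string

model_config = SimpleNamespace(
    hf_config=SimpleNamespace(
        quantization_config={
            "quant_method": "modelopt",
            "quantization": {"kv_cache_quant_algo": "FP8"},
        }
    )
)

# "auto" is resolved from the checkpoint's quantization config ...
assert resolve_kv_cache_dtype_string("auto", model_config) == "fp8_e4m3"
# ... while an explicitly requested dtype passes through untouched.
assert resolve_kv_cache_dtype_string("fp8_e5m2", model_config) == "fp8_e5m2"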
@@ -24,6 +24,10 @@ else:
     ModelConfig = object
     IntermediateTensors = object

+import logging
+
+logger = logging.getLogger(__name__)
+

 STR_DTYPE_TO_TORCH_DTYPE = {
     "float32": torch.float32,
@@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
 }


+MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = {
+    # TODO: Add more modelopt kv cache dtype
+    # mappings here when it is supported by some attention backend
+    # (for example supports nvfp4).
+    "fp8": "fp8_e4m3",
+}
+
 T = TypeVar("T")

@@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype(
     return torch_dtype


+def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None:
+    """Get the KV cache quantization algorithm string from the quantization config.
+
+    Maps various FP8 format names to vLLM's standard cache dtype strings.
+    Returns None if no kv_cache_quant_algo is specified.
+    Returns "auto" if the value is not recognized/supported.
+    """
+    # Mapping from model config values to vLLM cache_dtype strings
+
+    quant_method = quant_cfg.get("quant_method", "")
+    if quant_method.startswith("modelopt"):
+        quantization_inner = quant_cfg.get("quantization", quant_cfg)
+        # Check if quant config is specified and use kv cache quant algo
+        kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get(
+            "kv_cache_quant_algo"
+        )
+        if isinstance(kv_algo, str):
+            kv_algo_lower = kv_algo.lower()
+
+            # Try to map to vLLM's standard format
+            if kv_algo_lower in MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP:
+                return MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP[kv_algo_lower]
+            else:
+                # Unknown/unsupported format - return "auto" as safe fallback
+                logger.warning(
+                    "WARNING: Unknown kv_cache_quant_algo '%s' in model "
+                    "config. Supported values: %s. Falling back to 'auto'.",
+                    kv_algo,
+                    list(MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP.keys()),
+                )
+                return "auto"
+    return None
+
+
+def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None:
+    """Get the KV cache quantization algorithm dtype from the quantization config."""
+    kv_algo_str = get_kv_cache_quant_algo_string(quant_cfg)
+    if kv_algo_str is not None and kv_algo_str != "auto":
+        # Only convert if we have a valid dtype string (not "auto" fallback)
+        return STR_DTYPE_TO_TORCH_DTYPE[kv_algo_str]
+    return None
+
+
+def resolve_kv_cache_dtype_string(
+    kv_cache_dtype: str, model_config: ModelConfig
+) -> str:
+    """Resolve 'auto' kv_cache_dtype to the actual string value from model config.
+    Returns the resolved cache_dtype string.
+    """
+    if kv_cache_dtype != "auto":
+        return kv_cache_dtype
+
+    hf_cfg = getattr(model_config, "hf_config", None)
+    if hf_cfg is not None:
+        quant_cfg = getattr(hf_cfg, "quantization_config", None)
+        if quant_cfg is not None:
+            kv_algo_str = get_kv_cache_quant_algo_string(quant_cfg)
+            if kv_algo_str is not None:
+                return kv_algo_str
+
+    # Default to auto (will be handled by downstream code)
+    return "auto"
+
+
 def kv_cache_dtype_str_to_dtype(
     kv_cache_dtype: str, model_config: ModelConfig
 ) -> torch.dtype:
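And a short sketch of the two lower-level helpers added in the hunk above, again assuming they live in vllm.utils.torch_utils; the config dict is the same assumed ModelOpt-style example:

from vllm.utils.torch_utils import (
    get_kv_cache_quant_algo_dtype,
    get_kv_cache_quant_algo_string,
)

# Assumed ModelOpt-style quantization config (same shape as above).
quant_cfg = {
    "quant_method": "modelopt",
    "quantization": {"kv_cache_quant_algo": "FP8"},
}

assert get_kv_cache_quant_algo_string(quant_cfg) == "fp8_e4m3"

# The torch dtype comes from vLLM's STR_DTYPE_TO_TORCH_DTYPE table, so the exact
# value printed depends on how that table represents fp8_e4m3.
print(get_kv_cache_quant_algo_dtype(quant_cfg))

# A non-ModelOpt config yields None, leaving the caller's default in place.
assert get_kv_cache_quant_algo_string({"quant_method": "gptq"}) is None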