diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ca19e468914c7..03720bd2516d4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.mem_constants import GiB_bytes from vllm.utils.network_utils import get_ip +from vllm.utils.torch_utils import resolve_kv_cache_dtype_string from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: @@ -106,6 +107,7 @@ else: LoadFormats = Any UsageContext = Any + logger = init_logger(__name__) # object is used to allow for special typing forms @@ -1361,12 +1363,17 @@ class EngineArgs: f"dcp_size={self.decode_context_parallel_size}." ) + # Resolve "auto" kv_cache_dtype to actual value from model config + resolved_cache_dtype = resolve_kv_cache_dtype_string( + self.kv_cache_dtype, model_config + ) + cache_config = CacheConfig( block_size=self.block_size, gpu_memory_utilization=self.gpu_memory_utilization, kv_cache_memory_bytes=self.kv_cache_memory_bytes, swap_space=self.swap_space, - cache_dtype=self.kv_cache_dtype, + cache_dtype=resolved_cache_dtype, is_attention_free=model_config.is_attention_free, num_gpu_blocks_override=self.num_gpu_blocks_override, sliding_window=sliding_window, diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index c97efce312b56..b82e0171b7f7f 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -24,6 +24,10 @@ else: ModelConfig = object IntermediateTensors = object +import logging + +logger = logging.getLogger(__name__) + STR_DTYPE_TO_TORCH_DTYPE = { "float32": torch.float32, @@ -49,6 +53,13 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = { } +MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP = { + # TODO: Add more ModelOpt kv cache dtype + # mappings here when they are supported by some attention backend + # (for example nvfp4).
+ "fp8": "fp8_e4m3", +} + T = TypeVar("T") @@ -194,6 +205,70 @@ def get_kv_cache_torch_dtype( return torch_dtype +def get_kv_cache_quant_algo_string(quant_cfg: dict[str, Any]) -> str | None: + """Get the KV cache quantization algorithm string from the quantization config. + + Maps various FP8 format names to vLLM's standard cache dtype strings. + Returns None if no kv_cache_quant_algo is specified. + Returns "auto" if the value is not recognized/supported. + """ + # Uses the module-level MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP for the mapping. + + quant_method = quant_cfg.get("quant_method", "") + if quant_method.startswith("modelopt"): + quantization_inner = quant_cfg.get("quantization", quant_cfg) + # Check if quant config is specified and use kv cache quant algo + kv_algo = quantization_inner.get("kv_cache_quant_algo") or quant_cfg.get( + "kv_cache_quant_algo" + ) + if isinstance(kv_algo, str): + kv_algo_lower = kv_algo.lower() + + # Try to map to vLLM's standard format + if kv_algo_lower in MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP: + return MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP[kv_algo_lower] + else: + # Unknown/unsupported format - return "auto" as safe fallback + logger.warning( + "WARNING: Unknown kv_cache_quant_algo '%s' in model " + "config. Supported values: %s. Falling back to 'auto'.", + kv_algo, + list(MODELOPT_TO_VLLM_KV_CACHE_DTYPE_MAP.keys()), + ) + return "auto" + return None + + +def get_kv_cache_quant_algo_dtype(quant_cfg: dict[str, Any]) -> torch.dtype | None: + """Get the KV cache quantization algorithm dtype from the quantization config.""" + kv_algo_str = get_kv_cache_quant_algo_string(quant_cfg) + if kv_algo_str is not None and kv_algo_str != "auto": + # Only convert if we have a valid dtype string (not "auto" fallback) + return STR_DTYPE_TO_TORCH_DTYPE[kv_algo_str] + return None + + +def resolve_kv_cache_dtype_string( + kv_cache_dtype: str, model_config: ModelConfig +) -> str: + """Resolve 'auto' kv_cache_dtype to the actual string value from model config.
+ Returns the resolved cache_dtype string. + """ + if kv_cache_dtype != "auto": + return kv_cache_dtype + + hf_cfg = getattr(model_config, "hf_config", None) + if hf_cfg is not None: + quant_cfg = getattr(hf_cfg, "quantization_config", None) + if quant_cfg is not None: + kv_algo_str = get_kv_cache_quant_algo_string(quant_cfg) + if kv_algo_str is not None: + return kv_algo_str + + # Default to auto (will be handled by downstream code) + return "auto" + + def kv_cache_dtype_str_to_dtype( kv_cache_dtype: str, model_config: ModelConfig ) -> torch.dtype: