diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 8da3ae538d671..8abbe8ba0103e 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -122,10 +122,12 @@ class SchedulerConfig: the default scheduler. Can be a class directly or the path to a class of form "mod.custom_class".""" - disable_hybrid_kv_cache_manager: bool = False + disable_hybrid_kv_cache_manager: bool | None = None """If set to True, KV cache manager will allocate the same size of KV cache for all attention layers even if there are multiple type of attention layers like full attention and sliding window attention. + If set to None, the default value will be determined based on the environment + and starting configuration. """ async_scheduling: bool = False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b5f8f916de438..ace5adc109d86 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -887,17 +887,48 @@ class VllmConfig: if not self.instance_id: self.instance_id = random_uuid()[:5] - if not self.scheduler_config.disable_hybrid_kv_cache_manager: - # logger should only print warning message for hybrid models. As we - # can't know whether the model is hybrid or not now, so we don't log - # warning message here and will log it later. - if not current_platform.support_hybrid_kv_cache(): - # Hybrid KV cache manager is not supported on non-GPU platforms. - self.scheduler_config.disable_hybrid_kv_cache_manager = True + # Hybrid KV cache manager (HMA) runtime rules: + # - Explicit enable (--no-disable-kv-cache-manager): error if runtime + # disables it + # - No preference: auto-disable for unsupported features (e.g. kv connector) + # - Explicit disable (--disable-kv-cache-manager): always respect it + need_disable_hybrid_kv_cache_manager = False + # logger should only print warning message for hybrid models. As we + # can't know whether the model is hybrid or not now, so we don't log + # warning message here and will log it later. + if not current_platform.support_hybrid_kv_cache(): + # Hybrid KV cache manager is not supported on non-GPU platforms. + need_disable_hybrid_kv_cache_manager = True + if self.kv_events_config is not None: + # Hybrid KV cache manager is not compatible with KV events. + need_disable_hybrid_kv_cache_manager = True + if ( + self.model_config is not None + and self.model_config.attention_chunk_size is not None + ): + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): + # Hybrid KV cache manager is not yet supported with chunked + # local attention + eagle. + need_disable_hybrid_kv_cache_manager = True + elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: + logger.warning( + "There is a latency regression when using chunked local" + " attention with the hybrid KV cache manager. Disabling" + " it, by default. To enable it, set the environment " + "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." + ) + # Hybrid KV cache manager is not yet supported with chunked + # local attention. + need_disable_hybrid_kv_cache_manager = True + + if self.scheduler_config.disable_hybrid_kv_cache_manager is None: + # Default to disable HMA, but only if the user didn't express a preference. if self.kv_transfer_config is not None: - # NOTE(Kuntai): turn HMA off for connector for now. - # TODO(Kuntai): have a more elegent solution to check and - # turn off HMA for connector that does not support HMA. + # NOTE(Kuntai): turn HMA off for connector unless specifically enabled. + need_disable_hybrid_kv_cache_manager = True logger.warning( "Turning off hybrid kv cache manager because " "`--kv-transfer-config` is set. This will reduce the " @@ -905,33 +936,26 @@ class VllmConfig: "or Mamba attention. If you are a developer of kv connector" ", please consider supporting hybrid kv cache manager for " "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py." + " of `SupportsHMA` defined in kv_connector/v1/base.py and" + " use --no-disable-hybrid-kv-cache-manager to start vLLM." ) - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_events_config is not None: - # Hybrid KV cache manager is not compatible with KV events. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if ( - self.model_config is not None - and self.model_config.attention_chunk_size is not None - ): - if ( - self.speculative_config is not None - and self.speculative_config.use_eagle() - ): - # Hybrid KV cache manager is not yet supported with chunked - # local attention + eagle. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: - logger.warning( - "There is a latency regression when using chunked local" - " attention with the hybrid KV cache manager. Disabling" - " it, by default. To enable it, set the environment " - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." - ) - # Hybrid KV cache manager is not yet supported with chunked - # local attention. - self.scheduler_config.disable_hybrid_kv_cache_manager = True + self.scheduler_config.disable_hybrid_kv_cache_manager = ( + need_disable_hybrid_kv_cache_manager + ) + elif ( + self.scheduler_config.disable_hybrid_kv_cache_manager is False + and need_disable_hybrid_kv_cache_manager + ): + raise ValueError( + "Hybrid KV cache manager was explicitly enabled but is not " + "supported in this configuration. Consider omitting the " + "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide" + " automatically." + ) + + if self.scheduler_config.disable_hybrid_kv_cache_manager is None: + # Default to enable HMA if not explicitly disabled by user or logic above. + self.scheduler_config.disable_hybrid_kv_cache_manager = False if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 2867532756450..3862aa9222446 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -491,7 +491,7 @@ class EngineArgs: enable_chunked_prefill: bool | None = None disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input - disable_hybrid_kv_cache_manager: bool = ( + disable_hybrid_kv_cache_manager: bool | None = ( SchedulerConfig.disable_hybrid_kv_cache_manager )