diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index a987dc53878dc..b7d80f5194c0f 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import dataclasses
-import os
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
-FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
-                                            "NHD").upper()
+FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"
 
 
 class FlashInferBackend(AttentionBackend):
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index b9bed06d791c5..493235d724f4e 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -3,7 +3,6 @@
 """
 KV cache helper for store.
 """
-
 import torch
 
 import vllm.envs as envs
@@ -94,15 +93,17 @@ class model_aware_kv_ops_helper:
 
 
 def get_kv_connector_cache_layout():
+    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
+    # used for faster transfer.
     vllm_config = get_current_vllm_config()
     kv_config = vllm_config.kv_transfer_config
-    if vllm_config.model_config is None:
-        logger.warning("Unable to detect current VLLM config. " \
+    if vllm_config.model_config is None or kv_config is None:
+        logger.warning_once("Unable to detect current VLLM config. " \
             "Defaulting to NHD kv cache layout.")
     else:
         use_mla = vllm_config.model_config.use_mla
         if not use_mla and kv_config.kv_connector == "NixlConnector":
-            logger.info("NixlConnector detected. Setting KV cache " \
+            logger.info_once("NixlConnector detected. Setting KV cache " \
                 "layout to HND for better xfer performance.")
             return "HND"
     return "NHD"
diff --git a/vllm/envs.py b/vllm/envs.py
index 921052821ee3a..a4a1784f97f90 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
+    VLLM_KV_CACHE_LAYOUT: Optional[str] = None
 
 
 def get_default_cache_root():
@@ -879,6 +880,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # processes via zmq.
     "VLLM_MQ_MAX_CHUNK_BYTES_MB":
     lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")),
+
+    # KV cache layout used throughout vLLM.
+    # Some common values are:
+    # - NHD
+    # - HND
+    # Where N=number of tokens per block, H=num_heads and D=head_size. The
+    # default value (unset) leaves the layout choice to the backend. Mind that
+    # backends may only implement and support a subset of all possible layouts.
+ "VLLM_KV_CACHE_LAYOUT": + lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None) } # --8<-- [end:env-vars-definition] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 8b7745ceddd4e..43a664476aaae 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -16,13 +16,12 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, get_flash_attn_version) from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.distributed.kv_transfer.kv_connector.utils import ( - get_kv_connector_cache_layout) from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + get_kv_cache_layout) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -73,16 +72,15 @@ class FlashAttentionBackend(AttentionBackend): @staticmethod def get_kv_cache_stride_order() -> tuple[int, ...]: - # NOTE When running disaggregated PD with NIXL, HND layout is used for - # faster transfer. `stride_order` indicates the permutation that gets + # `stride_order` indicates the permutation that gets # us from `get_kv_cache_shape` to the actual memory layout we want. - cache_layout = get_kv_connector_cache_layout() + cache_layout = get_kv_cache_layout() if cache_layout == "NHD": stride_order = (0, 1, 2, 3, 4) elif cache_layout == "HND": stride_order = (0, 1, 3, 2, 4) else: - raise ValueError("Unknown cache layout format %s.", cache_layout) + raise ValueError(f"Unknown cache layout format {cache_layout}.") return stride_order diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index b2f54f37a6e19..03a2ed7139c7c 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -19,7 +19,8 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) + CommonAttentionMetadata, + get_kv_cache_layout) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -66,6 +67,19 @@ class FlashInferBackend(AttentionBackend): ) -> tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + # `stride_order` indicates the permutation that gets us from + # `get_kv_cache_shape` to the actual memory layout we want. 
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+
 
 @dataclass
 class PerLayerParameters:
@@ -290,7 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
             self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), "NHD")
+                self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper
 
     def _get_decode_wrapper(self):
@@ -303,14 +317,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
-                "NHD",
+                get_kv_cache_layout(),
                 use_tensor_cores=use_tensor_cores)
         return self._decode_wrapper
 
     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
             self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
-                2, self._get_workspace_buffer(), "NHD")
+                2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper
 
     def _plan(self, attn_metadata: FlashInferMetadata):
@@ -620,6 +634,7 @@ class FlashInferImpl(AttentionImpl):
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens
 
+        stride_order = FlashInferBackend.get_kv_cache_stride_order()
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
@@ -634,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -650,7 +665,7 @@ class FlashInferImpl(AttentionImpl):
             assert decode_wrapper._sm_scale == self.scale
             decode_wrapper.run(
                 decode_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[:num_decode_tokens],
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 8f6ecd532ccff..82798afee32cb 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
@@ -12,6 +13,13 @@ if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
 
+import vllm.envs as envs
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    get_kv_connector_cache_layout)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
 
 @dataclass
 class CommonAttentionMetadata:
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
         raise ValueError(
             error_msg +
             f"must be the same type as the current layer ({expected}).")
+
+
+@functools.lru_cache
+def get_kv_cache_layout():
+    # Override with format specified by the user.
+    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
+    if cache_layout is None:
+        cache_layout = get_kv_connector_cache_layout()
+    else:
+        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
+            "detected. Setting KV cache layout to %s.", cache_layout)
+
+    return cache_layout
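
Note for readers (not part of the patch): the standalone sketch below illustrates the stride-order plumbing this diff introduces. With the change, setting the VLLM_KV_CACHE_LAYOUT environment variable (for example VLLM_KV_CACHE_LAYOUT=HND) overrides the connector-derived default from get_kv_connector_cache_layout(), and the backends re-view the cache tensor with the matching permutation before handing it to the kernels. The snippet only borrows the shapes and permutations shown in the diff; the tensor sizes are made up and nothing here is vLLM API.

# Standalone sketch (assumed sizes, not vLLM code): how the stride_order
# values added in get_kv_cache_stride_order() map the canonical KV cache
# shape (num_blocks, 2, block_size, num_kv_heads, head_size) onto the NHD
# and HND layouts.
import torch


def stride_order_for(cache_layout: str) -> tuple[int, ...]:
    # Same mapping as the diff: identity for NHD, swap the block_size and
    # num_kv_heads axes for HND.
    if cache_layout == "NHD":
        return (0, 1, 2, 3, 4)
    if cache_layout == "HND":
        return (0, 1, 3, 2, 4)
    raise ValueError(f"Unknown cache layout format {cache_layout}.")


num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128
kv_cache = torch.empty(num_blocks, 2, block_size, num_kv_heads, head_size)

for layout in ("NHD", "HND"):
    view = kv_cache.permute(*stride_order_for(layout))
    print(layout, tuple(view.shape))
# NHD (4, 2, 16, 8, 128)  i.e. (..., block_size, num_kv_heads, head_size)
# HND (4, 2, 8, 16, 128)  i.e. (..., num_kv_heads, block_size, head_size)

Since Tensor.permute returns a strided view rather than a copy, the kv_cache.permute(*stride_order) calls added to FlashInferImpl.forward are cheap re-interpretations of the same memory.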