[V1][Kernel] Flashinfer HND KV cache layout (#19280)
Signed-off-by: NickLucche <nlucches@redhat.com>
parent 93aee29fdb
commit 4c8f64faa7
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import dataclasses
-import os
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -50,8 +49,7 @@ if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)

-FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
-                                            "NHD").upper()
+FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"


 class FlashInferBackend(AttentionBackend):
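The module-level constant now defers to the new global setting and falls back to NHD when it is unset. A minimal sketch of that fallback semantics (standalone; `resolve_layout` is a made-up helper, not part of the diff):

```python
from typing import Optional

def resolve_layout(vllm_kv_cache_layout: Optional[str]) -> str:
    # `x or "NHD"` treats None (variable unset) and "" alike: both fall back.
    return vllm_kv_cache_layout or "NHD"

assert resolve_layout(None) == "NHD"
assert resolve_layout("HND") == "HND"
```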
@@ -3,7 +3,6 @@
 """
 KV cache helper for store.
 """
-
 import torch

 import vllm.envs as envs
@@ -94,15 +93,17 @@ class model_aware_kv_ops_helper:


 def get_kv_connector_cache_layout():
     # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
     # used for faster transfer.
     vllm_config = get_current_vllm_config()
     kv_config = vllm_config.kv_transfer_config
-    if vllm_config.model_config is None:
-        logger.warning("Unable to detect current VLLM config. " \
+    if vllm_config.model_config is None or kv_config is None:
+        logger.warning_once("Unable to detect current VLLM config. " \
             "Defaulting to NHD kv cache layout.")
     else:
         use_mla = vllm_config.model_config.use_mla
         if not use_mla and kv_config.kv_connector == "NixlConnector":
-            logger.info("NixlConnector detected. Setting KV cache " \
+            logger.info_once("NixlConnector detected. Setting KV cache " \
                 "layout to HND for better xfer performance.")
             return "HND"
     return "NHD"
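The connector-aware default boils down to a simple rule: prefer HND only when a non-MLA model is served with the NIXL connector, otherwise stay with NHD. A simplified, self-contained sketch of that rule (the dataclasses here are stand-ins for vLLM's config objects, not its real API):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelCfg:
    use_mla: bool

@dataclass
class KVTransferCfg:
    kv_connector: str

def pick_connector_layout(model: Optional[ModelCfg],
                          kv: Optional[KVTransferCfg]) -> str:
    # Without enough config to decide, fall back to the NHD default.
    if model is None or kv is None:
        return "NHD"
    # HND speeds up NIXL transfers, but MLA models keep the default layout.
    if not model.use_mla and kv.kv_connector == "NixlConnector":
        return "HND"
    return "NHD"

assert pick_connector_layout(ModelCfg(False), KVTransferCfg("NixlConnector")) == "HND"
assert pick_connector_layout(ModelCfg(True), KVTransferCfg("NixlConnector")) == "NHD"
```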
vllm/envs.py
@@ -128,6 +128,7 @@ if TYPE_CHECKING:
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
     VLLM_SLEEP_WHEN_IDLE: bool = False
     VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
+    VLLM_KV_CACHE_LAYOUT: Optional[str] = None


 def get_default_cache_root():
@@ -879,6 +880,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # processes via zmq.
     "VLLM_MQ_MAX_CHUNK_BYTES_MB":
     lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")),
+
+    # KV Cache layout used throughout vllm.
+    # Some common values are:
+    # - NHD
+    # - HND
+    # Where N=num_blocks, H=num_heads and D=head_size. The default value will
+    # leave the layout choice to the backend. Mind that backends may only
+    # implement and support a subset of all possible layouts.
+    "VLLM_KV_CACHE_LAYOUT":
+    lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None)
 }

 # --8<-- [end:env-vars-definition]
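The stride orders used by the backends below ((0, 1, 2, 3, 4) for NHD, (0, 1, 3, 2, 4) for HND) imply that the two layouts differ only in the order of the block_size and num_kv_heads axes of the (num_blocks, 2, block_size, num_kv_heads, head_size) cache. A quick torch illustration of that difference (a sketch, not vLLM code; shapes are arbitrary):

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

# Logical NHD shape, as returned by get_kv_cache_shape() in the backends below.
nhd_shape = (num_blocks, 2, block_size, num_kv_heads, head_size)

# HND swaps the token (block_size) and head axes inside each block.
hnd_shape = (num_blocks, 2, num_kv_heads, block_size, head_size)

kv_nhd = torch.zeros(nhd_shape)
assert kv_nhd.permute(0, 1, 3, 2, 4).shape == hnd_shape
```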
@@ -16,13 +16,12 @@ from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.distributed.kv_transfer.kv_connector.utils import (
-    get_kv_connector_cache_layout)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -73,16 +72,15 @@ class FlashAttentionBackend(AttentionBackend):

     @staticmethod
     def get_kv_cache_stride_order() -> tuple[int, ...]:
-        # NOTE When running disaggregated PD with NIXL, HND layout is used for
-        # faster transfer. `stride_order` indicates the permutation that gets
+        # `stride_order` indicates the permutation that gets
         # us from `get_kv_cache_shape` to the actual memory layout we want.
-        cache_layout = get_kv_connector_cache_layout()
+        cache_layout = get_kv_cache_layout()
         if cache_layout == "NHD":
             stride_order = (0, 1, 2, 3, 4)
         elif cache_layout == "HND":
             stride_order = (0, 1, 3, 2, 4)
         else:
-            raise ValueError("Unknown cache layout format %s.", cache_layout)
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order


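To make the role of `stride_order` concrete: permuting the logical NHD-shaped cache by (0, 1, 3, 2, 4) yields a view whose axis order matches HND, without copying any data. A standalone sketch (plain torch, not vLLM code):

```python
import torch

num_blocks, block_size, num_kv_heads, head_size = 4, 16, 8, 128

# Logical shape returned by get_kv_cache_shape() (NHD ordering).
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

# The HND stride order from the diff swaps the block_size and head axes.
hnd_stride_order = (0, 1, 3, 2, 4)
kv_hnd_view = kv_cache.permute(*hnd_stride_order)

assert kv_hnd_view.shape == (num_blocks, 2, num_kv_heads, block_size, head_size)
assert kv_hnd_view.data_ptr() == kv_cache.data_ptr()  # a view, not a copy
```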
@@ -19,7 +19,8 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable

@@ -66,6 +67,19 @@ class FlashInferBackend(AttentionBackend):
     ) -> tuple[int, ...]:
         return (num_blocks, 2, block_size, num_kv_heads, head_size)

+    @staticmethod
+    def get_kv_cache_stride_order() -> tuple[int, ...]:
+        # `stride_order` indicates the permutation that gets us from
+        # `get_kv_cache_shape` to the actual memory layout we want.
+        cache_layout = get_kv_cache_layout()
+        if cache_layout == "NHD":
+            stride_order = (0, 1, 2, 3, 4)
+        elif cache_layout == "HND":
+            stride_order = (0, 1, 3, 2, 4)
+        else:
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
+        return stride_order
+

 @dataclass
 class PerLayerParameters:
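Downstream, a cache allocator can use this permutation to lay memory out in HND order while still handing out a tensor with the familiar NHD logical shape. A hedged sketch of that idea (the `allocate_kv_cache` helper is illustrative only, not vLLM's actual allocator):

```python
import torch

def allocate_kv_cache(logical_shape: tuple[int, ...],
                      stride_order: tuple[int, ...],
                      dtype=torch.float16) -> torch.Tensor:
    # Allocate memory in the physical (permuted) order, then permute back so
    # callers keep indexing the cache with the logical NHD shape.
    physical_shape = tuple(logical_shape[i] for i in stride_order)
    inverse = tuple(sorted(range(len(stride_order)), key=stride_order.__getitem__))
    return torch.empty(physical_shape, dtype=dtype).permute(*inverse)

logical = (4, 2, 16, 8, 128)        # (num_blocks, 2, block_size, heads, head_size)
kv = allocate_kv_cache(logical, (0, 1, 3, 2, 4))
assert kv.shape == logical          # logical view is unchanged
assert kv.stride(3) > kv.stride(2)  # but heads are now the slower-varying axis
```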
@@ -290,7 +304,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def _get_prefill_wrapper(self):
         if self._prefill_wrapper is None:
             self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                self._get_workspace_buffer(), "NHD")
+                self._get_workspace_buffer(), get_kv_cache_layout())
         return self._prefill_wrapper

     def _get_decode_wrapper(self):
@@ -303,14 +317,14 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
-                "NHD",
+                get_kv_cache_layout(),
                 use_tensor_cores=use_tensor_cores)
         return self._decode_wrapper

     def _get_cascade_wrapper(self):
         if self._cascade_wrapper is None:
             self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
-                2, self._get_workspace_buffer(), "NHD")
+                2, self._get_workspace_buffer(), get_kv_cache_layout())
         return self._cascade_wrapper

     def _plan(self, attn_metadata: FlashInferMetadata):
@@ -620,6 +634,7 @@ class FlashInferImpl(AttentionImpl):
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_prefill_tokens = attn_metadata.num_prefill_tokens

+        stride_order = FlashInferBackend.get_kv_cache_stride_order()
         # Regular attention (common case).
         # Decodes are at the front and prefills are at the back,
         # according to reorder_batch()
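The comment above matters for the two `run()` calls below: after `reorder_batch()`, decode requests occupy the first `num_decode_tokens` rows and prefills the rest, so the query and output tensors are simply sliced at that boundary. A toy illustration of the split (shapes are made up for the example):

```python
import torch

num_decode_tokens, num_prefill_tokens = 3, 5
num_heads, head_size = 8, 128

# Queries for the whole reordered batch: decodes first, prefills after.
query = torch.randn(num_decode_tokens + num_prefill_tokens, num_heads, head_size)

decode_query = query[:num_decode_tokens]
prefill_query = query[num_decode_tokens:]
assert decode_query.shape[0] == num_decode_tokens
assert prefill_query.shape[0] == num_prefill_tokens
```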
@@ -634,7 +649,7 @@ class FlashInferImpl(AttentionImpl):
             assert prefill_wrapper._sm_scale == self.scale
             prefill_wrapper.run(
                 prefill_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -650,7 +665,7 @@ class FlashInferImpl(AttentionImpl):
             assert decode_wrapper._sm_scale == self.scale
             decode_wrapper.run(
                 decode_query,
-                kv_cache,
+                kv_cache.permute(*stride_order),
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[:num_decode_tokens],
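Worth noting: under the default NHD layout the stride order is the identity permutation, so `kv_cache.permute(*stride_order)` returns a view equivalent to the old direct `kv_cache` argument and the change is behavior-preserving. A quick check of that claim (plain torch, not vLLM code):

```python
import torch

kv_cache = torch.zeros(4, 2, 16, 8, 128)
identity_order = (0, 1, 2, 3, 4)          # stride order for "NHD"

view = kv_cache.permute(*identity_order)
assert view.shape == kv_cache.shape
assert view.stride() == kv_cache.stride() # identical layout, zero-copy view
```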
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import abc
+import functools
 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
@@ -12,6 +13,13 @@ if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch

+import vllm.envs as envs
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    get_kv_connector_cache_layout)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+

 @dataclass
 class CommonAttentionMetadata:
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
         raise ValueError(
             error_msg +
             f"must be the same type as the current layer ({expected}).")
+
+
+@functools.lru_cache
+def get_kv_cache_layout():
+    # Override with format specified by the user.
+    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
+    if cache_layout is None:
+        cache_layout = get_kv_connector_cache_layout()
+    else:
+        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
+            "detected. Setting KV cache layout to %s.", cache_layout)
+
+    return cache_layout
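Because the helper is wrapped in `functools.lru_cache`, the layout is resolved once per process and every later caller gets the memoized string; tests that flip the environment variable would need to clear the cache. A minimal sketch of that caching behavior (standalone, with a stubbed resolver instead of vLLM's):

```python
import functools
import os

@functools.lru_cache
def get_layout() -> str:
    # Stub for get_kv_cache_layout(): read the override once, default to NHD.
    return os.getenv("VLLM_KV_CACHE_LAYOUT") or "NHD"

first = get_layout()
os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"
assert get_layout() == first      # still the memoized value
get_layout.cache_clear()          # e.g. in a test fixture
assert get_layout() == "HND"      # re-resolved after clearing the cache
```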