[XPU] Set consistent default KV cache layout (#24745)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-09-15 12:09:34 +02:00 committed by GitHub
parent bc0f6059a2
commit 2e41f5abca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 23 additions and 16 deletions

View File

@@ -56,9 +56,9 @@ except ImportError:
logger.warning("NIXL is not available") logger.warning("NIXL is not available")
NixlWrapper = None NixlWrapper = None
# Supported xPUs and types of kv transfer buffer. # Supported platforms and types of kv transfer buffer.
# {xPU: tuple of supported kv buffer types} # {device: tuple of supported kv buffer types}
_NIXL_SUPPORTED_XPUS = { _NIXL_SUPPORTED_DEVICE = {
"cuda": ("cuda", ), "cuda": ("cuda", ),
"tpu": ("cpu", ), "tpu": ("cpu", ),
"xpu": ("cpu", ), "xpu": ("cpu", ),
@@ -458,9 +458,9 @@ class NixlConnectorWorker:
self.device_type = current_platform.device_type self.device_type = current_platform.device_type
self.kv_buffer_device: str = \ self.kv_buffer_device: str = \
vllm_config.kv_transfer_config.kv_buffer_device vllm_config.kv_transfer_config.kv_buffer_device
if self.device_type not in _NIXL_SUPPORTED_XPUS: if self.device_type not in _NIXL_SUPPORTED_DEVICE:
raise RuntimeError(f"{self.device_type} is not supported.") raise RuntimeError(f"{self.device_type} is not supported.")
elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[ elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[
self.device_type]: self.device_type]:
raise RuntimeError( raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer " f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
@@ -468,7 +468,7 @@ class NixlConnectorWorker:
self.device_kv_caches: dict[str, torch.Tensor] = {} self.device_kv_caches: dict[str, torch.Tensor] = {}
# cpu kv buffer for xfer # cpu kv buffer for xfer
# used when xPU memory can not be registered under nixl # used when device memory can not be registered under nixl
self.host_xfer_buffers: dict[str, torch.Tensor] = {} self.host_xfer_buffers: dict[str, torch.Tensor] = {}
self.use_host_buffer = self.kv_buffer_device == "cpu" self.use_host_buffer = self.kv_buffer_device == "cpu"
if self.kv_buffer_device == "cuda": if self.kv_buffer_device == "cuda":
@@ -927,6 +927,9 @@ class NixlConnectorWorker:
if tp_ratio > 1: if tp_ratio > 1:
# Heterogeneous TP expects same kv_cache_layout. # Heterogeneous TP expects same kv_cache_layout.
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
if self.device_type == "xpu":
raise ValueError(
"Heterogeneous TP is not supported on XPU")
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, ( assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
"Remote P worker KV layer cache must be of shape [2, N, " "Remote P worker KV layer cache must be of shape [2, N, "

View File

@@ -9,6 +9,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from vllm.v1.attention.backends.utils import set_kv_cache_layout
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
@@ -164,12 +165,9 @@ class XPUPlatform(Platform):
vllm_config.scheduler_config.max_model_len, vllm_config.scheduler_config.max_model_len,
DEFAULT_MAX_NUM_BATCHED_TOKENS) DEFAULT_MAX_NUM_BATCHED_TOKENS)
if (envs.VLLM_KV_CACHE_LAYOUT is None set_kv_cache_layout("NHD")
or envs.VLLM_KV_CACHE_LAYOUT != "NHD"): logger.info("Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD" "only NHD layout is supported by XPU attention kernels.")
logger.info(
"Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
"only NHD layout is supported by XPU attention kernels.")
@classmethod @classmethod
def is_pin_memory_available(cls): def is_pin_memory_available(cls):

View File

@@ -5,8 +5,8 @@ import enum
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, fields, make_dataclass from dataclasses import dataclass, fields, make_dataclass
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol, from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional,
TypeVar) Protocol, TypeVar, Union, get_args)
import numpy as np import numpy as np
import torch import torch
@@ -30,7 +30,12 @@ from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
logger = init_logger(__name__) logger = init_logger(__name__)
_KV_CACHE_LAYOUT_OVERRIDE = None KVCacheLayoutType = Literal["NHD", "HND"]
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
def is_valid_kv_cache_layout(value: str) -> bool:
return value in get_args(KVCacheLayoutType)
@dataclass @dataclass
@@ -296,12 +301,13 @@ def get_kv_cache_layout():
if cache_layout is None: if cache_layout is None:
cache_layout = get_kv_connector_cache_layout() cache_layout = get_kv_connector_cache_layout()
else: else:
assert is_valid_kv_cache_layout(cache_layout)
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
"detected. Setting KV cache layout to %s.", cache_layout) "detected. Setting KV cache layout to %s.", cache_layout)
return cache_layout return cache_layout
def set_kv_cache_layout(cache_layout: str): def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
global _KV_CACHE_LAYOUT_OVERRIDE global _KV_CACHE_LAYOUT_OVERRIDE
_KV_CACHE_LAYOUT_OVERRIDE = cache_layout _KV_CACHE_LAYOUT_OVERRIDE = cache_layout