From 2e41f5abca796c615f4db8d9a496d037fa385653 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?=
Date: Mon, 15 Sep 2025 12:09:34 +0200
Subject: [PATCH] [XPU] Set consistent default KV cache layout (#24745)

Signed-off-by: NickLucche
---
 .../kv_transfer/kv_connector/v1/nixl_connector.py | 15 +++++++++------
 vllm/platforms/xpu.py                             | 10 ++++------
 vllm/v1/attention/backends/utils.py               | 14 ++++++++++----
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index c306eeb5aa7ab..1ff1407aeb99b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -56,9 +56,9 @@ except ImportError:
     logger.warning("NIXL is not available")
     NixlWrapper = None
 
-# Supported xPUs and types of kv transfer buffer.
-# {xPU: tuple of supported kv buffer types}
-_NIXL_SUPPORTED_XPUS = {
+# Supported platforms and types of kv transfer buffer.
+# {device: tuple of supported kv buffer types}
+_NIXL_SUPPORTED_DEVICE = {
     "cuda": ("cuda", ),
     "tpu": ("cpu", ),
     "xpu": ("cpu", ),
@@ -458,9 +458,9 @@ class NixlConnectorWorker:
         self.device_type = current_platform.device_type
         self.kv_buffer_device: str = \
             vllm_config.kv_transfer_config.kv_buffer_device
-        if self.device_type not in _NIXL_SUPPORTED_XPUS:
+        if self.device_type not in _NIXL_SUPPORTED_DEVICE:
             raise RuntimeError(f"{self.device_type} is not supported.")
-        elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[
+        elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[
                 self.device_type]:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
@@ -468,7 +468,7 @@
         self.device_kv_caches: dict[str, torch.Tensor] = {}
 
         # cpu kv buffer for xfer
-        # used when xPU memory can not be registered under nixl
+        # used when device memory can not be registered under nixl
         self.host_xfer_buffers: dict[str, torch.Tensor] = {}
         self.use_host_buffer = self.kv_buffer_device == "cpu"
         if self.kv_buffer_device == "cuda":
@@ -927,6 +927,9 @@ class NixlConnectorWorker:
         if tp_ratio > 1:
             # Heterogeneous TP expects same kv_cache_layout.
             assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
+            if self.device_type == "xpu":
+                raise ValueError(
+                    "Heterogeneous TP is not supported on XPU")
 
         assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
             "Remote P worker KV layer cache must be of shape [2, N, "
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 32208e7fff018..792115b33ea8d 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -9,6 +9,7 @@ import torch
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
+from vllm.v1.attention.backends.utils import set_kv_cache_layout
 
 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
 
@@ -164,12 +165,9 @@ class XPUPlatform(Platform):
                 vllm_config.scheduler_config.max_model_len,
                 DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
-        if (envs.VLLM_KV_CACHE_LAYOUT is None
-                or envs.VLLM_KV_CACHE_LAYOUT != "NHD"):
-            os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
-            logger.info(
-                "Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
-                "only NHD layout is supported by XPU attention kernels.")
+        set_kv_cache_layout("NHD")
+        logger.info("Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
+                    "only NHD layout is supported by XPU attention kernels.")
 
     @classmethod
     def is_pin_memory_available(cls):
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 009943fa743d8..ead70c910a8fa 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -5,8 +5,8 @@ import enum
 import functools
 from abc import abstractmethod
 from dataclasses import dataclass, fields, make_dataclass
-from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional,
+                    Protocol, TypeVar, Union, get_args)
 
 import numpy as np
 import torch
@@ -30,7 +30,12 @@ from vllm.logger import init_logger
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 logger = init_logger(__name__)
-_KV_CACHE_LAYOUT_OVERRIDE = None
+KVCacheLayoutType = Literal["NHD", "HND"]
+_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None
+
+
+def is_valid_kv_cache_layout(value: str) -> bool:
+    return value in get_args(KVCacheLayoutType)
 
 
 @dataclass
@@ -296,12 +301,13 @@ def get_kv_cache_layout():
     if cache_layout is None:
         cache_layout = get_kv_connector_cache_layout()
     else:
+        assert is_valid_kv_cache_layout(cache_layout)
         logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
            "detected. Setting KV cache layout to %s.", cache_layout)
     return cache_layout
 
 
-def set_kv_cache_layout(cache_layout: str):
+def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
     global _KV_CACHE_LAYOUT_OVERRIDE
     _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
 
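
For context, here is a minimal, self-contained sketch of the layout-override
pattern this patch moves XPU onto: the platform calls set_kv_cache_layout()
instead of mutating os.environ, and the layout type is a Literal so that
get_args() doubles as the validation set. The env-var fallback and the "NHD"
default below are simplified assumptions for illustration only; vLLM's actual
get_kv_cache_layout() also consults the KV connector's preferred layout.

import os
from typing import Literal, Union, get_args

KVCacheLayoutType = Literal["NHD", "HND"]
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None


def is_valid_kv_cache_layout(value: str) -> bool:
    # get_args(Literal["NHD", "HND"]) returns ("NHD", "HND"), so the type
    # annotation is the single source of truth for allowed layouts.
    return value in get_args(KVCacheLayoutType)


def set_kv_cache_layout(cache_layout: KVCacheLayoutType) -> None:
    # In-process override; takes precedence over the environment variable.
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout


def get_kv_cache_layout() -> str:
    # Override first, then the env var (validated), then a sketch-only default.
    if _KV_CACHE_LAYOUT_OVERRIDE is not None:
        return _KV_CACHE_LAYOUT_OVERRIDE
    env_layout = os.environ.get("VLLM_KV_CACHE_LAYOUT")
    if env_layout is not None:
        assert is_valid_kv_cache_layout(env_layout), env_layout
        return env_layout
    return "NHD"  # placeholder default, an assumption of this sketch


set_kv_cache_layout("NHD")  # what XPUPlatform's config hook now does
assert get_kv_cache_layout() == "NHD"

Routing the XPU default through set_kv_cache_layout() rather than os.environ
keeps the override visible to get_kv_cache_layout() without depending on when
the environment is read, which is what makes the default consistent.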