mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 09:12:36 +08:00
[XPU] Set consistent default KV cache layout (#24745)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
bc0f6059a2
commit
2e41f5abca
@ -56,9 +56,9 @@ except ImportError:
|
|||||||
logger.warning("NIXL is not available")
|
logger.warning("NIXL is not available")
|
||||||
NixlWrapper = None
|
NixlWrapper = None
|
||||||
|
|
||||||
# Supported xPUs and types of kv transfer buffer.
|
# Supported platforms and types of kv transfer buffer.
|
||||||
# {xPU: tuple of supported kv buffer types}
|
# {device: tuple of supported kv buffer types}
|
||||||
_NIXL_SUPPORTED_XPUS = {
|
_NIXL_SUPPORTED_DEVICE = {
|
||||||
"cuda": ("cuda", ),
|
"cuda": ("cuda", ),
|
||||||
"tpu": ("cpu", ),
|
"tpu": ("cpu", ),
|
||||||
"xpu": ("cpu", ),
|
"xpu": ("cpu", ),
|
||||||
@ -458,9 +458,9 @@ class NixlConnectorWorker:
|
|||||||
self.device_type = current_platform.device_type
|
self.device_type = current_platform.device_type
|
||||||
self.kv_buffer_device: str = \
|
self.kv_buffer_device: str = \
|
||||||
vllm_config.kv_transfer_config.kv_buffer_device
|
vllm_config.kv_transfer_config.kv_buffer_device
|
||||||
if self.device_type not in _NIXL_SUPPORTED_XPUS:
|
if self.device_type not in _NIXL_SUPPORTED_DEVICE:
|
||||||
raise RuntimeError(f"{self.device_type} is not supported.")
|
raise RuntimeError(f"{self.device_type} is not supported.")
|
||||||
elif self.kv_buffer_device not in _NIXL_SUPPORTED_XPUS[
|
elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[
|
||||||
self.device_type]:
|
self.device_type]:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
|
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
|
||||||
@ -468,7 +468,7 @@ class NixlConnectorWorker:
|
|||||||
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
self.device_kv_caches: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
# cpu kv buffer for xfer
|
# cpu kv buffer for xfer
|
||||||
# used when xPU memory can not be registered under nixl
|
# used when device memory can not be registered under nixl
|
||||||
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
|
self.host_xfer_buffers: dict[str, torch.Tensor] = {}
|
||||||
self.use_host_buffer = self.kv_buffer_device == "cpu"
|
self.use_host_buffer = self.kv_buffer_device == "cpu"
|
||||||
if self.kv_buffer_device == "cuda":
|
if self.kv_buffer_device == "cuda":
|
||||||
@ -927,6 +927,9 @@ class NixlConnectorWorker:
|
|||||||
if tp_ratio > 1:
|
if tp_ratio > 1:
|
||||||
# Heterogeneous TP expects same kv_cache_layout.
|
# Heterogeneous TP expects same kv_cache_layout.
|
||||||
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
|
||||||
|
if self.device_type == "xpu":
|
||||||
|
raise ValueError(
|
||||||
|
"Heterogeneous TP is not supported on XPU")
|
||||||
|
|
||||||
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
|
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
|
||||||
"Remote P worker KV layer cache must be of shape [2, N, "
|
"Remote P worker KV layer cache must be of shape [2, N, "
|
||||||
|
|||||||
@ -9,6 +9,7 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||||
|
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||||
|
|
||||||
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
|
||||||
|
|
||||||
@ -164,12 +165,9 @@ class XPUPlatform(Platform):
|
|||||||
vllm_config.scheduler_config.max_model_len,
|
vllm_config.scheduler_config.max_model_len,
|
||||||
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
DEFAULT_MAX_NUM_BATCHED_TOKENS)
|
||||||
|
|
||||||
if (envs.VLLM_KV_CACHE_LAYOUT is None
|
set_kv_cache_layout("NHD")
|
||||||
or envs.VLLM_KV_CACHE_LAYOUT != "NHD"):
|
logger.info("Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
|
||||||
os.environ["VLLM_KV_CACHE_LAYOUT"] = "NHD"
|
"only NHD layout is supported by XPU attention kernels.")
|
||||||
logger.info(
|
|
||||||
"Setting VLLM_KV_CACHE_LAYOUT to 'NHD' for XPU; "
|
|
||||||
"only NHD layout is supported by XPU attention kernels.")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_pin_memory_available(cls):
|
def is_pin_memory_available(cls):
|
||||||
|
|||||||
@ -5,8 +5,8 @@ import enum
|
|||||||
import functools
|
import functools
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass, fields, make_dataclass
|
from dataclasses import dataclass, fields, make_dataclass
|
||||||
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Optional, Protocol,
|
from typing import (TYPE_CHECKING, Any, ClassVar, Generic, Literal, Optional,
|
||||||
TypeVar)
|
Protocol, TypeVar, Union, get_args)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -30,7 +30,12 @@ from vllm.logger import init_logger
|
|||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
# Layout names recognized as valid KV cache layouts.
KVCacheLayoutType = Literal["NHD", "HND"]

# Module-level layout override; starts as None (nothing set yet).
_KV_CACHE_LAYOUT_OVERRIDE: Union[KVCacheLayoutType, None] = None


def is_valid_kv_cache_layout(value: str) -> bool:
    """Return True if *value* is one of the names in ``KVCacheLayoutType``.

    The check is an exact, case-sensitive membership test against the
    literal values of ``KVCacheLayoutType`` (so ``"nhd"`` is rejected).

    Args:
        value: Candidate layout name.

    Returns:
        ``True`` when *value* equals one of the allowed literals,
        ``False`` otherwise.
    """
    return value in get_args(KVCacheLayoutType)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -296,12 +301,13 @@ def get_kv_cache_layout():
|
|||||||
if cache_layout is None:
|
if cache_layout is None:
|
||||||
cache_layout = get_kv_connector_cache_layout()
|
cache_layout = get_kv_connector_cache_layout()
|
||||||
else:
|
else:
|
||||||
|
assert is_valid_kv_cache_layout(cache_layout)
|
||||||
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
|
||||||
"detected. Setting KV cache layout to %s.", cache_layout)
|
"detected. Setting KV cache layout to %s.", cache_layout)
|
||||||
return cache_layout
|
return cache_layout
|
||||||
|
|
||||||
|
|
||||||
def set_kv_cache_layout(cache_layout: KVCacheLayoutType):
    """Store *cache_layout* as the module-level KV cache layout override.

    Args:
        cache_layout: Layout name to record in ``_KV_CACHE_LAYOUT_OVERRIDE``.
            The annotation restricts it to ``KVCacheLayoutType``; no runtime
            validation is performed here.
    """
    # Rebind the module global rather than creating a function-local name.
    global _KV_CACHE_LAYOUT_OVERRIDE
    _KV_CACHE_LAYOUT_OVERRIDE = cache_layout
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user