[NIXL] use Host buffer to support TP_ratio > 1 for XPU (#27140)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>

Parent: 9771e0b432
Commit: 7c4767f1eb
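For context, a minimal standalone sketch (toy shapes and names, not vLLM code) of why the commit moves the host staging buffer to the HND layout when the decode side only needs a slice of the KV heads (heterogeneous TP, TP_ratio > 1): in NHD a per-rank head shard is a strided view, while in HND it is one contiguous region that can be registered and transferred as a single descriptor.

# Toy demonstration of NHD vs. HND head-sharding (assumed per-block layouts:
# NHD = (block_size, num_heads, head_size), HND = (num_heads, block_size, head_size)).
import torch

block_size, num_heads, head_size = 16, 8, 128
heads_per_decode_rank = num_heads // 2  # e.g. prefill TP=1, decode TP=2

nhd_block = torch.zeros(block_size, num_heads, head_size)
# Slicing a head shard out of NHD gives a strided (non-contiguous) view.
print(nhd_block[:, :heads_per_decode_rank, :].is_contiguous())  # False

hnd_block = nhd_block.permute(1, 0, 2).contiguous()
# In HND the same shard is a single contiguous chunk of memory.
print(hnd_block[:heads_per_decode_rank].is_contiguous())  # True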
@@ -755,6 +755,7 @@ class NixlConnectorWorker:
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER
         self._use_pallas = attn_backend == _Backend.PALLAS
         self.kv_cache_layout = get_kv_cache_layout()
+        self.host_buffer_kv_cache_layout = self.kv_cache_layout
         logger.debug("Detected attention backend %s", self.backend_name)
         logger.debug("Detected kv cache layout %s", self.kv_cache_layout)

@@ -875,6 +876,20 @@ class NixlConnectorWorker:
         for layer_name, kv_cache in kv_caches.items():
             kv_shape = kv_cache.shape
             kv_dtype = kv_cache.dtype
+            if (
+                self.kv_cache_layout == "NHD"
+                and self.vllm_config.kv_transfer_config is not None
+                and self.vllm_config.kv_transfer_config.enable_permute_local_kv
+            ):
+                logger.info_once(
+                    "'enable_permute_local_kv' flag is enabled while "
+                    "device KV Layout is NHD. Init host buffer with"
+                    " HND to better support Decode/Prefill TP_ratio > 1."
+                )
+                # Since NHD will not support Decode/Prefill TP_ratio > 1,
+                # we can leverage host_buffer for permute
+                self.host_buffer_kv_cache_layout = "HND"
+                kv_shape = tuple(kv_shape[i] for i in [0, 1, 3, 2, 4])
             xfer_buffers[layer_name] = torch.empty(
                 kv_shape, dtype=kv_dtype, device="cpu"
             )
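A minimal sketch of the shape permutation above, assuming the per-layer device cache shape is (2, num_blocks, block_size, num_heads, head_size) for an NHD layout (the exact shape depends on the attention backend): index order [0, 1, 3, 2, 4] swaps the block_size and num_heads axes, so the CPU transfer buffer is allocated directly in HND.

import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 8, 128  # toy values
nhd_shape = (2, num_blocks, block_size, num_heads, head_size)   # assumed device (NHD) shape
hnd_shape = tuple(nhd_shape[i] for i in [0, 1, 3, 2, 4])        # host buffer (HND) shape
host_buffer = torch.empty(hnd_shape, dtype=torch.float16, device="cpu")
print(hnd_shape)  # (2, 4, 8, 16, 128)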
@@ -1110,7 +1125,9 @@ class NixlConnectorWorker:
             num_blocks=self.num_blocks,
             block_lens=self.block_len_per_layer,
             attn_backend_name=self.backend_name,
-            kv_cache_layout=self.kv_cache_layout,
+            kv_cache_layout=self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout,
         )
         ready_event, stop_event = threading.Event(), threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
@@ -1273,7 +1290,12 @@ class NixlConnectorWorker:
         assert not self._use_pallas or tp_ratio == 1, (
             "TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
         )
-        if not self.use_mla and nixl_agent_meta.kv_cache_layout != self.kv_cache_layout:
+        kv_cache_layout = (
+            self.kv_cache_layout
+            if not self.use_host_buffer
+            else self.host_buffer_kv_cache_layout
+        )
+        if not self.use_mla and nixl_agent_meta.kv_cache_layout != kv_cache_layout:
             if (
                 self.kv_transfer_config.enable_permute_local_kv
                 and nixl_agent_meta.kv_cache_layout == "HND"
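A small sketch (hypothetical helper, not part of the PR) of the effective-layout selection introduced above: when the host staging buffer is in use, the handshake compares the remote agent's layout against the host buffer layout rather than the device layout, so an HND remote prefill worker no longer trips the mismatch path on an NHD device.

def effective_kv_layout(device_layout: str, host_layout: str, use_host_buffer: bool) -> str:
    # Mirrors the conditional expression in the diff above.
    return host_layout if use_host_buffer else device_layout

# XPU example: device cache is NHD, host buffer re-laid-out to HND.
assert effective_kv_layout("NHD", "HND", use_host_buffer=True) == "HND"
assert effective_kv_layout("NHD", "NHD", use_host_buffer=False) == "NHD"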
@@ -1299,9 +1321,6 @@ class NixlConnectorWorker:
             )
             remote_block_size = remote_block_len // (self.slot_size_per_layer[0])
         else:
-            if tp_ratio > 1 and self.device_type == "xpu":
-                # XPU uses NHD, hence it does not support splitting on H
-                raise ValueError("Heterogeneous TP is not supported on XPU")
             # When MLA is not used, this is a list of the same block length
             for block_len in nixl_agent_meta.block_lens:
                 assert block_len == remote_block_len, (
@@ -144,6 +144,8 @@ class XPUPlatform(Platform):
         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
+        if vllm_config.kv_transfer_config is not None:
+            vllm_config.kv_transfer_config.enable_permute_local_kv = True

         if parallel_config.distributed_executor_backend is None:
             if parallel_config.world_size > 1:
@@ -245,6 +247,10 @@ class XPUPlatform(Platform):
     ) -> None:
         """Copy blocks from src_cache to dst_cache on XPU."""
         _src_cache = src_cache[:, src_block_indices]
+        if _src_cache.shape[2:] != dst_cache.shape[2:]:
+            # To support TP_ratio, HOST KV might be initiated with HND
+            # while XPU device KV is with NHD
+            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.to(dst_cache.device)

     @classmethod
@@ -257,4 +263,8 @@ class XPUPlatform(Platform):
     ) -> None:
         """Copy blocks from XPU to host (CPU)."""
         _src_cache = src_cache[:, src_block_indices]
+        if _src_cache.shape[2:] != dst_cache.shape[2:]:
+            # XPU device KV is with NHD while HOST KV
+            # might be initiated with HND for TP_ratio support
+            _src_cache = _src_cache.permute(0, 1, 3, 2, 4)
         dst_cache[:, dst_block_indices] = _src_cache.cpu()
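A standalone sketch (CPU tensors and toy sizes, not the XPU kernels) of the conditional permute used by both copy paths above: when the host buffer was allocated as HND while the device cache is NHD, the trailing shapes differ, and swapping dims 2 and 3 makes the strided copy line up.

import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 8, 64
device_cache = torch.randn(2, num_blocks, block_size, num_heads, head_size)  # NHD (toy, on CPU)
host_cache = torch.empty(2, num_blocks, num_heads, block_size, head_size)    # HND staging buffer

src_blocks = torch.tensor([0, 2])
dst_blocks = torch.tensor([1, 3])

src = device_cache[:, src_blocks]
if src.shape[2:] != host_cache.shape[2:]:
    src = src.permute(0, 1, 3, 2, 4)   # NHD block -> HND block
host_cache[:, dst_blocks] = src        # strided copy, no .contiguous() needed

# Round-trip check: reading the HND blocks back as NHD recovers the original data.
back = host_cache[:, dst_blocks].permute(0, 1, 3, 2, 4)
assert torch.equal(back, device_cache[:, src_blocks])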