mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 18:54:33 +08:00
[Bugfix][Nixl] Fix kernel physical<>logical block_size issue (#28677)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
433c0f8675
commit
96b23b8e3b
@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
|
||||
req_index = 0
|
||||
block_table.append_row(kvcache_manager_blocks, req_index)
|
||||
# Get expected kernel blocks from the implementation for verification.
|
||||
expected_kernel_blocks = block_table._map_to_kernel_blocks(
|
||||
np.array(kvcache_manager_blocks)
|
||||
expected_kernel_blocks = block_table.map_to_kernel_blocks(
|
||||
np.array(kvcache_manager_blocks),
|
||||
block_table.blocks_per_kv_block,
|
||||
block_table._kernel_block_arange,
|
||||
)
|
||||
# Verify block table state
|
||||
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
|
||||
|
||||
@ -49,6 +49,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
# To be used when logical block size does not match the kernel block size
|
||||
local_physical_block_ids: list[int]
|
||||
remote_block_ids: list[int]
|
||||
remote_host: str
|
||||
remote_port: int
|
||||
@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
||||
assert load_remote_cache ^ save_to_host
|
||||
_req = ReqMeta(
|
||||
local_block_ids=local_block_ids,
|
||||
local_physical_block_ids=local_block_ids,
|
||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||
remote_host=kv_transfer_params["remote_host"],
|
||||
@ -935,6 +939,7 @@ class NixlConnectorWorker:
|
||||
attn_backend=backend,
|
||||
)
|
||||
self._use_pallas = self.kv_topo._use_pallas
|
||||
self._physical_blocks_per_logical_kv_block = 1
|
||||
|
||||
def _nixl_handshake(
|
||||
self,
|
||||
@ -1133,6 +1138,22 @@ class NixlConnectorWorker:
|
||||
if base_addr in seen_base_addresses:
|
||||
continue
|
||||
|
||||
# TODO (NickLucche): Get kernel_block_size in a cleaner way
|
||||
# NHD default "view" for non-MLA cache
|
||||
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
|
||||
|
||||
if self.block_size != kernel_block_size:
|
||||
logger.info_once(
|
||||
"User-specified logical block size (%s) does not match"
|
||||
" physical kernel block size (%s). Using the latter. ",
|
||||
self.block_size,
|
||||
kernel_block_size,
|
||||
)
|
||||
self._physical_blocks_per_logical_kv_block = (
|
||||
self.block_size // kernel_block_size
|
||||
)
|
||||
self.block_size = kernel_block_size
|
||||
|
||||
seen_base_addresses.append(base_addr)
|
||||
curr_tensor_size_bytes = cache.numel() * cache.element_size()
|
||||
|
||||
@ -1479,7 +1500,7 @@ class NixlConnectorWorker:
|
||||
assert self.use_host_buffer
|
||||
assert self.copy_blocks is not None
|
||||
|
||||
local_block_ids = meta.local_block_ids
|
||||
local_block_ids = meta.local_physical_block_ids
|
||||
self.copy_blocks(
|
||||
self.host_xfer_buffers,
|
||||
self.device_kv_caches,
|
||||
@ -1492,7 +1513,7 @@ class NixlConnectorWorker:
|
||||
"synced recved kv of request[%s] to device kv buffer,"
|
||||
"local_block_ids: %s. ",
|
||||
req_id,
|
||||
",".join(map(str, meta.local_block_ids)),
|
||||
",".join(map(str, local_block_ids)),
|
||||
)
|
||||
|
||||
def save_kv_to_host(self, metadata: NixlConnectorMetadata):
|
||||
@ -1501,19 +1522,22 @@ class NixlConnectorWorker:
|
||||
assert self.copy_blocks is not None
|
||||
|
||||
for req_id, meta in metadata.reqs_to_save.items():
|
||||
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.local_block_ids
|
||||
)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
"save_load_kv for request[%s] to host xfer buffer."
|
||||
"local_block_ids: %s. ",
|
||||
req_id,
|
||||
",".join(map(str, meta.local_block_ids)),
|
||||
",".join(map(str, meta.local_physical_block_ids)),
|
||||
)
|
||||
# blocking
|
||||
self.copy_blocks(
|
||||
self.device_kv_caches,
|
||||
self.host_xfer_buffers,
|
||||
meta.local_block_ids,
|
||||
meta.local_block_ids,
|
||||
meta.local_physical_block_ids,
|
||||
meta.local_physical_block_ids,
|
||||
"d2h",
|
||||
)
|
||||
|
||||
@ -1582,7 +1606,7 @@ class NixlConnectorWorker:
|
||||
if self.use_host_buffer:
|
||||
self.sync_recved_kv_to_device(req_id, meta)
|
||||
if self.enable_permute_local_kv:
|
||||
block_ids_to_permute += meta.local_block_ids
|
||||
block_ids_to_permute += meta.local_physical_block_ids
|
||||
if len(block_ids_to_permute) > 0:
|
||||
self.permute_device_kv(block_ids_to_permute)
|
||||
|
||||
@ -1669,7 +1693,7 @@ class NixlConnectorWorker:
|
||||
req_id,
|
||||
xfer_state,
|
||||
)
|
||||
# mark all blocks for this request as invalid
|
||||
# mark all (logical) blocks for this request as invalid
|
||||
if meta := self._recving_metadata.pop(req_id, None):
|
||||
self._invalid_block_ids.update(meta.local_block_ids)
|
||||
self._recving_metadata.pop(req_id, None)
|
||||
@ -1686,13 +1710,19 @@ class NixlConnectorWorker:
|
||||
We check for these transfers to complete in each step().
|
||||
"""
|
||||
for req_id, meta in metadata.reqs_to_recv.items():
|
||||
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.local_block_ids
|
||||
)
|
||||
meta.remote_block_ids = self._logical_to_kernel_block_ids(
|
||||
meta.remote_block_ids
|
||||
)
|
||||
remote_engine_id = meta.remote_engine_id
|
||||
logger.debug(
|
||||
"start_load_kv for request %s from remote engine %s. "
|
||||
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
|
||||
req_id,
|
||||
remote_engine_id,
|
||||
len(meta.local_block_ids),
|
||||
len(meta.local_physical_block_ids),
|
||||
len(meta.remote_block_ids),
|
||||
)
|
||||
# always store metadata for failure recovery
|
||||
@ -1740,7 +1770,7 @@ class NixlConnectorWorker:
|
||||
self._read_blocks(
|
||||
request_id=req_id,
|
||||
dst_engine_id=meta.remote_engine_id,
|
||||
local_block_ids=meta.local_block_ids,
|
||||
local_block_ids=meta.local_physical_block_ids,
|
||||
remote_block_ids=meta.remote_block_ids,
|
||||
)
|
||||
|
||||
@ -1867,7 +1897,7 @@ class NixlConnectorWorker:
|
||||
"Marking blocks as invalid.",
|
||||
request_id,
|
||||
)
|
||||
# mark all blocks for this request as invalid
|
||||
# mark all (logical) blocks for this request as invalid
|
||||
if meta := self._recving_metadata.get(request_id):
|
||||
self._invalid_block_ids.update(meta.local_block_ids)
|
||||
self.xfer_stats.record_failed_transfer()
|
||||
@ -1906,6 +1936,23 @@ class NixlConnectorWorker:
|
||||
descs_ids = region_ids * num_blocks + block_ids
|
||||
return descs_ids.flatten()
|
||||
|
||||
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
Convert logical block ids to kernel physical block ids.
|
||||
This is required when the logical block size (the one set by the user)
|
||||
does not match the one required by the attn backend.
|
||||
"""
|
||||
if self._physical_blocks_per_logical_kv_block == 1:
|
||||
# Noop when physical and logical block sizes are the same
|
||||
return block_ids
|
||||
block_ids_np = np.array(block_ids)
|
||||
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
|
||||
1, -1
|
||||
)
|
||||
return BlockTable.map_to_kernel_blocks(
|
||||
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
|
||||
).tolist()
|
||||
|
||||
def get_backend_aware_kv_block_len(self, layer_idx: int):
|
||||
"""
|
||||
Get the block length for one K/V element (K and V have the same size).
|
||||
|
||||
@ -98,7 +98,9 @@ class BlockTable:
|
||||
return
|
||||
|
||||
if self.use_hybrid_blocks:
|
||||
block_ids = self._map_to_kernel_blocks(np.array(block_ids))
|
||||
block_ids = self.map_to_kernel_blocks(
|
||||
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
|
||||
)
|
||||
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
@ -188,7 +190,12 @@ class BlockTable:
|
||||
self.block_table.gpu.fill_(0)
|
||||
self.block_table.cpu.fill_(0)
|
||||
|
||||
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
|
||||
@staticmethod
|
||||
def map_to_kernel_blocks(
|
||||
kv_manager_block_ids: np.ndarray,
|
||||
blocks_per_kv_block: int,
|
||||
kernel_block_arange: np.ndarray,
|
||||
) -> np.ndarray:
|
||||
"""Convert kv_manager_block_id IDs to kernel block IDs.
|
||||
|
||||
Example:
|
||||
@ -203,12 +210,12 @@ class BlockTable:
|
||||
# kv_manager_block_id 1 → kernel block id [2, 3]
|
||||
# kv_manager_block_id 2 → kernel block id [4, 5]
|
||||
"""
|
||||
if not self.use_hybrid_blocks:
|
||||
if blocks_per_kv_block == 1:
|
||||
return kv_manager_block_ids
|
||||
|
||||
kernel_block_ids = (
|
||||
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
|
||||
+ self._kernel_block_arange
|
||||
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
|
||||
+ kernel_block_arange
|
||||
)
|
||||
|
||||
return kernel_block_ids.reshape(-1)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user