[Bugfix][Nixl] Fix kernel physical<>logical block_size issue (#28677)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-11-14 15:40:05 +01:00 committed by GitHub
parent 433c0f8675
commit 96b23b8e3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 73 additions and 17 deletions

View File

@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks(
np.array(kvcache_manager_blocks)
expected_kernel_blocks = block_table.map_to_kernel_blocks(
np.array(kvcache_manager_blocks),
block_table.blocks_per_kv_block,
block_table._kernel_block_arange,
)
# Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)

View File

@ -49,6 +49,7 @@ from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
@dataclass
class ReqMeta:
local_block_ids: list[int]
# To be used when logical block size does not match the kernel block size
local_physical_block_ids: list[int]
remote_block_ids: list[int]
remote_host: str
remote_port: int
@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
assert load_remote_cache ^ save_to_host
_req = ReqMeta(
local_block_ids=local_block_ids,
local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
@ -935,6 +939,7 @@ class NixlConnectorWorker:
attn_backend=backend,
)
self._use_pallas = self.kv_topo._use_pallas
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake(
self,
@ -1133,6 +1138,22 @@ class NixlConnectorWorker:
if base_addr in seen_base_addresses:
continue
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. ",
self.block_size,
kernel_block_size,
)
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()
@ -1479,7 +1500,7 @@ class NixlConnectorWorker:
assert self.use_host_buffer
assert self.copy_blocks is not None
local_block_ids = meta.local_block_ids
local_block_ids = meta.local_physical_block_ids
self.copy_blocks(
self.host_xfer_buffers,
self.device_kv_caches,
@ -1492,7 +1513,7 @@ class NixlConnectorWorker:
"synced recved kv of request[%s] to device kv buffer,"
"local_block_ids: %s. ",
req_id,
",".join(map(str, meta.local_block_ids)),
",".join(map(str, local_block_ids)),
)
def save_kv_to_host(self, metadata: NixlConnectorMetadata):
@ -1501,19 +1522,22 @@ class NixlConnectorWorker:
assert self.copy_blocks is not None
for req_id, meta in metadata.reqs_to_save.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"save_load_kv for request[%s] to host xfer buffer."
"local_block_ids: %s. ",
req_id,
",".join(map(str, meta.local_block_ids)),
",".join(map(str, meta.local_physical_block_ids)),
)
# blocking
self.copy_blocks(
self.device_kv_caches,
self.host_xfer_buffers,
meta.local_block_ids,
meta.local_block_ids,
meta.local_physical_block_ids,
meta.local_physical_block_ids,
"d2h",
)
@ -1582,7 +1606,7 @@ class NixlConnectorWorker:
if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_block_ids
block_ids_to_permute += meta.local_physical_block_ids
if len(block_ids_to_permute) > 0:
self.permute_device_kv(block_ids_to_permute)
@ -1669,7 +1693,7 @@ class NixlConnectorWorker:
req_id,
xfer_state,
)
# mark all blocks for this request as invalid
# mark all (logical) blocks for this request as invalid
if meta := self._recving_metadata.pop(req_id, None):
self._invalid_block_ids.update(meta.local_block_ids)
self._recving_metadata.pop(req_id, None)
@ -1686,13 +1710,19 @@ class NixlConnectorWorker:
We check for these transactions to complete in each step().
"""
for req_id, meta in metadata.reqs_to_recv.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
meta.remote_block_ids = self._logical_to_kernel_block_ids(
meta.remote_block_ids
)
remote_engine_id = meta.remote_engine_id
logger.debug(
"start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
req_id,
remote_engine_id,
len(meta.local_block_ids),
len(meta.local_physical_block_ids),
len(meta.remote_block_ids),
)
# always store metadata for failure recovery
@ -1740,7 +1770,7 @@ class NixlConnectorWorker:
self._read_blocks(
request_id=req_id,
dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids,
local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids,
)
@ -1867,7 +1897,7 @@ class NixlConnectorWorker:
"Marking blocks as invalid.",
request_id,
)
# mark all blocks for this request as invalid
# mark all (logical) blocks for this request as invalid
if meta := self._recving_metadata.get(request_id):
self._invalid_block_ids.update(meta.local_block_ids)
self.xfer_stats.record_failed_transfer()
@ -1906,6 +1936,23 @@ class NixlConnectorWorker:
descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten()
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
"""
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
does not match the one required by the attn backend.
"""
if self._physical_blocks_per_logical_kv_block == 1:
# Noop when physical and logical block sizes are the same
return block_ids
block_ids_np = np.array(block_ids)
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
1, -1
)
return BlockTable.map_to_kernel_blocks(
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist()
def get_backend_aware_kv_block_len(self, layer_idx: int):
"""
Get the block length for one K/V element (K and V have the same size).

View File

@ -98,7 +98,9 @@ class BlockTable:
return
if self.use_hybrid_blocks:
block_ids = self._map_to_kernel_blocks(np.array(block_ids))
block_ids = self.map_to_kernel_blocks(
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
)
num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx]
@ -188,7 +190,12 @@ class BlockTable:
self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0)
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
@staticmethod
def map_to_kernel_blocks(
kv_manager_block_ids: np.ndarray,
blocks_per_kv_block: int,
kernel_block_arange: np.ndarray,
) -> np.ndarray:
"""Convert kv_manager_block_id IDs to kernel block IDs.
Example:
@ -203,12 +210,12 @@ class BlockTable:
# kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5]
"""
if not self.use_hybrid_blocks:
if blocks_per_kv_block == 1:
return kv_manager_block_ids
kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
+ self._kernel_block_arange
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+ kernel_block_arange
)
return kernel_block_ids.reshape(-1)