[Bugfix][Nixl] Fix kernel physical<>logical block_size issue (#28677)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-11-14 15:40:05 +01:00 committed by GitHub
parent 433c0f8675
commit 96b23b8e3b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 73 additions and 17 deletions

View File

@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
req_index = 0 req_index = 0
block_table.append_row(kvcache_manager_blocks, req_index) block_table.append_row(kvcache_manager_blocks, req_index)
# Get expected kernel blocks from the implementation for verification. # Get expected kernel blocks from the implementation for verification.
expected_kernel_blocks = block_table._map_to_kernel_blocks( expected_kernel_blocks = block_table.map_to_kernel_blocks(
np.array(kvcache_manager_blocks) np.array(kvcache_manager_blocks),
block_table.blocks_per_kv_block,
block_table._kernel_block_arange,
) )
# Verify block table state # Verify block table state
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks) assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)

View File

@ -49,6 +49,7 @@ from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
@dataclass @dataclass
class ReqMeta: class ReqMeta:
local_block_ids: list[int] local_block_ids: list[int]
# To be used when logical block size does not match the kernel block size
local_physical_block_ids: list[int]
remote_block_ids: list[int] remote_block_ids: list[int]
remote_host: str remote_host: str
remote_port: int remote_port: int
@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
assert load_remote_cache ^ save_to_host assert load_remote_cache ^ save_to_host
_req = ReqMeta( _req = ReqMeta(
local_block_ids=local_block_ids, local_block_ids=local_block_ids,
local_physical_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params["remote_block_ids"], remote_block_ids=kv_transfer_params["remote_block_ids"],
remote_engine_id=kv_transfer_params["remote_engine_id"], remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"], remote_host=kv_transfer_params["remote_host"],
@ -935,6 +939,7 @@ class NixlConnectorWorker:
attn_backend=backend, attn_backend=backend,
) )
self._use_pallas = self.kv_topo._use_pallas self._use_pallas = self.kv_topo._use_pallas
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake( def _nixl_handshake(
self, self,
@ -1133,6 +1138,22 @@ class NixlConnectorWorker:
if base_addr in seen_base_addresses: if base_addr in seen_base_addresses:
continue continue
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
" physical kernel block size (%s). Using the latter. ",
self.block_size,
kernel_block_size,
)
self._physical_blocks_per_logical_kv_block = (
self.block_size // kernel_block_size
)
self.block_size = kernel_block_size
seen_base_addresses.append(base_addr) seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size() curr_tensor_size_bytes = cache.numel() * cache.element_size()
@ -1479,7 +1500,7 @@ class NixlConnectorWorker:
assert self.use_host_buffer assert self.use_host_buffer
assert self.copy_blocks is not None assert self.copy_blocks is not None
local_block_ids = meta.local_block_ids local_block_ids = meta.local_physical_block_ids
self.copy_blocks( self.copy_blocks(
self.host_xfer_buffers, self.host_xfer_buffers,
self.device_kv_caches, self.device_kv_caches,
@ -1492,7 +1513,7 @@ class NixlConnectorWorker:
"synced recved kv of request[%s] to device kv buffer," "synced recved kv of request[%s] to device kv buffer,"
"local_block_ids: %s. ", "local_block_ids: %s. ",
req_id, req_id,
",".join(map(str, meta.local_block_ids)), ",".join(map(str, local_block_ids)),
) )
def save_kv_to_host(self, metadata: NixlConnectorMetadata): def save_kv_to_host(self, metadata: NixlConnectorMetadata):
@ -1501,19 +1522,22 @@ class NixlConnectorWorker:
assert self.copy_blocks is not None assert self.copy_blocks is not None
for req_id, meta in metadata.reqs_to_save.items(): for req_id, meta in metadata.reqs_to_save.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug( logger.debug(
"save_load_kv for request[%s] to host xfer buffer." "save_load_kv for request[%s] to host xfer buffer."
"local_block_ids: %s. ", "local_block_ids: %s. ",
req_id, req_id,
",".join(map(str, meta.local_block_ids)), ",".join(map(str, meta.local_physical_block_ids)),
) )
# blocking # blocking
self.copy_blocks( self.copy_blocks(
self.device_kv_caches, self.device_kv_caches,
self.host_xfer_buffers, self.host_xfer_buffers,
meta.local_block_ids, meta.local_physical_block_ids,
meta.local_block_ids, meta.local_physical_block_ids,
"d2h", "d2h",
) )
@ -1582,7 +1606,7 @@ class NixlConnectorWorker:
if self.use_host_buffer: if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta) self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv: if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_block_ids block_ids_to_permute += meta.local_physical_block_ids
if len(block_ids_to_permute) > 0: if len(block_ids_to_permute) > 0:
self.permute_device_kv(block_ids_to_permute) self.permute_device_kv(block_ids_to_permute)
@ -1669,7 +1693,7 @@ class NixlConnectorWorker:
req_id, req_id,
xfer_state, xfer_state,
) )
# mark all blocks for this request as invalid # mark all (logical)blocks for this request as invalid
if meta := self._recving_metadata.pop(req_id, None): if meta := self._recving_metadata.pop(req_id, None):
self._invalid_block_ids.update(meta.local_block_ids) self._invalid_block_ids.update(meta.local_block_ids)
self._recving_metadata.pop(req_id, None) self._recving_metadata.pop(req_id, None)
@ -1686,13 +1710,19 @@ class NixlConnectorWorker:
We check for these trnxs to complete in each step(). We check for these trnxs to complete in each step().
""" """
for req_id, meta in metadata.reqs_to_recv.items(): for req_id, meta in metadata.reqs_to_recv.items():
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
meta.local_block_ids
)
meta.remote_block_ids = self._logical_to_kernel_block_ids(
meta.remote_block_ids
)
remote_engine_id = meta.remote_engine_id remote_engine_id = meta.remote_engine_id
logger.debug( logger.debug(
"start_load_kv for request %s from remote engine %s. " "start_load_kv for request %s from remote engine %s. "
"Num local_block_ids: %s. Num remote_block_ids: %s. ", "Num local_block_ids: %s. Num remote_block_ids: %s. ",
req_id, req_id,
remote_engine_id, remote_engine_id,
len(meta.local_block_ids), len(meta.local_physical_block_ids),
len(meta.remote_block_ids), len(meta.remote_block_ids),
) )
# always store metadata for failure recovery # always store metadata for failure recovery
@ -1740,7 +1770,7 @@ class NixlConnectorWorker:
self._read_blocks( self._read_blocks(
request_id=req_id, request_id=req_id,
dst_engine_id=meta.remote_engine_id, dst_engine_id=meta.remote_engine_id,
local_block_ids=meta.local_block_ids, local_block_ids=meta.local_physical_block_ids,
remote_block_ids=meta.remote_block_ids, remote_block_ids=meta.remote_block_ids,
) )
@ -1867,7 +1897,7 @@ class NixlConnectorWorker:
"Marking blocks as invalid.", "Marking blocks as invalid.",
request_id, request_id,
) )
# mark all blocks for this request as invalid # mark all (logical) blocks for this request as invalid
if meta := self._recving_metadata.get(request_id): if meta := self._recving_metadata.get(request_id):
self._invalid_block_ids.update(meta.local_block_ids) self._invalid_block_ids.update(meta.local_block_ids)
self.xfer_stats.record_failed_transfer() self.xfer_stats.record_failed_transfer()
@ -1906,6 +1936,23 @@ class NixlConnectorWorker:
descs_ids = region_ids * num_blocks + block_ids descs_ids = region_ids * num_blocks + block_ids
return descs_ids.flatten() return descs_ids.flatten()
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
"""
Convert logical block ids to kernel physical block ids.
This is required when the logical block size (the one set by the user)
does not match the one required by the attn backend.
"""
if self._physical_blocks_per_logical_kv_block == 1:
# Noop when physical and logical block sizes are the same
return block_ids
block_ids_np = np.array(block_ids)
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
1, -1
)
return BlockTable.map_to_kernel_blocks(
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
).tolist()
def get_backend_aware_kv_block_len(self, layer_idx: int): def get_backend_aware_kv_block_len(self, layer_idx: int):
""" """
Get the block length for one K/V element (K and V have the same size). Get the block length for one K/V element (K and V have the same size).

View File

@ -98,7 +98,9 @@ class BlockTable:
return return
if self.use_hybrid_blocks: if self.use_hybrid_blocks:
block_ids = self._map_to_kernel_blocks(np.array(block_ids)) block_ids = self.map_to_kernel_blocks(
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
)
num_blocks = len(block_ids) num_blocks = len(block_ids)
start = self.num_blocks_per_row[row_idx] start = self.num_blocks_per_row[row_idx]
@ -188,7 +190,12 @@ class BlockTable:
self.block_table.gpu.fill_(0) self.block_table.gpu.fill_(0)
self.block_table.cpu.fill_(0) self.block_table.cpu.fill_(0)
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray: @staticmethod
def map_to_kernel_blocks(
kv_manager_block_ids: np.ndarray,
blocks_per_kv_block: int,
kernel_block_arange: np.ndarray,
) -> np.ndarray:
"""Convert kv_manager_block_id IDs to kernel block IDs. """Convert kv_manager_block_id IDs to kernel block IDs.
Example: Example:
@ -203,12 +210,12 @@ class BlockTable:
# kv_manager_block_id 1 → kernel block id [2, 3] # kv_manager_block_id 1 → kernel block id [2, 3]
# kv_manager_block_id 2 → kernel block id [4, 5] # kv_manager_block_id 2 → kernel block id [4, 5]
""" """
if not self.use_hybrid_blocks: if blocks_per_kv_block == 1:
return kv_manager_block_ids return kv_manager_block_ids
kernel_block_ids = ( kernel_block_ids = (
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
+ self._kernel_block_arange + kernel_block_arange
) )
return kernel_block_ids.reshape(-1) return kernel_block_ids.reshape(-1)