mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-06 19:29:09 +08:00
[Bugfix][Nixl] Fix kernel physical<>logical block_size issue (#28677)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
433c0f8675
commit
96b23b8e3b
@ -985,8 +985,10 @@ def test_hybrid_block_table_initialization():
|
|||||||
req_index = 0
|
req_index = 0
|
||||||
block_table.append_row(kvcache_manager_blocks, req_index)
|
block_table.append_row(kvcache_manager_blocks, req_index)
|
||||||
# Get expected kernel blocks from the implementation for verification.
|
# Get expected kernel blocks from the implementation for verification.
|
||||||
expected_kernel_blocks = block_table._map_to_kernel_blocks(
|
expected_kernel_blocks = block_table.map_to_kernel_blocks(
|
||||||
np.array(kvcache_manager_blocks)
|
np.array(kvcache_manager_blocks),
|
||||||
|
block_table.blocks_per_kv_block,
|
||||||
|
block_table._kernel_block_arange,
|
||||||
)
|
)
|
||||||
# Verify block table state
|
# Verify block table state
|
||||||
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
|
assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks)
|
||||||
|
|||||||
@ -49,6 +49,7 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
|
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
|
||||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
@ -112,6 +113,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ReqMeta:
|
class ReqMeta:
|
||||||
local_block_ids: list[int]
|
local_block_ids: list[int]
|
||||||
|
# To be used when logical block size does not match the kernel block size
|
||||||
|
local_physical_block_ids: list[int]
|
||||||
remote_block_ids: list[int]
|
remote_block_ids: list[int]
|
||||||
remote_host: str
|
remote_host: str
|
||||||
remote_port: int
|
remote_port: int
|
||||||
@ -139,6 +142,7 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
|||||||
assert load_remote_cache ^ save_to_host
|
assert load_remote_cache ^ save_to_host
|
||||||
_req = ReqMeta(
|
_req = ReqMeta(
|
||||||
local_block_ids=local_block_ids,
|
local_block_ids=local_block_ids,
|
||||||
|
local_physical_block_ids=local_block_ids,
|
||||||
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
remote_block_ids=kv_transfer_params["remote_block_ids"],
|
||||||
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
remote_engine_id=kv_transfer_params["remote_engine_id"],
|
||||||
remote_host=kv_transfer_params["remote_host"],
|
remote_host=kv_transfer_params["remote_host"],
|
||||||
@ -935,6 +939,7 @@ class NixlConnectorWorker:
|
|||||||
attn_backend=backend,
|
attn_backend=backend,
|
||||||
)
|
)
|
||||||
self._use_pallas = self.kv_topo._use_pallas
|
self._use_pallas = self.kv_topo._use_pallas
|
||||||
|
self._physical_blocks_per_logical_kv_block = 1
|
||||||
|
|
||||||
def _nixl_handshake(
|
def _nixl_handshake(
|
||||||
self,
|
self,
|
||||||
@ -1133,6 +1138,22 @@ class NixlConnectorWorker:
|
|||||||
if base_addr in seen_base_addresses:
|
if base_addr in seen_base_addresses:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# TODO (NickLucche): Get kernel_block_size in a cleaner way
|
||||||
|
# NHD default "view" for non-MLA cache
|
||||||
|
kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3]
|
||||||
|
|
||||||
|
if self.block_size != kernel_block_size:
|
||||||
|
logger.info_once(
|
||||||
|
"User-specified logical block size (%s) does not match"
|
||||||
|
" physical kernel block size (%s). Using the latter. ",
|
||||||
|
self.block_size,
|
||||||
|
kernel_block_size,
|
||||||
|
)
|
||||||
|
self._physical_blocks_per_logical_kv_block = (
|
||||||
|
self.block_size // kernel_block_size
|
||||||
|
)
|
||||||
|
self.block_size = kernel_block_size
|
||||||
|
|
||||||
seen_base_addresses.append(base_addr)
|
seen_base_addresses.append(base_addr)
|
||||||
curr_tensor_size_bytes = cache.numel() * cache.element_size()
|
curr_tensor_size_bytes = cache.numel() * cache.element_size()
|
||||||
|
|
||||||
@ -1479,7 +1500,7 @@ class NixlConnectorWorker:
|
|||||||
assert self.use_host_buffer
|
assert self.use_host_buffer
|
||||||
assert self.copy_blocks is not None
|
assert self.copy_blocks is not None
|
||||||
|
|
||||||
local_block_ids = meta.local_block_ids
|
local_block_ids = meta.local_physical_block_ids
|
||||||
self.copy_blocks(
|
self.copy_blocks(
|
||||||
self.host_xfer_buffers,
|
self.host_xfer_buffers,
|
||||||
self.device_kv_caches,
|
self.device_kv_caches,
|
||||||
@ -1492,7 +1513,7 @@ class NixlConnectorWorker:
|
|||||||
"synced recved kv of request[%s] to device kv buffer,"
|
"synced recved kv of request[%s] to device kv buffer,"
|
||||||
"local_block_ids: %s. ",
|
"local_block_ids: %s. ",
|
||||||
req_id,
|
req_id,
|
||||||
",".join(map(str, meta.local_block_ids)),
|
",".join(map(str, local_block_ids)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_kv_to_host(self, metadata: NixlConnectorMetadata):
|
def save_kv_to_host(self, metadata: NixlConnectorMetadata):
|
||||||
@ -1501,19 +1522,22 @@ class NixlConnectorWorker:
|
|||||||
assert self.copy_blocks is not None
|
assert self.copy_blocks is not None
|
||||||
|
|
||||||
for req_id, meta in metadata.reqs_to_save.items():
|
for req_id, meta in metadata.reqs_to_save.items():
|
||||||
|
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
|
||||||
|
meta.local_block_ids
|
||||||
|
)
|
||||||
if logger.isEnabledFor(logging.DEBUG):
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"save_load_kv for request[%s] to host xfer buffer."
|
"save_load_kv for request[%s] to host xfer buffer."
|
||||||
"local_block_ids: %s. ",
|
"local_block_ids: %s. ",
|
||||||
req_id,
|
req_id,
|
||||||
",".join(map(str, meta.local_block_ids)),
|
",".join(map(str, meta.local_physical_block_ids)),
|
||||||
)
|
)
|
||||||
# blocking
|
# blocking
|
||||||
self.copy_blocks(
|
self.copy_blocks(
|
||||||
self.device_kv_caches,
|
self.device_kv_caches,
|
||||||
self.host_xfer_buffers,
|
self.host_xfer_buffers,
|
||||||
meta.local_block_ids,
|
meta.local_physical_block_ids,
|
||||||
meta.local_block_ids,
|
meta.local_physical_block_ids,
|
||||||
"d2h",
|
"d2h",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1582,7 +1606,7 @@ class NixlConnectorWorker:
|
|||||||
if self.use_host_buffer:
|
if self.use_host_buffer:
|
||||||
self.sync_recved_kv_to_device(req_id, meta)
|
self.sync_recved_kv_to_device(req_id, meta)
|
||||||
if self.enable_permute_local_kv:
|
if self.enable_permute_local_kv:
|
||||||
block_ids_to_permute += meta.local_block_ids
|
block_ids_to_permute += meta.local_physical_block_ids
|
||||||
if len(block_ids_to_permute) > 0:
|
if len(block_ids_to_permute) > 0:
|
||||||
self.permute_device_kv(block_ids_to_permute)
|
self.permute_device_kv(block_ids_to_permute)
|
||||||
|
|
||||||
@ -1669,7 +1693,7 @@ class NixlConnectorWorker:
|
|||||||
req_id,
|
req_id,
|
||||||
xfer_state,
|
xfer_state,
|
||||||
)
|
)
|
||||||
# mark all blocks for this request as invalid
|
# mark all (logical)blocks for this request as invalid
|
||||||
if meta := self._recving_metadata.pop(req_id, None):
|
if meta := self._recving_metadata.pop(req_id, None):
|
||||||
self._invalid_block_ids.update(meta.local_block_ids)
|
self._invalid_block_ids.update(meta.local_block_ids)
|
||||||
self._recving_metadata.pop(req_id, None)
|
self._recving_metadata.pop(req_id, None)
|
||||||
@ -1686,13 +1710,19 @@ class NixlConnectorWorker:
|
|||||||
We check for these trnxs to complete in each step().
|
We check for these trnxs to complete in each step().
|
||||||
"""
|
"""
|
||||||
for req_id, meta in metadata.reqs_to_recv.items():
|
for req_id, meta in metadata.reqs_to_recv.items():
|
||||||
|
meta.local_physical_block_ids = self._logical_to_kernel_block_ids(
|
||||||
|
meta.local_block_ids
|
||||||
|
)
|
||||||
|
meta.remote_block_ids = self._logical_to_kernel_block_ids(
|
||||||
|
meta.remote_block_ids
|
||||||
|
)
|
||||||
remote_engine_id = meta.remote_engine_id
|
remote_engine_id = meta.remote_engine_id
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"start_load_kv for request %s from remote engine %s. "
|
"start_load_kv for request %s from remote engine %s. "
|
||||||
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
|
"Num local_block_ids: %s. Num remote_block_ids: %s. ",
|
||||||
req_id,
|
req_id,
|
||||||
remote_engine_id,
|
remote_engine_id,
|
||||||
len(meta.local_block_ids),
|
len(meta.local_physical_block_ids),
|
||||||
len(meta.remote_block_ids),
|
len(meta.remote_block_ids),
|
||||||
)
|
)
|
||||||
# always store metadata for failure recovery
|
# always store metadata for failure recovery
|
||||||
@ -1740,7 +1770,7 @@ class NixlConnectorWorker:
|
|||||||
self._read_blocks(
|
self._read_blocks(
|
||||||
request_id=req_id,
|
request_id=req_id,
|
||||||
dst_engine_id=meta.remote_engine_id,
|
dst_engine_id=meta.remote_engine_id,
|
||||||
local_block_ids=meta.local_block_ids,
|
local_block_ids=meta.local_physical_block_ids,
|
||||||
remote_block_ids=meta.remote_block_ids,
|
remote_block_ids=meta.remote_block_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1867,7 +1897,7 @@ class NixlConnectorWorker:
|
|||||||
"Marking blocks as invalid.",
|
"Marking blocks as invalid.",
|
||||||
request_id,
|
request_id,
|
||||||
)
|
)
|
||||||
# mark all blocks for this request as invalid
|
# mark all (logical) blocks for this request as invalid
|
||||||
if meta := self._recving_metadata.get(request_id):
|
if meta := self._recving_metadata.get(request_id):
|
||||||
self._invalid_block_ids.update(meta.local_block_ids)
|
self._invalid_block_ids.update(meta.local_block_ids)
|
||||||
self.xfer_stats.record_failed_transfer()
|
self.xfer_stats.record_failed_transfer()
|
||||||
@ -1906,6 +1936,23 @@ class NixlConnectorWorker:
|
|||||||
descs_ids = region_ids * num_blocks + block_ids
|
descs_ids = region_ids * num_blocks + block_ids
|
||||||
return descs_ids.flatten()
|
return descs_ids.flatten()
|
||||||
|
|
||||||
|
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
|
||||||
|
"""
|
||||||
|
Convert logical block ids to kernel physical block ids.
|
||||||
|
This is required when the logical block size (the one set by the user)
|
||||||
|
does not match the one required by the attn backend.
|
||||||
|
"""
|
||||||
|
if self._physical_blocks_per_logical_kv_block == 1:
|
||||||
|
# Noop when physical and logical block sizes are the same
|
||||||
|
return block_ids
|
||||||
|
block_ids_np = np.array(block_ids)
|
||||||
|
block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape(
|
||||||
|
1, -1
|
||||||
|
)
|
||||||
|
return BlockTable.map_to_kernel_blocks(
|
||||||
|
block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange
|
||||||
|
).tolist()
|
||||||
|
|
||||||
def get_backend_aware_kv_block_len(self, layer_idx: int):
|
def get_backend_aware_kv_block_len(self, layer_idx: int):
|
||||||
"""
|
"""
|
||||||
Get the block length for one K/V element (K and V have the same size).
|
Get the block length for one K/V element (K and V have the same size).
|
||||||
|
|||||||
@ -98,7 +98,9 @@ class BlockTable:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if self.use_hybrid_blocks:
|
if self.use_hybrid_blocks:
|
||||||
block_ids = self._map_to_kernel_blocks(np.array(block_ids))
|
block_ids = self.map_to_kernel_blocks(
|
||||||
|
np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange
|
||||||
|
)
|
||||||
|
|
||||||
num_blocks = len(block_ids)
|
num_blocks = len(block_ids)
|
||||||
start = self.num_blocks_per_row[row_idx]
|
start = self.num_blocks_per_row[row_idx]
|
||||||
@ -188,7 +190,12 @@ class BlockTable:
|
|||||||
self.block_table.gpu.fill_(0)
|
self.block_table.gpu.fill_(0)
|
||||||
self.block_table.cpu.fill_(0)
|
self.block_table.cpu.fill_(0)
|
||||||
|
|
||||||
def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray:
|
@staticmethod
|
||||||
|
def map_to_kernel_blocks(
|
||||||
|
kv_manager_block_ids: np.ndarray,
|
||||||
|
blocks_per_kv_block: int,
|
||||||
|
kernel_block_arange: np.ndarray,
|
||||||
|
) -> np.ndarray:
|
||||||
"""Convert kv_manager_block_id IDs to kernel block IDs.
|
"""Convert kv_manager_block_id IDs to kernel block IDs.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -203,12 +210,12 @@ class BlockTable:
|
|||||||
# kv_manager_block_id 1 → kernel block id [2, 3]
|
# kv_manager_block_id 1 → kernel block id [2, 3]
|
||||||
# kv_manager_block_id 2 → kernel block id [4, 5]
|
# kv_manager_block_id 2 → kernel block id [4, 5]
|
||||||
"""
|
"""
|
||||||
if not self.use_hybrid_blocks:
|
if blocks_per_kv_block == 1:
|
||||||
return kv_manager_block_ids
|
return kv_manager_block_ids
|
||||||
|
|
||||||
kernel_block_ids = (
|
kernel_block_ids = (
|
||||||
kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block
|
kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block
|
||||||
+ self._kernel_block_arange
|
+ kernel_block_arange
|
||||||
)
|
)
|
||||||
|
|
||||||
return kernel_block_ids.reshape(-1)
|
return kernel_block_ids.reshape(-1)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user