[Bugfix][NIXL] Fix block_size_ratio when logical != physical blocks (#28925)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
commit 184b12fdc6 (parent b9489f51e1)
Author: Nicolò Lucchesi <nlucches@redhat.com>
Date:   2025-11-18 15:07:50 +01:00 (committed by GitHub)


@@ -677,12 +677,13 @@ class NixlConnectorWorker:
         mapping between local and remote TP workers.
         """

-        tp_size: int
         tp_rank: int
         remote_tp_size: dict[EngineId, int]
         is_mla: bool
         total_num_kv_heads: int
         attn_backend: type[AttentionBackend]
+        engine_id: EngineId
+        remote_block_size: dict[EngineId, int]

         def __post_init__(self):
             # Figure out whether the first dimension of the cache is K/V
@@ -710,8 +711,13 @@ class NixlConnectorWorker:
                 self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
             )

-        block_size: int
-        remote_block_size: dict[EngineId, int]
+        @property
+        def tp_size(self) -> int:
+            return self.remote_tp_size[self.engine_id]
+
+        @property
+        def block_size(self) -> int:
+            return self.remote_block_size[self.engine_id]

         def tp_ratio(
             self,
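
Net effect of this hunk: tp_size and block_size are no longer stored fields but live lookups into the shared per-engine dicts, keyed by the topology's own engine_id. A minimal, self-contained sketch of the pattern (names simplified; EngineId reduced to a plain str rather than vLLM's actual alias):

from dataclasses import dataclass

EngineId = str  # simplified stand-in for vLLM's EngineId alias

@dataclass
class Topology:
    engine_id: EngineId
    remote_tp_size: dict[EngineId, int]     # shared with the worker
    remote_block_size: dict[EngineId, int]  # shared with the worker

    @property
    def tp_size(self) -> int:
        # The local TP size is just the shared-dict entry for our engine.
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        # Likewise for block size, so later dict updates are visible here.
        return self.remote_block_size[self.engine_id]

tp_sizes = {"local": 4, "remote": 8}
block_sizes = {"local": 128, "remote": 128}
topo = Topology("local", tp_sizes, block_sizes)
assert topo.block_size == 128
block_sizes["local"] = 16     # e.g. kernel uses 16-token physical blocks
assert topo.block_size == 16  # property tracks the shared dict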
@@ -957,13 +963,12 @@ class NixlConnectorWorker:
         self.xfer_stats = NixlKVConnectorStats()

         self.kv_topo = self.TpKVTopology(
-            tp_size=self.world_size,
             tp_rank=self.tp_rank,
+            engine_id=self.engine_id,
             remote_tp_size=self._tp_size,  # shared state
+            remote_block_size=self._block_size,  # shared state
             is_mla=self.use_mla,
             total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
-            block_size=self.block_size,
-            remote_block_size=self._block_size,
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
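
The two "# shared state" comments carry the point of this hunk: the constructor now hands the topology the worker's own dicts instead of scalar snapshots, so later updates to those dicts remain visible through the new properties. A toy contrast between the old capture-by-value style and the new lookup-on-use style (illustrative only, not vLLM code):

def make_frozen(block_size: int):
    return lambda: block_size             # old style: value captured once

def make_live(block_size: dict, engine_id: str):
    return lambda: block_size[engine_id]  # new style: looked up on each use

sizes = {"eng0": 128}
frozen, live = make_frozen(sizes["eng0"]), make_live(sizes, "eng0")
sizes["eng0"] = 16
assert frozen() == 128  # stale: the failure mode this commit fixes
assert live() == 16     # current: what the properties now return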
@@ -1185,6 +1190,7 @@ class NixlConnectorWorker:
                     self.block_size // kernel_block_size
                 )
                 self.block_size = kernel_block_size
+                self._block_size[self.engine_id] = kernel_block_size

             seen_base_addresses.append(base_addr)
             curr_tensor_size_bytes = cache.numel() * cache.element_size()
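
This last hunk is the fix proper: when the attention backend registers a kernel (physical) block size smaller than the logical one, the worker now also updates the shared _block_size entry for its own engine, so block-size ratios derived from that dict match the physical layout. A sketch of the failure mode, using a hypothetical block_size_ratio helper named after the commit title (the connector's real computation may differ):

def block_size_ratio(local_block: int, remote_block: int) -> int:
    # Hypothetical stand-in for the ratio named in the commit title:
    # how many local physical blocks make up one remote block.
    assert remote_block % local_block == 0, "sizes must divide evenly"
    return remote_block // local_block

shared_block_size = {"local": 128, "remote": 128}
kernel_block_size = 16  # physical block size reported by the backend

# Before the fix only self.block_size was rewritten, so a ratio computed
# from the shared dict still used the logical size and came out as 1:
assert block_size_ratio(shared_block_size["local"],
                        shared_block_size["remote"]) == 1
# The added line keeps the shared dict in sync with the kernel:
shared_block_size["local"] = kernel_block_size
assert block_size_ratio(shared_block_size["local"],
                        shared_block_size["remote"]) == 8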