diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index a70c98b63713..5ff95876ef34 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -677,12 +677,13 @@ class NixlConnectorWorker:
         mapping between local and remote TP workers.
         """
 
-        tp_size: int
         tp_rank: int
         remote_tp_size: dict[EngineId, int]
         is_mla: bool
         total_num_kv_heads: int
         attn_backend: type[AttentionBackend]
+        engine_id: EngineId
+        remote_block_size: dict[EngineId, int]
 
         def __post_init__(self):
             # Figure out whether the first dimension of the cache is K/V
@@ -710,8 +711,13 @@ class NixlConnectorWorker:
                 self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first
             )
 
-        block_size: int
-        remote_block_size: dict[EngineId, int]
+        @property
+        def tp_size(self) -> int:
+            return self.remote_tp_size[self.engine_id]
+
+        @property
+        def block_size(self) -> int:
+            return self.remote_block_size[self.engine_id]
 
         def tp_ratio(
             self,
@@ -957,13 +963,12 @@ class NixlConnectorWorker:
         self.xfer_stats = NixlKVConnectorStats()
 
         self.kv_topo = self.TpKVTopology(
-            tp_size=self.world_size,
             tp_rank=self.tp_rank,
+            engine_id=self.engine_id,
             remote_tp_size=self._tp_size,  # shared state
+            remote_block_size=self._block_size,  # shared state
             is_mla=self.use_mla,
             total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
-            block_size=self.block_size,
-            remote_block_size=self._block_size,
             attn_backend=backend,
         )
         self._use_pallas = self.kv_topo._use_pallas
@@ -1185,6 +1190,7 @@ class NixlConnectorWorker:
                         self.block_size // kernel_block_size
                     )
                     self.block_size = kernel_block_size
+                    self._block_size[self.engine_id] = kernel_block_size
 
                 seen_base_addresses.append(base_addr)
                 curr_tensor_size_bytes = cache.numel() * cache.element_size()
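
For context, here is a minimal, self-contained sketch of the pattern this patch introduces: the topology no longer stores its own copies of `tp_size`/`block_size`, but exposes them as properties that read the worker's shared per-engine dicts keyed by `engine_id`, so later updates to those dicts (for example the kernel block-size adjustment in the last hunk) are visible without re-syncing fields. The class and variable names below are illustrative only, not the actual vLLM types; `EngineId = str` is assumed.

```python
from dataclasses import dataclass

EngineId = str  # assumption: stand-in alias for the engine identifier type


@dataclass
class TopologySketch:
    """Illustrative stand-in for the patched topology dataclass."""

    tp_rank: int
    engine_id: EngineId
    remote_tp_size: dict[EngineId, int]      # shared with the owning worker
    remote_block_size: dict[EngineId, int]   # shared with the owning worker

    @property
    def tp_size(self) -> int:
        # The local TP size is simply this engine's entry in the shared dict.
        return self.remote_tp_size[self.engine_id]

    @property
    def block_size(self) -> int:
        return self.remote_block_size[self.engine_id]


# Usage: mutating the shared dict is immediately reflected by the property,
# mirroring how register_kv_caches can adjust the block size after construction.
tp_size_map = {"engine-0": 4}
block_size_map = {"engine-0": 16}
topo = TopologySketch(
    tp_rank=0,
    engine_id="engine-0",
    remote_tp_size=tp_size_map,
    remote_block_size=block_size_map,
)
assert topo.block_size == 16
block_size_map["engine-0"] = 128  # e.g. adjusted to the kernel block size later
assert topo.block_size == 128
```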