diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c6437610e3bab..fecf08c7e647c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -1452,11 +1452,17 @@ class NixlConnectorWorker: len(done_recving), ) - # clean up metadata for completed requests + block_ids_to_permute = [] for req_id in done_recving: + # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) - if self.use_host_buffer and meta: + assert meta is not None, f"{req_id} not found in recving_metadata list" + if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) + if self.enable_permute_local_kv: + block_ids_to_permute += meta.local_block_ids + if len(block_ids_to_permute) > 0: + self.permute_device_kv(block_ids_to_permute) # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() @@ -1477,15 +1483,6 @@ class NixlConnectorWorker: del self._reqs_to_send[req_id] done_sending.add(req_id) - if self.enable_permute_local_kv and len(done_recving) > 0: - block_ids = [] - for req_id in done_recving: - meta = self._recving_metadata.pop(req_id) - assert meta, f"{req_id} not found in recving_metadata list" - block_ids += meta.local_block_ids - - self.permute_device_kv(block_ids) - return done_sending, done_recving def _get_new_notifs(self) -> set[str]: