diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5d49bbaf270bb..4f6cf9a350706 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -316,6 +316,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Cached outputs.
         self._draft_token_ids: Optional[Union[list[list[int]],
                                               torch.Tensor]] = None
+        self.transfer_event = torch.cuda.Event()
+        self.sampled_token_ids_pinned_cpu = torch.empty(
+            (self.max_model_len, 1),
+            dtype=torch.int64,
+            device="cpu",
+            pin_memory=True)
 
     def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
         return CpuGpuBuffer(*args,
@@ -1691,7 +1697,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         max_gen_len = sampled_token_ids.shape[-1]
         if max_gen_len == 1:
             # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
+            valid_sampled_token_ids = self._to_list(sampled_token_ids)
         else:
             # Includes spec decode tokens.
             valid_sampled_token_ids = self.rejection_sampler.parse_output(
@@ -2219,7 +2225,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
             - CUDAGraphMode.FULL: Full cudagraph, attention metadata is needed.
-            force_attention: If True, always create attention metadata. Used to 
+            force_attention: If True, always create attention metadata. Used to
                 warm up attention backend when mode is NONE.
             uniform_decode: If True, the batch is a uniform decode batch.
             skip_eplb: If True, skip EPLB state update.
@@ -3233,3 +3239,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     mamba_type=mamba_module.mamba_type)
 
         return kv_cache_spec
+
+    def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
+        # This is a short-term mitigation for the issue described in
+        # https://github.com/vllm-project/vllm/issues/22754.
+        # `tolist` triggers a CUDA-wide stream sync, which blocks
+        # copy ops issued on other CUDA streams. Syncing on a CUDA
+        # event instead avoids that. Since this call sits on the
+        # critical path of every model forward loop, the blocking
+        # sync caused a perf regression in disaggregated serving
+        # setups.
+        pinned = self.sampled_token_ids_pinned_cpu[:sampled_token_ids.shape[0]]
+        pinned.copy_(sampled_token_ids, non_blocking=True)
+        self.transfer_event.record()
+        self.transfer_event.synchronize()
+        return pinned.tolist()
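
For context, here is a minimal, self-contained sketch (not part of the diff) of the pattern the new `_to_list` helper relies on: an async copy into a preallocated pinned-CPU buffer, followed by a sync on a CUDA event recorded for that copy, instead of letting `.tolist()` on a CUDA tensor force a blocking device-wide sync. The module-level `pinned_buf`/`copy_done` globals and the `gpu_to_list`/`MAX_ROWS` names are illustrative stand-ins, not vLLM APIs.

```python
# Minimal sketch (not vLLM code) of the event-sync pattern used above.
# Assumes a CUDA-capable PyTorch build; names are illustrative only.
import torch

MAX_ROWS = 8192  # stand-in for self.max_model_len

# Allocated once, like sampled_token_ids_pinned_cpu / transfer_event.
pinned_buf = torch.empty((MAX_ROWS, 1), dtype=torch.int64,
                         device="cpu", pin_memory=True)
copy_done = torch.cuda.Event()


def gpu_to_list(sampled: torch.Tensor) -> list[list[int]]:
    """Copy an (N, 1) CUDA tensor to Python lists without a device-wide sync.

    `sampled.tolist()` would issue a blocking copy that synchronizes the
    whole device; instead, do an async copy into pinned host memory and
    wait only on the event that marks this copy's completion.
    """
    rows = sampled.shape[0]
    dst = pinned_buf[:rows]
    dst.copy_(sampled, non_blocking=True)  # async D->H copy on the current stream
    copy_done.record()        # record right after the copy on the current stream
    copy_done.synchronize()   # wait for just this copy, not the whole device
    return dst.tolist()       # CPU tensor -> plain Python lists, no CUDA sync


if __name__ == "__main__":
    ids = torch.randint(0, 32_000, (16, 1), device="cuda")
    print(gpu_to_list(ids)[:3])
```

Preallocating the pinned buffer once mirrors the diff's `sampled_token_ids_pinned_cpu`: pinned allocations are comparatively expensive, and reusing a single buffer keeps the per-step cost to one async copy plus one event sync.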