From 20098c10d9206fc8d18e2db20a2aaef73f395707 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Fri, 19 Sep 2025 23:27:51 +0000 Subject: [PATCH] Remove global CUDA graph pool Signed-off-by: Woosuk Kwon --- vllm/compilation/cuda_graph.py | 2 +- vllm/platforms/interface.py | 7 +------ vllm/v1/worker/gpu_ubatch_wrapper.py | 2 +- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index e233f959c0a4a..7463c655d5407 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -82,7 +82,7 @@ class CUDAGraphWrapper: # TODO: in the future, if we want to use multiple # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams - self.graph_pool = current_platform.get_global_graph_pool() + self.graph_pool = current_platform.graph_pool_handle() if cudagraph_options is None: cudagraph_options = CUDAGraphOptions() diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 53fc762dce540..3ea2e219dc79e 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -140,8 +140,6 @@ class Platform: additional_env_vars: list[str] = [] - _global_graph_pool: Optional[Any] = None - @property def supported_dtypes(self) -> list[torch.dtype]: """Returns the supported dtypes for the current platform.""" @@ -539,10 +537,7 @@ class Platform: """ Return the global graph pool for this platform. """ - cls = self.__class__ - if cls._global_graph_pool is None: - cls._global_graph_pool = self.graph_pool_handle() - return cls._global_graph_pool + return self.graph_pool_handle() @classmethod def get_cu_count(cls, device_id: int = 0) -> int: diff --git a/vllm/v1/worker/gpu_ubatch_wrapper.py b/vllm/v1/worker/gpu_ubatch_wrapper.py index 5012ad0483c84..e29c3811eeccf 100644 --- a/vllm/v1/worker/gpu_ubatch_wrapper.py +++ b/vllm/v1/worker/gpu_ubatch_wrapper.py @@ -54,7 +54,7 @@ class UBatchWrapper: if runtime_mode is not CUDAGraphMode.NONE: self.cudagraph_wrapper = CUDAGraphWrapper( runnable, vllm_config, runtime_mode=runtime_mode) - self.graph_pool = current_platform.get_global_graph_pool() + self.graph_pool = current_platform.graph_pool_handle() def __getattr__(self, key: str): # allow accessing the attributes of the runnable.