From 9c1baa5bc6caedabeac1a6da57ec79b41e13056d Mon Sep 17 00:00:00 2001 From: Shanshan Shen <467638484@qq.com> Date: Fri, 23 May 2025 12:38:50 +0800 Subject: [PATCH] [Misc] Replace `cuda` hard code with `current_platform` (#16983) Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm/distributed/parallel_state.py | 5 +++-- vllm/forward_context.py | 5 ++++- vllm/spec_decode/metrics.py | 8 ++++---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 51c519d8f8623..f67c018891889 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1221,8 +1221,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False): ray.shutdown() gc.collect() from vllm.platforms import current_platform - if not current_platform.is_cpu(): - torch.cuda.empty_cache() + empty_cache = current_platform.empty_cache + if empty_cache is not None: + empty_cache() try: torch._C._host_emptyCache() except AttributeError: diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 5d2d95f18d2fa..3c8083e3dd0dd 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any, # we use synchronous scheduling right now, # adding a sync point here should not affect # scheduling of the next batch - torch.cuda.synchronize() + from vllm.platforms import current_platform + synchronize = current_platform.synchronize + if synchronize is not None: + synchronize() now = time.perf_counter() # time measurement is in milliseconds batchsize_forward_time[batchsize].append( diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 0bb8d602ec8f1..4430da26c0493 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -126,12 +126,12 @@ class AsyncMetricsCollector: """Copy rejection/typical-acceptance sampling metrics (number of accepted tokens, etc) to CPU asynchronously. - Returns a CUDA event recording when the copy is complete. + Returns a device event recording when the copy is complete. """ assert self._copy_stream is not None - self._copy_stream.wait_stream(torch.cuda.current_stream()) + self._copy_stream.wait_stream(current_platform.current_stream()) - with torch.cuda.stream(self._copy_stream): + with current_platform.stream(self._copy_stream): self._aggregate_num_accepted_tokens.copy_( self.spec_decode_sampler.num_accepted_tokens, non_blocking=True) @@ -142,7 +142,7 @@ class AsyncMetricsCollector: self._aggregate_num_draft_tokens = ( self.spec_decode_sampler.num_draft_tokens) - aggregate_metrics_ready = torch.cuda.Event() + aggregate_metrics_ready = current_platform.Event() aggregate_metrics_ready.record(self._copy_stream) return aggregate_metrics_ready