From 7272bfae77310c417e0b747a0590d9dd782c9f6b Mon Sep 17 00:00:00 2001 From: Shanshan Shen <467638484@qq.com> Date: Mon, 21 Apr 2025 21:25:49 +0800 Subject: [PATCH] [Misc] Refactor platform to get device specific stream and event (#14411) Signed-off-by: shen-shanshan <467638484@qq.com> --- vllm/platforms/interface.py | 9 +++++++++ vllm/spec_decode/metrics.py | 11 ++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 8c099b9531c5f..4707c3749b7e2 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -404,6 +404,15 @@ class Platform: ) -> None: """Raises if this request is unsupported on this platform""" + def __getattr__(self, key: str): + device = getattr(torch, self.device_name, None) + if device is not None and hasattr(device, key): + return getattr(device, key) + else: + logger.warning("Current platform %s doesn't have '%s' attribute.", + self.device_name, key) + return None + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index bc0e0a121cd55..0bb8d602ec8f1 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -8,6 +8,7 @@ import torch from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler) +from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @@ -89,14 +90,14 @@ class AsyncMetricsCollector: self._rank = rank if isinstance(device_type, torch.device): device_type = device_type.type - if device_type == 'cuda': - self._copy_stream = torch.cuda.Stream() + stream = current_platform.Stream + if stream is not None: + self._copy_stream = stream() def maybe_collect_rejsample_metrics( self, k: int) -> Optional[SpecDecodeWorkerMetrics]: - # currently using cuda.Event, skip for any non_cuda_alike platform - from vllm.platforms import current_platform - if not 
current_platform.is_cuda_alike(): + # Skip for any platform that doesn't have device Event + if current_platform.Event is None: return None # If a copy was initiated in the previous call, collect and return.