[Misc] Refactor platform to get device specific stream and event (#14411)

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen 2025-04-21 21:25:49 +08:00 committed by GitHub
parent d9ac9e3dc5
commit 7272bfae77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 15 additions and 5 deletions

View File

@ -404,6 +404,15 @@ class Platform:
) -> None:
"""Raises if this request is unsupported on this platform"""
def __getattr__(self, key: str):
device = getattr(torch, self.device_name, None)
if device is not None and hasattr(device, key):
return getattr(device, key)
else:
logger.warning("Current platform %s doesn't has '%s' attribute.",
self.device_name, key)
return None
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

View File

@ -8,6 +8,7 @@ import torch
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
@ -89,14 +90,14 @@ class AsyncMetricsCollector:
self._rank = rank
if isinstance(device_type, torch.device):
device_type = device_type.type
if device_type == 'cuda':
self._copy_stream = torch.cuda.Stream()
stream = current_platform.Stream
if stream is not None:
self._copy_stream = stream()
def maybe_collect_rejsample_metrics(
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
# currently using cuda.Event, skip for any non_cuda_alike platform
from vllm.platforms import current_platform
if not current_platform.is_cuda_alike():
# Skip for any platform that doesn't have device Event
if current_platform.Event is None:
return None
# If a copy was initiated in the previous call, collect and return.