mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-18 23:54:38 +08:00
[Misc] Refactor platform to get device specific stream and event (#14411)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
parent
d9ac9e3dc5
commit
7272bfae77
@ -404,6 +404,15 @@ class Platform:
|
||||
) -> None:
|
||||
"""Raises if this request is unsupported on this platform"""
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
device = getattr(torch, self.device_name, None)
|
||||
if device is not None and hasattr(device, key):
|
||||
return getattr(device, key)
|
||||
else:
|
||||
logger.warning("Current platform %s doesn't has '%s' attribute.",
|
||||
self.device_name, key)
|
||||
return None
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
||||
@ -8,6 +8,7 @@ import torch
|
||||
|
||||
from vllm.model_executor.layers.spec_decode_base_sampler import (
|
||||
SpecDecodeBaseSampler)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
@ -89,14 +90,14 @@ class AsyncMetricsCollector:
|
||||
self._rank = rank
|
||||
if isinstance(device_type, torch.device):
|
||||
device_type = device_type.type
|
||||
if device_type == 'cuda':
|
||||
self._copy_stream = torch.cuda.Stream()
|
||||
stream = current_platform.Stream
|
||||
if stream is not None:
|
||||
self._copy_stream = stream()
|
||||
|
||||
def maybe_collect_rejsample_metrics(
|
||||
self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
|
||||
# currently using cuda.Event, skip for any non_cuda_alike platform
|
||||
from vllm.platforms import current_platform
|
||||
if not current_platform.is_cuda_alike():
|
||||
# Skip for any platform that doesn't have device Event
|
||||
if current_platform.Event is None:
|
||||
return None
|
||||
|
||||
# If a copy was initiated in the previous call, collect and return.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user