[Platform] Move async output check to platform (#10768)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
parent e691b26f6f
commit aea2fc38c3
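
Summary of the change: the hard-coded device-type check in ModelConfig is replaced by a per-platform hook, is_async_output_supported(enforce_eager), which every platform backend now implements. The sketch below is a minimal, self-contained illustration of that pattern, not the vLLM code itself; only the hook name and signature mirror this diff, while the GraphCapablePlatform / EagerOnlyPlatform classes and the decide_async_output_proc helper are invented for the example.

from typing import Optional


class Platform:
    """Trimmed-down stand-in for the vLLM Platform base class."""

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Check if the current platform supports async output."""
        raise NotImplementedError


class GraphCapablePlatform(Platform):
    """Illustrative CUDA/ROCm-like platform: async output needs graph mode."""

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return not enforce_eager


class EagerOnlyPlatform(Platform):
    """Illustrative CPU/Neuron/OpenVINO-like platform: never supported."""

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return False


def decide_async_output_proc(platform: Platform, enforce_eager: bool) -> bool:
    # Mirrors the new ModelConfig logic: the config no longer matches
    # device_type strings, it simply asks the active platform.
    return platform.is_async_output_supported(enforce_eager)


assert decide_async_output_proc(GraphCapablePlatform(), enforce_eager=False)
assert not decide_async_output_proc(GraphCapablePlatform(), enforce_eager=True)
assert not decide_async_output_proc(EagerOnlyPlatform(), enforce_eager=False)

Centralizing the decision in the platform classes means adding a new accelerator no longer requires touching ModelConfig. The per-file hunks follow.
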
@@ -513,11 +513,10 @@ class ModelConfig:

         # Reminder: Please update docs/source/usage/compatibility_matrix.rst
         # If the feature combo become valid
-        if device_config.device_type not in ("cuda", "tpu", "xpu", "hpu"):
+        if not current_platform.is_async_output_supported(self.enforce_eager):
             logger.warning(
-                "Async output processing is only supported for CUDA, TPU, XPU "
-                "and HPU."
-                "Disabling it for other platforms.")
+                "Async output processing is not supported on the "
+                "current platform type %s.", current_platform.device_type)
             self.use_async_output_proc = False
             return

@@ -527,16 +526,6 @@ class ModelConfig:
             self.use_async_output_proc = False
             return

-        # Reminder: Please update docs/source/usage/compatibility_matrix.rst
-        # If the feature combo become valid
-        if device_config.device_type == "cuda" and self.enforce_eager:
-            logger.warning(
-                "To see benefits of async output processing, enable CUDA "
-                "graph. Since, enforce-eager is enabled, async output "
-                "processor cannot be used")
-            self.use_async_output_proc = not self.enforce_eager
-            return
-
         # Async postprocessor is not necessary with embedding mode
         # since there is no token generation
         if self.task == "embedding":
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import psutil
 import torch

@@ -37,6 +37,10 @@ class CpuPlatform(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         return psutil.virtual_memory().total

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        return False
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
@@ -4,7 +4,7 @@ pynvml. However, it should not initialize cuda context.

 import os
 from functools import lru_cache, wraps
-from typing import TYPE_CHECKING, Callable, List, TypeVar
+from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar

 import pynvml
 import torch

@@ -88,6 +88,16 @@ class CudaPlatformBase(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        if enforce_eager:
+            logger.warning(
+                "To see benefits of async output processing, enable CUDA "
+                "graph. Since, enforce-eager is enabled, async output "
+                "processor cannot be used")
+            return False
+        return True
+
     @classmethod
     def is_full_nvlink(cls, device_ids: List[int]) -> bool:
         raise NotImplementedError
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

@@ -20,6 +20,10 @@ class HpuPlatform(Platform):
     def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
         return _Backend.HPU_ATTN

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        return True
+
     @staticmethod
     def inference_mode():
         return torch.no_grad()
@@ -6,11 +6,15 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
 import numpy as np
 import torch

+from vllm.logger import init_logger
+
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 else:
     VllmConfig = None

+logger = init_logger(__name__)
+

 class _Backend(enum.Enum):
     FLASH_ATTN = enum.auto()

@@ -147,6 +151,13 @@ class Platform:
         """Get the total memory of a device in bytes."""
         raise NotImplementedError

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        """
+        Check if the current platform supports async output.
+        """
+        raise NotImplementedError
+
     @classmethod
     def inference_mode(cls):
         """A device-specific wrapper of `torch.inference_mode`.
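
With the hook declared on the Platform base class above, callers go through current_platform and any new backend only needs to override one method. A hedged usage sketch follows, assuming a vLLM tree that already contains this commit; MyAcceleratorPlatform and should_use_async_output_proc are hypothetical names used only for illustration.

from typing import Optional

# Assumes a vLLM checkout that includes this commit; current_platform and
# Platform are the real vLLM objects, everything else is illustrative.
from vllm.platforms import current_platform
from vllm.platforms.interface import Platform


class MyAcceleratorPlatform(Platform):
    """Hypothetical out-of-tree backend opting in to async output."""

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        # Illustrative policy: support async output unless eager mode is forced.
        return not enforce_eager


def should_use_async_output_proc(enforce_eager: bool) -> bool:
    # Same shape as the new ModelConfig check in this diff.
    if not current_platform.is_async_output_supported(enforce_eager):
        print("Async output processing is not supported on the "
              f"current platform type {current_platform.device_type}.")
        return False
    return True
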
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 from .interface import Platform, PlatformEnum

@@ -18,6 +18,10 @@ class NeuronPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return "neuron"

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        return False
+
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

@@ -37,6 +37,10 @@ class OpenVinoPlatform(Platform):
     def get_device_name(self, device_id: int = 0) -> str:
         return "openvino"

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        return False
+
     @classmethod
     def inference_mode(self):
         return torch.inference_mode(mode=True)
@@ -1,6 +1,6 @@
 import os
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

@@ -72,6 +72,16 @@ class RocmPlatform(Platform):
         device_props = torch.cuda.get_device_properties(device_id)
         return device_props.total_memory

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        if enforce_eager:
+            logger.warning(
+                "To see benefits of async output processing, enable CUDA "
+                "graph. Since, enforce-eager is enabled, async output "
+                "processor cannot be used")
+            return False
+        return True
+
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

@@ -35,6 +35,10 @@ class TpuPlatform(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        return True
+
     @classmethod
     def inference_mode(cls):
         return torch.no_grad()
@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

@@ -41,6 +41,10 @@ class XPUPlatform(Platform):
         device_props = torch.xpu.get_device_properties(device_id)
         return device_props.total_memory

+    @classmethod
+    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
+        return True
+
     @staticmethod
     def inference_mode():
         return torch.no_grad()