mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-20 23:49:14 +08:00
[Bugfix][plugin] fla crash on plugin (#27322)
This commit is contained in:
parent
01baefe674
commit
ccd3e55e51
@ -17,6 +17,7 @@ from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -137,8 +138,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
|
||||
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
|
||||
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
|
||||
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
|
||||
device = get_available_device() if get_available_device() != "hip" else "cuda"
|
||||
device_torch_lib = getattr(torch, device)
|
||||
device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
|
||||
device_torch_lib = getattr(torch, device, None)
|
||||
device_platform = _check_platform()
|
||||
|
||||
is_amd = device_platform == "amd"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user