mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-29 07:17:14 +08:00
[Hardware] correct method signatures for HPU,ROCm,XPU (#18551)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
parent
9c1baa5bc6
commit
60cad94b86
@ -42,7 +42,6 @@ def tpu_platform_plugin() -> Optional[str]:
|
||||
logger.debug("Confirmed TPU platform is available.")
|
||||
except Exception as e:
|
||||
logger.debug("TPU platform is not available because: %s", str(e))
|
||||
pass
|
||||
|
||||
return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None
|
||||
|
||||
@ -112,7 +111,6 @@ def rocm_platform_plugin() -> Optional[str]:
|
||||
amdsmi.amdsmi_shut_down()
|
||||
except Exception as e:
|
||||
logger.debug("ROCm platform is not available because: %s", str(e))
|
||||
pass
|
||||
|
||||
return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None
|
||||
|
||||
@ -130,7 +128,6 @@ def hpu_platform_plugin() -> Optional[str]:
|
||||
"habana_frameworks is not found.")
|
||||
except Exception as e:
|
||||
logger.debug("HPU platform is not available because: %s", str(e))
|
||||
pass
|
||||
|
||||
return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None
|
||||
|
||||
@ -148,7 +145,6 @@ def xpu_platform_plugin() -> Optional[str]:
|
||||
logger.debug("Confirmed XPU platform is available.")
|
||||
except Exception as e:
|
||||
logger.debug("XPU platform is not available because: %s", str(e))
|
||||
pass
|
||||
|
||||
return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None
|
||||
|
||||
@ -170,7 +166,6 @@ def cpu_platform_plugin() -> Optional[str]:
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("CPU platform is not available because: %s", str(e))
|
||||
pass
|
||||
|
||||
return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None
|
||||
|
||||
@ -222,8 +217,11 @@ def resolve_current_platform_cls_qualname() -> str:
|
||||
platform_cls_qualname = func()
|
||||
if platform_cls_qualname is not None:
|
||||
activated_plugins.append(name)
|
||||
logger.info("Platform plugin %s loaded.", name)
|
||||
logger.warning(
|
||||
"Platform plugin %s function's return value is None", name)
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("Failed to load platform plugin %s", name)
|
||||
|
||||
activated_builtin_plugins = list(
|
||||
set(activated_plugins) & set(builtin_platform_plugins.keys()))
|
||||
|
||||
@ -39,8 +39,8 @@ class HpuPlatform(Platform):
|
||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def inference_mode():
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.no_grad()
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -217,9 +217,9 @@ class RocmPlatform(Platform):
|
||||
major, minor = torch.cuda.get_device_capability(device_id)
|
||||
return DeviceCapability(major=major, minor=minor)
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
@with_amdsmi_context
|
||||
def is_fully_connected(physical_device_ids: list[int]) -> bool:
|
||||
def is_fully_connected(cls, physical_device_ids: list[int]) -> bool:
|
||||
"""
|
||||
Query if the set of gpus are fully connected by xgmi (1 hop)
|
||||
"""
|
||||
|
||||
@ -37,15 +37,17 @@ class XPUPlatform(Platform):
|
||||
logger.info("Using IPEX attention backend.")
|
||||
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def get_device_capability(
|
||||
device_id: int = 0) -> Optional[DeviceCapability]:
|
||||
cls,
|
||||
device_id: int = 0,
|
||||
) -> Optional[DeviceCapability]:
|
||||
# capacity format differs from cuda's and will cause unexpected
|
||||
# failure, so use None directly
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_device_name(device_id: int = 0) -> str:
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return torch.xpu.get_device_name(device_id)
|
||||
|
||||
@classmethod
|
||||
@ -57,8 +59,8 @@ class XPUPlatform(Platform):
|
||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def inference_mode():
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.no_grad()
|
||||
|
||||
@classmethod
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user