mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 09:35:53 +08:00
[Platform] Move get_punica_wrapper() function to Platform (#11516)
Signed-off-by: Shanshan Shen <467638484@qq.com>
This commit is contained in:
parent
458e63a2c6
commit
a7d59688fb
@ -1,5 +1,6 @@
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import resolve_obj_by_qualname
|
||||||
|
|
||||||
from .punica_base import PunicaWrapperBase
|
from .punica_base import PunicaWrapperBase
|
||||||
|
|
||||||
@ -7,20 +8,11 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
|
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
|
||||||
if current_platform.is_cuda_alike():
|
punica_wrapper_qualname = current_platform.get_punica_wrapper()
|
||||||
# Lazy import to avoid ImportError
|
punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
|
||||||
from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
|
punica_wrapper = punica_wrapper_cls(*args, **kwargs)
|
||||||
logger.info_once("Using PunicaWrapperGPU.")
|
assert punica_wrapper is not None, \
|
||||||
return PunicaWrapperGPU(*args, **kwargs)
|
"the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
|
||||||
elif current_platform.is_cpu():
|
logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
|
||||||
# Lazy import to avoid ImportError
|
".")
|
||||||
from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
|
return punica_wrapper
|
||||||
logger.info_once("Using PunicaWrapperCPU.")
|
|
||||||
return PunicaWrapperCPU(*args, **kwargs)
|
|
||||||
elif current_platform.is_hpu():
|
|
||||||
# Lazy import to avoid ImportError
|
|
||||||
from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
|
|
||||||
logger.info_once("Using PunicaWrapperHPU.")
|
|
||||||
return PunicaWrapperHPU(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|||||||
@ -109,3 +109,7 @@ class CpuPlatform(Platform):
|
|||||||
def is_pin_memory_available(cls) -> bool:
|
def is_pin_memory_available(cls) -> bool:
|
||||||
logger.warning("Pin memory is not supported on CPU.")
|
logger.warning("Pin memory is not supported on CPU.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
|
||||||
|
|||||||
@ -218,6 +218,10 @@ class CudaPlatformBase(Platform):
|
|||||||
logger.info("Using Flash Attention backend.")
|
logger.info("Using Flash Attention backend.")
|
||||||
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
|
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
|
||||||
|
|
||||||
|
|
||||||
# NVML utils
|
# NVML utils
|
||||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||||
|
|||||||
@ -63,3 +63,7 @@ class HpuPlatform(Platform):
|
|||||||
def is_pin_memory_available(cls):
|
def is_pin_memory_available(cls):
|
||||||
logger.warning("Pin memory is not supported on HPU.")
|
logger.warning("Pin memory is not supported on HPU.")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
|
||||||
|
|||||||
@ -276,6 +276,13 @@ class Platform:
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
"""
|
||||||
|
Return the punica wrapper for current platform.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class UnspecifiedPlatform(Platform):
|
class UnspecifiedPlatform(Platform):
|
||||||
_enum = PlatformEnum.UNSPECIFIED
|
_enum = PlatformEnum.UNSPECIFIED
|
||||||
|
|||||||
@ -153,3 +153,7 @@ class RocmPlatform(Platform):
|
|||||||
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
|
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
|
||||||
" is not set, enabling VLLM_USE_TRITON_AWQ.")
|
" is not set, enabling VLLM_USE_TRITON_AWQ.")
|
||||||
envs.VLLM_USE_TRITON_AWQ = True
|
envs.VLLM_USE_TRITON_AWQ = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user