diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py
index 9f1606e672de..a29322465199 100644
--- a/vllm/lora/punica_wrapper/punica_selector.py
+++ b/vllm/lora/punica_wrapper/punica_selector.py
@@ -1,5 +1,6 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import resolve_obj_by_qualname
 
 from .punica_base import PunicaWrapperBase
 
@@ -7,20 +8,10 @@ logger = init_logger(__name__)
 
 
 def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
-    if current_platform.is_cuda_alike():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
-        logger.info_once("Using PunicaWrapperGPU.")
-        return PunicaWrapperGPU(*args, **kwargs)
-    elif current_platform.is_cpu():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
-        logger.info_once("Using PunicaWrapperCPU.")
-        return PunicaWrapperCPU(*args, **kwargs)
-    elif current_platform.is_hpu():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
-        logger.info_once("Using PunicaWrapperHPU.")
-        return PunicaWrapperHPU(*args, **kwargs)
-    else:
-        raise NotImplementedError
+    punica_wrapper_qualname = current_platform.get_punica_wrapper()
+    punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
+    punica_wrapper = punica_wrapper_cls(*args, **kwargs)
+    assert punica_wrapper is not None, (
+        f"Invalid punica wrapper qualname: {punica_wrapper_qualname}")
+    logger.info_once(f"Using {punica_wrapper_qualname.rsplit('.', 1)[1]}.")
+    return punica_wrapper
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index eb3e269cac28..4d3b84fea887 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -109,3 +109,7 @@ class CpuPlatform(Platform):
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on CPU.")
         return False
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index b6a6c461369f..80cefcb49253 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -218,6 +218,10 @@ class CudaPlatformBase(Platform):
         logger.info("Using Flash Attention backend.")
         return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index b579ebf494bd..242c2c127979 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -63,3 +63,7 @@ class HpuPlatform(Platform):
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on HPU.")
         return False
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index afa9daa9c98a..3c2ec9636df9 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -276,6 +276,13 @@ class Platform:
             return False
         return True
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        """
+        Return the punica wrapper class qualname for the current platform.
+        """
+        raise NotImplementedError
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 7f1e8aef528a..43105d7855e7 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -153,3 +153,7 @@ class RocmPlatform(Platform):
                 "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                 " is not set, enabling VLLM_USE_TRITON_AWQ.")
             envs.VLLM_USE_TRITON_AWQ = True
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
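For reviewers, a minimal self-contained sketch of the dispatch pattern this diff introduces. The resolve_obj_by_qualname stand-in below mirrors the contract of vllm.utils.resolve_obj_by_qualname; the ExamplePlatform class and the "my_plugin.punica_foo" qualname are hypothetical illustrations, not part of this change.

import importlib


# Stand-in with the same contract as vllm.utils.resolve_obj_by_qualname:
# split "pkg.module.Obj" into a module path and an attribute name, import
# the module, and return the attribute.
def resolve_obj_by_qualname(qualname: str):
    module_name, obj_name = qualname.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), obj_name)


# A hypothetical out-of-tree platform only has to return the qualified
# name of its wrapper class; nothing is imported until the selector asks,
# so platform-specific dependencies stay optional.
class ExamplePlatform:
    @classmethod
    def get_punica_wrapper(cls) -> str:
        return "my_plugin.punica_foo.PunicaWrapperFoo"

This is also why the per-platform lazy imports (and their "# Lazy import to avoid ImportError" comments) can be deleted from the selector: the import cost and failure mode move into resolve_obj_by_qualname, and supporting a new platform no longer requires touching get_punica_wrapper.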