From a7d59688fb75827db4316c24a057ac6097114bd3 Mon Sep 17 00:00:00 2001
From: Shanshan Shen <467638484@qq.com>
Date: Mon, 13 Jan 2025 21:12:10 +0800
Subject: [PATCH] [Platform] Move get_punica_wrapper() function to Platform
 (#11516)

Signed-off-by: Shanshan Shen <467638484@qq.com>
---
 vllm/lora/punica_wrapper/punica_selector.py | 26 +++++++--------------
 vllm/platforms/cpu.py                       |  4 ++++
 vllm/platforms/cuda.py                      |  4 ++++
 vllm/platforms/hpu.py                       |  4 ++++
 vllm/platforms/interface.py                 |  7 ++++++
 vllm/platforms/rocm.py                      |  4 ++++
 6 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py
index 9f1606e672dea..a293224651992 100644
--- a/vllm/lora/punica_wrapper/punica_selector.py
+++ b/vllm/lora/punica_wrapper/punica_selector.py
@@ -1,5 +1,6 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import resolve_obj_by_qualname
 
 from .punica_base import PunicaWrapperBase
 
@@ -7,20 +8,11 @@
 logger = init_logger(__name__)
 
 def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
-    if current_platform.is_cuda_alike():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
-        logger.info_once("Using PunicaWrapperGPU.")
-        return PunicaWrapperGPU(*args, **kwargs)
-    elif current_platform.is_cpu():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU
-        logger.info_once("Using PunicaWrapperCPU.")
-        return PunicaWrapperCPU(*args, **kwargs)
-    elif current_platform.is_hpu():
-        # Lazy import to avoid ImportError
-        from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
-        logger.info_once("Using PunicaWrapperHPU.")
-        return PunicaWrapperHPU(*args, **kwargs)
-    else:
-        raise NotImplementedError
+    punica_wrapper_qualname = current_platform.get_punica_wrapper()
+    punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
+    punica_wrapper = punica_wrapper_cls(*args, **kwargs)
+    assert punica_wrapper is not None, \
+        "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
+    logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] +
+                     ".")
+    return punica_wrapper
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
index eb3e269cac285..4d3b84fea887f 100644
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -109,3 +109,7 @@ class CpuPlatform(Platform):
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on CPU.")
         return False
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index b6a6c461369f9..80cefcb492531 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -218,6 +218,10 @@ class CudaPlatformBase(Platform):
         logger.info("Using Flash Attention backend.")
         return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py
index b579ebf494bdc..242c2c127979a 100644
--- a/vllm/platforms/hpu.py
+++ b/vllm/platforms/hpu.py
@@ -63,3 +63,7 @@ class HpuPlatform(Platform):
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on HPU.")
         return False
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index afa9daa9c98a7..3c2ec9636df91 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -276,6 +276,13 @@ class Platform:
             return False
         return True
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        """
+        Return the punica wrapper for current platform.
+        """
+        raise NotImplementedError
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 7f1e8aef528a6..43105d7855e79 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -153,3 +153,7 @@ class RocmPlatform(Platform):
             "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
             " is not set, enabling VLLM_USE_TRITON_AWQ.")
         envs.VLLM_USE_TRITON_AWQ = True
+
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"