From a4c23314c04a0ce3e507cd199d6372fb83cb6732 Mon Sep 17 00:00:00 2001
From: Yan Ma
Date: Tue, 8 Jul 2025 22:07:10 +0800
Subject: [PATCH] [xpu]feat: support multi-lora on xpu (#20616)

Signed-off-by: yan
---
 vllm/lora/ops/triton_ops/lora_expand_op.py     |  2 ++
 vllm/lora/ops/triton_ops/lora_shrink_op.py     |  2 ++
 vllm/lora/ops/triton_ops/utils.py              | 12 +++++++++---
 vllm/model_executor/model_loader/tensorizer.py |  5 ++++-
 vllm/platforms/xpu.py                          | 11 +++++++++++
 5 files changed, 28 insertions(+), 4 deletions(-)

diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py
index 9e1f90e757cde..eaef8e2c1905e 100644
--- a/vllm/lora/ops/triton_ops/lora_expand_op.py
+++ b/vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -13,6 +13,7 @@ import triton.language as tl
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 
@@ -283,6 +284,7 @@ try:
         op_func=_lora_expand,
         mutates_args=["output_tensor"],
         fake_impl=_lora_expand_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
     lora_expand = torch.ops.vllm.lora_expand
 
diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py
index 3f9edfc6d655c..d299fa5e8e1a5 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink_op.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -13,6 +13,7 @@ import triton.language as tl
 
 from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
 from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
 
@@ -237,6 +238,7 @@ try:
         op_func=_lora_shrink,
         mutates_args=["output_tensor"],
         fake_impl=_lora_shrink_fake,
+        dispatch_key=current_platform.dispatch_key,
     )
     lora_shrink = torch.ops.vllm.lora_shrink
 
diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py
index 5857f7fecb5b4..4c50fbd270516 100644
--- a/vllm/lora/ops/triton_ops/utils.py
+++ b/vllm/lora/ops/triton_ops/utils.py
@@ -35,7 +35,9 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
         lora_strides_d1.append(lora_a_weight.stride(1))
         lora_strides_d2.append(lora_a_weight.stride(2))
     if len(lora_a_weights) > 1:
-        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+        lora_ptr_tensor = torch.tensor(tensor_ptrs,
+                                       device=device,
+                                       dtype=torch.uint64)
     else:
         lora_ptr_tensor = lora_a_weights[0]
 
@@ -89,8 +91,12 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
 
     if len(lora_weights) > 1:
         # note these are device tensors
-        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
-        slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
+        lora_ptr_tensor = torch.tensor(tensor_ptrs,
+                                       device=device,
+                                       dtype=torch.uint64)
+        slice_start_tensor = torch.tensor(slice_offset_lst,
+                                          device=device,
+                                          dtype=torch.uint64)
     else:
         slice_start_tensor = slice_offset_lst[0]
         lora_ptr_tensor = lora_b_weight[0]
diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py
index ff101b664130e..3bf6571a6addf 100644
--- a/vllm/model_executor/model_loader/tensorizer.py
+++ b/vllm/model_executor/model_loader/tensorizer.py
@@ -27,6 +27,7 @@ from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
 from vllm.logger import init_logger
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, PlaceholderModule
 
 if TYPE_CHECKING:
@@ -513,7 +514,9 @@ def deserialize_tensorizer_model(model: nn.Module,
             **tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
                 stream,
                 dtype=tensorizer_config.dtype,
-                device=torch.device("cuda", torch.cuda.current_device()),
+                device=f'xpu:{torch.xpu.current_device()}'
+                if current_platform.is_xpu() else
+                f'cuda:{torch.cuda.current_device()}',
                 **tensorizer_args.deserialization_kwargs) as deserializer:
         deserializer.load_into_module(model)
         end = time.perf_counter()
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index e2871c1064926..9bc2e2c57e996 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -58,6 +58,10 @@ class XPUPlatform(Platform):
     def get_device_name(cls, device_id: int = 0) -> str:
         return torch.xpu.get_device_name(device_id)
 
+    @classmethod
+    def get_punica_wrapper(cls) -> str:
+        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         device_props = torch.xpu.get_device_properties(device_id)
@@ -78,6 +82,13 @@ class XPUPlatform(Platform):
         if cache_config and cache_config.block_size is None:
            cache_config.block_size = 64
 
+        # FIXME: Temporarily forcing eager mode;
+        # remove after torch.compile support stabilizes.
+        if (envs.VLLM_USE_V1 and vllm_config.model_config is not None
+                and not vllm_config.model_config.enforce_eager):
+            from vllm.config import CompilationLevel
+            vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION  # noqa: E501
+
         # Instances created using VllmConfig() typically have model_config as
         # None by default. The modification involves adding a check to prevent
         # potential null exceptions check and update model config.
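
For context on the dtype=torch.uint64 additions in vllm/lora/ops/triton_ops/utils.py: Tensor.data_ptr() returns an unsigned device address, and letting torch.tensor infer signed int64 can misrepresent addresses that fall outside the signed range, which appears to matter for XPU allocations. Below is a minimal sketch of the pointer-table pattern; build_lora_ptr_table is a hypothetical helper for illustration only, not code from this patch.

# Minimal sketch (assumes PyTorch >= 2.3 for torch.uint64), illustrating why
# the LoRA pointer table is built with an explicit unsigned 64-bit dtype.
import torch


def build_lora_ptr_table(weights: list[torch.Tensor],
                         device: torch.device) -> torch.Tensor:
    # Raw addresses of each stacked LoRA weight tensor.
    ptrs = [w.data_ptr() for w in weights]
    # Store as uint64 so large unsigned addresses are preserved exactly when
    # the table is handed to the Triton shrink/expand kernels.
    return torch.tensor(ptrs, device=device, dtype=torch.uint64)


if __name__ == "__main__":
    lora_a_stack = [torch.randn(4, 16, 128) for _ in range(3)]
    table = build_lora_ptr_table(lora_a_stack, torch.device("cpu"))
    print(table.dtype, table.shape)

The patch gives slice_start_tensor the same explicit dtype, presumably so the kernels see a consistent 64-bit unsigned width for all pointer and offset arguments.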