diff --git a/vllm/lora/ops/ipex_ops/__init__.py b/vllm/lora/ops/ipex_ops/__init__.py
new file mode 100644
index 0000000000000..5daa432493b19
--- /dev/null
+++ b/vllm/lora/ops/ipex_ops/__init__.py
@@ -0,0 +1,7 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.lora.ops.ipex_ops.lora_ops import (bgmv_expand, bgmv_expand_slice,
+                                             bgmv_shrink)
+
+__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
diff --git a/vllm/lora/ops/ipex_ops/lora_ops.py b/vllm/lora/ops/ipex_ops/lora_ops.py
new file mode 100644
index 0000000000000..7590c868ecb67
--- /dev/null
+++ b/vllm/lora/ops/ipex_ops/lora_ops.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+try:
+    import intel_extension_for_pytorch as ipex
+except ImportError as e:
+    logger.error("Failed to import intel_extension_for_pytorch: %s", e)
+    raise
+
+
+def bgmv_shrink(inputs: torch.Tensor,
+                lora_a_weights: torch.Tensor,
+                output_tensor: torch.Tensor,
+                lora_indices_tensor: torch.Tensor,
+                scaling: float = 1.0) -> None:
+    ipex.llm.functional.bgmv_shrink(inputs, lora_a_weights, output_tensor,
+                                    lora_indices_tensor, scaling)
+
+
+def bgmv_expand(inputs: torch.Tensor,
+                lora_b_weights: torch.Tensor,
+                output_tensor: torch.Tensor,
+                lora_indices_tensor: torch.Tensor,
+                add_inputs: bool = True) -> None:
+    ipex.llm.functional.bgmv_expand(inputs, lora_b_weights, output_tensor,
+                                    lora_indices_tensor, add_inputs)
+
+
+def bgmv_expand_slice(inputs: torch.Tensor,
+                      lora_b_weights: torch.Tensor,
+                      output_tensor: torch.Tensor,
+                      lora_indices_tensor: torch.Tensor,
+                      slice_offset: int,
+                      slice_size: int,
+                      add_inputs: bool = True) -> None:
+    ipex.llm.functional.bgmv_expand_slice(inputs, lora_b_weights,
+                                          output_tensor, lora_indices_tensor,
+                                          slice_offset, slice_size, add_inputs)
diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py
new file mode 100644
index 0000000000000..572e39e0eced0
--- /dev/null
+++ b/vllm/lora/punica_wrapper/punica_xpu.py
@@ -0,0 +1,269 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Based on:
+Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
+Punica: Multi-Tenant LoRA Serving.
+https://arxiv.org/abs/2310.18547
+"""
+
+from typing import Optional, Union, final
+
+import torch
+
+from vllm.lora.layers import LoRAMapping
+from vllm.lora.ops.ipex_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
+
+from .punica_base import PunicaWrapperBase
+
+
+@final
+class PunicaWrapperXPU(PunicaWrapperBase):
+    """
+    PunicaWrapperXPU is designed to manage and provide metadata for the punica
+    kernel. The main function is to maintain the state information for
+    Multi-LoRA, and to provide the interface for the punica ipex kernel.
+    """
+
+    def __init__(self, max_num_batched_tokens: int, max_batches: int,
+                 device: Union[torch.device, str], **kwargs):
+        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
+                                   device)
+        torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
+        torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
+        torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
+
+    def update_metadata(self, mapping: LoRAMapping,
+                        lora_index_to_id: list[Optional[int]], max_loras: int,
+                        vocab_size: int, extra_vocab_size: int, **kwargs):
+
+        self.is_prefill = mapping.is_prefill
+        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
+                                   vocab_size, extra_vocab_size)
+
+    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
+        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
+
+    def _apply_shrink(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        scale: float,
+    ):
+        bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), scale)
+
+    def _apply_expand(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        w_t_all: torch.Tensor,
+        y_offset: int,
+        y_slice_size: int,
+        add_inputs: bool,
+    ):
+        token_lora_indices = self._get_token_lora_indices(x)
+        bgmv_expand_slice(x, w_t_all, y, token_lora_indices, y_offset,
+                          y_slice_size, add_inputs)
+
+    def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
+                   lora_a_stacked: tuple[torch.Tensor,
+                                         ...], scale: float, **kwargs):
+        """
+        Performs GEMM for multiple slices of lora_a.
+
+        Semantics:
+        for i in range(len(lora_a_stacked)):
+            y[i] += (x @ lora_a_stacked[i]) * scale
+
+        Args:
+            y (torch.Tensor): Output tensors
+            x (torch.Tensor): Input tensor
+            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
+            scale (float): Scaling factor for the operation
+        """
+
+        x = x.view(-1, x.shape[-1])
+        for slice_idx in range(len(lora_a_stacked)):
+            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
+                               scale)
+
+    def add_expand(self,
+                   y: torch.Tensor,
+                   x: torch.Tensor,
+                   lora_b_stacked: tuple[torch.Tensor, ...],
+                   lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
+                   output_slices: tuple[int, ...],
+                   offset_start: int = 0,
+                   add_inputs=True,
+                   **kwargs) -> None:
+        """
+        Performs GEMM and bias addition for multiple slices of lora_b.
+
+        Semantics:
+            for i in range(len(lora_b_stacked)):
+                slice = output_slices[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
+                    lora_bias_stacked[i]
+                offset += slice
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensors
+            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
+            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
+                bias's weight
+            output_slices (tuple[int, ...]): Every slice's size
+            add_inputs (bool): Defaults to True.
+        """
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        if lora_bias_stacked is not None:
+            token_lora_indices = self._get_token_lora_indices(y)
+            self._apply_bias(token_lora_indices, y, output_slices,
+                             lora_bias_stacked)
+
+        assert x.ndim == 3
+        assert x.size(0) == len(output_slices)
+
+        # TODO fuse these kernels
+        for slice_idx in range(len(lora_b_stacked)):
+            self._apply_expand(
+                y,
+                x[slice_idx],
+                lora_b_stacked[slice_idx],
+                offset_start,
+                output_slices[slice_idx],
+                add_inputs=add_inputs,
+            )
+            offset_start += output_slices[slice_idx]
+        y = y.view_as(y_org)
+
+    def add_lora_embedding(self,
+                           y: torch.Tensor,
+                           x: torch.Tensor,
+                           lora_b_stacked: torch.Tensor,
+                           add_inputs: bool = True,
+                           **kwargs) -> None:
+        """
+        Applies lora specifically for VocabParallelEmbeddingWithLoRA.
+
+        Semantics:
+            y += x @ lora_b_stacked
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensor.
+            lora_b_stacked (torch.Tensor): lora_b's weights.
+            add_inputs (bool): Defaults to True.
+        """
+        token_lora_indices = self._get_token_lora_indices(x)
+        bgmv_expand(x, lora_b_stacked, y, token_lora_indices, add_inputs)
+
+    def add_lora_linear(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: tuple[torch.Tensor, ...],
+                        lora_b_stacked: tuple[torch.Tensor, ...],
+                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
+                        scale: float,
+                        output_slices: tuple[int, ...],
+                        *,
+                        buffer: Optional[torch.Tensor] = None,
+                        **kwargs) -> None:
+        """
+        Applicable to linear-related lora.
+
+        Semantics:
+            for i in range(len(lora_a_stacked)):
+                y[i] += (
+                    x[i].unsqueeze(0)
+                    @ lora_a_stacked[indices[i], layer_idx, :, :]
+                    @ lora_b_stacked[indices[i], layer_idx, :, :]
+                    * scale
+                    ).squeeze(0)+lora_bias_stacked[i]
+
+        Args:
+            y (torch.Tensor): Output tensor. Will be changed in-place.
+            x (torch.Tensor): Input tensor
+            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
+            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
+            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
+            scale (float): Scaling factor.
+            output_slices (tuple[int, ...]): Every slice's size.
+            buffer (Optional[torch.Tensor]): Defaults to None.
+        """
+
+        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+        if lora_bias_stacked is not None:
+            assert len(lora_bias_stacked) == len(output_slices)
+            token_lora_indices = self._get_token_lora_indices(y)
+            y = self._apply_bias(token_lora_indices, y, output_slices,
+                                 lora_bias_stacked)
+
+        if buffer is None:
+            r = lora_b_stacked[0].size(-1)
+            # We set the buffer to be float32 by default, refer to:
+            # https://github.com/triton-lang/triton/issues/1387
+            buffer = torch.zeros(  # type: ignore
+                (len(output_slices), x.size(0), r),
+                dtype=torch.float32,
+                device=x.device,
+            )
+        self.add_shrink(
+            buffer,  # type: ignore
+            x,
+            lora_a_stacked,
+            scale,
+            **kwargs)
+        self.add_expand(
+            y,
+            buffer,  # type: ignore
+            lora_b_stacked,
+            None,
+            output_slices,
+            add_inputs=True,
+            **kwargs)
+
+    def add_lora_logits(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: torch.Tensor,
+                        lora_b_stacked: torch.Tensor,
+                        scale,
+                        *,
+                        buffer: Optional[torch.Tensor] = None,
+                        **kwargs) -> None:
+        """
+        Applies lora specifically for LogitsProcessorWithLoRA.
+
+        Semantics:
+            buffer = (x @ lora_a_stacked) * scale
+            y += buffer @ lora_b_stacked
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensor.
+            lora_a_stacked (torch.Tensor): lora_a's weights.
+            lora_b_stacked (torch.Tensor): lora_b's weights.
+            scale (float): Scaling factor.
+            buffer (Optional[torch.Tensor]): Defaults to None.
+        """
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        x = x.view(-1, x.shape[-1])
+        r = lora_b_stacked.size(-1)
+        if buffer is None:
+            # We set the buffer to be float32 by default, refer to:
+            # https://github.com/triton-lang/triton/issues/1387
+            buffer = torch.zeros((x.size(0), r),
+                                 dtype=torch.float32,
+                                 device=x.device)
+
+        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+        bgmv_expand(buffer,
+                    lora_b_stacked,
+                    y,
+                    self.sampler_indices,
+                    add_inputs=True)
+        y = y.view_as(y_org)
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index c4530c1dfaa3f..1d0bb3654929b 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -67,7 +67,7 @@ class XPUPlatform(Platform):
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
-        return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
+        return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int: