From 2385b60d8300ce730ae67d9ea945f06de9ec4e21 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 22 Nov 2024 01:18:11 +0800 Subject: [PATCH] [Kernel] Register punica ops directly (#10522) Signed-off-by: Jee Jee Li --- tests/lora/test_punica_variation.py | 23 ++++++++++++++++------ vllm/lora/ops/bgmv_expand.py | 23 +++++++++++++++++++--- vllm/lora/ops/bgmv_expand_slice.py | 25 +++++++++++++++++++++--- vllm/lora/ops/bgmv_shrink.py | 23 +++++++++++++++++++--- vllm/lora/ops/sgmv_expand.py | 29 +++++++++++++++++++++++++--- vllm/lora/ops/sgmv_expand_slice.py | 30 ++++++++++++++++++++++++++--- vllm/lora/ops/sgmv_shrink.py | 28 ++++++++++++++++++++++++--- 7 files changed, 157 insertions(+), 24 deletions(-) diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 52b82f25d23e1..3b20033271d26 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -6,12 +6,13 @@ maximum ranks. import pytest import torch -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice -from vllm.lora.ops.sgmv_shrink import sgmv_shrink +# Enable custom op register +import vllm.lora.ops.bgmv_expand +import vllm.lora.ops.bgmv_expand_slice +import vllm.lora.ops.bgmv_shrink +import vllm.lora.ops.sgmv_expand +import vllm.lora.ops.sgmv_expand_slice +import vllm.lora.ops.sgmv_shrink # noqa: F401 from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, @@ -37,6 +38,16 @@ def assert_close(a, b): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) +# Unlike test_punica_sizes.py, we directly utilize custom op for +# testing, which verifies the correct registration of these ops. +bgmv_expand = torch.ops.vllm.bgmv_expand +bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice +bgmv_shrink = torch.ops.vllm.bgmv_shrink +sgmv_expand = torch.ops.vllm.sgmv_expand +sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice +sgmv_shrink = torch.ops.vllm.sgmv_shrink + + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index f176259fddc78..42adb191b8ead 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -9,6 +9,8 @@ import torch import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -162,9 +164,24 @@ def _bgmv_expand( return +def bgmv_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> None: + return + + try: - bgmv_expand = torch.library.custom_op("lora::bgmv_expand", - _bgmv_expand, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_expand", + op_func=_bgmv_expand, + mutates_args=["output_tensor"], + fake_impl=bgmv_expand_fake, + ) + bgmv_expand = torch.ops.vllm.bgmv_expand + except AttributeError: bgmv_expand = _bgmv_expand diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index 2c6ed96c253f0..f397d752a3ea9 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -9,6 +9,8 @@ import torch import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -179,9 +181,26 @@ def _bgmv_expand_slice( return +def bgmv_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> None: + return + + try: - bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", - _bgmv_expand_slice, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_expand_slice", + op_func=_bgmv_expand_slice, + mutates_args=["output_tensor"], + fake_impl=bgmv_expand_slice_fake, + ) + bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice + except AttributeError: bgmv_expand_slice = _bgmv_expand_slice diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py index 0846ff36b1692..f3ef01d39e776 100644 --- a/vllm/lora/ops/bgmv_shrink.py +++ b/vllm/lora/ops/bgmv_shrink.py @@ -9,6 +9,8 @@ import torch import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -142,9 +144,24 @@ def _bgmv_shrink( return +def bgmv_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> None: + return + + try: - bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", - _bgmv_shrink, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_shrink", + op_func=_bgmv_shrink, + mutates_args=["output_tensor"], + fake_impl=bgmv_shrink_fake, + ) + bgmv_shrink = torch.ops.vllm.bgmv_shrink + except AttributeError: bgmv_shrink = _bgmv_shrink diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index ee2cd2e05e2ee..77c5178493c44 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -9,6 +9,8 @@ import torch import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_expand_kernel( @@ -196,9 +198,30 @@ def _sgmv_expand( return +def sgmv_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +) -> None: + return + + try: - sgmv_expand = torch.library.custom_op("lora::sgmv_expand", - _sgmv_expand, - mutates_args=["output_tensor"]) + + direct_register_custom_op( + op_name="sgmv_expand", + op_func=_sgmv_expand, + mutates_args=["output_tensor"], + fake_impl=sgmv_expand_fake, + ) + sgmv_expand = torch.ops.vllm.sgmv_expand + except AttributeError: sgmv_expand = _sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 5244fa14913a4..55c4fb68ed128 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -9,6 +9,8 @@ import torch import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_expand_slice_kernel( @@ -209,9 +211,31 @@ def _sgmv_expand_slice( return +def sgmv_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + return + + try: - sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", - _sgmv_expand_slice, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="sgmv_expand_slice", + op_func=_sgmv_expand_slice, + mutates_args=["output_tensor"], + fake_impl=sgmv_expand_slice_fake, + ) + sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice + except AttributeError: sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index b4d893047b06b..37d1dc84eebca 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -9,6 +9,8 @@ import torch import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_shrink_kernel( @@ -190,9 +192,29 @@ def _sgmv_shrink( return +def sgmv_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> None: + return + + try: - sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", - _sgmv_shrink, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="sgmv_shrink", + op_func=_sgmv_shrink, + mutates_args=["output_tensor"], + fake_impl=sgmv_shrink_fake, + ) + sgmv_shrink = torch.ops.vllm.sgmv_shrink + except AttributeError: sgmv_shrink = _sgmv_shrink