[Kernel] Register punica ops directly (#10522)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date:   2024-11-22 01:18:11 +08:00 (committed by GitHub)
Parent: da7e702c6f
Commit: 2385b60d83
7 changed files with 157 additions and 24 deletions

File: tests/lora/test_punica_variation.py

@@ -6,12 +6,13 @@ maximum ranks.
 import pytest
 import torch

-from vllm.lora.ops.bgmv_expand import bgmv_expand
-from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
-from vllm.lora.ops.bgmv_shrink import bgmv_shrink
-from vllm.lora.ops.sgmv_expand import sgmv_expand
-from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
-from vllm.lora.ops.sgmv_shrink import sgmv_shrink
+# Enable custom op registration
+import vllm.lora.ops.bgmv_expand
+import vllm.lora.ops.bgmv_expand_slice
+import vllm.lora.ops.bgmv_shrink
+import vllm.lora.ops.sgmv_expand
+import vllm.lora.ops.sgmv_expand_slice
+import vllm.lora.ops.sgmv_shrink  # noqa: F401
 from vllm.platforms import current_platform

 from .utils import (generate_data, generate_data_for_expand_nslices,

@@ -37,6 +38,16 @@ def assert_close(a, b):
     torch.testing.assert_close(a, b, rtol=rtol, atol=atol)

+
+# Unlike test_punica_sizes.py, here we use the custom ops directly,
+# which verifies that these ops are registered correctly.
+bgmv_expand = torch.ops.vllm.bgmv_expand
+bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
+bgmv_shrink = torch.ops.vllm.bgmv_shrink
+sgmv_expand = torch.ops.vllm.sgmv_expand
+sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
+sgmv_shrink = torch.ops.vllm.sgmv_shrink
+
 @pytest.mark.parametrize("batches", BATCHES)
 @pytest.mark.parametrize("num_loras", NUM_LORA)
 @pytest.mark.parametrize("rank", MAX_RANKS)
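
What the new aliasing buys, in one self-contained sketch: once a module under
vllm.lora.ops is imported, its kernel is reachable through torch.ops.vllm, and
FakeTensorMode (the mode torch.compile traces under) dispatches to the
registered fake impl instead of launching a Triton kernel. This sketch assumes
vllm and triton are importable; the shapes are illustrative only.

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    import vllm.lora.ops.bgmv_expand  # noqa: F401  (side effect: registers the op)

    with FakeTensorMode():
        inputs = torch.empty(8, 16)                      # (num_tokens, rank)
        lora_b_weights = torch.empty(4, 32, 16)          # (num_loras, hidden, rank)
        output_tensor = torch.empty(8, 32)               # (num_tokens, hidden)
        lora_indices = torch.zeros(8, dtype=torch.long)  # LoRA id per token
        # Dispatches to bgmv_expand_fake; no GPU work is needed to trace.
        torch.ops.vllm.bgmv_expand(inputs, lora_b_weights, output_tensor,
                                   lora_indices, add_inputs=True)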

File: vllm/lora/ops/bgmv_expand.py

@@ -9,6 +9,8 @@ import torch
 import triton
 import triton.language as tl

+from vllm.utils import direct_register_custom_op
+
 from .utils import get_lora_op_configs

@@ -162,9 +164,24 @@ def _bgmv_expand(
     return


+def bgmv_expand_fake(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    add_inputs: bool = True,
+) -> None:
+    return
+
+
 try:
-    bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
-                                          _bgmv_expand,
-                                          mutates_args=["output_tensor"])
+    direct_register_custom_op(
+        op_name="bgmv_expand",
+        op_func=_bgmv_expand,
+        mutates_args=["output_tensor"],
+        fake_impl=bgmv_expand_fake,
+    )
+    bgmv_expand = torch.ops.vllm.bgmv_expand
 except AttributeError:
     bgmv_expand = _bgmv_expand
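
For readers unfamiliar with direct_register_custom_op: it wraps the
torch.library machinery used above. Below is a minimal, self-contained sketch
of the same registration pattern, using a hypothetical "toy" namespace and op
rather than vLLM's actual implementation.

    import torch
    from torch.library import Library

    def _toy_scale(inputs: torch.Tensor, output_tensor: torch.Tensor) -> None:
        # Real impl: mutates output_tensor in place, like the punica kernels.
        output_tensor.copy_(inputs * 2)

    def _toy_scale_fake(inputs: torch.Tensor, output_tensor: torch.Tensor) -> None:
        # Fake impl: computes nothing; it only lets FakeTensorMode and
        # torch.compile trace through the op without running a kernel.
        return

    toy_lib = Library("toy", "FRAGMENT")  # hypothetical namespace
    # "(a!)" marks output_tensor as mutated, mirroring mutates_args above.
    toy_lib.define("toy_scale(Tensor inputs, Tensor(a!) output_tensor) -> ()")
    toy_lib.impl("toy_scale", _toy_scale, "CompositeExplicitAutograd")
    torch.library.register_fake("toy::toy_scale", _toy_scale_fake)

    x = torch.ones(4)
    out = torch.empty(4)
    torch.ops.toy.toy_scale(x, out)  # eager call hits _toy_scale
    assert torch.equal(out, torch.full((4,), 2.0))

Without a fake impl, torch.compile typically stops at a mutable custom op with
a missing-fake-impl error; registering one, as this commit does for every
punica op, is what makes the ops traceable.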

File: vllm/lora/ops/bgmv_expand_slice.py

@@ -9,6 +9,8 @@ import torch
 import triton
 import triton.language as tl

+from vllm.utils import direct_register_custom_op
+
 from .utils import get_lora_op_configs

@@ -179,9 +181,26 @@ def _bgmv_expand_slice(
     return


+def bgmv_expand_slice_fake(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    slice_offset: int,
+    slice_size: int,
+    add_inputs: bool = True,
+) -> None:
+    return
+
+
 try:
-    bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
-                                                _bgmv_expand_slice,
-                                                mutates_args=["output_tensor"])
+    direct_register_custom_op(
+        op_name="bgmv_expand_slice",
+        op_func=_bgmv_expand_slice,
+        mutates_args=["output_tensor"],
+        fake_impl=bgmv_expand_slice_fake,
+    )
+    bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice
 except AttributeError:
     bgmv_expand_slice = _bgmv_expand_slice

File: vllm/lora/ops/bgmv_shrink.py

@@ -9,6 +9,8 @@ import torch
 import triton
 import triton.language as tl

+from vllm.utils import direct_register_custom_op
+
 from .utils import get_lora_op_configs

@@ -142,9 +144,24 @@ def _bgmv_shrink(
     return


+def bgmv_shrink_fake(
+    inputs: torch.Tensor,
+    lora_a_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    scaling: float = 1.0,
+) -> None:
+    return
+
+
 try:
-    bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
-                                          _bgmv_shrink,
-                                          mutates_args=["output_tensor"])
+    direct_register_custom_op(
+        op_name="bgmv_shrink",
+        op_func=_bgmv_shrink,
+        mutates_args=["output_tensor"],
+        fake_impl=bgmv_shrink_fake,
+    )
+    bgmv_shrink = torch.ops.vllm.bgmv_shrink
 except AttributeError:
     bgmv_shrink = _bgmv_shrink
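
A side benefit of the new fake impls is that they document the op signatures in
plain Python. Contrasting the two families (bgmv above, sgmv in the files
below): bgmv_* ops take only a per-token LoRA index, while sgmv_* ops also
carry sequence-layout metadata so one call covers a variable-length, grouped
batch. Signatures copied from the fake impls in this commit, bodies elided:

    def bgmv_shrink(inputs, lora_a_weights, output_tensor,
                    lora_indices_tensor, scaling=1.0) -> None: ...

    def sgmv_shrink(inputs, lora_a_weights, output_tensor,
                    b_seq_start_loc, seq_len_tensor, lora_indices_tensor,
                    batches, max_seq_length, token_nums, scaling) -> None: ...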

File: vllm/lora/ops/sgmv_expand.py

@@ -9,6 +9,8 @@ import torch
 import triton
 import triton.language as tl

+from vllm.utils import direct_register_custom_op
+

 @triton.jit
 def _sgmv_expand_kernel(

@@ -196,9 +198,30 @@ def _sgmv_expand(
     return


+def sgmv_expand_fake(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    token_nums: int,
+    add_inputs: bool = False,
+) -> None:
+    return
+
+
 try:
-    sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
-                                          _sgmv_expand,
-                                          mutates_args=["output_tensor"])
+    direct_register_custom_op(
+        op_name="sgmv_expand",
+        op_func=_sgmv_expand,
+        mutates_args=["output_tensor"],
+        fake_impl=sgmv_expand_fake,
+    )
+    sgmv_expand = torch.ops.vllm.sgmv_expand
 except AttributeError:
     sgmv_expand = _sgmv_expand

File: vllm/lora/ops/sgmv_expand_slice.py

@@ -9,6 +9,8 @@ import torch
 import triton
 import triton.language as tl

+from vllm.utils import direct_register_custom_op
+

 @triton.jit
 def _sgmv_expand_slice_kernel(

@@ -209,9 +211,31 @@ def _sgmv_expand_slice(
     return


+def sgmv_expand_slice_fake(
+    inputs: torch.Tensor,
+    lora_b_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    token_nums: int,
+    slice_offset: int,
+    slice_size: int,
+    add_inputs: bool = False,
+) -> None:
+    return
+
+
 try:
-    sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
-                                                _sgmv_expand_slice,
-                                                mutates_args=["output_tensor"])
+    direct_register_custom_op(
+        op_name="sgmv_expand_slice",
+        op_func=_sgmv_expand_slice,
+        mutates_args=["output_tensor"],
+        fake_impl=sgmv_expand_slice_fake,
+    )
+    sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice
 except AttributeError:
     sgmv_expand_slice = _sgmv_expand_slice
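
The *_slice variants add slice_offset and slice_size, which let several expand
results share one pre-allocated output buffer (useful when multiple LoRA slices
feed a fused projection). Below is a pure-PyTorch stand-in for that output
contract, with hypothetical names and shapes; the real op is the Triton kernel
above.

    import torch

    def toy_expand_slice(x: torch.Tensor, lora_b: torch.Tensor, out: torch.Tensor,
                         slice_offset: int, slice_size: int) -> None:
        # Stand-in for the contract: write one column slice of `out` in place.
        out[:, slice_offset:slice_offset + slice_size] = x @ lora_b

    num_tokens, rank, q_size, k_size = 4, 8, 16, 16
    x = torch.randn(num_tokens, rank)
    out = torch.empty(num_tokens, q_size + k_size)
    toy_expand_slice(x, torch.randn(rank, q_size), out, 0, q_size)
    toy_expand_slice(x, torch.randn(rank, k_size), out, q_size, k_size)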

File: vllm/lora/ops/sgmv_shrink.py

@@ -9,6 +9,8 @@ import torch
 import triton
 import triton.language as tl

+from vllm.utils import direct_register_custom_op
+

 @triton.jit
 def _sgmv_shrink_kernel(

@@ -190,9 +192,29 @@ def _sgmv_shrink(
     return


+def sgmv_shrink_fake(
+    inputs: torch.Tensor,
+    lora_a_weights: torch.Tensor,
+    output_tensor: torch.Tensor,
+    b_seq_start_loc: torch.Tensor,
+    seq_len_tensor: torch.Tensor,
+    lora_indices_tensor: torch.Tensor,
+    batches: int,
+    max_seq_length: int,
+    token_nums: int,
+    scaling: float,
+) -> None:
+    return
+
+
 try:
-    sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
-                                          _sgmv_shrink,
-                                          mutates_args=["output_tensor"])
+    direct_register_custom_op(
+        op_name="sgmv_shrink",
+        op_func=_sgmv_shrink,
+        mutates_args=["output_tensor"],
+        fake_impl=sgmv_shrink_fake,
+    )
+    sgmv_shrink = torch.ops.vllm.sgmv_shrink
 except AttributeError:
     sgmv_shrink = _sgmv_shrink