Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-14 00:34:58 +08:00
[Kernel] register punica functions as torch ops (#7591)
parent d4f0f17b02
commit 9f69856356
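The same change repeats across all six punica kernels below: the public wrapper is renamed with a leading underscore, the override_config escape hatch is dropped, the signature gains an explicit -> None annotation, and the public name is rebound to a torch.library.custom_op registration that declares output_tensor as mutated. Below is a minimal sketch of that registration pattern, assuming PyTorch >= 2.4 (where torch.library.custom_op is available); the op name mylib::scale_ and its body are illustrative stand-ins, not vLLM code:

import torch


def _scale_(output_tensor: torch.Tensor, scaling: float) -> None:
    # Like the punica wrappers, the private impl writes into its
    # out-argument and returns nothing.
    output_tensor.mul_(scaling)


# mutates_args tells the dispatcher which arguments the op writes to.
# custom_op infers the op schema from the type annotations, which is
# also why the wrappers below pick up explicit -> None annotations.
scale_ = torch.library.custom_op("mylib::scale_", _scale_,
                                 mutates_args=["output_tensor"])

x = torch.ones(4)
scale_(x, 2.0)  # also reachable as torch.ops.mylib.scale_
print(x)        # tensor([2., 2., 2., 2.])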
vllm/lora/ops/bgmv_expand.py
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import Dict, Optional
-
 import torch
 import triton
 import triton.language as tl
@@ -86,14 +84,13 @@ def _bgmv_expand_kernel(
 
 
 @torch.inference_mode()
-def bgmv_expand(
+def _bgmv_expand(
     inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
     lora_indices_tensor: torch.Tensor,
     add_inputs: bool = True,
-    override_config: Optional[Dict[str, int]] = None,
-):
+) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
@@ -105,10 +102,7 @@ def bgmv_expand(
         batches (int): batch size
         add_inputs (bool, optional): Defaults to False. adds the final lora
             results to the output.
-        override_config (Optional[Dict[str, int]], optional): Defaults to None.
-            Triton grid config
     """
-
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     assert lora_b_weights.dtype in [
         torch.float16,
@@ -138,10 +132,7 @@ def bgmv_expand(
     ]:
         CAST_TYPE = True
     batches = lora_indices_tensor.size(0)
-    if override_config:
-        config = override_config
-    else:
-        config = get_lora_op_configs("expand", batches, N)
+    config = get_lora_op_configs("expand", batches, N)
     grid = lambda META: (
         META["SPLIT_N"],
         batches,
@@ -167,3 +158,8 @@ def bgmv_expand(
         **config,
     )
     return
+
+
+bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
+                                      _bgmv_expand,
+                                      mutates_args=["output_tensor"])
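Call sites are unchanged: the module-level name bgmv_expand now binds to the op object returned by torch.library.custom_op, which is callable with the same signature, and the op is also reachable as torch.ops.lora.bgmv_expand. The practical payoff is torch.compile: a registered op with declared mutations is treated as one opaque graph node instead of being traced through the Triton launch code. A hedged sketch with a toy op (mylib::add_one_, not vLLM code):

import torch


def _add_one_(out: torch.Tensor) -> None:
    out.add_(1.0)


add_one_ = torch.library.custom_op("mylib::add_one_", _add_one_,
                                   mutates_args=["out"])


@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    buf = torch.zeros_like(x)
    add_one_(buf)  # the in-place write to buf is tracked by the graph
    return x + buf


print(f(torch.ones(3)))  # tensor([2., 2., 2.])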
vllm/lora/ops/bgmv_expand_slice.py
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import Dict, Optional
-
 import torch
 import triton
 import triton.language as tl
@@ -89,7 +87,7 @@ def _bgmv_expand_slice_kernel(
 
 
 @torch.inference_mode()
-def bgmv_expand_slice(
+def _bgmv_expand_slice(
     inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
@@ -97,8 +95,7 @@ def bgmv_expand_slice(
     slice_offset: int,
     slice_size: int,
     add_inputs: bool = True,
-    override_config: Optional[Dict[str, int]] = None,
-):
+) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
@@ -111,10 +108,7 @@ def bgmv_expand_slice(
         slice_size (int): current output_tensor's size
         batches (int): batch size
        add_inputs (bool, optional): Defaults to False.
-        override_config (Optional[Dict[str, int]], optional): Defaults to None.
-            Triton grid config
     """
-
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     assert lora_b_weights.dtype in [
         torch.float16,
@@ -149,10 +143,7 @@ def bgmv_expand_slice(
 
     batches = lora_indices_tensor.size(0)
 
-    if override_config:
-        config = override_config
-    else:
-        config = get_lora_op_configs("expand", batches, N)
+    config = get_lora_op_configs("expand", batches, N)
 
     grid = lambda META: (
         META["SPLIT_N"],
@@ -180,3 +171,8 @@ def bgmv_expand_slice(
         **config,
     )
     return
+
+
+bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
+                                            _bgmv_expand_slice,
+                                            mutates_args=["output_tensor"])
vllm/lora/ops/bgmv_shrink.py
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import Dict, Optional
-
 import torch
 import triton
 import triton.language as tl
@@ -78,14 +76,13 @@ def _bgmv_shrink_kernel(
 
 
 @torch.inference_mode()
-def bgmv_shrink(
+def _bgmv_shrink(
     inputs: torch.Tensor,
     lora_a_weights: torch.Tensor,
     output_tensor: torch.Tensor,
     lora_indices_tensor: torch.Tensor,
     scaling: float = 1.0,
-    override_config: Optional[Dict[str, int]] = None,
-):
+) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
@@ -96,8 +93,6 @@ def bgmv_shrink(
             applied.
         batches (int): batch size
         scaling (float): Scaling factor.
-        override_config (Optional[Dict[str, int]], optional): Defaults to None.
-            Triton grid config
     """
     assert inputs.dtype == lora_a_weights.dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
@@ -119,11 +114,8 @@ def bgmv_shrink(
     batches = lora_indices_tensor.size(0)
     N, K = lora_a_weights.shape[-2:]  # K=hidden_size,N=rank
     BLOCK_N = triton.next_power_of_2(N)
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        config = get_lora_op_configs("bgmv_shrink", batches, K)
+    # First try to load optimal config from the file
+    config = get_lora_op_configs("bgmv_shrink", batches, K)
 
     grid = lambda META: (
         META["SPLIT_K"],
@@ -148,3 +140,8 @@ def bgmv_shrink(
         **config,
     )
     return
+
+
+bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
+                                      _bgmv_shrink,
+                                      mutates_args=["output_tensor"])
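A plausible reason override_config is removed outright rather than deprecated: custom-op schemas are inferred from the Python type annotations, and the supported parameter types are roughly Tensors, ints, floats, bools, strings, and Optionals or sequences of those. Optional[Dict[str, int]] is not representable in a schema, so the launch config is now always resolved through get_lora_op_configs. A hedged illustration with a toy op (exact exception type and message vary by PyTorch version):

from typing import Dict, Optional

import torch

try:
    @torch.library.custom_op("mylib::bad_op", mutates_args=["out"])
    def bad_op(out: torch.Tensor, cfg: Optional[Dict[str, int]] = None) -> None:
        out.zero_()
except Exception as exc:
    # Expected: schema inference rejects the Dict-typed parameter.
    print(f"rejected: {exc}")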
vllm/lora/ops/sgmv_expand.py
@@ -97,7 +97,7 @@ def _sgmv_expand_kernel(
 
 
 @torch.inference_mode()
-def sgmv_expand(
+def _sgmv_expand(
     inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
@@ -107,7 +107,7 @@ def sgmv_expand(
     batches: int,
     max_seq_length: int,
     add_inputs: bool = False,
-):
+) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
@@ -190,3 +190,8 @@ def sgmv_expand(
         CAST_TYPE,
     )
     return
+
+
+sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
+                                      _sgmv_expand,
+                                      mutates_args=["output_tensor"])
vllm/lora/ops/sgmv_expand_slice.py
@@ -103,7 +103,7 @@ def _sgmv_expand_slice_kernel(
 
 
 @torch.inference_mode()
-def sgmv_expand_slice(
+def _sgmv_expand_slice(
     inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
@@ -115,7 +115,7 @@ def sgmv_expand_slice(
     slice_offset: int,
     slice_size: int,
     add_inputs: bool = False,
-):
+) -> None:
     """_summary_
 
     Args:
@@ -203,3 +203,8 @@ def sgmv_expand_slice(
         CAST_TYPE,
     )
     return
+
+
+sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
+                                            _sgmv_expand_slice,
+                                            mutates_args=["output_tensor"])
vllm/lora/ops/sgmv_shrink.py
@@ -101,7 +101,7 @@ def _sgmv_shrink_kernel(
 
 
 @torch.inference_mode()
-def sgmv_shrink(
+def _sgmv_shrink(
     inputs: torch.Tensor,
     lora_a_weights: torch.Tensor,
     output_tensor: torch.Tensor,
@@ -111,7 +111,7 @@ def sgmv_shrink(
     batches: int,
     max_seq_length: int,
     scaling: float,
-):
+) -> None:
     """
 
     Args:
@@ -187,3 +187,8 @@ def sgmv_shrink(
         SPLIT_K,
     )
     return
+
+
+sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
+                                      _sgmv_shrink,
+                                      mutates_args=["output_tensor"])
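With all six wrappers registered the same way, registrations like these can be sanity-checked with torch.library.opcheck (PyTorch >= 2.4), which exercises the inferred schema and fake-tensor behavior among other tests. A hedged sketch against a toy op, since the real kernels require CUDA tensors shaped to satisfy their asserts:

import torch


def _double_(out: torch.Tensor) -> None:
    # Toy stand-in for a punica kernel: the result is written into out.
    out.mul_(2.0)


double_ = torch.library.custom_op("mylib::double_", _double_,
                                  mutates_args=["out"])


@double_.register_fake
def _(out: torch.Tensor) -> None:
    # Nothing to allocate; the op only mutates its input.
    pass


torch.library.opcheck(double_, (torch.randn(8),))
print("opcheck passed")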