diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py
index dcaf2e3d462c..0bbc1844ef45 100644
--- a/vllm/lora/ops/bgmv_expand.py
+++ b/vllm/lora/ops/bgmv_expand.py
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import Dict, Optional
-
 import torch
 import triton
 import triton.language as tl
@@ -86,14 +84,13 @@ def _bgmv_expand_kernel(
 
 
 @torch.inference_mode()
-def bgmv_expand(
+def _bgmv_expand(
     inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
     lora_indices_tensor: torch.Tensor,
     add_inputs: bool = True,
-    override_config: Optional[Dict[str, int]] = None,
-):
+) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
@@ -105,10 +102,7 @@ def bgmv_expand(
         batches (int): batch size
         add_inputs (bool, optional): Defaults to False. adds the final lora
             results to the output.
-        override_config (Optional[Dict[str, int]], optional): Defaults to None.
-            Triton grid config
     """
-
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     assert lora_b_weights.dtype in [
         torch.float16,
@@ -138,10 +132,7 @@ def bgmv_expand(
     ]:
         CAST_TYPE = True
     batches = lora_indices_tensor.size(0)
-    if override_config:
-        config = override_config
-    else:
-        config = get_lora_op_configs("expand", batches, N)
+    config = get_lora_op_configs("expand", batches, N)
     grid = lambda META: (
         META["SPLIT_N"],
         batches,
@@ -167,3 +158,8 @@ def bgmv_expand(
         **config,
     )
     return
+
+
+bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
+                                      _bgmv_expand,
+                                      mutates_args=["output_tensor"])
diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py
index fa6571074f3a..87d7d9902a4c 100644
--- a/vllm/lora/ops/bgmv_expand_slice.py
+++ b/vllm/lora/ops/bgmv_expand_slice.py
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import Dict, Optional
-
 import torch
 import triton
 import triton.language as tl
@@ -89,7 +87,7 @@ def _bgmv_expand_slice_kernel(
 
 
 @torch.inference_mode()
-def bgmv_expand_slice(
+def _bgmv_expand_slice(
     inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
@@ -97,8 +95,7 @@ def bgmv_expand_slice(
     slice_offset: int,
     slice_size: int,
     add_inputs: bool = True,
-    override_config: Optional[Dict[str, int]] = None,
-):
+) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
@@ -111,10 +108,7 @@ def bgmv_expand_slice(
         slice_size (int): current output_tensor's size
         batches (int): batch size
         add_inputs (bool, optional): Defaults to False.
-        override_config (Optional[Dict[str, int]], optional): Defaults to None.
-            Triton grid config
     """
-
     assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
     assert lora_b_weights.dtype in [
         torch.float16,
@@ -149,10 +143,7 @@ def bgmv_expand_slice(
 
     batches = lora_indices_tensor.size(0)
 
-    if override_config:
-        config = override_config
-    else:
-        config = get_lora_op_configs("expand", batches, N)
+    config = get_lora_op_configs("expand", batches, N)
 
     grid = lambda META: (
         META["SPLIT_N"],
@@ -180,3 +171,8 @@ def bgmv_expand_slice(
         **config,
     )
     return
+
+
+bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
+                                            _bgmv_expand_slice,
+                                            mutates_args=["output_tensor"])
diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py
index e69d33078f5a..c979d758492d 100644
--- a/vllm/lora/ops/bgmv_shrink.py
+++ b/vllm/lora/ops/bgmv_shrink.py
@@ -5,8 +5,6 @@ Punica: Multi-Tenant LoRA Serving.
 https://arxiv.org/abs/2310.18547
 """
 
-from typing import Dict, Optional
-
 import torch
 import triton
 import triton.language as tl
@@ -78,14 +76,13 @@ def _bgmv_shrink_kernel(
 
 
 @torch.inference_mode()
-def bgmv_shrink(
+def _bgmv_shrink(
     inputs: torch.Tensor,
     lora_a_weights: torch.Tensor,
     output_tensor: torch.Tensor,
     lora_indices_tensor: torch.Tensor,
     scaling: float = 1.0,
-    override_config: Optional[Dict[str, int]] = None,
-):
+) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
@@ -96,8 +93,6 @@ def bgmv_shrink(
             applied.
         batches (int): batch size
         scaling (float): Scaling factor.
-        override_config (Optional[Dict[str, int]], optional): Defaults to None.
-            Triton grid config
     """
     assert inputs.dtype == lora_a_weights.dtype
     assert inputs.dtype in [torch.float16, torch.bfloat16]
@@ -119,11 +114,8 @@ def bgmv_shrink(
     batches = lora_indices_tensor.size(0)
     N, K = lora_a_weights.shape[-2:]  # K=hidden_size,N=rank
     BLOCK_N = triton.next_power_of_2(N)
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        config = get_lora_op_configs("bgmv_shrink", batches, K)
+    # First try to load optimal config from the file
+    config = get_lora_op_configs("bgmv_shrink", batches, K)
 
     grid = lambda META: (
         META["SPLIT_K"],
@@ -148,3 +140,8 @@ def bgmv_shrink(
         **config,
     )
     return
+
+
+bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
+                                      _bgmv_shrink,
+                                      mutates_args=["output_tensor"])
diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py
index 459049546909..80a0b605b0fe 100644
--- a/vllm/lora/ops/sgmv_expand.py
+++ b/vllm/lora/ops/sgmv_expand.py
@@ -97,7 +97,7 @@ def _sgmv_expand_kernel(
 
 
 @torch.inference_mode()
-def sgmv_expand(
+def _sgmv_expand(
     inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
@@ -107,7 +107,7 @@ def sgmv_expand(
     batches: int,
     max_seq_length: int,
     add_inputs: bool = False,
-):
+) -> None:
     """
     Args:
         inputs (torch.Tensor): input tensor
@@ -190,3 +190,8 @@ def sgmv_expand(
         CAST_TYPE,
     )
     return
+
+
+sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
+                                      _sgmv_expand,
+                                      mutates_args=["output_tensor"])
diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py
index ff3bcda071b8..53237166a1c6 100644
--- a/vllm/lora/ops/sgmv_expand_slice.py
+++ b/vllm/lora/ops/sgmv_expand_slice.py
@@ -103,7 +103,7 @@ def _sgmv_expand_slice_kernel(
 
 
 @torch.inference_mode()
-def sgmv_expand_slice(
+def _sgmv_expand_slice(
     inputs: torch.Tensor,
     lora_b_weights: torch.Tensor,
     output_tensor: torch.Tensor,
@@ -115,7 +115,7 @@ def sgmv_expand_slice(
     slice_offset: int,
     slice_size: int,
     add_inputs: bool = False,
-):
+) -> None:
     """_summary_
 
     Args:
@@ -203,3 +203,8 @@ def sgmv_expand_slice(
         CAST_TYPE,
     )
     return
+
+
+sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
+                                            _sgmv_expand_slice,
+                                            mutates_args=["output_tensor"])
diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py
index 8ab049989abe..51d2a09eee94 100644
--- a/vllm/lora/ops/sgmv_shrink.py
+++ b/vllm/lora/ops/sgmv_shrink.py
@@ -101,7 +101,7 @@ def _sgmv_shrink_kernel(
 
 
 @torch.inference_mode()
-def sgmv_shrink(
+def _sgmv_shrink(
     inputs: torch.Tensor,
     lora_a_weights: torch.Tensor,
     output_tensor: torch.Tensor,
@@ -111,7 +111,7 @@ def sgmv_shrink(
     batches: int,
     max_seq_length: int,
     scaling: float,
-):
+) -> None:
     """
 
     Args:
@@ -187,3 +187,8 @@ def sgmv_shrink(
         SPLIT_K,
     )
     return
+
+
+sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
+                                      _sgmv_shrink,
+                                      mutates_args=["output_tensor"])
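
Reviewer note (not part of the patch): torch.library.custom_op registers a
Python function as a real PyTorch operator, so torch.compile treats each
Triton kernel launch above as an opaque call instead of tracing into it, and
mutates_args declares which arguments are written in place so the dispatcher
preserves the mutation. A minimal sketch of the same registration pattern on
a toy op; the names mylib::scaled_add, _scaled_add, and scaled_add are
illustrative, not from this patch:

    import torch

    def _scaled_add(x: torch.Tensor, out: torch.Tensor,
                    scale: float = 1.0) -> None:
        # Writes into `out` in place, mirroring how the LoRA ops above
        # mutate `output_tensor` instead of returning a new tensor.
        out.add_(x, alpha=scale)

    # Same pattern as the patch: keep the raw implementation private and
    # expose the registered op under the original public name.
    scaled_add = torch.library.custom_op("mylib::scaled_add",
                                         _scaled_add,
                                         mutates_args=["out"])

    x = torch.ones(4)
    out = torch.zeros(4)
    scaled_add(x, out, scale=2.0)  # out is now [2., 2., 2., 2.]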