mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-14 23:37:57 +08:00
[LoRA][Kernel] Remove the unused libentry module (#10214)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
58170d6503
commit
36e4acd02a
@ -4,8 +4,6 @@ hidden_sizes included in the LoRA models currently supported by vLLM. It tests
|
|||||||
whether the corresponding Triton kernel can run normally when tensor parallelism
|
whether the corresponding Triton kernel can run normally when tensor parallelism
|
||||||
is set to [1, 2, 4, 8, 16, 32, 64].
|
is set to [1, 2, 4, 8, 16, 32, 64].
|
||||||
"""
|
"""
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -16,7 +14,6 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
|
|||||||
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
||||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils.libentry import LibEntry
|
|
||||||
|
|
||||||
from .utils import (generate_data, generate_data_for_expand_nslices,
|
from .utils import (generate_data, generate_data_for_expand_nslices,
|
||||||
ref_torch_groupgemm)
|
ref_torch_groupgemm)
|
||||||
@ -235,9 +232,6 @@ def test_punica_bgmv(
|
|||||||
seed: int,
|
seed: int,
|
||||||
device: str,
|
device: str,
|
||||||
):
|
):
|
||||||
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
|
|
||||||
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
|
|
||||||
|
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
@ -262,33 +256,21 @@ def test_punica_bgmv(
|
|||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
if op_type == "shrink":
|
if op_type == "shrink":
|
||||||
# The current _bgmv_shrink_kernel does not require the libentry
|
bgmv_shrink(
|
||||||
# decoration. The purpose of adding this patch is to test the
|
inputs_tensor,
|
||||||
# correctness of libentry.
|
lora_weights,
|
||||||
with patch(
|
our_out_tensor,
|
||||||
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
|
indices,
|
||||||
LibEntry(_bgmv_shrink_kernel),
|
scaling,
|
||||||
):
|
)
|
||||||
bgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
indices,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# ditto
|
bgmv_expand(
|
||||||
with patch(
|
inputs_tensor,
|
||||||
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
|
lora_weights,
|
||||||
LibEntry(_bgmv_expand_kernel),
|
our_out_tensor,
|
||||||
):
|
indices,
|
||||||
bgmv_expand(
|
add_inputs=True,
|
||||||
inputs_tensor,
|
)
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
indices,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
ref_torch_groupgemm(
|
ref_torch_groupgemm(
|
||||||
ref_out_tensor,
|
ref_out_tensor,
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
@ -324,7 +306,6 @@ def test_punica_expand_nslices(
|
|||||||
seed: int,
|
seed: int,
|
||||||
device: str,
|
device: str,
|
||||||
):
|
):
|
||||||
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
|
|
||||||
|
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
@ -374,22 +355,16 @@ def test_punica_expand_nslices(
|
|||||||
add_inputs=True,
|
add_inputs=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# The current _bgmv_expand_slice_kernel does not require the
|
|
||||||
# libentry decoration. The purpose of adding this patch is to test
|
bgmv_expand_slice(
|
||||||
# the correctness of libentry.
|
inputs_tensor,
|
||||||
with patch(
|
lora_weights,
|
||||||
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
|
our_outputs,
|
||||||
LibEntry(_bgmv_expand_slice_kernel),
|
indices,
|
||||||
):
|
slice_offset,
|
||||||
bgmv_expand_slice(
|
slice_size=hidden_size,
|
||||||
inputs_tensor,
|
add_inputs=True,
|
||||||
lora_weights,
|
)
|
||||||
our_outputs,
|
|
||||||
indices,
|
|
||||||
slice_offset,
|
|
||||||
slice_size=hidden_size,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
ref_torch_groupgemm(
|
ref_torch_groupgemm(
|
||||||
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
|
|||||||
@ -3,8 +3,6 @@ This script is mainly used to test whether Triton kernels can run normally
|
|||||||
under different conditions, including various batches, numbers of LoRA, and
|
under different conditions, including various batches, numbers of LoRA, and
|
||||||
maximum ranks.
|
maximum ranks.
|
||||||
"""
|
"""
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -15,7 +13,6 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand
|
|||||||
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
||||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils.libentry import LibEntry
|
|
||||||
|
|
||||||
from .utils import (generate_data, generate_data_for_expand_nslices,
|
from .utils import (generate_data, generate_data_for_expand_nslices,
|
||||||
ref_torch_groupgemm)
|
ref_torch_groupgemm)
|
||||||
@ -150,8 +147,6 @@ def test_punica_bgmv(
|
|||||||
seed: int,
|
seed: int,
|
||||||
device: str,
|
device: str,
|
||||||
):
|
):
|
||||||
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
|
|
||||||
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
|
|
||||||
|
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
@ -177,33 +172,22 @@ def test_punica_bgmv(
|
|||||||
device,
|
device,
|
||||||
)
|
)
|
||||||
if op_type == "shrink":
|
if op_type == "shrink":
|
||||||
# The current _bgmv_shrink_kernel does not require the libentry
|
bgmv_shrink(
|
||||||
# decoration. The purpose of adding this patch is to test the
|
inputs_tensor,
|
||||||
# correctness of libentry.
|
lora_weights,
|
||||||
with patch(
|
our_out_tensor,
|
||||||
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
|
indices,
|
||||||
LibEntry(_bgmv_shrink_kernel),
|
scaling,
|
||||||
):
|
)
|
||||||
bgmv_shrink(
|
|
||||||
inputs_tensor,
|
|
||||||
lora_weights,
|
|
||||||
our_out_tensor,
|
|
||||||
indices,
|
|
||||||
scaling,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# ditto
|
|
||||||
with patch(
|
bgmv_expand(
|
||||||
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
|
inputs_tensor,
|
||||||
LibEntry(_bgmv_expand_kernel),
|
lora_weights,
|
||||||
):
|
our_out_tensor,
|
||||||
bgmv_expand(
|
indices,
|
||||||
inputs_tensor,
|
add_inputs=True,
|
||||||
lora_weights,
|
)
|
||||||
our_out_tensor,
|
|
||||||
indices,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
ref_torch_groupgemm(
|
ref_torch_groupgemm(
|
||||||
ref_out_tensor,
|
ref_out_tensor,
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
@ -239,8 +223,6 @@ def test_punica_expand_nslices(
|
|||||||
seed: int,
|
seed: int,
|
||||||
device: str,
|
device: str,
|
||||||
):
|
):
|
||||||
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
|
|
||||||
|
|
||||||
torch.set_default_device(device)
|
torch.set_default_device(device)
|
||||||
current_platform.seed_everything(seed)
|
current_platform.seed_everything(seed)
|
||||||
|
|
||||||
@ -289,22 +271,15 @@ def test_punica_expand_nslices(
|
|||||||
add_inputs=True,
|
add_inputs=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# The current _bgmv_expand_slice_kernel does not require the
|
bgmv_expand_slice(
|
||||||
# libentry decoration. The purpose of adding this patch is to test
|
inputs_tensor,
|
||||||
# the correctness of libentry.
|
lora_weights,
|
||||||
with patch(
|
our_outputs,
|
||||||
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
|
indices,
|
||||||
LibEntry(_bgmv_expand_slice_kernel),
|
slice_offset,
|
||||||
):
|
slice_size=hidden_size,
|
||||||
bgmv_expand_slice(
|
add_inputs=True,
|
||||||
inputs_tensor,
|
)
|
||||||
lora_weights,
|
|
||||||
our_outputs,
|
|
||||||
indices,
|
|
||||||
slice_offset,
|
|
||||||
slice_size=hidden_size,
|
|
||||||
add_inputs=True,
|
|
||||||
)
|
|
||||||
ref_torch_groupgemm(
|
ref_torch_groupgemm(
|
||||||
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
||||||
inputs_tensor,
|
inputs_tensor,
|
||||||
|
|||||||
@ -9,10 +9,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from vllm.triton_utils import libentry
|
|
||||||
|
|
||||||
|
|
||||||
@libentry()
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _sgmv_expand_kernel(
|
def _sgmv_expand_kernel(
|
||||||
input_ptr,
|
input_ptr,
|
||||||
|
|||||||
@ -9,10 +9,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from vllm.triton_utils import libentry
|
|
||||||
|
|
||||||
|
|
||||||
@libentry()
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _sgmv_expand_slice_kernel(
|
def _sgmv_expand_slice_kernel(
|
||||||
input_ptr,
|
input_ptr,
|
||||||
|
|||||||
@ -9,10 +9,7 @@ import torch
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from vllm.triton_utils import libentry
|
|
||||||
|
|
||||||
|
|
||||||
@libentry()
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _sgmv_shrink_kernel(
|
def _sgmv_shrink_kernel(
|
||||||
input_ptr,
|
input_ptr,
|
||||||
|
|||||||
@ -6,6 +6,5 @@ if HAS_TRITON:
|
|||||||
|
|
||||||
from vllm.triton_utils.custom_cache_manager import (
|
from vllm.triton_utils.custom_cache_manager import (
|
||||||
maybe_set_triton_cache_manager)
|
maybe_set_triton_cache_manager)
|
||||||
from vllm.triton_utils.libentry import libentry
|
|
||||||
|
|
||||||
__all__ += ["maybe_set_triton_cache_manager", "libentry"]
|
__all__ += ["maybe_set_triton_cache_manager"]
|
||||||
|
|||||||
@ -1,167 +0,0 @@
|
|||||||
# Copied From https://github.com/FlagOpen/FlagGems
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
import triton
|
|
||||||
|
|
||||||
|
|
||||||
class LibEntry(triton.KernelInterface):
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
fn,
|
|
||||||
):
|
|
||||||
self.fn = fn
|
|
||||||
self.arg_names = fn.arg_names
|
|
||||||
self.divisibility = 16
|
|
||||||
self.kernel_cache = dict()
|
|
||||||
fn = self.fn
|
|
||||||
while not isinstance(fn, triton.runtime.JITFunction):
|
|
||||||
fn = fn.fn
|
|
||||||
self.jit_function: triton.runtime.JITFunction = fn
|
|
||||||
self.specialize_indices = [
|
|
||||||
p.num for p in self.jit_function.params
|
|
||||||
if not p.is_constexpr and not p.do_not_specialize
|
|
||||||
]
|
|
||||||
self.do_not_specialize_indices = [
|
|
||||||
p.num for p in self.jit_function.params
|
|
||||||
if not p.is_constexpr and p.do_not_specialize
|
|
||||||
]
|
|
||||||
|
|
||||||
def key(self, spec_args, dns_args, const_args):
|
|
||||||
spec_key = [(arg.dtype, arg.data_ptr() %
|
|
||||||
self.divisibility == 0) if hasattr(arg, "data_ptr") else
|
|
||||||
(type(arg), arg) for arg in spec_args]
|
|
||||||
dns_key = [
|
|
||||||
arg.dtype if hasattr(
|
|
||||||
arg, "data_ptr") else type(arg) if not isinstance(arg, int)
|
|
||||||
else "i32" if arg >= -(2**31) and arg <= 2**31 -
|
|
||||||
1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64"
|
|
||||||
for arg in dns_args
|
|
||||||
]
|
|
||||||
# const args passed by position
|
|
||||||
return tuple(spec_key + dns_key + const_args)
|
|
||||||
|
|
||||||
def run(self, *args, **kwargs):
|
|
||||||
grid = kwargs["grid"]
|
|
||||||
# collect all the arguments
|
|
||||||
spec_args = [] # specialize arguments
|
|
||||||
dns_args = [] # do not specialize arguments
|
|
||||||
const_args = [] # constexpr arguments
|
|
||||||
k_args = [] # kernel arguments
|
|
||||||
for i, arg in enumerate(args):
|
|
||||||
if i in self.specialize_indices:
|
|
||||||
k_args.append(arg)
|
|
||||||
spec_args.append(arg)
|
|
||||||
elif i in self.do_not_specialize_indices:
|
|
||||||
k_args.append(arg)
|
|
||||||
dns_args.append(arg)
|
|
||||||
else:
|
|
||||||
const_args.append(arg)
|
|
||||||
for p in self.jit_function.params[len(args):]:
|
|
||||||
if p.name in kwargs:
|
|
||||||
val = kwargs[p.name]
|
|
||||||
elif p.default is inspect._empty:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
val = p.default
|
|
||||||
|
|
||||||
if p.is_constexpr:
|
|
||||||
const_args.append(val)
|
|
||||||
elif p.do_not_specialize:
|
|
||||||
dns_args.append(val)
|
|
||||||
k_args.append(val)
|
|
||||||
else:
|
|
||||||
spec_args.append(val)
|
|
||||||
k_args.append(val)
|
|
||||||
|
|
||||||
entry_key = self.key(spec_args, dns_args, const_args)
|
|
||||||
|
|
||||||
if entry_key not in self.kernel_cache:
|
|
||||||
# compiling the kernel also completes the related computations
|
|
||||||
kernel = self.fn.run(*args, **kwargs)
|
|
||||||
fn = self.fn
|
|
||||||
# collect constexpr arguments for grid computation
|
|
||||||
constexprs = {}
|
|
||||||
while not isinstance(fn, triton.runtime.JITFunction):
|
|
||||||
if isinstance(fn, triton.runtime.Autotuner):
|
|
||||||
config = fn.best_config
|
|
||||||
constexprs["num_warps"] = config.num_warps
|
|
||||||
constexprs["num_stages"] = config.num_stages
|
|
||||||
constexprs["num_ctas"] = config.num_ctas
|
|
||||||
constexprs = {**constexprs, **config.kwargs}
|
|
||||||
elif isinstance(fn, triton.runtime.Heuristics):
|
|
||||||
for v, heur in fn.values.items():
|
|
||||||
constexprs[v] = heur({
|
|
||||||
**dict(zip(fn.arg_names, args)),
|
|
||||||
**kwargs,
|
|
||||||
**constexprs,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Invalid Runtime Function")
|
|
||||||
fn = fn.fn
|
|
||||||
# In vLLM, certain kernels like fused_moe_kernel get the
|
|
||||||
# best_config(as kwargs) from a configuration json file, rather
|
|
||||||
# than using Autotuner & Heuristics. Therefore, all their constexprs
|
|
||||||
# (tl.constexpr) are assigned values through the following loop.
|
|
||||||
for p in self.jit_function.params:
|
|
||||||
if p.is_constexpr and p.name not in constexprs:
|
|
||||||
constexprs[p.name] = p.default #default=inspect._empty
|
|
||||||
self.kernel_cache[entry_key] = (kernel, constexprs)
|
|
||||||
else:
|
|
||||||
# load kernel from cache directly
|
|
||||||
kernel, constexprs = self.kernel_cache[entry_key]
|
|
||||||
|
|
||||||
if callable(grid):
|
|
||||||
# collect all arguments to the grid fn, i.e.:
|
|
||||||
# 1. args,
|
|
||||||
# 2. kwargs,
|
|
||||||
# 3. all other captured arguments in CompiledKernel from
|
|
||||||
# Autotuner & Heuristics; when kwargs & captured args conflict,
|
|
||||||
# captured args have higher priority
|
|
||||||
# 4. We must first filter out captured args that still hold their default value
|
|
||||||
constexprs = {
|
|
||||||
k: v
|
|
||||||
for k, v in constexprs.items() if v is not inspect._empty
|
|
||||||
}
|
|
||||||
meta = {
|
|
||||||
**dict(zip(self.arg_names, args)),
|
|
||||||
**kwargs,
|
|
||||||
**constexprs,
|
|
||||||
}
|
|
||||||
grid = grid(meta)
|
|
||||||
if isinstance(grid, tuple):
|
|
||||||
grid = grid + (1, 1)
|
|
||||||
elif isinstance(grid, list):
|
|
||||||
grid = grid + [1, 1]
|
|
||||||
kernel[grid[0:3]](*k_args)
|
|
||||||
# maintaining the same return type as the JITFunction.run
|
|
||||||
return kernel
|
|
||||||
|
|
||||||
|
|
||||||
def libentry():
|
|
||||||
"""
|
|
||||||
Decorator for triton library entries.
|
|
||||||
Motivation:
|
|
||||||
The runtime overhead of Triton kernels is the reason for the lower
|
|
||||||
performance of small kernels, particularly evident with smaller models.
|
|
||||||
Using this decorator can reduce Triton runtime overhead.
|
|
||||||
How:
|
|
||||||
The `run` function of JITFunction needs to accomplish:
|
|
||||||
- Parameter binding using inspect
|
|
||||||
- KernelArg type wrapping
|
|
||||||
- Cache key calculation
|
|
||||||
When dealing with small size, these steps can become bottlenecks in
|
|
||||||
Triton runtime. Libentry simplifies these steps to reduce runtime
|
|
||||||
overhead, thereby improving the runtime expenses of small kernels.
|
|
||||||
NOTE:
|
|
||||||
When Triton is upgraded to version 3.0.0, libentry can be removed,
|
|
||||||
see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(fn):
|
|
||||||
return LibEntry(fn)
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
Loading…
x
Reference in New Issue
Block a user