[V1] LoRA - Add triton kernels for V1 (#13096)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath authored 2025-03-10 17:27:53 -04:00; committed by GitHub
parent 0967110e42
commit 5ff0d32580
11 changed files with 1165 additions and 191 deletions

---

@@ -23,6 +23,7 @@ from vllm.lora.ops.triton_ops.bgmv_shrink import bgmv_shrink
from vllm.lora.ops.triton_ops.sgmv_expand import sgmv_expand
from vllm.lora.ops.triton_ops.sgmv_shrink import sgmv_shrink
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink
from vllm.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
@@ -171,6 +172,8 @@ class OpType(Enum):
    SGMV_EXPAND = auto()
    BGMV_EXPAND = auto()
    BGMV_EXPAND_SLICE = auto()
    V1_SHRINK = auto()
    V1_EXPAND = auto()

    @staticmethod
    def from_str(s: str) -> "OpType":
@@ -184,28 +187,43 @@ class OpType(Enum):
            return OpType.BGMV_EXPAND
        if s.lower() == "bgmv_expand_slice":
            return OpType.BGMV_EXPAND_SLICE
        if s.lower() == "v1_shrink":
            return OpType.V1_SHRINK
        if s.lower() == "v1_expand":
            return OpType.V1_EXPAND
        raise ValueError(f"Unrecognized str {s} to convert to OpType")

    def is_shrink_fn(self) -> bool:
        return self in [
            OpType.SGMV_SHRINK, OpType.BGMV_SHRINK, OpType.V1_SHRINK
        ]

    def is_expand_fn(self) -> bool:
        return self in [
            OpType.SGMV_EXPAND, OpType.BGMV_EXPAND, OpType.V1_EXPAND
        ]

    def is_prefill_op(self) -> bool:
        return self in [
            OpType.SGMV_SHRINK, OpType.SGMV_EXPAND, OpType.V1_SHRINK,
            OpType.V1_EXPAND
        ]

    def is_decode_op(self) -> bool:
        return self in [
            OpType.BGMV_SHRINK, OpType.BGMV_EXPAND, OpType.BGMV_EXPAND_SLICE,
            OpType.V1_SHRINK, OpType.V1_EXPAND
        ]

    def is_expand_slice_fn(self) -> bool:
        return self in [OpType.BGMV_EXPAND_SLICE]

    def num_slices(self) -> list[int]:
        if self in [
                OpType.SGMV_EXPAND, OpType.SGMV_SHRINK, OpType.V1_SHRINK,
                OpType.V1_EXPAND
        ]:
            # SGMV and V1 kernels support slices
            return [1, 2, 3]
        if self in [OpType.BGMV_SHRINK, OpType.BGMV_EXPAND]:
            return [1]
@@ -250,11 +268,13 @@ class OpType(Enum):
        m, k, n = self.mkn(batch_size, seq_length, hidden_size, lora_rank)
        b_shape = (num_loras, n, k)  # col-major
        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
            # SGMV shrink and V1 shrink kernels support num_slices inherently
            # in the kernel.
            return ((m, k), b_shape, (num_slices, m, n))
        if self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
            # SGMV expand and V1 expand kernels support num_slices inherently
            # in the kernel.
            return ((num_slices, m, k), b_shape, (m, n * num_slices))
        if self == OpType.BGMV_SHRINK:
            return ((m, k), b_shape, (m, n))
@@ -281,25 +301,30 @@ class OpType(Enum):
            return bgmv_expand
        if self == OpType.BGMV_EXPAND_SLICE:
            return emulate_bgmv_expand_slice
        if self == OpType.V1_SHRINK:
            return v1_shrink
        if self == OpType.V1_EXPAND:
            return v1_expand
        raise ValueError(f"Unrecognized optype {self}")

    def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
                           lora_weights: list[torch.Tensor],
                           **kwargs) -> Callable:
        """Each benchmark operation expects the input, lora_weights and outputs
        in a slightly different format. Refer to self.matmul_shapes().
        run_ref_group_gemm accounts for those differences in executing a
        reference group gemm for correctness testing.
        """
        w_dtype = lora_weights[0].dtype
        num_slices = len(lora_weights)
        if self in [OpType.SGMV_SHRINK, OpType.V1_SHRINK]:
            for slice_idx in range(num_slices):
                ref_group_gemm(ref_out=output[slice_idx, :],
                               input=input,
                               lora_weights=lora_weights[slice_idx],
                               **kwargs)
        elif self in [OpType.SGMV_EXPAND, OpType.V1_EXPAND]:
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
@@ -308,19 +333,19 @@ class OpType(Enum):
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs)
        elif self == OpType.BGMV_SHRINK:
            assert num_slices == 1
            ref_group_gemm(ref_out=output,
                           input=input,
                           lora_weights=lora_weights[0],
                           **kwargs)
        elif self == OpType.BGMV_EXPAND:
            assert num_slices == 1
            ref_group_gemm(ref_out=output,
                           input=input.clone().to(dtype=w_dtype),
                           lora_weights=lora_weights[0],
                           **kwargs)
        elif self == OpType.BGMV_EXPAND_SLICE:
            hidden_size = lora_weights[0].shape[1]
            for slice_idx in range(num_slices):
                slice_offset = slice_idx * hidden_size
@@ -329,6 +354,7 @@ class OpType(Enum):
                    input=input[slice_idx].clone().to(dtype=w_dtype),
                    lora_weights=lora_weights[slice_idx],
                    **kwargs)
        else:
            raise ValueError(f"Unrecognized optype {self}")
@@ -390,6 +416,8 @@ class BenchmarkTensors:
    seq_start_loc: torch.Tensor
    prompt_lora_mapping: torch.Tensor
    token_lora_mapping: torch.Tensor
    # v1 kernel metadata
    v1_kernel_meta: Optional[V1KernelMeta] = None

    def io_types(self) -> str:
        return (f"{dtype_to_str(self.input.dtype)}x"
@@ -432,10 +460,19 @@ class BenchmarkTensors:
            total_tokens, ctx.batch_size, prompt_lora_indices_tensor,
            seq_len_tensor, "cpu")

        v1_kernel_meta = None
        if op_type in [OpType.V1_SHRINK, OpType.V1_EXPAND]:
            v1_kernel_meta = V1KernelMeta.make(
                max_loras=ctx.num_loras,
                max_num_tokens=token_lora_indices_tensor.size(0),
                device="cpu")
            v1_kernel_meta.prepare_tensors(
                token_lora_mapping=token_lora_indices_tensor)

        return BenchmarkTensors(input_tensor, lora_weights, output_tensor,
                                seq_len_tensor, seq_start_loc_tensor,
                                prompt_lora_indices_tensor,
                                token_lora_indices_tensor, v1_kernel_meta)
    def sanity_check(self) -> None:
        """
@@ -468,6 +505,13 @@ class BenchmarkTensors:
        for i in range(len(self.lora_weights_lst)):
            self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])

        # v1 meta
        if self.v1_kernel_meta:
            for field_name in V1KernelMeta.__dataclass_fields__:
                field = getattr(self.v1_kernel_meta, field_name)
                assert isinstance(field, torch.Tensor)
                setattr(self.v1_kernel_meta, field_name, to_device(field))
    def metadata(self) -> tuple[int, int, int]:
        """
        Return num_seqs, num_tokens and max_seq_len
@@ -667,6 +711,78 @@ class BenchmarkTensors:
        })
        return {'kwargs_list': kwargs_list}
    def as_v1_shrink_kwargs(self) -> dict[str, Any]:
        assert self.v1_kernel_meta is not None
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape [num_tokens, hidden_size]
        assert len(i_shape) == 2
        assert i_shape[0] == num_tokens
        hidden_size = i_shape[1]
        # Expected lora weight shape [num_loras, lora_rank, hidden_size]
        assert len(lw_shape) == 3
        assert lw_shape[2] == hidden_size
        lora_rank = lw_shape[1]
        # Expected output shape [num_slices, num_tokens, lora_rank]
        assert len(o_shape) == 3
        assert o_shape == (num_slices, num_tokens, lora_rank)

        return {
            'inputs': self.input,
            'lora_a_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
            'token_indices_sorted_by_lora_ids':
            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
            'lora_ids': self.v1_kernel_meta.active_lora_ids,
            'scaling': 1.0,
        }

    def as_v1_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
        assert self.v1_kernel_meta is not None
        self.sanity_check()
        self.to_device(self.input.device)

        _, num_tokens, _, num_slices = self.metadata()

        # Sanity check matrix shapes.
        i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[
            0].shape, self.output.shape
        # Expected input shape : [num_slices, num_tokens, lora_rank]
        assert len(i_shape) == 3
        assert i_shape[0] == num_slices
        assert i_shape[1] == num_tokens
        lora_rank = i_shape[2]
        # Expected lora weight shape : [num_loras, hidden_size, lora_rank]
        assert len(lw_shape) == 3
        assert lw_shape[2] == lora_rank
        hidden_size = lw_shape[1]
        # Expected output shape : [num_tokens, hidden_size * num_slices]
        assert len(o_shape) == 2
        assert o_shape == (num_tokens, hidden_size * num_slices)

        return {
            'inputs': self.input,
            'lora_b_weights': self.lora_weights_lst,
            'output_tensor': self.output,
            'token_lora_mapping': self.v1_kernel_meta.token_lora_mapping,
            'token_indices_sorted_by_lora_ids':
            self.v1_kernel_meta.token_indices_sorted_by_lora_ids,
            'num_tokens_per_lora': self.v1_kernel_meta.num_tokens_per_lora,
            'lora_token_start_loc': self.v1_kernel_meta.lora_token_start_loc,
            'lora_ids': self.v1_kernel_meta.active_lora_ids,
            'offset_start': 0,
            'add_inputs': add_inputs,
        }
    def bench_fn_kwargs(self,
                        op_type: OpType,
                        add_inputs: Optional[bool] = None) -> dict[str, Any]:
@@ -685,6 +801,10 @@ class BenchmarkTensors:
            return self.as_bgmv_expand_kwargs(add_inputs)
        if op_type == OpType.BGMV_EXPAND_SLICE:
            return self.as_bgmv_expand_slice_kwargs(add_inputs)
        if op_type == OpType.V1_SHRINK:
            return self.as_v1_shrink_kwargs()
        if op_type == OpType.V1_EXPAND:
            return self.as_v1_expand_kwargs(add_inputs)
        raise ValueError(f"Unrecognized optype {self}")
    def test_correctness(self, op_type: OpType,
@@ -872,12 +992,9 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
    timers = []
    for bench_ctx in bench_ctxs:
        for seq_len in args.seq_lengths:
            bench_ops: list[OpType] = args.op_types
            if seq_len > 1:
                # bench only prefill ops
                bench_ops = [op for op in args.op_types if op.is_prefill_op()]

            seq_len_timers = []
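As a quick sanity illustration of how the new op types classify (this restates the OpType logic above and is not part of the commit; it assumes OpType from this benchmark file is importable):

# Restating the OpType behavior added above (illustrative only).
assert OpType.from_str("v1_shrink") == OpType.V1_SHRINK
assert OpType.V1_SHRINK.is_shrink_fn() and OpType.V1_EXPAND.is_expand_fn()
# V1 kernels serve both phases, so they are benchmarked for prefill
# (seq_len > 1) and decode (seq_len == 1) alike.
assert OpType.V1_EXPAND.is_prefill_op() and OpType.V1_EXPAND.is_decode_op()
# Like the SGMV kernels, the V1 kernels fuse multiple slices per launch.
assert OpType.V1_EXPAND.num_slices() == [1, 2, 3]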

---

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import importlib
import random
from copy import deepcopy
from dataclasses import dataclass
@@ -63,6 +64,36 @@ DEVICES = ([
# stages, so we need to verify this. prefill stage(True) or decode stage(False)
STAGES = [True, False]

# With the inclusion of V1 tests (see run_with_both_engines_lora), the tests
# in this file run twice, once with the V0 engine and then with the V1 engine.

# NUM_RANDOM_SEEDS was previously 10. It is cut in half with the inclusion of
# the V1 tests to keep CI test times in check.
NUM_RANDOM_SEEDS = 5

# VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS was previously 256. It is cut
# in half with the inclusion of the V1 tests to keep CI test times in check.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
    # Simple autouse wrapper to run both engines for each test.
    # This can be promoted up to conftest.py to run for every test
    # in a package.

    # Reload punica_gpu as the kernels used are tied to the engine type.
    from vllm.lora.punica_wrapper import punica_gpu
    importlib.reload(punica_gpu)

    # Release any memory we might be holding on to. CI runs OOM otherwise.
    from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT,
                                                _LORA_B_PTR_DICT)
    _LORA_B_PTR_DICT.clear()
    _LORA_A_PTR_DICT.clear()
    yield
def get_random_id_to_index(num_loras: int,
                           num_slots: int,
@@ -226,7 +257,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
@@ -241,7 +272,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
        return embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -329,7 +360,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
@@ -353,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
        return expanded_embedding, lora_embedding

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -468,7 +499,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
    torch.set_default_device(device)
    max_loras = 8
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
@@ -490,7 +521,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -600,10 +631,10 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16,
@@ -627,7 +658,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage,
        assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -716,10 +747,10 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
@@ -753,7 +784,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
        assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -842,10 +873,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    max_loras = 8
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
@@ -900,7 +931,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
        assert lora_linear.lora_bias_stacked is None
        return linear, lora_linear

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)
        id_to_index = get_random_id_to_index(num_loras, max_loras)
@@ -1002,12 +1033,12 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
                                       is_neox_style, rotary_dim, head_size,
                                       seq_len) -> None:
    dtype = torch.float16
    max_loras = 8
    seed = 0
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             long_lora_scaling_factors=scaling_factors,
@@ -1083,7 +1114,8 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize(
    "seed", list(range(VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
    random.seed(seed)
    vocab_size = random.randint(4000, 64000)

---

@@ -5,10 +5,12 @@ import pytest
import torch

import vllm.lora.ops.triton_ops  # noqa: F401
import vllm.lora.ops.triton_ops.v1  # noqa: F401
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
                                     bgmv_shrink, sgmv_expand,
                                     sgmv_expand_slice, sgmv_shrink)
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
from vllm.lora.ops.triton_ops.v1 import V1KernelMeta
from vllm.platforms import current_platform

from .utils import (PunicaTensors, assert_close, generate_data,
@@ -91,12 +93,12 @@ def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
_dict_lock = Lock()


def check_shrink_kernels(batches: int, num_loras: int, rank: int,
                         hidden_size: int, nslices: int, dtype: torch.dtype,
                         device: str, seq_length: int, scaling: float):
    """
    Compare outputs of vllm.sgmv_shrink and vllm.v1_shrink kernels against a
    reference implementation.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
@@ -111,44 +113,63 @@ def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
    )
    max_seq_length, token_nums = data.meta()

    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
                      data.prompt_lora_mapping, batches, max_seq_length,
                      token_nums)

    # Setup metadata information for the V1 kernel.
    v1_meta = V1KernelMeta.make(max_loras=num_loras,
                                max_num_tokens=token_nums,
                                device='cuda')
    v1_meta.prepare_tensors(data.token_lora_mapping)

    ref_out_tensor = data.ref_out_tensor
    sgmv_out_tensor = data.our_out_tensor
    v1_out_tensor = data.our_out_tensor.clone()

    # Preventing cache error pointer.
    with _dict_lock:
        # SGMV shrink kernel
        _LORA_A_PTR_DICT.clear()
        torch.ops.vllm.sgmv_shrink(
            data.inputs_tensor,
            data.lora_weights,
            sgmv_out_tensor,
            *sgmv_meta_args,
            scaling,
        )

        # V1 shrink kernel
        _LORA_A_PTR_DICT.clear()
        torch.ops.vllm.v1_shrink(
            data.inputs_tensor,
            data.lora_weights,
            v1_out_tensor,
            *v1_meta.meta_args(token_nums=token_nums),
            scaling,
        )

    # Reference
    sgmv_shrink_for_nslices(
        nslices,
        data.inputs_tensor,
        data.lora_weights,
        ref_out_tensor,
        *sgmv_meta_args,
        scaling,
    )

    assert_close(sgmv_out_tensor, ref_out_tensor)
    assert_close(v1_out_tensor, ref_out_tensor)
def check_expand_kernels(batches: int, num_loras: int, rank: int,
                         hidden_size: int, nslices: int, dtype: torch.dtype,
                         device: str, seq_length: int, add_inputs: bool):
    """
    Compare outputs of vllm.sgmv_expand and vllm.v1_expand kernels against a
    reference implementation.
    """
    data: PunicaTensors = generate_data_for_nslices(
        batches,
@@ -164,36 +185,54 @@ def check_sgmv_expand(batches: int, num_loras: int, rank: int,
    max_seq_length, token_nums = data.meta()
    # Setup metadata information for SGMV and reference kernels
    sgmv_meta_args = (data.b_seq_start_loc, data.seq_len_tensor,
                      data.prompt_lora_mapping, batches, max_seq_length,
                      token_nums)

    # Setup metadata information for the V1 kernel.
    v1_meta = V1KernelMeta.make(max_loras=num_loras,
                                max_num_tokens=token_nums,
                                device='cuda')
    v1_meta.prepare_tensors(data.token_lora_mapping)

    # Setup output tensors
    ref_out_tensor = data.ref_out_tensor
    sgmv_out_tensor = data.our_out_tensor
    v1_out_tensor = data.our_out_tensor.clone()

    with _dict_lock:
        # SGMV expand kernel
        _LORA_B_PTR_DICT.clear()
        torch.ops.vllm.sgmv_expand(
            data.inputs_tensor,
            data.lora_weights,
            sgmv_out_tensor,
            *sgmv_meta_args,
            offset_start=0,
            add_inputs=add_inputs,
        )

        # V1 expand kernel
        _LORA_B_PTR_DICT.clear()
        torch.ops.vllm.v1_expand(data.inputs_tensor,
                                 data.lora_weights,
                                 v1_out_tensor,
                                 *v1_meta.meta_args(token_nums=token_nums),
                                 offset_start=0,
                                 add_inputs=add_inputs)

    # Reference
    sgmv_expand_for_nslices(nslices,
                            hidden_size,
                            data.inputs_tensor,
                            data.lora_weights,
                            ref_out_tensor,
                            *sgmv_meta_args,
                            add_inputs=add_inputs)

    assert_close(sgmv_out_tensor, ref_out_tensor)
    assert_close(v1_out_tensor, ref_out_tensor)
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
@@ -439,7 +478,7 @@ SEED = [0]
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels(
    batches: int,
    num_loras: int,
    rank: int,
@@ -450,11 +489,14 @@ def test_punica_sgmv(
    seed: int,
    op_type: str,
):
    """
    Tests SGMV and V1 kernels.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    if op_type == "shrink":
        check_shrink_kernels(batches=batches,
                             num_loras=num_loras,
                             rank=rank,
                             hidden_size=hidden_size,
@@ -464,7 +506,7 @@ def test_punica_sgmv(
                             seq_length=128,
                             scaling=0.5)
    else:
        check_expand_kernels(batches=batches,
                             num_loras=num_loras,
                             rank=rank,
                             hidden_size=hidden_size,
@@ -484,7 +526,7 @@ def test_punica_sgmv(
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_kernels_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
@@ -495,11 +537,14 @@ def test_punica_sgmv_hidden_size(
    seed: int,
    op_type: str,
):
    """
    Tests SGMV and V1 kernels.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    if op_type == "shrink":
        check_shrink_kernels(batches=batches,
                             num_loras=num_loras,
                             rank=rank,
                             hidden_size=hidden_size,
@@ -509,7 +554,7 @@ def test_punica_sgmv_hidden_size(
                             seq_length=128,
                             scaling=0.5)
    else:
        check_expand_kernels(batches=batches,
                             num_loras=num_loras,
                             rank=rank,
                             hidden_size=hidden_size,

---

@@ -326,9 +326,11 @@ class LoRAModelManager(AdapterModelManager):
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = get_punica_wrapper(
            max_num_batched_tokens,
            max_batches=self.max_num_seqs,
            device=self.device,
            max_loras=self.lora_config.max_loras)
        # Scaling factor -> offset into the sin_cos_cache.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}

---

@@ -54,7 +54,7 @@ _LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {}


def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: torch.device):
    """
    `_LORA_A_PTR_DICT` collects the required information during `profile_run`.
    After this, it remains constant and subsequent usage is through LUT.
@@ -100,7 +100,7 @@ def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str):
def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int,
                    device: torch.device):
    """
    `_LORA_B_PTR_DICT` collects the required information during `profile_run`.
    After this, it remains constant and subsequent usage is through LUT.

---

@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

from vllm.lora.ops.triton_ops.v1.v1_expand import v1_expand
from vllm.lora.ops.triton_ops.v1.v1_kernel_metadata import V1KernelMeta
from vllm.lora.ops.triton_ops.v1.v1_shrink import v1_shrink

__all__ = [
    "v1_expand",
    "v1_shrink",
    "V1KernelMeta",
]
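Putting these exports together, a minimal end-to-end sketch of the shrink -> expand flow (illustrative shapes and random data, not from the commit; assumes a CUDA device with triton available):

import torch

from vllm.lora.ops.triton_ops.v1 import V1KernelMeta, v1_expand, v1_shrink

num_tokens, hidden_size, rank, max_loras = 16, 128, 8, 4
device = "cuda"

x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device=device)
lora_a = [
    torch.randn(max_loras, rank, hidden_size,
                dtype=torch.float16, device=device)
]
lora_b = [
    torch.randn(max_loras, hidden_size, rank,
                dtype=torch.float16, device=device)
]

# Map each token to a LoRA id (-1 means "no LoRA") and build kernel metadata.
token_lora_mapping = torch.randint(-1, max_loras, (num_tokens, ),
                                   dtype=torch.int32, device=device)
meta = V1KernelMeta.make(max_loras=max_loras,
                         max_num_tokens=num_tokens,
                         device=device)
meta.prepare_tensors(token_lora_mapping)

# Shrink into a float32 buffer, then expand back into the output.
buffer = torch.zeros(1, num_tokens, rank, dtype=torch.float32, device=device)
y = torch.zeros(num_tokens, hidden_size, dtype=torch.float16, device=device)
v1_shrink(x, lora_a, buffer, *meta.meta_args(num_tokens), 1.0)
v1_expand(buffer, lora_b, y, *meta.meta_args(num_tokens), add_inputs=True)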

---

@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List

import torch
import triton
import triton.language as tl

from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
from vllm.utils import direct_register_custom_op


@triton.jit
def _v1_expand_kernel(
        input_ptr,
        lora_ptr,
        out_ptr,
        M,
        N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        slice_start_loc,
        input_d0_stride,
        input_d1_stride,
        input_d2_stride,  # 1
        ls_d0_ptr,
        ls_d1_ptr,
        ls_d2_ptr,  # 1
        output_d0_stride,
        output_d1_stride,  # 1
        output_hs_ptr,
        BLOCK_M: tl.constexpr,
        BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr,
        EVEN_K: tl.constexpr,
        ADD_INPUTS: tl.constexpr,
        CAST_TYPE: tl.constexpr,
        SLICE_NUM: tl.constexpr,
        SAME_STRIDE: tl.constexpr):

    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    pid_mn = tl.program_id(axis=0)
    pid_m = pid_mn % cta_m_num
    pid_n = (pid_mn // cta_m_num) % cta_n_num

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return

    # When the output dimensions of each slice are the same, cur_n = N;
    # otherwise cur_n = tl.load(output_hs_ptr + slice_id). This situation
    # exists in GQA's qkv linear.
    curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
    if pid_n * BLOCK_N >= curr_N:
        # Early exit CTA.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
                            lora_m_indices_start + cta_m_offset)

    # Load all relevant row indices.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_expand_kernel(
        pid_n,
        lora_id,
        slice_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        curr_N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        slice_start_loc,
        # input ptr strides
        input_d0_stride,
        input_d1_stride,
        input_d2_stride,
        # lora ptr strides
        ls_d0_ptr,
        ls_d1_ptr,
        ls_d2_ptr,
        # out ptr strides
        output_d0_stride,
        output_d1_stride,
        # constants
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        SAME_STRIDE,
        SLICE_NUM,
        EVEN_K,
        CAST_TYPE,
        ADD_INPUTS)


@torch.inference_mode()
def _v1_expand(
    inputs: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
    lora_b_weights: List[
        torch.Tensor],  # shape [num_loras, hidden_size, lora_rank]
    output_tensor: torch.Tensor,  # shape [num_tokens, hidden_size*num_slices]
    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
    lora_ids: torch.Tensor,  # shape [max-loras + 1]
    offset_start: int = 0,
    add_inputs: bool = False,
) -> None:
    """
    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (List[torch.Tensor]): LoRA-B weights
        output_tensor (torch.Tensor): output tensor
        token_lora_mapping (torch.Tensor): A tensor mapping each input token
            to the lora-id related to that token. A value of -1 indicates
            that LoRA doesn't apply to that token.
        token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices
            from the A matrix grouped by LoRA IDs.
        num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the
            number of tokens that are to be processed by LoRA ID lora_ids[i]
        lora_token_start_loc (torch.Tensor): A cumulative sum of
            num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
            lora_token_start_loc[i], along with num_tokens_per_lora[i],
            identifies the region in token_indices_sorted_by_lora_ids that
            LoRA lora_ids[i] should process.
        lora_ids (torch.Tensor): LoRA ids to process.
        offset_start (int, optional): Offset start for output_tensor.
            Defaults to 0.
        add_inputs (bool, optional): Whether to add the input tensor to the
            output tensor. Defaults to False.
    """
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    for weight in lora_b_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]

    assert inputs.size(0) == len(lora_b_weights)
    assert output_tensor.is_contiguous()

    # metadata sanity check.
    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
        0)
    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

    (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
     lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
     same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
                                           inputs.device)

    K = lora_b_weights[0].shape[-1]  # K = rank
    M = inputs.size(1)
    ADD_INPUTS = add_inputs
    MAX_LORAS = lora_ids.size(0)
    CAST_TYPE = False
    NUM_SLICES = len(lora_b_weights)

    # Triton kernel configs.
    BLOCK_M = 64
    BLOCK_N = 128
    BLOCK_K = 16
    NUM_WARPS = 4
    NUM_CTAS = 1
    NUM_STAGES = 2
    MAX_NREG = None

    EVEN_K = K % BLOCK_K == 0  # type: ignore

    if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True

    # TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only a few input tokens
    # require LoRA. This might not be the best in all cases.
    grid = (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
        NUM_SLICES,
        # Each LoRA receives its own set of thread blocks for output
        # computation. If some LoRA doesn't have any tokens to process, its
        # thread blocks simply exit.
        MAX_LORAS,
    )
    _v1_expand_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        M,
        MAX_N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        slice_start_tensor,
        inputs.stride(0),
        inputs.stride(1),
        inputs.stride(2),
        lora_strides_d0_tensor,
        lora_strides_d1_tensor,
        lora_strides_d2_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        hidden_sizes_tensor,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
        NUM_SLICES,
        same_stride,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        num_stages=NUM_STAGES,
        maxnreg=MAX_NREG,
    )
    return


def _v1_expand_fake(
    inputs: torch.Tensor,
    lora_b_weights: List[torch.Tensor],
    output_tensor: torch.Tensor,
    token_lora_mapping: torch.Tensor,
    token_indices_sorted_by_lora_ids: torch.Tensor,
    num_tokens_per_lora: torch.Tensor,
    lora_token_start_loc: torch.Tensor,
    lora_ids: torch.Tensor,
    offset_start: int = 0,
    add_inputs: bool = False,
) -> None:
    return


try:
    direct_register_custom_op(
        op_name="v1_expand",
        op_func=_v1_expand,
        mutates_args=["output_tensor"],
        fake_impl=_v1_expand_fake,
    )
    v1_expand = torch.ops.vllm.v1_expand

except AttributeError:
    v1_expand = _v1_expand
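For intuition, a small sketch (hypothetical sizes, not part of the commit) of how each slice's output lands in the flattened output tensor, per the shape comments above and assuming all slices share the same hidden_size:

# Illustrative column windows for the v1_expand output.
num_slices, hidden_size, offset_start = 3, 4, 0
col_ranges = [(offset_start + i * hidden_size,
               offset_start + (i + 1) * hidden_size)
              for i in range(num_slices)]
# col_ranges == [(0, 4), (4, 8), (8, 12)]: slice i of the
# [num_slices, num_tokens, lora_rank] input is expanded into its own
# column window of the [num_tokens, hidden_size * num_slices] output.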

---

@@ -0,0 +1,117 @@
# SPDX-License-Identifier: Apache-2.0
"""
V1 LoRA kernels metadata preparation utilities.
"""
from dataclasses import dataclass
from typing import Tuple, Union

import torch


@dataclass
class V1KernelMeta:
    token_lora_mapping: torch.Tensor
    token_indices_sorted_by_lora_ids: torch.Tensor
    active_lora_ids: torch.Tensor
    num_tokens_per_lora: torch.Tensor
    lora_token_start_loc: torch.Tensor

    @staticmethod
    def make(max_loras: int, max_num_tokens: int,
             device: Union[torch.device, str]) -> "V1KernelMeta":

        token_lora_mapping = torch.empty(max_num_tokens,
                                         dtype=torch.int32,
                                         device=device)

        token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens,
                                                       dtype=torch.int32,
                                                       device=device)

        # +1 because "no-lora" is also a possibility.
        # example: let max_loras be 3, then an active_lora_ids of
        # [-1, 0, 2, 1] is a possibility.
        active_lora_ids = torch.empty(max_loras + 1,
                                      dtype=torch.int32,
                                      device=device)

        # using the running example, [3, 10, 5, 2] is a possibility.
        num_tokens_per_lora = torch.zeros(max_loras + 1,
                                          dtype=torch.int32,
                                          device=device)

        # +2 for this because the first index is always 0.
        # using the running example, lora_token_start_loc
        # is [0, 3, 13, 18, 20].
        lora_token_start_loc = torch.zeros(max_loras + 2,
                                           dtype=torch.int32,
                                           device=device)
        return V1KernelMeta(
            token_lora_mapping=token_lora_mapping,
            token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
            active_lora_ids=active_lora_ids,
            num_tokens_per_lora=num_tokens_per_lora,
            lora_token_start_loc=lora_token_start_loc)

    def _reset(self):
        self.active_lora_ids.fill_(-1)
        self.num_tokens_per_lora.fill_(0)
        self.lora_token_start_loc.fill_(0)

    def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
        """
        Prepare kernel metadata tensors for the current forward pass.

        Args:
            token_lora_mapping (torch.Tensor): Tensor containing lora indices
                for each input token.
        """
        self._reset()

        num_tokens = token_lora_mapping.size(0)

        # copy token lora mapping
        self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping,
                                                   non_blocking=True)

        # token_indices_sorted_by_lora_ids
        _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping,
                                                         stable=True)
        # start gpu transfer
        self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
            token_indices_sorted_by_lora_ids, non_blocking=True)

        # active_lora_ids, num_tokens_per_lora
        lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
                                                     sorted=False,
                                                     return_counts=True)
        self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids,
                                                      non_blocking=True)
        self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_(
            num_tokens_per_lora, non_blocking=True)

        # lora_token_start_loc
        lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
        self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_(
            lora_token_start_loc, non_blocking=True)

    def meta_args(
        self, token_nums: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """
        This function returns the kernel metadata required for the current
        forward pass execution of the kernel. The function returns all the
        metadata required by the kernel, in order, as a tuple, so it can be
        unpacked directly during the v1_shrink/v1_expand function call.

        Args:
            token_nums (int): Number of input tokens in the current forward
                pass.
        """
        return (self.token_lora_mapping[:token_nums],
                self.token_indices_sorted_by_lora_ids[:token_nums],
                self.num_tokens_per_lora, self.lora_token_start_loc,
                self.active_lora_ids)
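The running example in the comments above can be traced concretely; this small sketch (hypothetical mapping values, not part of the commit) mirrors what prepare_tensors computes:

import torch

# 6 tokens, max_loras = 3; token i is served by LoRA id mapping[i], -1 = none.
token_lora_mapping = torch.tensor([2, -1, 0, 2, 0, 0], dtype=torch.int32)

# Stable sort groups token indices by LoRA id:
# sorted ids -> [-1, 0, 0, 0, 2, 2], token order -> [1, 2, 4, 5, 0, 3]
_, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping,
                                                 stable=True)

# Unique ids and counts: active_lora_ids -> [-1, 0, 2] (padded with -1 up to
# max_loras + 1 entries), num_tokens_per_lora -> [1, 3, 2]
lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
                                             return_counts=True)

# Cumulative sum with a leading zero slices the sorted order per LoRA:
# lora_token_start_loc -> [0, 1, 4, 6]; LoRA id 0 owns sorted positions [1, 4)
lora_token_start_loc = torch.cat(
    [torch.zeros(1, dtype=torch.int64),
     torch.cumsum(num_tokens_per_lora, dim=0)])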

---

@@ -0,0 +1,236 @@
# SPDX-License-Identifier: Apache-2.0
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import List

import torch
import triton
import triton.language as tl

from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
from vllm.utils import direct_register_custom_op


@triton.jit
def _v1_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
                      token_indices_sorted_by_lora_ids, num_tokens_per_lora,
                      lora_token_start_loc, lora_ids, scaling,
                      input_d0_stride, input_d1_stride, lora_d0_stride,
                      lora_d1_stride, lora_d2_stride, output_d0_stride,
                      output_d1_stride, output_d2_stride,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                      BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
                      SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):

    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    pid_sk_m_n = tl.program_id(axis=0)
    pid_sk = pid_sk_m_n % SPLIT_K
    pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
    pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
                            lora_m_indices_start + cta_m_offset)

    # Load all relevant row indices.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_shrink_kernel(
        pid_n,
        pid_sk,
        slice_id,
        lora_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        # input strides
        input_d0_stride,
        input_d1_stride,
        # lora strides
        lora_d0_stride,
        lora_d1_stride,
        lora_d2_stride,
        # output strides
        output_d0_stride,
        output_d1_stride,
        output_d2_stride,
        scaling,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        SLICE_NUM)


@torch.inference_mode()
def _v1_shrink(
    inputs: torch.Tensor,  # shape [num_tokens, hidden_size]
    lora_a_weights: List[
        torch.Tensor],  # shape [num_loras, lora_rank, hidden_size]
    output_tensor: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
    lora_ids: torch.Tensor,  # shape [max-loras + 1]
    scaling: float,
) -> None:
    """
    Args:
        inputs (torch.Tensor): Input tensor
        lora_a_weights (List[torch.Tensor]): LoRA weights
        output_tensor (torch.Tensor): output tensor
        token_lora_mapping (torch.Tensor): A tensor mapping each input token
            to the lora-id related to that token. A value of -1 indicates
            that LoRA doesn't apply to that token.
        token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices
            from the A matrix grouped by LoRA IDs.
        num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the
            number of tokens that are to be processed by LoRA ID lora_ids[i]
        lora_token_start_loc (torch.Tensor): A cumulative sum of
            num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
            lora_token_start_loc[i], along with num_tokens_per_lora[i],
            identifies the region in token_indices_sorted_by_lora_ids that
            LoRA lora_ids[i] should process.
        lora_ids (torch.Tensor): LoRA ids to process.
        scaling (float): Scaling factor.
    """
    assert inputs.dtype == lora_a_weights[0].dtype
    assert inputs.dtype in [torch.float16, torch.bfloat16]
    for weight in lora_a_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]

    assert inputs.size(1) == lora_a_weights[0].size(-1)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    # metadata sanity check
    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
        0)
    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

    (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
     lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)

    N, K = lora_a_weights[0].shape[-2:]  # K = hidden_size, N = rank
    M = inputs.size(0)
    NUM_SLICES = len(lora_a_weights)
    MAX_LORAS = lora_ids.size(0)

    # Triton kernel configs
    BLOCK_M = 32
    BLOCK_N = 16
    BLOCK_K = 256 if M < 128 else 32
    SPLIT_K = 64 if M < 128 else 8
    NUM_WARPS = 4
    NUM_CTAS = 1
    NUM_STAGES = 2
    MAX_NREG = None

    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore

    # TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only a few of the input
    # tokens require LoRA. This might not be the best in all cases.
    grid = (
        SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        NUM_SLICES,
        # Each LoRA receives its own set of thread blocks for output
        # computation. If some LoRA doesn't have any tokens to process, its
        # thread blocks exit early.
        MAX_LORAS,
    )
    _v1_shrink_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        M,
        N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        scaling,
        inputs.stride(0),
        inputs.stride(1),
        lora_strides_d0,
        lora_strides_d1,
        lora_strides_d2,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor.stride(2),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        NUM_SLICES,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        num_stages=NUM_STAGES,
        maxnreg=MAX_NREG,
    )
    return


def _v1_shrink_fake(
    inputs: torch.Tensor,
    lora_a_weights: List[torch.Tensor],
    output_tensor: torch.Tensor,
    token_lora_mapping: torch.Tensor,
    token_indices_sorted_by_lora_ids: torch.Tensor,
    num_tokens_per_lora: torch.Tensor,
    lora_token_start_loc: torch.Tensor,
    lora_ids: torch.Tensor,
    scaling: float,
) -> None:
    return


try:
    direct_register_custom_op(
        op_name="v1_shrink",
        op_func=_v1_shrink,
        mutates_args=["output_tensor"],
        fake_impl=_v1_shrink_fake,
    )
    v1_shrink = torch.ops.vllm.v1_shrink

except AttributeError:
    v1_shrink = _v1_shrink
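To make the launch-grid arithmetic concrete, a small sketch with hypothetical sizes mirroring the config logic above (not part of the commit):

import triton

M, N = 64, 16                  # num_tokens, lora_rank (illustrative)
BLOCK_M, BLOCK_N = 32, 16
BLOCK_K = 256 if M < 128 else 32   # -> 256
SPLIT_K = 64 if M < 128 else 8     # -> 64
NUM_SLICES, MAX_LORAS = 3, 5   # len(lora_a_weights), lora_ids.size(0)

# axis 0 fuses split-K, M-tiles and N-tiles: 64 * 2 * 1 = 128 CTAs
grid = (SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        NUM_SLICES, MAX_LORAS)
# CTAs assigned to a LoRA with no tokens (or to the -1 "no-lora" slot)
# exit early, which is the wasteful-launch tradeoff the TODO above notes.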

---

@@ -6,13 +6,19 @@ Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""

from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final

import torch

import vllm.envs as env
from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    if env.VLLM_USE_V1:
        from vllm.lora.ops.triton_ops.v1 import (V1KernelMeta, v1_expand,
                                                 v1_shrink)
    else:
        from vllm.lora.ops.triton_ops import bgmv_expand
        from vllm.lora.ops.triton_ops import bgmv_expand_slice
        from vllm.lora.ops.triton_ops import bgmv_shrink
@@ -21,9 +27,62 @@ if HAS_TRITON:
from .punica_base import PunicaWrapperBase

if TYPE_CHECKING:
    # avoid circular import
    from vllm.lora.models import LongContextLoRAContext


class V1KernelMixin:

    def _v1_make_metadata(self, max_loras: int, max_num_batched_tokens: int,
                          max_batches: int, device: Union[torch.device, str]):
        self.token_mapping_v1_meta = V1KernelMeta.make(max_loras,
                                                       max_num_batched_tokens,
                                                       device=device)
        self.prompt_mapping_v1_meta = V1KernelMeta.make(max_loras,
                                                        max_batches,
                                                        device=device)

    def _v1_prepare_metadata_tensors(self, token_lora_indices: torch.Tensor,
                                     sampler_indices: torch.Tensor):
        self.token_mapping_v1_meta.prepare_tensors(token_lora_indices)
        self.prompt_mapping_v1_meta.prepare_tensors(sampler_indices)

    def _v1_apply_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: Tuple[torch.Tensor, ...],
        scale: float,
    ):
        v1_shrink(
            x,
            w_t_all,
            y,
            *self.token_mapping_v1_meta.meta_args(x.size(0)),
            scale,
        )

    def _v1_apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: Tuple[torch.Tensor, ...],
        offset_start: int,
        add_inputs: bool,
    ):
        v1_expand(
            x,
            w_t_all,
            y,
            *self.token_mapping_v1_meta.meta_args(x.size(0)),
            offset_start=offset_start,
            add_inputs=add_inputs,
        )
@final
class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
    """
    PunicaWrapperGPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
@@ -35,6 +94,36 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

        self.max_loras = kwargs['max_loras']
        if env.VLLM_USE_V1:
            self._v1_make_metadata(self.max_loras, max_num_batched_tokens,
                                   max_batches, device)

    def update_metadata(
            self,
            mapping: LoRAMapping,
            lora_index_to_id: List[Optional[int]],
            max_loras: int,
            vocab_size: int,
            extra_vocab_size: int,
            long_lora_context: Optional["LongContextLoRAContext"] = None,
            **kwargs):

        if env.VLLM_USE_V1:
            self.is_prefill = mapping.is_prefill
            self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                       vocab_size, extra_vocab_size,
                                       long_lora_context)
            self._v1_prepare_metadata_tensors(self.token_lora_indices,
                                              self.sampler_indices)
        else:
            # Forward to base class update_metadata
            PunicaWrapperBase.update_metadata(self, mapping, lora_index_to_id,
                                              max_loras, vocab_size,
                                              extra_vocab_size,
                                              long_lora_context, **kwargs)
    def _apply_shrink_prefill(
        self,
        y: torch.Tensor,
@@ -66,7 +155,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: Tuple[torch.Tensor, ...],
        offset_start: int,
        add_inputs: bool,
    ):
@@ -118,9 +207,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        x = x.view(-1, x.shape[-1])

        if env.VLLM_USE_V1:
            self._v1_apply_shrink(y, x, lora_a_stacked, scale)  # type: ignore
        else:
            if self.is_prefill:
                # NOTE fused kernel
                self._apply_shrink_prefill(
                    y,  # type: ignore
                    x,
                    lora_a_stacked,
                    scale)
            else:
                # TODO fuse these kernels
                for slice_idx in range(len(lora_a_stacked)):
@@ -160,10 +256,23 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)

        if env.VLLM_USE_V1:
            # TODO (varun): Profile with add_inputs = False. i.e. move the
            # addition out of the kernel
            self._v1_apply_expand(
                y,
                x,  # type: ignore
                lora_b_stacked,
                offset_start,
                add_inputs=True)
        else:
            if self.is_prefill:
                # NOTE fused kernel
                self._apply_expand_prefill(
                    y,
                    x,  # type: ignore
                    lora_b_stacked,
                    offset_start,
                    add_inputs=True)
@@ -200,10 +309,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            add_inputs (bool): Default to True.
        """

        if env.VLLM_USE_V1:
            self._v1_apply_expand(y,
                                  x.unsqueeze(dim=0), (lora_b_stacked, ),
                                  offset_start=0,
                                  add_inputs=add_inputs)
        else:
            if self.is_prefill:
                sgmv_expand(
                    x.unsqueeze(dim=0),
                    (lora_b_stacked, ),
                    y,
                    *self.prefill_metadata,
                    offset_start=0,
@@ -257,14 +372,20 @@ class PunicaWrapperGPU(PunicaWrapperBase):
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros(  # type: ignore
                (len(output_slices), x.size(0), r),
                dtype=torch.float32,
                device=x.device,
            )
        self.add_shrink(
            buffer,  # type: ignore
            x,
            lora_a_stacked,
            scale,
            **kwargs)
        self.add_expand(
            y,
            buffer,  # type: ignore
            lora_b_stacked,
            None,
            output_slices,
@@ -305,7 +426,18 @@ class PunicaWrapperGPU(PunicaWrapperBase):
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)

        if env.VLLM_USE_V1:
            v1_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
                      *self.prompt_mapping_v1_meta.meta_args(x.size(0)),
                      scale)

            v1_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
                      y,
                      *self.prompt_mapping_v1_meta.meta_args(buffer.size(0)),
                      add_inputs=True)
        else:
            # V0 LogitsProcessorWithLoRA always uses bgmv.
            bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
            bgmv_expand(buffer,
                        lora_b_stacked,

---

@@ -62,9 +62,9 @@ class LoRAModelRunnerMixin:
        if not self.lora_manager:
            raise RuntimeError("LoRA is not enabled.")

        # Set is_prefill to True, so we always use the SGMV kernels.
        # For cuda platforms, we have specialized triton kernels and the
        # cuda path ignores `is_prefill`.
        lora_mapping = LoRAMapping(token_lora_mapping,
                                   prompt_lora_mapping,
                                   is_prefill=True)