[Kernel] Full Tensor Parallelism for LoRA Layers (#3524)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Austin Veselka 2024-04-27 02:03:48 -05:00 committed by GitHub
parent 18d23f642a
commit eefeb16464
19 changed files with 686 additions and 111 deletions


@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)


@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)


@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, 128, narrow) \
f(in_T, out_T, W_T, 256, narrow) \
f(in_T, out_T, W_T, 512, narrow) \
f(in_T, out_T, W_T, 640, narrow) \
f(in_T, out_T, W_T, 768, narrow) \
f(in_T, out_T, W_T, 1024, narrow) \
f(in_T, out_T, W_T, 1152, narrow) \
f(in_T, out_T, W_T, 1280, narrow) \
f(in_T, out_T, W_T, 1536, narrow) \
f(in_T, out_T, W_T, 1728, narrow) \
f(in_T, out_T, W_T, 1792, narrow) \
f(in_T, out_T, W_T, 2048, narrow) \
f(in_T, out_T, W_T, 2304, narrow) \
f(in_T, out_T, W_T, 2560, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4608, narrow) \
f(in_T, out_T, W_T, 5120, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 8192, narrow) \
f(in_T, out_T, W_T, 9216, narrow) \
f(in_T, out_T, W_T, 10240, narrow) \
f(in_T, out_T, W_T, 11008, narrow) \
f(in_T, out_T, W_T, 12288, narrow) \
f(in_T, out_T, W_T, 13696, narrow) \
f(in_T, out_T, W_T, 13824, narrow) \
f(in_T, out_T, W_T, 14336, narrow) \
f(in_T, out_T, W_T, 15360, narrow) \
f(in_T, out_T, W_T, 16384, narrow) \
f(in_T, out_T, W_T, 20480, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32512, narrow) \
f(in_T, out_T, W_T, 32768, narrow) \
f(in_T, out_T, W_T, 33024, narrow) \
f(in_T, out_T, W_T, 36864, narrow) \
f(in_T, out_T, W_T, 43264, narrow) \
f(in_T, out_T, W_T, 49152, narrow) \
f(in_T, out_T, W_T, 64000, narrow) \
f(in_T, out_T, W_T, 64256, narrow) \
f(in_T, out_T, W_T, 64512, narrow) \
f(in_T, out_T, W_T, 102400, narrow) \
f(in_T, out_T, W_T, 102656, narrow) \
f(in_T, out_T, W_T, 102912, narrow) \
f(in_T, out_T, W_T, 128000, narrow) \
f(in_T, out_T, W_T, 128256, narrow) \
f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
f(in_T, out_T, W_T, 8, 64) \
f(in_T, out_T, W_T, 16, 64) \
f(in_T, out_T, W_T, 32, 64) \
f(in_T, out_T, W_T, 64, 64)
// clang-format on


@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)


@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)


@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)


@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)


@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
constexpr int tz = 4;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-if constexpr (feat_in < feat_out) {
+if constexpr (feat_in <= feat_out) {
static_assert(feat_in % vec_size == 0);
constexpr int tx = feat_in / vec_size;
@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)


@@ -10,6 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip() # noqa: E501 """.lstrip() # noqa: E501
for input_dtype in DTYPES: for input_dtype in DTYPES:


@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
default:
return false;
}
return true;
}


@@ -8,6 +8,10 @@ import torch
import torch.nn.functional as F
from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -524,13 +528,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_linear_parallel(dist_init, num_loras, orientation, device) -> None: def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
device) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
def create_random_linear_parallel_layer():
@@ -540,14 +547,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
-lora_linear = RowParallelLinearWithLoRA(linear)
+lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
+else RowParallelLinearWithShardedLoRA(linear))
else:
linear = ColumnParallelLinear(4096,
4096,
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
-lora_linear = ColumnParallelLinearWithLoRA(linear)
+lora_linear = (ColumnParallelLinearWithLoRA(linear)
+if not fully_shard else
+ColumnParallelLinearWithShardedLoRA(linear))
lora_linear.create_lora_weights(max_loras, lora_config)
return linear, lora_linear
@@ -629,13 +639,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
device) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras, lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8, max_lora_rank=8,
fully_sharded_loras=fully_shard,
lora_dtype=torch.float16)
def create_column_parallel_packed_layer():
@@ -644,7 +657,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
-lora_linear = MergedColumnParallelLinearWithLoRA(linear)
+lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
+if not fully_shard else
+MergedColumnParallelLinearWithShardedLoRA(linear))
elif repeats == 3:
linear = QKVParallelLinear(4096,
64,
@@ -652,7 +667,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
bias=False,
params_dtype=torch.float16)
linear.weight.data = torch.rand_like(linear.weight.data)
-lora_linear = MergedQKVParallelLinearWithLora(linear)
+lora_linear = (MergedQKVParallelLinearWithLora(linear)
+if not fully_shard else
+MergedQKVParallelLinearWithShardedLora(linear))
else:
linear = QKVParallelLinear(4096,
64,


@@ -34,11 +34,14 @@ def _lora_ref_impl(
for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
xi = x[i].unsqueeze(0).to(torch.float32)
wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
-wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
+if wb_T_all is not None:
+wb = wb_T_all[lora_idx, layer_idx].transpose(-1,
+-2).to(torch.float32)
tmp = xi @ wa
y_stage_1[i] = tmp.squeeze(0)
-y_final[i] += (tmp @ wb).squeeze(0) * s
+y_final[i] += ((tmp @ wb).squeeze(0) *
+s if wb_T_all is not None else y_stage_1[i])
return y_final, y_stage_1
@@ -91,12 +94,56 @@ H1 = H2 = [
128000,
128256,
]
H2 = [64] + H2
R = [1, 2, 4]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("r", R)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
torch.manual_seed(seed)
num_loras = 4
num_layers = 1
bs = 32
dtype = getattr(torch, dtype_str)
device = torch.device("cuda")
wa_T_all = torch.randn(num_loras,
num_layers,
r,
h1,
dtype=dtype,
device=device)
indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
for layer_idx in range(num_layers):
x = torch.randn(bs, h1, dtype=dtype, device=device)
y = torch.randn(bs, r, dtype=dtype, device=device)
y_ref = y.clone()
_lora_ref_impl(
y_ref,
x,
wa_T_all,
None,
indices,
layer_idx,
1.0,
)
y_our = y.clone()
punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0)
assert_close(y_ref, y_our)
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1) @pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2) @pytest.mark.parametrize("h2", H2)


@@ -862,6 +862,7 @@ class SpeculativeConfig:
class LoRAConfig:
max_lora_rank: int
max_loras: int
fully_sharded_loras: bool = False
max_cpu_loras: Optional[int] = None
lora_dtype: Optional[torch.dtype] = None
lora_extra_vocab_size: int = 256
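For reference, a minimal sketch of opting into the new field, mirroring the test setup earlier in this diff (the surrounding values are illustrative):

import torch
from vllm.config import LoRAConfig

# fully_sharded_loras=True routes layer replacement to the
# *WithShardedLoRA classes added in vllm/lora/fully_sharded_layers.py.
lora_config = LoRAConfig(max_loras=8,
                         max_lora_rank=8,
                         fully_sharded_loras=True,
                         lora_dtype=torch.float16)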


@@ -52,6 +52,7 @@ class EngineArgs:
enable_lora: bool = False
max_loras: int = 1
max_lora_rank: int = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
@@ -376,6 +377,14 @@ class EngineArgs:
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_num_seqs. '
'Defaults to max_num_seqs.'))
parser.add_argument(
'--fully-sharded-loras',
action='store_true',
help=('By default, only half of the LoRA computation is '
'sharded with tensor parallelism. '
'Enabling this will use the fully sharded layers. '
'At high sequence length, max rank or '
'tensor parallel size, this is likely faster.'))
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
default=EngineArgs.device, default=EngineArgs.device,
@@ -509,6 +518,7 @@ class EngineArgs:
lora_config = LoRAConfig(
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
lora_extra_vocab_size=self.lora_extra_vocab_size,
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
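A hedged usage sketch of the new flag at the engine level. The model name, adapter path, and the assumption that LLM forwards extra keyword arguments to EngineArgs are illustrative, not taken from this diff:

from vllm import LLM
from vllm.lora.request import LoRARequest

# Assumption: LLM(**kwargs) forwards fully_sharded_loras to EngineArgs,
# the same way enable_lora and max_lora_rank are forwarded.
llm = LLM(model="meta-llama/Llama-2-7b-hf",  # example model
          enable_lora=True,
          max_lora_rank=64,
          fully_sharded_loras=True,
          tensor_parallel_size=2)
outputs = llm.generate(
    ["Hello, my name is"],
    lora_request=LoRARequest("example-adapter", 1, "/path/to/adapter"),
)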


@@ -0,0 +1,262 @@
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
if TYPE_CHECKING:
pass
def _fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (can_replace(*args, **kwargs)
and kwargs['lora_config'].fully_sharded_loras)
return dec
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
"""
Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked.shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_gather(buffer)
bgmv(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
# now have column partitioned output
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
def _mcp_apply_weights(x, bias, layer):
"""
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
LoRa weight application method.
The main difference is the step by shard_size for lora_b which can
vary for QKVParallelLinearWithShardedLora but is constant for
MergedColumnParallelLinearWithShardedLoRA.
"""
# expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked)
output = layer.base_layer.linear_method.apply_weights(
layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device)
for idx in range(n):
bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0)
buffers = tensor_model_parallel_all_gather(buffers)
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]
dispatch_bgmv_low_level(output, buffers[idx],
layer.lora_b_stacked[idx],
layer.indices[:layer.indices_len[0]], 0, 1.0,
left_offset, shard_size)
left_offset += shard_size
output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output
class MergedColumnParallelLinearWithShardedLoRA(
MergedColumnParallelLinearWithLoRA):
"""
Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
for i in range(2)
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
"""
Differs from QKVParallelLinearWithLora by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[i][:, start_idx[i]:start_idx[i] +
shard_size[i]] if lora_a[i] is not None else None
for i in range(3)
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""
Differs from RowParallelLinearWithLoRA by slicing the
LoRA B's also.
Based on S-LoRA, slicing happens along the output dim.
This yields a combined partial sum from the row parallel base
layer and column partitioned output from the LoRA.
"""
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
output.shape[-1]), output.shape
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
dtype=torch.float32,
device=x.device)
bgmv(buffer, x, self.lora_a_stacked,
self.indices[:self.indices_len[0]], 0, 1.0)
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output
# tensor, which is a partial sum due to row parallel. All that
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
self.indices[:self.indices_len[0]], 0, 1.0,
start_idx, shard_size)
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
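To make the communication pattern above concrete, a small shape walk-through of the column-parallel path. This is a sketch with made-up sizes, not code from this commit; the all_gather is simulated with a concatenation:

import torch

bs, hidden, rank, tp = 2, 8, 4, 2             # illustrative sizes
x = torch.randn(bs, hidden)
# Fully sharded LoRA A: each rank holds a rank/tp slice along the rank dim.
a_shard = torch.randn(hidden, rank // tp)
buffer = x @ a_shard                          # (bs, rank // tp) on this rank
# tensor_model_parallel_all_gather concatenates the per-rank buffers:
other = torch.randn(bs, rank // tp)           # stand-in for the other rank
gathered = torch.cat([buffer, other], dim=-1) # (bs, rank)
# Each rank then applies its column shard of LoRA B, producing its slice
# of the column-parallel output, as in apply_weights above.
b_shard = torch.randn(rank, hidden // tp)
partial_out = gathered @ b_shard              # (bs, hidden // tp)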


@@ -1,8 +1,7 @@
# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
+from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import torch.nn as nn
@@ -16,6 +15,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.distributed.utils import divide
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
-ParallelLMHead, VocabParallelEmbedding)
+VocabParallelEmbedding)
if TYPE_CHECKING:
pass
@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _not_fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
condition = (not kwargs['lora_config'].fully_sharded_loras
if decorate else True)
return can_replace(*args, **kwargs) and condition
return dec
def _apply_lora(
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
@@ -130,6 +145,14 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
"""Slice lora b if splitting with tensor parallelism."""
...
def create_lora_weights(
self,
max_loras: int,
@@ -317,6 +340,11 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
"""
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
@@ -331,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
self.lora_a_stacked = torch.zeros(
max_loras,
1,
-lora_config.max_lora_rank,
+lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -357,6 +390,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def set_lora(
self,
index: int,
@@ -365,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
-tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-shard_size = self.output_dim
-start_idx = tensor_model_parallel_rank * shard_size
-end_idx = (tensor_model_parallel_rank + 1) * shard_size
-lora_b = lora_b[:, start_idx:end_idx]
+lora_a = self.slice_lora_a(lora_a)
+lora_b = self.slice_lora_b(lora_b)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
@@ -426,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
return output, output_bias
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
@@ -451,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
n_slices = 2
if not (len(self.base_layer.output_sizes) == n_slices
and self.base_layer.output_sizes[0]
@@ -459,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.")
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
-lora_config.max_lora_rank,
+lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -489,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
shard_size = self.output_dim
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = [
lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
]
return lora_b
def set_lora(
self,
index: int,
@@ -499,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.reset_lora(index)
if self.tp_size > 1:
-tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-shard_size = self.output_dim
-start_idx = tensor_model_parallel_rank * shard_size
-end_idx = (tensor_model_parallel_rank + 1) * shard_size
-lora_b = lora_b[0][:,
-start_idx:end_idx], lora_b[1][:,
-start_idx:end_idx]
+lora_a = self.slice_lora_a(lora_a)
+lora_b = self.slice_lora_b(lora_b)
if lora_a[0] is not None:
self.lora_a_stacked[0][
@@ -536,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
return output
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
@@ -627,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_size = get_tensor_model_parallel_world_size()
-tp_rank = get_tensor_model_parallel_rank()
+self.tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
-self.q_shard_id = tp_rank
-self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+self.q_shard_id = self.tp_rank
+self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
# q, k, v
self.lora_a_stacked = (
torch.zeros(
max_loras,
1,
-lora_config.max_lora_rank,
+lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -649,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros(
max_loras,
1,
-lora_config.max_lora_rank,
+lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -657,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch.zeros(
max_loras,
1,
-lora_config.max_lora_rank,
+lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
@@ -705,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
lora_b = [lora_b_q, lora_b_k, lora_b_v]
return lora_b
def set_lora(
self,
index: int,
@@ -715,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_b[0] is not None:
-lora_b_q = lora_b[0][:, self.q_proj_shard_size *
-self.q_shard_id:self.q_proj_shard_size *
-(self.q_shard_id + 1)]
+lora_b_q = lora_b[0]
self.lora_b_stacked[0][
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
lora_b_q.T, non_blocking=True)
if lora_b[1] is not None:
-lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
-self.kv_shard_id:self.kv_proj_shard_size *
-(self.kv_shard_id + 1)]
+lora_b_k = lora_b[1]
self.lora_b_stacked[1][
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
lora_b_k.T, non_blocking=True)
if lora_b[2] is not None:
-lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
-self.kv_shard_id:self.kv_proj_shard_size *
-(self.kv_shard_id + 1)]
+lora_b_v = lora_b[2]
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
-else:
-if lora_b[0] is not None:
-self.lora_b_stacked[0][
-index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
-lora_b[0].T, non_blocking=True)
-if lora_b[1] is not None:
-self.lora_b_stacked[1][
-index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
-lora_b[1].T, non_blocking=True)
-if lora_b[2] is not None:
-self.lora_b_stacked[2][
-index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
-lora_b[2].T, non_blocking=True)
if lora_a[0] is not None:
self.lora_a_stacked[0][
@@ -777,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
return output
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
@@ -798,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
self.lora_config = lora_config
self.tp_rank = get_tensor_model_parallel_rank()
self.lora_a_stacked = torch.zeros(
(
max_loras,
@@ -808,11 +876,16 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
tp_size = get_tensor_model_parallel_world_size()
lora_b_output_size_per_partition = (
self.output_size if not lora_config.fully_sharded_loras else
divide(self.output_size, tp_size))
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
-self.output_size,
+lora_b_output_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
@@ -826,6 +899,17 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.input_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
return lora_b
def set_lora(
self,
index: int,
@@ -834,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.base_layer.tp_size > 1:
-tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-shard_size = self.input_size
-start_idx = tensor_model_parallel_rank * shard_size
-end_idx = (tensor_model_parallel_rank + 1) * shard_size
-lora_a = lora_a[start_idx:end_idx, :]
+lora_a = self.slice_lora_a(lora_a)
+lora_b = self.slice_lora_b(lora_b)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -915,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.base_layer, "weight") else self.base_layer.qweight
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
@@ -1096,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor.
return False
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
cls
for cls in globals().values() if inspect.isclass(cls)
and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}
def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
def from_layer_logits_processor(
layer: LogitsProcessor,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret


@@ -11,10 +11,10 @@ from torch import nn
from vllm.config import LoRAConfig
from vllm.logger import init_logger
-from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
-from_layer_logits_processor)
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
-from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
+from vllm.lora.utils import (from_layer, from_layer_logits_processor,
+parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache, is_pin_memory_available
logger = init_logger(__name__)


@@ -49,6 +49,49 @@ def bgmv(
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
w_t_all: torch.Tensor, indicies: torch.LongTensor,
layer_idx: int, scale: float, y_offset: int,
y_slice_size: int):
"""
Same as `bgmv` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
all of the transposed LoRA matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try:
import vllm._punica_C as punica_kernels
except ImportError as e:
_raise_import_error(e)
punica_kernels.dispatch_bgmv_low_level(
y,
x,
w_t_all,
indicies,
layer_idx,
scale,
x.size(1),
y_slice_size,
y_offset,
)
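A hedged usage sketch of the wrapper above, mirroring how the fully sharded layers call it to add a LoRA-B shard into one column slice of a preallocated output. The sizes and the all-zero index tensor are illustrative, and it assumes a CUDA build with the punica extension and kernel instantiations covering these shapes and dtypes:

import torch
from vllm.lora.punica import dispatch_bgmv_low_level

bs, rank, full_out, shard = 4, 16, 1024, 256          # illustrative sizes
buffer = torch.randn(bs, rank, dtype=torch.float32, device="cuda")
lora_b_t_all = torch.randn(8, 1, shard, rank, dtype=torch.float16, device="cuda")
y = torch.zeros(bs, full_out, dtype=torch.float16, device="cuda")
indices = torch.zeros(bs, dtype=torch.long, device="cuda")
offset = 2 * shard  # write into columns [offset, offset + shard) of y
dispatch_bgmv_low_level(y, buffer, lora_b_t_all, indices, 0, 1.0, offset, shard)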
def add_lora(y: torch.Tensor,
x: torch.Tensor,
wa_t_all: torch.Tensor,


@@ -1,11 +1,69 @@
-from typing import Tuple
+from typing import List, Optional, Set, Tuple, Type
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
}
def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
# specifying kwargs so they can be easily accessed in decorator
if lora_cls.can_replace_layer(source_layer=layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer
def from_layer_logits_processor(
layer: LogitsProcessor,
lm_head: ParallelLMHead,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
def replace_submodule(model: nn.Module, module_name: str,
new_module: nn.Module) -> nn.Module: