[Kernel] Full Tensor Parallelism for LoRA Layers (#3524)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
parent 18d23f642a
commit eefeb16464
@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)

@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py

// Used for defining kernels going from the variety of
// dim in to the narrow dim out
// Using it for the fully sharded column
// parallel LoRA A which splits the rank dim
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
  f(in_T, out_T, W_T, 128, narrow) \
  f(in_T, out_T, W_T, 256, narrow) \
  f(in_T, out_T, W_T, 512, narrow) \
  f(in_T, out_T, W_T, 640, narrow) \
  f(in_T, out_T, W_T, 768, narrow) \
  f(in_T, out_T, W_T, 1024, narrow) \
  f(in_T, out_T, W_T, 1152, narrow) \
  f(in_T, out_T, W_T, 1280, narrow) \
  f(in_T, out_T, W_T, 1536, narrow) \
  f(in_T, out_T, W_T, 1728, narrow) \
  f(in_T, out_T, W_T, 1792, narrow) \
  f(in_T, out_T, W_T, 2048, narrow) \
  f(in_T, out_T, W_T, 2304, narrow) \
  f(in_T, out_T, W_T, 2560, narrow) \
  f(in_T, out_T, W_T, 2752, narrow) \
  f(in_T, out_T, W_T, 2816, narrow) \
  f(in_T, out_T, W_T, 3072, narrow) \
  f(in_T, out_T, W_T, 3456, narrow) \
  f(in_T, out_T, W_T, 3584, narrow) \
  f(in_T, out_T, W_T, 4096, narrow) \
  f(in_T, out_T, W_T, 4608, narrow) \
  f(in_T, out_T, W_T, 5120, narrow) \
  f(in_T, out_T, W_T, 5504, narrow) \
  f(in_T, out_T, W_T, 5632, narrow) \
  f(in_T, out_T, W_T, 6144, narrow) \
  f(in_T, out_T, W_T, 6848, narrow) \
  f(in_T, out_T, W_T, 6912, narrow) \
  f(in_T, out_T, W_T, 7168, narrow) \
  f(in_T, out_T, W_T, 8192, narrow) \
  f(in_T, out_T, W_T, 9216, narrow) \
  f(in_T, out_T, W_T, 10240, narrow) \
  f(in_T, out_T, W_T, 11008, narrow) \
  f(in_T, out_T, W_T, 12288, narrow) \
  f(in_T, out_T, W_T, 13696, narrow) \
  f(in_T, out_T, W_T, 13824, narrow) \
  f(in_T, out_T, W_T, 14336, narrow) \
  f(in_T, out_T, W_T, 15360, narrow) \
  f(in_T, out_T, W_T, 16384, narrow) \
  f(in_T, out_T, W_T, 20480, narrow) \
  f(in_T, out_T, W_T, 22016, narrow) \
  f(in_T, out_T, W_T, 24576, narrow) \
  f(in_T, out_T, W_T, 27392, narrow) \
  f(in_T, out_T, W_T, 28672, narrow) \
  f(in_T, out_T, W_T, 32000, narrow) \
  f(in_T, out_T, W_T, 32256, narrow) \
  f(in_T, out_T, W_T, 32512, narrow) \
  f(in_T, out_T, W_T, 32768, narrow) \
  f(in_T, out_T, W_T, 33024, narrow) \
  f(in_T, out_T, W_T, 36864, narrow) \
  f(in_T, out_T, W_T, 43264, narrow) \
  f(in_T, out_T, W_T, 49152, narrow) \
  f(in_T, out_T, W_T, 64000, narrow) \
  f(in_T, out_T, W_T, 64256, narrow) \
  f(in_T, out_T, W_T, 64512, narrow) \
  f(in_T, out_T, W_T, 102400, narrow) \
  f(in_T, out_T, W_T, 102656, narrow) \
  f(in_T, out_T, W_T, 102912, narrow) \
  f(in_T, out_T, W_T, 128000, narrow) \
  f(in_T, out_T, W_T, 128256, narrow) \
  f(in_T, out_T, W_T, 128512, narrow) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA


// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
  FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
  FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
  FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)


#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
  FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
  FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
  FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
  f(in_T, out_T, W_T, 8, 64) \
  f(in_T, out_T, W_T, 16, 64) \
  f(in_T, out_T, W_T, 32, 64) \
  f(in_T, out_T, W_T, 64, 64)

// clang-format on

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)

@@ -2,3 +2,4 @@
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
  if constexpr (feat_in <= feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
      int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
      int64_t num_layers, int64_t layer_idx, float scale);

#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
  INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)

#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
  INST_BGMV(narrow, wide, in_T, out_T, W_T) \
  INST_BGMV(wide, narrow, in_T, out_T, W_T)

@@ -10,6 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()  # noqa: E501

for input_dtype in DTYPES:

@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
    CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)

    FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
    FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
    default:
      return false;
  }

  return true;
}
@@ -8,6 +8,10 @@ import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -524,13 +528,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                         device) -> None:

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16)

    def create_random_linear_parallel_layer():

@@ -540,14 +547,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = RowParallelLinearWithLoRA(linear)
            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
                           else RowParallelLinearWithShardedLoRA(linear))
        else:
            linear = ColumnParallelLinear(4096,
                                          4096,
                                          bias=False,
                                          params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = ColumnParallelLinearWithLoRA(linear)
            lora_linear = (ColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           ColumnParallelLinearWithShardedLoRA(linear))
        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

@@ -629,13 +639,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("fully_shard", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
                                device) -> None:

    torch.set_default_device(device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             fully_sharded_loras=fully_shard,
                             lora_dtype=torch.float16)

    def create_column_parallel_packed_layer():

@@ -644,7 +657,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
                                                bias=False,
                                                params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = MergedColumnParallelLinearWithLoRA(linear)
            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
                           if not fully_shard else
                           MergedColumnParallelLinearWithShardedLoRA(linear))
        elif repeats == 3:
            linear = QKVParallelLinear(4096,
                                       64,
@@ -652,7 +667,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
                                       bias=False,
                                       params_dtype=torch.float16)
            linear.weight.data = torch.rand_like(linear.weight.data)
            lora_linear = MergedQKVParallelLinearWithLora(linear)
            lora_linear = (MergedQKVParallelLinearWithLora(linear)
                           if not fully_shard else
                           MergedQKVParallelLinearWithShardedLora(linear))
        else:
            linear = QKVParallelLinear(4096,
                                       64,
@@ -34,11 +34,14 @@ def _lora_ref_impl(
    for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
        xi = x[i].unsqueeze(0).to(torch.float32)
        wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
        wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
        if wb_T_all is not None:
            wb = wb_T_all[lora_idx, layer_idx].transpose(-1,
                                                         -2).to(torch.float32)

        tmp = xi @ wa
        y_stage_1[i] = tmp.squeeze(0)
        y_final[i] += (tmp @ wb).squeeze(0) * s
        y_final[i] += ((tmp @ wb).squeeze(0) *
                       s if wb_T_all is not None else y_stage_1[i])
    return y_final, y_stage_1


@@ -91,12 +94,56 @@ H1 = H2 = [
    128000,
    128256,
]
H2 = [64] + H2
R = [1, 2, 4]
SEED = [0xabcdabcd987]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("r", R)
@pytest.mark.parametrize("seed", SEED)
@torch.inference_mode()
def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
    torch.manual_seed(seed)
    num_loras = 4
    num_layers = 1
    bs = 32
    dtype = getattr(torch, dtype_str)
    device = torch.device("cuda")

    wa_T_all = torch.randn(num_loras,
                           num_layers,
                           r,
                           h1,
                           dtype=dtype,
                           device=device)
    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)

    for layer_idx in range(num_layers):
        x = torch.randn(bs, h1, dtype=dtype, device=device)
        y = torch.randn(bs, r, dtype=dtype, device=device)

        y_ref = y.clone()
        _lora_ref_impl(
            y_ref,
            x,
            wa_T_all,
            None,
            indices,
            layer_idx,
            1.0,
        )

        y_our = y.clone()
        punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0)

        assert_close(y_ref, y_our)


@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
@pytest.mark.parametrize("h1", H1)
@pytest.mark.parametrize("h2", H2)
@@ -862,6 +862,7 @@ class SpeculativeConfig:
class LoRAConfig:
    max_lora_rank: int
    max_loras: int
    fully_sharded_loras: bool = False
    max_cpu_loras: Optional[int] = None
    lora_dtype: Optional[torch.dtype] = None
    lora_extra_vocab_size: int = 256

@@ -52,6 +52,7 @@ class EngineArgs:
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    fully_sharded_loras: bool = False
    lora_extra_vocab_size: int = 256
    lora_dtype = 'auto'
    max_cpu_loras: Optional[int] = None
@@ -376,6 +377,14 @@ class EngineArgs:
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument(
            '--fully-sharded-loras',
            action='store_true',
            help=('By default, only half of the LoRA computation is '
                  'sharded with tensor parallelism. '
                  'Enabling this will use the fully sharded layers. '
                  'At high sequence length, max rank or '
                  'tensor parallel size, this is likely faster.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
@@ -509,6 +518,7 @@ class EngineArgs:
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
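Not part of the diff: a minimal usage sketch of the new option. The field names come from the EngineArgs and LoRAConfig hunks above; the model name and sizes are placeholders, and tensor_parallel_size/enable_lora are pre-existing engine arguments.

# Illustrative sketch only (not from the commit): enable the fully sharded
# LoRA layers together with tensor parallelism via the new EngineArgs field.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-2-7b-hf",  # placeholder model
    enable_lora=True,
    max_loras=4,
    max_lora_rank=64,
    fully_sharded_loras=True,  # field added by this commit
    tensor_parallel_size=2,
)
# On the command line the same option is spelled --fully-sharded-loras, and
# it flows into LoRAConfig exactly as shown in the last hunk above.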
vllm/lora/fully_sharded_layers.py (new file, 262 lines)
@@ -0,0 +1,262 @@
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              RowParallelLinearWithLoRA)
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level

if TYPE_CHECKING:
    pass


def _fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs['lora_config'].fully_sharded_loras)

    return dec


# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.


class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)

        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        bgmv(output, buffer, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # now have column partitioned output

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


def _mcp_apply_weights(x, bias, layer):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    QKVParallelLinearWithShardedLora share the same
    LoRA weight application method.

    The main difference is the step by shard_size for lora_b which can
    vary for QKVParallelLinearWithShardedLora but is constant for
    MergedColumnParallelLinearWithShardedLoRA.
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
    output = layer.base_layer.linear_method.apply_weights(
        layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
                          dtype=torch.float32,
                          device=x.device)
    for idx in range(n):
        bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
             layer.indices[:layer.indices_len[0]], 0, 1.0)

    buffers = tensor_model_parallel_all_gather(buffers)
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]
        dispatch_bgmv_low_level(output, buffers[idx],
                                layer.lora_b_stacked[idx],
                                layer.indices[:layer.indices_len[0]], 0, 1.0,
                                left_offset, shard_size)
        left_offset += shard_size

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output


class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
            for i in range(2)
        ]
        return lora_a

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply_weights(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[i][:, start_idx[i]:start_idx[i] +
                      shard_size[i]] if lora_a[i] is not None else None
            for i in range(3)
        ]
        return lora_a

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply_weights(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
                                self.indices[:self.indices_len[0]], 0, 1.0,
                                start_idx, shard_size)

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
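Not part of the commit: a small self-contained sketch of the sharding arithmetic the classes above rely on. It runs in a single process, imitating all_gather with a plain concatenation, and checks that splitting LoRA A along the rank dim (slice_lora_a) and LoRA B along the output dim, then gathering the intermediate activations, reproduces the unsharded x @ A @ B.

# Standalone numerical sketch (assumption: single process, all_gather
# replaced by torch.cat); this is not vLLM code.
import torch

torch.manual_seed(0)
bs, hidden, rank, tp = 4, 32, 8, 2
x = torch.randn(bs, hidden)
lora_a = torch.randn(hidden, rank)   # full LoRA A
lora_b = torch.randn(rank, hidden)   # full LoRA B

reference = x @ lora_a @ lora_b      # unsharded result

# Fully sharded path: each rank holds rank // tp columns of A and
# hidden // tp output columns of B.
a_shards = lora_a.chunk(tp, dim=1)
partials = [x @ a for a in a_shards]            # per-rank bgmv on LoRA A
gathered = torch.cat(partials, dim=1)           # stands in for all_gather
out_shards = [gathered @ b for b in lora_b.chunk(tp, dim=1)]
sharded = torch.cat(out_shards, dim=1)          # column-partitioned outputs

assert torch.allclose(reference, sharded, atol=1e-5)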
@@ -1,8 +1,7 @@
# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -16,6 +15,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce,
                              tensor_model_parallel_gather)
from vllm.distributed.utils import divide
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
    VocabParallelEmbedding)

if TYPE_CHECKING:
    pass
@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
    raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
        condition = (not kwargs['lora_config'].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition

    return dec


def _apply_lora(
    x: torch.Tensor,
    lora_a_stacked: torch.Tensor,
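Not part of the diff: a toy, self-contained illustration of how this decorator pairs with _fully_sharded_can_replace from vllm/lora/fully_sharded_layers.py. The config and layer classes below are simplified stand-ins; the two decorator bodies are copied from the hunks in this commit. The decorate=False keyword that the sharded subclasses pass to super().can_replace_layer() switches off the "not fully sharded" condition here, so for a given LoRAConfig exactly one of the base or sharded variants accepts a layer.

# Toy sketch (not vLLM code): ToyConfig, BaseLayer and ShardedLayer are
# hypothetical stand-ins; the decorators mirror the ones in this commit.
from dataclasses import dataclass


@dataclass
class ToyConfig:
    fully_sharded_loras: bool = False


def _not_fully_sharded_can_replace(can_replace):
    def dec(*args, **kwargs):
        decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
        condition = (not kwargs['lora_config'].fully_sharded_loras
                     if decorate else True)
        return can_replace(*args, **kwargs) and condition
    return dec


def _fully_sharded_can_replace(can_replace):
    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs['lora_config'].fully_sharded_loras)
    return dec


class BaseLayer:
    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, *, lora_config, **kwargs):
        return True  # pretend the structural checks passed


class ShardedLayer(BaseLayer):
    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, *, lora_config, **kwargs):
        # decorate=False disables the "not fully sharded" condition of the
        # base decorator, as in the real super().can_replace_layer() calls.
        return super().can_replace_layer(lora_config=lora_config,
                                         decorate=False, **kwargs)


cfg = ToyConfig(fully_sharded_loras=True)
assert not BaseLayer.can_replace_layer(lora_config=cfg)
assert ShardedLayer.can_replace_layer(lora_config=cfg)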
@@ -130,6 +145,14 @@ class LoRAMapping:

class BaseLayerWithLoRA(nn.Module):

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
@@ -317,6 +340,11 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.

    LoRA B is sliced for tensor parallelism.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
@@ -331,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_config.max_lora_rank,
            lora_a_output_size_per_partition,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
@@ -357,6 +390,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.output_dim
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def set_lora(
        self,
        index: int,
@@ -365,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
@@ -426,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
@@ -451,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        n_slices = 2
        if not (len(self.base_layer.output_sizes) == n_slices
                and self.base_layer.output_sizes[0]
@@ -459,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                "LoRAColumnParallelLinear2Slice requires 2 slices with "
                "the same size.")
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
@@ -489,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        self.lora_b_stacked[0][index] = 0
        self.lora_b_stacked[1][index] = 0

    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
        return lora_a

    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
        shard_size = self.output_dim
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = [
            lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
        ]
        return lora_b

    def set_lora(
        self,
        index: int,
@@ -499,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[0][:,
                               start_idx:end_idx], lora_b[1][:,
                                                             start_idx:end_idx]
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
@@ -536,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
@@ -627,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.q_shard_id = tp_rank
        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size))
        # q, k, v
        self.lora_a_stacked = (
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
@@ -649,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
@@ -657,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
            torch.zeros(
                max_loras,
                1,
                lora_config.max_lora_rank,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
@@ -705,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
        self.lora_a_stacked[2][index] = 0
        self.lora_b_stacked[2][index] = 0

    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
        return lora_a

    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
        if lora_b[0] is not None:
            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                 self.q_shard_id:self.q_proj_shard_size *
                                 (self.q_shard_id + 1)]
        if lora_b[1] is not None:
            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        if lora_b[2] is not None:
            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                 self.kv_shard_id:self.kv_proj_shard_size *
                                 (self.kv_shard_id + 1)]
        lora_b = [lora_b_q, lora_b_k, lora_b_v]
        return lora_b

    def set_lora(
        self,
        index: int,
@@ -715,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
        self.reset_lora(index)

        if self.tp_size > 1:
            if lora_b[0] is not None:
                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                     self.q_shard_id:self.q_proj_shard_size *
                                     (self.q_shard_id + 1)]
                self.lora_b_stacked[0][
                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                        lora_b_q.T, non_blocking=True)
            if lora_b[1] is not None:
                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[1][
                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                        lora_b_k.T, non_blocking=True)
            if lora_b[2] is not None:
                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
                                     self.kv_shard_id:self.kv_proj_shard_size *
                                     (self.kv_shard_id + 1)]
                self.lora_b_stacked[2][
                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                        lora_b_v.T, non_blocking=True)
        else:
            if lora_b[0] is not None:
                self.lora_b_stacked[0][
                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                        lora_b[0].T, non_blocking=True)
            if lora_b[1] is not None:
                self.lora_b_stacked[1][
                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                        lora_b[1].T, non_blocking=True)
            if lora_b[2] is not None:
                self.lora_b_stacked[2][
                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
                        lora_b[2].T, non_blocking=True)
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        if lora_b[0] is not None:
            lora_b_q = lora_b[0]
            self.lora_b_stacked[0][
                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
                    lora_b_q.T, non_blocking=True)
        if lora_b[1] is not None:
            lora_b_k = lora_b[1]
            self.lora_b_stacked[1][
                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
                    lora_b_k.T, non_blocking=True)
        if lora_b[2] is not None:
            lora_b_v = lora_b[2]
            self.lora_b_stacked[2][
                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
                    lora_b_v.T, non_blocking=True)

        if lora_a[0] is not None:
            self.lora_a_stacked[0][
@@ -777,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
        return output

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
@@ -798,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        self.lora_config = lora_config
        self.tp_rank = get_tensor_model_parallel_rank()
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
@@ -808,11 +876,16 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        tp_size = get_tensor_model_parallel_world_size()
        lora_b_output_size_per_partition = (
            self.output_size if not lora_config.fully_sharded_loras else
            divide(self.output_size, tp_size))

        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.output_size,
                lora_b_output_size_per_partition,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
@@ -826,6 +899,17 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        shard_size = self.input_size
        start_idx = tensor_model_parallel_rank * shard_size
        end_idx = (tensor_model_parallel_rank + 1) * shard_size
        lora_a = lora_a[start_idx:end_idx, :]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def set_lora(
        self,
        index: int,
@@ -834,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.base_layer.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.input_size
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_a = lora_a[start_idx:end_idx, :]
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -915,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
            self.base_layer, "weight") else self.base_layer.qweight

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
@@ -1096,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False


_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    cls
    for cls in globals().values() if inspect.isclass(cls)
    and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
                                      model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret
@@ -11,10 +11,10 @@ from torch import nn

from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
                              from_layer_logits_processor)
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                             parse_fine_tuned_lora_name, replace_submodule)
from vllm.utils import LRUCache, is_pin_memory_available

logger = init_logger(__name__)

@@ -49,6 +49,49 @@ def bgmv(
    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)


def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
                            w_t_all: torch.Tensor, indicies: torch.LongTensor,
                            layer_idx: int, scale: float, y_offset: int,
                            y_slice_size: int):
    """
    Same as `bgmv` but you can operate on slices of y.
    Pass whole y, define y_offset and y_slice_size.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
            all of the transposed LoRA matrices.
        indicies: Shape: `[B]`. Indices of the LoRA weights.
        layer_idx: Layer index of LoRA weights.
        scale: Scaling factor.
        y_offset: Offset to apply to the starting column of y.
        y_slice_size: Size of the y column slice.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)
    punica_kernels.dispatch_bgmv_low_level(
        y,
        x,
        w_t_all,
        indicies,
        layer_idx,
        scale,
        x.size(1),
        y_slice_size,
        y_offset,
    )


def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
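Not part of the commit: a pure-PyTorch restatement of the documented semantics above, useful as a mental model of what the kernel writes into the y slice. It mirrors the docstring only; the real work is done by the CUDA kernel.

# Reference sketch of the dispatch_bgmv_low_level semantics (readability aid,
# not the kernel). Shapes follow the docstring: y [B, H2], x [B, H1],
# w_t_all [num_loras, L, y_slice_size, H1].
import torch


def dispatch_bgmv_low_level_ref(y, x, w_t_all, indicies, layer_idx, scale,
                                y_offset, y_slice_size):
    for i in range(x.shape[0]):
        w = w_t_all[indicies[i], layer_idx].transpose(-1, -2).to(torch.float32)
        delta = (x[i].unsqueeze(0).to(torch.float32) @ w).squeeze(0) * scale
        y[i, y_offset:y_offset + y_slice_size] += delta.to(y.dtype)
    return y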
@@ -1,11 +1,69 @@
from typing import Tuple
from typing import List, Optional, Set, Tuple, Type

from torch import nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead

logger = init_logger(__name__)

_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
    LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device)
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module: