mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 06:31:20 +08:00
[Misc][LoRA] Refactor and clean MergedQKVParallelLinearWithLora implementation (#10958)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
f13cf9ad50
commit
b26b4cd03c
@ -542,10 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
Both slices must have the same size.
|
Both slices must have the same size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
|
def __init__(
|
||||||
|
self, base_layer: Union[MergedColumnParallelLinear,
|
||||||
|
QKVParallelLinear]) -> None:
|
||||||
super().__init__(base_layer)
|
super().__init__(base_layer)
|
||||||
# There are two LoRA layers
|
# There are two LoRA layers
|
||||||
self.n_slices = len(self.base_layer.output_sizes)
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
|
||||||
|
# we need to divide it by the tp_size to get correct slices size
|
||||||
|
output_sizes = self.base_layer.output_sizes
|
||||||
|
self.output_slices = tuple(
|
||||||
|
divide(output_size, self.tp_size) for output_size in output_sizes)
|
||||||
|
self.n_slices = len(self.output_slices)
|
||||||
|
self.output_ids = (self.tp_rank, ) * self.n_slices
|
||||||
|
|
||||||
def create_lora_weights(
|
def create_lora_weights(
|
||||||
self,
|
self,
|
||||||
@ -559,15 +569,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
"""
|
"""
|
||||||
self.lora_config = lora_config
|
self.lora_config = lora_config
|
||||||
|
|
||||||
if not (len(self.base_layer.output_sizes) == self.n_slices == 2
|
|
||||||
and self.base_layer.output_sizes[0]
|
|
||||||
== self.base_layer.output_sizes[1]):
|
|
||||||
raise ValueError(
|
|
||||||
"LoRAColumnParallelLinear2Slice requires 2 slices with "
|
|
||||||
"the same size.")
|
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
|
|
||||||
lora_a_output_size_per_partition = (
|
lora_a_output_size_per_partition = (
|
||||||
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
|
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
|
||||||
else divide(lora_config.max_lora_rank, self.tp_size))
|
else divide(lora_config.max_lora_rank, self.tp_size))
|
||||||
@ -585,22 +586,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
torch.zeros(
|
torch.zeros(
|
||||||
max_loras,
|
max_loras,
|
||||||
1,
|
1,
|
||||||
self.output_size // 2,
|
output_size,
|
||||||
lora_config.max_lora_rank,
|
lora_config.max_lora_rank,
|
||||||
dtype=lora_config.lora_dtype,
|
dtype=lora_config.lora_dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
) for _ in range(self.n_slices))
|
) for output_size in self.output_slices)
|
||||||
if lora_config.bias_enabled:
|
if lora_config.bias_enabled:
|
||||||
self.lora_bias_stacked = tuple(
|
self.lora_bias_stacked = tuple(
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
max_loras,
|
max_loras,
|
||||||
1,
|
1,
|
||||||
self.output_size // 2,
|
output_size,
|
||||||
dtype=lora_config.lora_dtype,
|
dtype=lora_config.lora_dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
) for _ in range(self.n_slices))
|
) for output_size in self.output_slices)
|
||||||
self.output_dim = self.lora_b_stacked[0].shape[2]
|
|
||||||
self.output_slices = (self.output_dim, self.output_dim)
|
|
||||||
|
|
||||||
def slice_lora_a(
|
def slice_lora_a(
|
||||||
self, lora_a: List[Union[torch.Tensor, None]]
|
self, lora_a: List[Union[torch.Tensor, None]]
|
||||||
@ -610,27 +609,21 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
def slice_lora_b(
|
def slice_lora_b(
|
||||||
self, lora_b: List[Union[torch.Tensor, None]]
|
self, lora_b: List[Union[torch.Tensor, None]]
|
||||||
) -> List[Union[torch.Tensor, None]]:
|
) -> List[Union[torch.Tensor, None]]:
|
||||||
#NOTE: lora_b contains 2 subloras, and each sublora could be None.
|
for i, (shard_id, shard_size) in enumerate(
|
||||||
shard_size = self.output_dim
|
zip(self.output_ids, self.output_slices)):
|
||||||
start_idx = self.tp_rank * shard_size
|
if (lora_b_i := lora_b[i]) is not None:
|
||||||
end_idx = (self.tp_rank + 1) * shard_size
|
lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
|
||||||
lora_b = [
|
(shard_id + 1)]
|
||||||
lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
|
|
||||||
lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
|
|
||||||
]
|
|
||||||
return lora_b
|
return lora_b
|
||||||
|
|
||||||
def slice_bias(
|
def slice_bias(
|
||||||
self, bias: List[Union[torch.Tensor,
|
self, bias: List[Union[torch.Tensor,
|
||||||
None]]) -> List[Union[torch.Tensor, None]]:
|
None]]) -> List[Union[torch.Tensor, None]]:
|
||||||
# NOTE : each bias could be None.
|
for i, (shard_id, shard_size) in enumerate(
|
||||||
shard_size = self.output_dim
|
zip(self.output_ids, self.output_slices)):
|
||||||
start_idx = self.tp_rank * shard_size
|
if (bias_i := bias[i]) is not None:
|
||||||
end_idx = (self.tp_rank + 1) * shard_size
|
bias[i] = bias_i[shard_size * shard_id:shard_size *
|
||||||
bias = [
|
(shard_id + 1)]
|
||||||
bias[0][start_idx:end_idx] if bias[0] is not None else None,
|
|
||||||
bias[1][start_idx:end_idx] if bias[1] is not None else None
|
|
||||||
]
|
|
||||||
return bias
|
return bias
|
||||||
|
|
||||||
def set_lora(
|
def set_lora(
|
||||||
@ -649,30 +642,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
if lora_bias is not None:
|
if lora_bias is not None:
|
||||||
lora_bias = self.slice_bias(lora_bias)
|
lora_bias = self.slice_bias(lora_bias)
|
||||||
|
|
||||||
if lora_a[0] is not None:
|
for i in range(self.n_slices):
|
||||||
self.lora_a_stacked[0][
|
if (lora_a_i := lora_a[i]) is not None:
|
||||||
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
|
self.lora_a_stacked[i][
|
||||||
lora_a[0].T, non_blocking=True)
|
index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
|
||||||
self.lora_b_stacked[0][
|
lora_a_i.T, non_blocking=True)
|
||||||
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
|
if (lora_b_i := lora_b[i]) is not None:
|
||||||
lora_b[0].T, non_blocking=True)
|
self.lora_b_stacked[i][
|
||||||
if lora_bias is not None and lora_bias[0] is not None:
|
index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
|
||||||
|
lora_b_i.T, non_blocking=True)
|
||||||
|
|
||||||
|
if lora_bias is not None:
|
||||||
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
|
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
|
||||||
self.lora_bias_stacked)
|
self.lora_bias_stacked)
|
||||||
self.lora_bias_stacked[0][index, 0, :lora_bias[0].shape[0]].copy_(
|
for i in range(self.n_slices):
|
||||||
lora_bias[0].T, non_blocking=True)
|
if (lora_bias_i := lora_bias[i]) is not None:
|
||||||
if lora_a[1] is not None:
|
self.lora_bias_stacked[i][index,
|
||||||
self.lora_a_stacked[1][
|
0, :lora_bias_i.shape[0]].copy_(
|
||||||
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
lora_bias_i.T,
|
||||||
lora_a[1].T, non_blocking=True)
|
non_blocking=True)
|
||||||
self.lora_b_stacked[1][
|
|
||||||
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
|
||||||
lora_b[1].T, non_blocking=True)
|
|
||||||
if lora_bias is not None and lora_bias[1] is not None:
|
|
||||||
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
|
|
||||||
self.lora_bias_stacked)
|
|
||||||
self.lora_bias_stacked[1][index, 0, :lora_bias[1].shape[0]].copy_(
|
|
||||||
lora_bias[1].T, non_blocking=True)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@_not_fully_sharded_can_replace
|
@_not_fully_sharded_can_replace
|
||||||
@ -755,8 +743,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
packed_modules_list) == 1
|
packed_modules_list) == 1
|
||||||
|
|
||||||
|
|
||||||
class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA):
|
||||||
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
|
"""MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
|
||||||
packed together in qkv proj fashion
|
packed together in qkv proj fashion
|
||||||
(q_proj + k_proj + v_proj -> qkv_proj).
|
(q_proj + k_proj + v_proj -> qkv_proj).
|
||||||
|
|
||||||
@ -773,6 +761,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
|
self.q_proj_shard_size = (self.base_layer.num_heads *
|
||||||
|
self.base_layer.head_size)
|
||||||
|
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
|
||||||
|
self.base_layer.head_size)
|
||||||
|
self.q_shard_id = self.tp_rank
|
||||||
|
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
|
||||||
|
|
||||||
|
self.output_slices = (
|
||||||
|
self.q_proj_shard_size,
|
||||||
|
self.kv_proj_shard_size,
|
||||||
|
self.kv_proj_shard_size,
|
||||||
|
)
|
||||||
|
self.output_ids = (
|
||||||
|
self.q_shard_id,
|
||||||
|
self.kv_shard_id,
|
||||||
|
self.kv_shard_id,
|
||||||
|
)
|
||||||
|
|
||||||
def create_lora_weights(
|
def create_lora_weights(
|
||||||
self,
|
self,
|
||||||
max_loras: int,
|
max_loras: int,
|
||||||
@ -783,216 +789,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
The main reason for overloading this function is to handle inconsistent
|
The main reason for overloading this function is to handle inconsistent
|
||||||
weight dimensions in qkv lora.
|
weight dimensions in qkv lora.
|
||||||
"""
|
"""
|
||||||
self.lora_config = lora_config
|
super().create_lora_weights(max_loras, lora_config, model_config)
|
||||||
|
|
||||||
if not (len(self.base_layer.output_sizes) == self.n_slices == 3):
|
|
||||||
raise ValueError(
|
|
||||||
"LoRAColumnParallelLinear3Slice requires 3 slices.")
|
|
||||||
|
|
||||||
self.q_proj_shard_size = (self.base_layer.num_heads *
|
|
||||||
self.base_layer.head_size)
|
|
||||||
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
|
|
||||||
self.base_layer.head_size)
|
|
||||||
self.q_shard_id = self.tp_rank
|
|
||||||
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
|
|
||||||
|
|
||||||
lora_a_output_size_per_partition = (
|
|
||||||
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
|
|
||||||
else divide(lora_config.max_lora_rank, self.tp_size))
|
|
||||||
# q, k, v
|
|
||||||
self.lora_a_stacked = (
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
lora_a_output_size_per_partition,
|
|
||||||
self.input_size,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
lora_a_output_size_per_partition,
|
|
||||||
self.input_size,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
lora_a_output_size_per_partition,
|
|
||||||
self.input_size,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.lora_b_stacked = (
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
self.q_proj_shard_size,
|
|
||||||
lora_config.max_lora_rank,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
self.kv_proj_shard_size,
|
|
||||||
lora_config.max_lora_rank,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
self.kv_proj_shard_size,
|
|
||||||
lora_config.max_lora_rank,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if lora_config.bias_enabled:
|
|
||||||
self.lora_bias_stacked = (
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
self.q_proj_shard_size,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
self.kv_proj_shard_size,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
torch.zeros(
|
|
||||||
max_loras,
|
|
||||||
1,
|
|
||||||
self.kv_proj_shard_size,
|
|
||||||
dtype=lora_config.lora_dtype,
|
|
||||||
device=self.device,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.output_slices = (
|
|
||||||
self.q_proj_shard_size,
|
|
||||||
self.kv_proj_shard_size,
|
|
||||||
self.kv_proj_shard_size,
|
|
||||||
)
|
|
||||||
self.packed_indices: Optional[torch.Tensor] = None
|
|
||||||
self.standard_indices: Optional[torch.Tensor] = None
|
|
||||||
# lazily initialized.
|
|
||||||
self.indices: torch.Tensor
|
|
||||||
self.indices_len: List[int]
|
|
||||||
|
|
||||||
def slice_lora_a(
|
|
||||||
self, lora_a: List[Union[torch.Tensor, None]]
|
|
||||||
) -> List[Union[torch.Tensor, None]]:
|
|
||||||
return lora_a
|
|
||||||
|
|
||||||
def slice_lora_b(
|
|
||||||
self, lora_b: List[Union[torch.Tensor, None]]
|
|
||||||
) -> List[Union[torch.Tensor, None]]:
|
|
||||||
lora_b_q, lora_b_k, lora_b_v = None, None, None
|
|
||||||
if lora_b[0] is not None:
|
|
||||||
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
|
|
||||||
self.q_shard_id:self.q_proj_shard_size *
|
|
||||||
(self.q_shard_id + 1), ]
|
|
||||||
if lora_b[1] is not None:
|
|
||||||
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
|
|
||||||
self.kv_shard_id:self.kv_proj_shard_size *
|
|
||||||
(self.kv_shard_id + 1), ]
|
|
||||||
if lora_b[2] is not None:
|
|
||||||
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
|
|
||||||
self.kv_shard_id:self.kv_proj_shard_size *
|
|
||||||
(self.kv_shard_id + 1), ]
|
|
||||||
lora_b = [lora_b_q, lora_b_k, lora_b_v]
|
|
||||||
return lora_b
|
|
||||||
|
|
||||||
def slice_bias(
|
|
||||||
self, bias: List[Union[torch.Tensor,
|
|
||||||
None]]) -> List[Union[torch.Tensor, None]]:
|
|
||||||
bias_q, bias_k, bias_v = bias
|
|
||||||
if bias_q is not None:
|
|
||||||
bias_q = bias_q[self.q_proj_shard_size *
|
|
||||||
self.q_shard_id:self.q_proj_shard_size *
|
|
||||||
(self.q_shard_id + 1)]
|
|
||||||
if bias_k is not None:
|
|
||||||
bias_k = bias_k[self.kv_proj_shard_size *
|
|
||||||
self.kv_shard_id:self.kv_proj_shard_size *
|
|
||||||
(self.kv_shard_id + 1)]
|
|
||||||
if bias_v is not None:
|
|
||||||
bias_v = bias_v[self.kv_proj_shard_size *
|
|
||||||
self.kv_shard_id:self.kv_proj_shard_size *
|
|
||||||
(self.kv_shard_id + 1)]
|
|
||||||
bias = [bias_q, bias_k, bias_v]
|
|
||||||
return bias
|
|
||||||
|
|
||||||
def set_lora(
|
|
||||||
self,
|
|
||||||
index: int,
|
|
||||||
lora_a: torch.Tensor,
|
|
||||||
lora_b: torch.Tensor,
|
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
|
||||||
lora_bias: Optional[torch.Tensor] = None,
|
|
||||||
):
|
|
||||||
self.reset_lora(index)
|
|
||||||
|
|
||||||
if self.tp_size > 1:
|
|
||||||
lora_a = self.slice_lora_a(lora_a)
|
|
||||||
lora_b = self.slice_lora_b(lora_b)
|
|
||||||
if lora_bias is not None:
|
|
||||||
lora_bias = self.slice_bias(lora_bias)
|
|
||||||
|
|
||||||
if lora_b[0] is not None:
|
|
||||||
lora_b_q = lora_b[0]
|
|
||||||
self.lora_b_stacked[0][
|
|
||||||
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
|
|
||||||
lora_b_q.T, non_blocking=True)
|
|
||||||
if lora_b[1] is not None:
|
|
||||||
lora_b_k = lora_b[1]
|
|
||||||
self.lora_b_stacked[1][
|
|
||||||
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
|
|
||||||
lora_b_k.T, non_blocking=True)
|
|
||||||
if lora_b[2] is not None:
|
|
||||||
lora_b_v = lora_b[2]
|
|
||||||
self.lora_b_stacked[2][
|
|
||||||
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
|
|
||||||
lora_b_v.T, non_blocking=True)
|
|
||||||
|
|
||||||
if lora_a[0] is not None:
|
|
||||||
self.lora_a_stacked[0][
|
|
||||||
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
|
|
||||||
lora_a[0].T, non_blocking=True)
|
|
||||||
if lora_a[1] is not None:
|
|
||||||
self.lora_a_stacked[1][
|
|
||||||
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
|
||||||
lora_a[1].T, non_blocking=True)
|
|
||||||
if lora_a[2] is not None:
|
|
||||||
self.lora_a_stacked[2][
|
|
||||||
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
|
|
||||||
lora_a[2].T, non_blocking=True)
|
|
||||||
|
|
||||||
if lora_bias is not None:
|
|
||||||
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
|
|
||||||
self.lora_bias_stacked)
|
|
||||||
if lora_bias[0] is not None:
|
|
||||||
self.lora_bias_stacked[0][index,
|
|
||||||
0, :lora_bias[0].shape[0]].copy_(
|
|
||||||
lora_bias[0].T,
|
|
||||||
non_blocking=True)
|
|
||||||
if lora_bias[1] is not None:
|
|
||||||
self.lora_bias_stacked[1][index,
|
|
||||||
0, :lora_bias[1].shape[0]].copy_(
|
|
||||||
lora_bias[1].T,
|
|
||||||
non_blocking=True)
|
|
||||||
if lora_bias[2] is not None:
|
|
||||||
self.lora_bias_stacked[2][index,
|
|
||||||
0, :lora_bias[2].shape[0]].copy_(
|
|
||||||
lora_bias[2].T,
|
|
||||||
non_blocking=True)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@_not_fully_sharded_can_replace
|
@_not_fully_sharded_can_replace
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user