[Misc][LoRA] Refactor and clean MergedQKVParallelLinearWithLora implementation (#10958)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-12-07 18:33:49 +08:00 committed by GitHub
parent f13cf9ad50
commit b26b4cd03c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -542,10 +542,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
Both slices must have the same size. Both slices must have the same size.
""" """
def __init__(self, base_layer: MergedColumnParallelLinear) -> None: def __init__(
self, base_layer: Union[MergedColumnParallelLinear,
QKVParallelLinear]) -> None:
super().__init__(base_layer) super().__init__(base_layer)
# There are two LoRA layers # There are two LoRA layers
self.n_slices = len(self.base_layer.output_sizes) self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
# we need to divide it by the tp_size to get correct slices size
output_sizes = self.base_layer.output_sizes
self.output_slices = tuple(
divide(output_size, self.tp_size) for output_size in output_sizes)
self.n_slices = len(self.output_slices)
self.output_ids = (self.tp_rank, ) * self.n_slices
def create_lora_weights( def create_lora_weights(
self, self,
@ -559,15 +569,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
""" """
self.lora_config = lora_config self.lora_config = lora_config
if not (len(self.base_layer.output_sizes) == self.n_slices == 2
and self.base_layer.output_sizes[0]
== self.base_layer.output_sizes[1]):
raise ValueError(
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size.")
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
lora_a_output_size_per_partition = ( lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size)) else divide(lora_config.max_lora_rank, self.tp_size))
@ -585,22 +586,20 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
self.output_size // 2, output_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) for _ in range(self.n_slices)) ) for output_size in self.output_slices)
if lora_config.bias_enabled: if lora_config.bias_enabled:
self.lora_bias_stacked = tuple( self.lora_bias_stacked = tuple(
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
self.output_size // 2, output_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.device, device=self.device,
) for _ in range(self.n_slices)) ) for output_size in self.output_slices)
self.output_dim = self.lora_b_stacked[0].shape[2]
self.output_slices = (self.output_dim, self.output_dim)
def slice_lora_a( def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]] self, lora_a: List[Union[torch.Tensor, None]]
@ -610,27 +609,21 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def slice_lora_b( def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]] self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]: ) -> List[Union[torch.Tensor, None]]:
#NOTE: lora_b contains 2 subloras, and each sublora could be None. for i, (shard_id, shard_size) in enumerate(
shard_size = self.output_dim zip(self.output_ids, self.output_slices)):
start_idx = self.tp_rank * shard_size if (lora_b_i := lora_b[i]) is not None:
end_idx = (self.tp_rank + 1) * shard_size lora_b[i] = lora_b_i[:, shard_size * shard_id:shard_size *
lora_b = [ (shard_id + 1)]
lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
]
return lora_b return lora_b
def slice_bias( def slice_bias(
self, bias: List[Union[torch.Tensor, self, bias: List[Union[torch.Tensor,
None]]) -> List[Union[torch.Tensor, None]]: None]]) -> List[Union[torch.Tensor, None]]:
# NOTE : each bias could be None. for i, (shard_id, shard_size) in enumerate(
shard_size = self.output_dim zip(self.output_ids, self.output_slices)):
start_idx = self.tp_rank * shard_size if (bias_i := bias[i]) is not None:
end_idx = (self.tp_rank + 1) * shard_size bias[i] = bias_i[shard_size * shard_id:shard_size *
bias = [ (shard_id + 1)]
bias[0][start_idx:end_idx] if bias[0] is not None else None,
bias[1][start_idx:end_idx] if bias[1] is not None else None
]
return bias return bias
def set_lora( def set_lora(
@ -649,30 +642,25 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if lora_bias is not None: if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias) lora_bias = self.slice_bias(lora_bias)
if lora_a[0] is not None: for i in range(self.n_slices):
self.lora_a_stacked[0][ if (lora_a_i := lora_a[i]) is not None:
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( self.lora_a_stacked[i][
lora_a[0].T, non_blocking=True) index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
self.lora_b_stacked[0][ lora_a_i.T, non_blocking=True)
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( if (lora_b_i := lora_b[i]) is not None:
lora_b[0].T, non_blocking=True) self.lora_b_stacked[i][
if lora_bias is not None and lora_bias[0] is not None: index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
lora_b_i.T, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...], self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked) self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias[0].shape[0]].copy_( for i in range(self.n_slices):
lora_bias[0].T, non_blocking=True) if (lora_bias_i := lora_bias[i]) is not None:
if lora_a[1] is not None: self.lora_bias_stacked[i][index,
self.lora_a_stacked[1][ 0, :lora_bias_i.shape[0]].copy_(
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( lora_bias_i.T,
lora_a[1].T, non_blocking=True) non_blocking=True)
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
if lora_bias is not None and lora_bias[1] is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked)
self.lora_bias_stacked[1][index, 0, :lora_bias[1].shape[0]].copy_(
lora_bias[1].T, non_blocking=True)
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace
@ -755,8 +743,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
packed_modules_list) == 1 packed_modules_list) == 1
class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): class MergedQKVParallelLinearWithLora(MergedColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices) """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj). (q_proj + k_proj + v_proj -> qkv_proj).
@ -773,6 +761,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
self.output_slices = (
self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.output_ids = (
self.q_shard_id,
self.kv_shard_id,
self.kv_shard_id,
)
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
@ -783,216 +789,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
The main reason for overloading this function is to handle inconsistent The main reason for overloading this function is to handle inconsistent
weight dimensions in qkv lora. weight dimensions in qkv lora.
""" """
self.lora_config = lora_config super().create_lora_weights(max_loras, lora_config, model_config)
if not (len(self.base_layer.output_sizes) == self.n_slices == 3):
raise ValueError(
"LoRAColumnParallelLinear3Slice requires 3 slices.")
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_a_output_size_per_partition = (
lora_config.max_lora_rank if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size))
# q, k, v
self.lora_a_stacked = (
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
)
self.lora_b_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
),
)
if lora_config.bias_enabled:
self.lora_bias_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
)
self.output_slices = (
self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.packed_indices: Optional[torch.Tensor] = None
self.standard_indices: Optional[torch.Tensor] = None
# lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
return lora_a
def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
lora_b_q, lora_b_k, lora_b_v = None, None, None
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1), ]
if lora_b[1] is not None:
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1), ]
if lora_b[2] is not None:
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1), ]
lora_b = [lora_b_q, lora_b_k, lora_b_v]
return lora_b
def slice_bias(
self, bias: List[Union[torch.Tensor,
None]]) -> List[Union[torch.Tensor, None]]:
bias_q, bias_k, bias_v = bias
if bias_q is not None:
bias_q = bias_q[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
if bias_k is not None:
bias_k = bias_k[self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
if bias_v is not None:
bias_v = bias_v[self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
bias = [bias_q, bias_k, bias_v]
return bias
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
lora_bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if lora_bias is not None:
lora_bias = self.slice_bias(lora_bias)
if lora_b[0] is not None:
lora_b_q = lora_b[0]
self.lora_b_stacked[0][
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
lora_b_q.T, non_blocking=True)
if lora_b[1] is not None:
lora_b_k = lora_b[1]
self.lora_b_stacked[1][
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
lora_b_k.T, non_blocking=True)
if lora_b[2] is not None:
lora_b_v = lora_b[2]
self.lora_b_stacked[2][
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
lora_b_v.T, non_blocking=True)
if lora_a[0] is not None:
self.lora_a_stacked[0][
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
lora_a[0].T, non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
lora_a[1].T, non_blocking=True)
if lora_a[2] is not None:
self.lora_a_stacked[2][
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
self.lora_bias_stacked)
if lora_bias[0] is not None:
self.lora_bias_stacked[0][index,
0, :lora_bias[0].shape[0]].copy_(
lora_bias[0].T,
non_blocking=True)
if lora_bias[1] is not None:
self.lora_bias_stacked[1][index,
0, :lora_bias[1].shape[0]].copy_(
lora_bias[1].T,
non_blocking=True)
if lora_bias[2] is not None:
self.lora_bias_stacked[2][index,
0, :lora_bias[2].shape[0]].copy_(
lora_bias[2].T,
non_blocking=True)
@classmethod @classmethod
@_not_fully_sharded_can_replace @_not_fully_sharded_can_replace