[Core] Optimize LoRA weight loading (#25403)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent 231c2c63e4
commit 273690a50a
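Context for the hunks that follow: the commit stops transposing LoRA weights at load time, so lora_a is now kept as [rank, input_dim] and lora_b as [output_dim, rank] (the orientation the adapter checkpoint already uses), and every consumer is updated to index or transpose accordingly. A minimal sketch of the two conventions; the dimensions and variable names below are illustrative only, not vLLM API:

import torch

# Illustrative sizes only.
input_dim, output_dim, rank, scaling = 16, 32, 8, 1.0
x = torch.randn(4, input_dim)             # a batch of activations

# New layout (after this commit): same orientation as the checkpoint.
lora_a = torch.randn(rank, input_dim)     # A: [rank, input_dim]
lora_b = torch.randn(output_dim, rank)    # B: [output_dim, rank]

# The delta is computed with explicit transposes, as the updated tests do.
delta_new = x @ lora_a.T @ lora_b.T * scaling

# Old layout (before this commit): tensors were transposed when loaded,
# so the delta was written without the .T calls.
lora_a_old = lora_a.T.contiguous()        # [input_dim, rank]
lora_b_old = lora_b.T.contiguous()        # [rank, output_dim]
delta_old = x @ lora_a_old @ lora_b_old * scaling

assert torch.allclose(delta_new, delta_old, atol=1e-5)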
@@ -164,8 +164,8 @@ def populate_loras(
 weight=layer_weights,
 generate_embeddings_tensor=generate_embeddings_tensor,
 )
-sublora.lora_b = sublora.lora_b[:, (sublora_len *
-i):(sublora_len * (i + 1))]
+sublora.lora_b = sublora.lora_b[(sublora_len *
+i):(sublora_len * (i + 1)), :]
 sublora.optimize()
 subloras.append(sublora)

@@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
 result = embedding(input_)
 after_a = F.embedding(
 input_,
-lora.lora_a,
+lora.lora_a.T,
 )
-result += (after_a @ lora.lora_b)
+result += (after_a @ lora.lora_b.T)
 expected_results.append(result)
 expected_result = torch.cat(expected_results)

@@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
 result = expanded_embedding(input_)
 after_a = F.embedding(
 original_input_,
-lora.lora_a,
+lora.lora_a.T,
 )
-result += (after_a @ lora.lora_b)
+result += (after_a @ lora.lora_b.T)
 expected_results.append(result)
 expected_result = torch.cat(expected_results)

@@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
 lm_head=linear,
 embedding_bias=None)
 result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
-result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
 expected_results.append(result)
 expected_result = torch.cat(expected_results)
 logits_processor.org_vocab_size = vocab_size

@@ -692,9 +692,10 @@ def test_linear_replicated(
 expected_results: list[torch.Tensor] = []
 for input_, lora_id in zip(inputs, prompt_mapping):
 
 lora = lora_dict[lora_id]
 result = linear(input_)[0]
-result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
 expected_results.append(result)
 expected_result = torch.cat(expected_results)

@@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
 for input_, lora_id in zip(inputs, prompt_mapping):
 lora = lora_dict[lora_id]
 result = linear(input_)[0]
-result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
+result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
 expected_results.append(result)
 expected_result = torch.cat(expected_results)

@@ -965,8 +966,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
 result = linear(input_)[0]
 subloras = sublora_dict[lora_id]
 for i, sublora in enumerate(subloras):
-result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
-(i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
+result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
+(i + 1)] += (
+input_ @ sublora.lora_a.T @ sublora.lora_b.T *
 sublora.scaling)
 expected_results.append(result)
 expected_result = torch.cat(expected_results)
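In the test hunks above, lora_b for a packed layer is now sliced along dim 0 (its output rows) rather than dim 1. A small illustrative sketch; sublora_len, n_subloras and rank are made-up values:

import torch

rank, sublora_len, n_subloras = 8, 32, 2
# B for a packed layer stacks the per-slice output rows along dim 0.
lora_b = torch.randn(sublora_len * n_subloras, rank)

# Old layout sliced columns: lora_b[:, sublora_len*i : sublora_len*(i+1)].
# New layout slices rows instead:
sub_bs = [
    lora_b[sublora_len * i:sublora_len * (i + 1), :]
    for i in range(n_subloras)
]
assert all(b.shape == (sublora_len, rank) for b in sub_bs)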
@@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
 assert lora.lora_b is not None
 assert lora.lora_a.device == torch.device(device)
 assert lora.lora_b.device == torch.device(device)
-assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
+assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
 ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
-assert lora.lora_a.shape[1] == 8
+assert lora.lora_a.shape[0] == 8
 embeddings_module = next(
 (k for k in EMBEDDING_MODULES if k in module_name), None)
 if embeddings_module:

@@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
 name,
 8,
 16,
-torch.rand([w.shape[1], 8], device=device),
-torch.rand([8, w.shape[0]], device=device),
+torch.rand([8, w.shape[1]], device=device),
+torch.rand([w.shape[0], 8], device=device),
 )
 return LoRAModel(lora_id, 8, loras)

@@ -109,8 +109,8 @@ def create_packed_lora(
 replaced_module_name,
 8,
 16,
-torch.rand([w.shape[1], 8], device=device),
-torch.rand([8, w.shape[0] // len(replaced_module_names)],
+torch.rand([8, w.shape[1]], device=device),
+torch.rand([w.shape[0] // len(replaced_module_names), 8],
 device=device),
 )
 return LoRAModel(lora_id, 8, loras)
@@ -36,10 +36,10 @@ class DummyLoRAManager:
 module_name,
 rank=rank,
 lora_alpha=1,
-lora_a=torch.rand([weight.shape[1], rank],
+lora_a=torch.rand([rank, weight.shape[1]],
 dtype=weight.dtype,
 device=self._device),
-lora_b=torch.rand([rank, weight.shape[0]],
+lora_b=torch.rand([weight.shape[0], rank],
 dtype=weight.dtype,
 device=self._device),
 )

@@ -67,8 +67,8 @@ class DummyLoRAManager:
 module_name,
 rank=rank,
 lora_alpha=1,
-lora_a=torch.rand([input_dim, rank], device="cuda"),
-lora_b=torch.rand([rank, output_dim], device="cuda"),
+lora_a=torch.rand([rank, input_dim], device="cuda"),
+lora_b=torch.rand([output_dim, input_dim], device="cuda"),
 embeddings_tensor=embeddings_tensor,
 )
 self.set_module_lora(module_name, lora)
@@ -121,18 +121,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
 lora_bias = self.slice_bias(lora_bias)
 
 self.lora_a_stacked[0][index,
-0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
-lora_a.T, non_blocking=True)
+0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
+lora_a, non_blocking=True)
 self.lora_b_stacked[0][index,
-0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-lora_b.T, non_blocking=True)
+0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
+lora_b, non_blocking=True)
 if lora_bias is not None:
 
 self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
 self.lora_bias_stacked)
 assert len(self.lora_bias_stacked)
 self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
-lora_bias.T, non_blocking=True)
+lora_bias, non_blocking=True)
 
 def apply(self,
 x: torch.Tensor,
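The set_lora hunk above can drop the .T because the incoming tensors already match the trailing dimensions of the preallocated stacked buffers. A rough shape sketch; max_loras, max_rank and the buffer shapes here are illustrative, not the exact vLLM allocation:

import torch

max_loras, max_rank, input_dim, output_dim = 4, 16, 64, 128
rank, index = 8, 0

# Preallocated stacked buffers, one slot per active adapter
# (shapes follow the [loras, 1, ..., ...] indexing pattern in the diff).
lora_a_stacked = torch.zeros(max_loras, 1, max_rank, input_dim)
lora_b_stacked = torch.zeros(max_loras, 1, output_dim, max_rank)

# Incoming adapter weights in the new checkpoint layout.
lora_a = torch.randn(rank, input_dim)     # [rank, input_dim]
lora_b = torch.randn(output_dim, rank)    # [output_dim, rank]

# No transpose needed any more: the slice and the source agree.
lora_a_stacked[index, 0, :lora_a.shape[0], :lora_a.shape[1]].copy_(lora_a)
lora_b_stacked[index, 0, :lora_b.shape[0], :lora_b.shape[1]].copy_(lora_b)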
@@ -99,13 +99,13 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
 if self.is_merged_col_linear:
 tp_rank = get_tensor_model_parallel_rank()
 shard_size = self.output_size // 2
-offset = lora_b.shape[-1] // 2
+offset = lora_b.shape[0] // 2
 
-left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
-shard_size]
-right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
-(tp_rank + 1) * shard_size]
-lora_b = torch.cat([left_weight, right_weight], dim=1)
+left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
+shard_size, :]
+right_weight = lora_b[offset + tp_rank * shard_size:offset +
+(tp_rank + 1) * shard_size, :]
+lora_b = torch.cat([left_weight, right_weight], dim=0)
 # Applicable to cases where the base_layer is
 # ColumnParallelLinear.
 else:

@@ -113,7 +113,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
 shard_size = self.output_size
 start_idx = tensor_model_parallel_rank * shard_size
 end_idx = (tensor_model_parallel_rank + 1) * shard_size
-lora_b = lora_b[:, start_idx:end_idx]
+lora_b = lora_b[start_idx:end_idx, :]
 return lora_b
 
 def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
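For the column-parallel slicing above: each tensor-parallel rank keeps only its shard of the output dimension, and with lora_b stored as [output_dim, rank] that shard is now a row slice. A standalone sketch with assumed sizes (tp_size and output_size_per_rank are made up):

import torch

tp_size, rank = 2, 8
output_size_per_rank = 64
full_output = tp_size * output_size_per_rank

lora_b = torch.randn(full_output, rank)   # [output_dim, rank]

def slice_lora_b(lora_b: torch.Tensor, tp_rank: int) -> torch.Tensor:
    # Old layout used lora_b[:, start:end]; the new layout slices rows.
    start = tp_rank * output_size_per_rank
    end = (tp_rank + 1) * output_size_per_rank
    return lora_b[start:end, :]

shards = [slice_lora_b(lora_b, r) for r in range(tp_size)]
assert torch.equal(torch.cat(shards, dim=0), lora_b)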
@@ -251,9 +251,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
 for i, (shard_id, shard_size) in enumerate(
 zip(self.output_ids, self.output_slices)):
 if (lora_b_i := lora_b[i]) is not None:
-sliced_lora_b[i] = lora_b_i[:,
-shard_size * shard_id:shard_size *
-(shard_id + 1)]
+sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size *
+(shard_id + 1), :]
 return sliced_lora_b
 
 def slice_bias(

@@ -285,12 +284,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
 for i in range(self.n_slices):
 if (lora_a_i := lora_a[i]) is not None:
 self.lora_a_stacked[i][
-index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
-lora_a_i.T, non_blocking=True)
+index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_(
+lora_a_i, non_blocking=True)
 if (lora_b_i := lora_b[i]) is not None:
 self.lora_b_stacked[i][
-index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
-lora_b_i.T, non_blocking=True)
+index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
+lora_b_i, non_blocking=True)
 
 if lora_bias is not None:
 self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],

@@ -299,7 +298,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
 if (lora_bias_i := lora_bias[i]) is not None:
 self.lora_bias_stacked[i][index,
 0, :lora_bias_i.shape[0]].copy_(
-lora_bias_i.T,
+lora_bias_i,
 non_blocking=True)
 
 @classmethod

@@ -345,18 +344,18 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
 tp_rank = get_tensor_model_parallel_rank()
 self.q_shard_id = tp_rank
 self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
-lora_b_q = lora_b[:, self.q_proj_shard_size *
+lora_b_q = lora_b[self.q_proj_shard_size *
 self.q_shard_id:self.q_proj_shard_size *
-(self.q_shard_id + 1)]
+(self.q_shard_id + 1), :]
 k_offset = self.q_proj_total_size
-lora_b_k = lora_b[:, k_offset +
+lora_b_k = lora_b[k_offset +
 self.kv_proj_shard_size * self.kv_shard_id:k_offset +
-self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
 v_offset = k_offset + self.kv_proj_total_size
-lora_b_v = lora_b[:, v_offset +
+lora_b_v = lora_b[v_offset +
 self.kv_proj_shard_size * self.kv_shard_id:v_offset +
-self.kv_proj_shard_size * (self.kv_shard_id + 1)]
-lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
+self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
+lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
 return lora_b
 
 def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:

@@ -465,7 +464,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
 tp_rank = get_tensor_model_parallel_rank()
 shard_size = self.lora_a_stacked[0].shape[2]
 start_idx = tp_rank * shard_size
-lora_a = lora_a[:, start_idx:start_idx + shard_size]
+lora_a = lora_a[start_idx:start_idx + shard_size, :]
 return lora_a
 
 def apply(self,

@@ -508,10 +507,10 @@ class MergedColumnParallelLinearWithShardedLoRA(
 output_shard_size = self.lora_a_stacked[0].shape[2]
 output_start_idx = self.tp_rank * output_shard_size
 lora_a = [
-lora_a[0][:, output_start_idx:output_start_idx +
-output_shard_size] if lora_a[0] is not None else None,
-lora_a[1][:, output_start_idx:output_start_idx +
-output_shard_size] if lora_a[1] is not None else None,
+lora_a[0][output_start_idx:output_start_idx +
+output_shard_size, :] if lora_a[0] is not None else None,
+lora_a[1][output_start_idx:output_start_idx +
+output_shard_size, :] if lora_a[1] is not None else None,
 ]
 return lora_a

@@ -551,7 +550,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
 tp_rank = get_tensor_model_parallel_rank()
 shard_size = self.lora_a_stacked[0].shape[2]
 start_idx = tp_rank * shard_size
-lora_a = lora_a[:, start_idx:start_idx + shard_size]
+lora_a = lora_a[start_idx:start_idx + shard_size, :]
 return lora_a
 
 def apply(self,

@@ -589,12 +588,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
 shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
 start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
 lora_a = [
-lora_a[0][:, start_idx[0]:start_idx[0] +
-shard_size[0]] if lora_a[0] is not None else None,
-lora_a[1][:, start_idx[1]:start_idx[1] +
-shard_size[1]] if lora_a[1] is not None else None,
-lora_a[2][:, start_idx[2]:start_idx[2] +
-shard_size[2]] if lora_a[2] is not None else None,
+lora_a[0][start_idx[0]:start_idx[0] +
+shard_size[0], :] if lora_a[0] is not None else None,
+lora_a[1][start_idx[1]:start_idx[1] +
+shard_size[1], :] if lora_a[1] is not None else None,
+lora_a[2][start_idx[2]:start_idx[2] +
+shard_size[2], :] if lora_a[2] is not None else None,
 ]
 return lora_a
@@ -140,11 +140,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
 ):
 self.reset_lora(index)
 self.lora_a_stacked[index,
-0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
-lora_a.T, non_blocking=True)
+0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
+lora_a, non_blocking=True)
 self.lora_b_stacked[index,
-0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-lora_b.T, non_blocking=True)
+0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
+lora_b, non_blocking=True)
 if embeddings_tensor is not None:
 self.embeddings_tensors[
 index,
@@ -39,7 +39,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
 shard_size = self.input_size
 start_idx = self.tp_rank * shard_size
 end_idx = (self.tp_rank + 1) * shard_size
-lora_a = lora_a[start_idx:end_idx, :]
+lora_a = lora_a[:, start_idx:end_idx]
 return lora_a
 
 def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:

@@ -122,7 +122,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
 shard_size = self.lora_b_stacked[0].shape[2]
 start_idx = self.tp_rank * shard_size
 end_idx = (self.tp_rank + 1) * shard_size
-lora_b = lora_b[:, start_idx:end_idx]
+lora_b = lora_b[start_idx:end_idx, :]
 return lora_b
 
 def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
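The row-parallel hunks above are the mirror image: RowParallelLinear shards the input dimension, so with lora_a now [rank, input_dim] the shard is taken along dim 1 (columns). A minimal sketch with assumed sizes:

import torch

tp_size, rank = 2, 8
input_size_per_rank = 64
full_input = tp_size * input_size_per_rank

lora_a = torch.randn(rank, full_input)    # [rank, input_dim]

def slice_lora_a(lora_a: torch.Tensor, tp_rank: int) -> torch.Tensor:
    # Old layout sliced rows (lora_a[start:end, :]); the new layout
    # keeps only the input columns owned by this TP rank.
    start = tp_rank * input_size_per_rank
    end = (tp_rank + 1) * input_size_per_rank
    return lora_a[:, start:end]

shards = [slice_lora_a(lora_a, r) for r in range(tp_size)]
assert torch.equal(torch.cat(shards, dim=1), lora_a)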
@@ -95,11 +95,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
 bias: Optional[torch.Tensor] = None,
 ):
 self.reset_lora(index)
-self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
-lora_a, non_blocking=True)
+# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
+# so we need transpose here
+self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
+lora_a.T, non_blocking=True)
 self.lora_b_stacked[index,
-0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-lora_b.T, non_blocking=True)
+0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
+lora_b, non_blocking=True)
 if embeddings_tensor is not None:
 self.embeddings_tensors[
 index,
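The embedding hunk above is the one place that keeps a transpose: the NOTE explains that lora_a_stacked stays row-major, i.e. indexed by token id as [vocab, rank], which is also what the updated tests feed to F.embedding via lora.lora_a.T. A sketch of why the lookup wants that orientation; vocab, rank and the token ids are illustrative:

import torch
import torch.nn.functional as F

vocab, rank = 100, 8
token_ids = torch.tensor([3, 7, 42])

# Incoming adapter weight in the new checkpoint layout: [rank, vocab].
lora_a = torch.randn(rank, vocab)

# The stacked buffer keeps [vocab, rank] so token ids can index rows
# directly, which is what F.embedding expects as its weight.
lora_a_table = lora_a.T.contiguous()          # [vocab, rank]
after_a = F.embedding(token_ids, lora_a_table)

# Equivalent to selecting columns of lora_a and transposing.
assert torch.allclose(after_a, lora_a[:, token_ids].T)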
@@ -86,11 +86,11 @@ class LoRALayerWeights:
 embeddings_tensor_dim: Optional[int] = None,
 bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
 pin_memory = str(device) == "cpu" and is_pin_memory_available()
-lora_a = torch.zeros([input_dim, rank],
+lora_a = torch.zeros([rank, input_dim],
 dtype=dtype,
 device=device,
 pin_memory=pin_memory)
-lora_b = torch.zeros([rank, output_dim],
+lora_b = torch.zeros([output_dim, rank],
 dtype=dtype,
 device=device,
 pin_memory=pin_memory)
@@ -152,30 +152,29 @@ class LoRAModel:
 module_name, peft_helper, lora_embeddings_tensor)
 
 if is_bias:
-loras[module_name].bias = tensor.to(device=device,
-dtype=dtype).t()
-bias = tensor.to(device=device, dtype=dtype).t()
+loras[module_name].bias = tensor.to(device=device, dtype=dtype)
+bias = tensor.to(device=device, dtype=dtype)
 if pin_memory:
 bias = bias.pin_memory()
 loras[module_name].bias = bias
 elif is_lora_a:
 loras[module_name].lora_a = tensor.to(device=device,
-dtype=dtype).t()
+dtype=dtype)
 if pin_memory:
 loras[module_name].lora_a = loras[
 module_name].lora_a.pin_memory()
 else:
 loras[module_name].lora_b = tensor.to(device=device,
-dtype=dtype).t()
+dtype=dtype)
 assert embedding_padding_modules is not None
 if any(name in module_name
 for name in embedding_padding_modules
 ) and target_embedding_padding is not None:
 lora_b = loras[module_name].lora_b
-assert target_embedding_padding >= lora_b.shape[1]
-addition = target_embedding_padding - lora_b.shape[1]
+assert target_embedding_padding >= lora_b.shape[0]
+addition = target_embedding_padding - lora_b.shape[0]
 loras[module_name].lora_b = torch.nn.functional.pad(
-lora_b, (0, addition))
+lora_b, (0, 0, 0, addition))
 if pin_memory:
 loras[module_name].lora_b = loras[
 module_name].lora_b.pin_memory()

@@ -585,7 +584,6 @@ class LoRAModelManager:
 "cpu",
 bias_enabled=bias_enabled,
 )
-lora.optimize()
 else:
 parts = module_name.split(".")
 replacements = self.packed_modules_mapping[parts[-1]]

@@ -600,7 +598,6 @@ class LoRAModelManager:
 "cpu",
 bias_enabled=bias_enabled,
 )
-lora.optimize()
 subloras.append(lora)
 lora = PackedLoRALayerWeights.pack(subloras)
 model.loras[module_name] = lora
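In the from_lora_tensors hunk above, the embedding-padding path now pads lora_b's output (vocab) dimension, which lives in dim 0 under the new layout; torch.nn.functional.pad takes amounts starting from the last dimension, so (0, 0, 0, addition) leaves the rank dim alone and appends addition rows. A small sketch with made-up sizes:

import torch

rank, vocab, target_vocab = 8, 32000, 32064
lora_b = torch.randn(vocab, rank)             # [output_dim(vocab), rank]

addition = target_vocab - lora_b.shape[0]
# Pad spec is (last-dim left, last-dim right, dim0 left, dim0 right):
# (0, 0, 0, addition) appends `addition` zero rows to dim 0.
lora_b = torch.nn.functional.pad(lora_b, (0, 0, 0, addition))
assert lora_b.shape == (target_vocab, rank)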