[Core] Optimize LoRA weight loading (#25403)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-09-23 18:19:45 +08:00 committed by GitHub
parent 231c2c63e4
commit 273690a50a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 83 additions and 83 deletions

View File

@ -164,8 +164,8 @@ def populate_loras(
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora.lora_b = sublora.lora_b[:, (sublora_len *
i):(sublora_len * (i + 1))]
sublora.lora_b = sublora.lora_b[(sublora_len *
i):(sublora_len * (i + 1)), :]
sublora.optimize()
subloras.append(sublora)
@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
result = embedding(input_)
after_a = F.embedding(
input_,
lora.lora_a,
lora.lora_a.T,
)
result += (after_a @ lora.lora_b)
result += (after_a @ lora.lora_b.T)
expected_results.append(result)
expected_result = torch.cat(expected_results)
@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
result = expanded_embedding(input_)
after_a = F.embedding(
original_input_,
lora.lora_a,
lora.lora_a.T,
)
result += (after_a @ lora.lora_b)
result += (after_a @ lora.lora_b.T)
expected_results.append(result)
expected_result = torch.cat(expected_results)
@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
lm_head=linear,
embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = vocab_size
@ -692,9 +692,10 @@ def test_linear_replicated(
expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
@ -965,9 +966,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
result = linear(input_)[0]
subloras = sublora_dict[lora_id]
for i, sublora in enumerate(subloras):
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
(i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
sublora.scaling)
result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
(i + 1)] += (
input_ @ sublora.lora_a.T @ sublora.lora_b.T *
sublora.scaling)
expected_results.append(result)
expected_result = torch.cat(expected_results)

View File

@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
assert lora.lora_b is not None
assert lora.lora_a.device == torch.device(device)
assert lora.lora_b.device == torch.device(device)
assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
assert lora.lora_a.shape[1] == 8
assert lora.lora_a.shape[0] == 8
embeddings_module = next(
(k for k in EMBEDDING_MODULES if k in module_name), None)
if embeddings_module:
@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
name,
8,
16,
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0]], device=device),
torch.rand([8, w.shape[1]], device=device),
torch.rand([w.shape[0], 8], device=device),
)
return LoRAModel(lora_id, 8, loras)
@ -109,8 +109,8 @@ def create_packed_lora(
replaced_module_name,
8,
16,
torch.rand([w.shape[1], 8], device=device),
torch.rand([8, w.shape[0] // len(replaced_module_names)],
torch.rand([8, w.shape[1]], device=device),
torch.rand([w.shape[0] // len(replaced_module_names), 8],
device=device),
)
return LoRAModel(lora_id, 8, loras)

View File

@ -36,10 +36,10 @@ class DummyLoRAManager:
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([weight.shape[1], rank],
lora_a=torch.rand([rank, weight.shape[1]],
dtype=weight.dtype,
device=self._device),
lora_b=torch.rand([rank, weight.shape[0]],
lora_b=torch.rand([weight.shape[0], rank],
dtype=weight.dtype,
device=self._device),
)
@ -67,8 +67,8 @@ class DummyLoRAManager:
module_name,
rank=rank,
lora_alpha=1,
lora_a=torch.rand([input_dim, rank], device="cuda"),
lora_b=torch.rand([rank, output_dim], device="cuda"),
lora_a=torch.rand([rank, input_dim], device="cuda"),
lora_b=torch.rand([output_dim, input_dim], device="cuda"),
embeddings_tensor=embeddings_tensor,
)
self.set_module_lora(module_name, lora)

View File

@ -121,18 +121,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
lora_bias = self.slice_bias(lora_bias)
self.lora_a_stacked[0][index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[0][index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
self.lora_bias_stacked)
assert len(self.lora_bias_stacked)
self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
lora_bias.T, non_blocking=True)
lora_bias, non_blocking=True)
def apply(self,
x: torch.Tensor,

View File

@ -99,13 +99,13 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
if self.is_merged_col_linear:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.output_size // 2
offset = lora_b.shape[-1] // 2
offset = lora_b.shape[0] // 2
left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) *
shard_size]
right_weight = lora_b[:, offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size]
lora_b = torch.cat([left_weight, right_weight], dim=1)
left_weight = lora_b[tp_rank * shard_size:(tp_rank + 1) *
shard_size, :]
right_weight = lora_b[offset + tp_rank * shard_size:offset +
(tp_rank + 1) * shard_size, :]
lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
else:
@ -113,7 +113,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
shard_size = self.output_size
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
@ -251,9 +251,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)):
if (lora_b_i := lora_b[i]) is not None:
sliced_lora_b[i] = lora_b_i[:,
shard_size * shard_id:shard_size *
(shard_id + 1)]
sliced_lora_b[i] = lora_b_i[shard_size * shard_id:shard_size *
(shard_id + 1), :]
return sliced_lora_b
def slice_bias(
@ -285,12 +284,12 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None:
self.lora_a_stacked[i][
index, 0, :lora_a_i.shape[1], :lora_a_i.shape[0]].copy_(
lora_a_i.T, non_blocking=True)
index, 0, :lora_a_i.shape[0], :lora_a_i.shape[1]].copy_(
lora_a_i, non_blocking=True)
if (lora_b_i := lora_b[i]) is not None:
self.lora_b_stacked[i][
index, 0, :lora_b_i.shape[1], :lora_b_i.shape[0]].copy_(
lora_b_i.T, non_blocking=True)
index, 0, :lora_b_i.shape[0], :lora_b_i.shape[1]].copy_(
lora_b_i, non_blocking=True)
if lora_bias is not None:
self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
@ -299,7 +298,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
if (lora_bias_i := lora_bias[i]) is not None:
self.lora_bias_stacked[i][index,
0, :lora_bias_i.shape[0]].copy_(
lora_bias_i.T,
lora_bias_i,
non_blocking=True)
@classmethod
@ -345,18 +344,18 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
lora_b_q = lora_b[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
(self.q_shard_id + 1), :]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset +
lora_b_k = lora_b[k_offset +
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset +
lora_b_v = lora_b[v_offset +
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
self.kv_proj_shard_size * (self.kv_shard_id + 1), :]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
@ -465,7 +464,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a
def apply(self,
@ -508,10 +507,10 @@ class MergedColumnParallelLinearWithShardedLoRA(
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[0] is not None else None,
lora_a[1][:, output_start_idx:output_start_idx +
output_shard_size] if lora_a[1] is not None else None,
lora_a[0][output_start_idx:output_start_idx +
output_shard_size, :] if lora_a[0] is not None else None,
lora_a[1][output_start_idx:output_start_idx +
output_shard_size, :] if lora_a[1] is not None else None,
]
return lora_a
@ -551,7 +550,7 @@ class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = tp_rank * shard_size
lora_a = lora_a[:, start_idx:start_idx + shard_size]
lora_a = lora_a[start_idx:start_idx + shard_size, :]
return lora_a
def apply(self,
@ -589,12 +588,12 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][:, start_idx[0]:start_idx[0] +
shard_size[0]] if lora_a[0] is not None else None,
lora_a[1][:, start_idx[1]:start_idx[1] +
shard_size[1]] if lora_a[1] is not None else None,
lora_a[2][:, start_idx[2]:start_idx[2] +
shard_size[2]] if lora_a[2] is not None else None,
lora_a[0][start_idx[0]:start_idx[0] +
shard_size[0], :] if lora_a[0] is not None else None,
lora_a[1][start_idx[1]:start_idx[1] +
shard_size[1], :] if lora_a[1] is not None else None,
lora_a[2][start_idx[2]:start_idx[2] +
shard_size[2], :] if lora_a[2] is not None else None,
]
return lora_a

View File

@ -140,11 +140,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
):
self.reset_lora(index)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,

View File

@ -39,7 +39,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
shard_size = self.input_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :]
lora_a = lora_a[:,start_idx:end_idx]
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
@ -122,7 +122,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
lora_b = lora_b[ start_idx:end_idx,:]
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:

View File

@ -95,11 +95,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
lora_a, non_blocking=True)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need transpose here
self.lora_a_stacked[index, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
lora_b, non_blocking=True)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,

View File

@ -86,11 +86,11 @@ class LoRALayerWeights:
embeddings_tensor_dim: Optional[int] = None,
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
lora_a = torch.zeros([rank, input_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
lora_b = torch.zeros([rank, output_dim],
lora_b = torch.zeros([output_dim, rank],
dtype=dtype,
device=device,
pin_memory=pin_memory)

View File

@ -152,30 +152,29 @@ class LoRAModel:
module_name, peft_helper, lora_embeddings_tensor)
if is_bias:
loras[module_name].bias = tensor.to(device=device,
dtype=dtype).t()
bias = tensor.to(device=device, dtype=dtype).t()
loras[module_name].bias = tensor.to(device=device, dtype=dtype)
bias = tensor.to(device=device, dtype=dtype)
if pin_memory:
bias = bias.pin_memory()
loras[module_name].bias = bias
elif is_lora_a:
loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t()
dtype=dtype)
if pin_memory:
loras[module_name].lora_a = loras[
module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device,
dtype=dtype).t()
dtype=dtype)
assert embedding_padding_modules is not None
if any(name in module_name
for name in embedding_padding_modules
) and target_embedding_padding is not None:
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[1]
addition = target_embedding_padding - lora_b.shape[1]
assert target_embedding_padding >= lora_b.shape[0]
addition = target_embedding_padding - lora_b.shape[0]
loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, addition))
lora_b, (0, 0, 0, addition))
if pin_memory:
loras[module_name].lora_b = loras[
module_name].lora_b.pin_memory()
@ -585,7 +584,6 @@ class LoRAModelManager:
"cpu",
bias_enabled=bias_enabled,
)
lora.optimize()
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
@ -600,7 +598,6 @@ class LoRAModelManager:
"cpu",
bias_enabled=bias_enabled,
)
lora.optimize()
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora