[Bugfix] Fix fully sharded LoRA bug (#10352)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li, 2024-11-15 18:34:58 +08:00 (committed by GitHub)
parent 26908554b2
commit 1d65ec7eeb
3 changed files with 21 additions and 19 deletions
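
The gist of the change: the previous slice_lora_a / slice_lora_b / slice_bias implementations returned their input unsliced as soon as any sub-LoRA of a merged (packed) module was None, so an adapter that only targets some of the packed projections was never sharded across tensor-parallel ranks. The new code slices each sub-LoRA independently and passes None entries through. A minimal standalone sketch of that pattern, in which the helper name shard_subloras and the 2-D column slicing are illustrative assumptions rather than vLLM API:

from typing import List, Optional

import torch


def shard_subloras(subloras: List[Optional[torch.Tensor]],
                   shard_sizes: List[int],
                   tp_rank: int) -> List[Optional[torch.Tensor]]:
    """Slice each sub-LoRA to this TP rank's shard; keep missing ones as None."""
    sharded = []
    for tensor, size in zip(subloras, shard_sizes):
        if tensor is None:
            # This packed projection has no LoRA weights in the adapter.
            sharded.append(None)
        else:
            start = tp_rank * size
            sharded.append(tensor[:, start:start + size])
    return sharded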


@@ -165,15 +165,14 @@ class MergedColumnParallelLinearWithShardedLoRA(
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_a[0] is None or lora_a[1] is None:
-            return lora_a
+        # NOTE: lora_a contains 2 subloras, and each sublora could be None.
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
         lora_a = [
-            lora_a[0][:,
-                      output_start_idx:output_start_idx + output_shard_size],
-            lora_a[1][:,
-                      output_start_idx:output_start_idx + output_shard_size],
+            lora_a[0][:, output_start_idx:output_start_idx +
+                      output_shard_size] if lora_a[0] is not None else None,
+            lora_a[1][:, output_start_idx:output_start_idx +
+                      output_shard_size] if lora_a[1] is not None else None,
         ]
         return lora_a
 
@@ -261,14 +260,16 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
-            return lora_a
+        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
         lora_a = [
-            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
-            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
-            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]],
+            lora_a[0][:, start_idx[0]:start_idx[0] +
+                      shard_size[0]] if lora_a[0] is not None else None,
+            lora_a[1][:, start_idx[1]:start_idx[1] +
+                      shard_size[1]] if lora_a[1] is not None else None,
+            lora_a[2][:, start_idx[2]:start_idx[2] +
+                      shard_size[2]] if lora_a[2] is not None else None,
         ]
         return lora_a
 
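
The effect of the two hunks above is easiest to see on a toy input where one sub-LoRA is missing; the shapes, output_shard_size, and tp_rank below are assumed values for illustration, not taken from vLLM:

import torch

# Merged module with 2 sub-LoRAs; the second projection has no LoRA weights.
lora_a = [torch.randn(16, 32), None]
output_shard_size, tp_rank = 16, 1
output_start_idx = tp_rank * output_shard_size
lora_a = [
    t[:, output_start_idx:output_start_idx + output_shard_size]
    if t is not None else None for t in lora_a
]
assert lora_a[0].shape == (16, 16) and lora_a[1] is None

With the old early return, the same input would have come back as the full (16, 32) tensor on every rank.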


@@ -685,26 +685,27 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def slice_lora_b(
         self, lora_b: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_b[0] is None or lora_b[1] is None:
-            return lora_b
+        # NOTE: lora_b contains 2 subloras, and each sublora could be None.
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
         lora_b = [
-            lora_b[0][:, start_idx:end_idx],
-            lora_b[1][:, start_idx:end_idx],
+            lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
+            lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
         ]
         return lora_b
 
     def slice_bias(
         self, bias: List[Union[torch.Tensor,
                                None]]) -> List[Union[torch.Tensor, None]]:
-        if bias[0] is None or bias[1] is None:
-            return bias
+        # NOTE: each bias could be None.
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
-        bias = [bias[0][start_idx:end_idx], bias[1][start_idx:end_idx]]
+        bias = [
+            bias[0][start_idx:end_idx] if bias[0] is not None else None,
+            bias[1][start_idx:end_idx] if bias[1] is not None else None
+        ]
         return bias
 
     def set_lora(
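
slice_lora_b and slice_bias above get the same guard; a short sketch of the bias path, assuming output_dim=4 and tp_rank=0 purely for illustration:

import torch

output_dim, tp_rank = 4, 0
start_idx = tp_rank * output_dim
end_idx = (tp_rank + 1) * output_dim
bias = [torch.arange(8.0), None]  # second packed projection has no LoRA bias
bias = [b[start_idx:end_idx] if b is not None else None for b in bias]
assert bias[0].tolist() == [0.0, 1.0, 2.0, 3.0] and bias[1] is None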


@@ -232,7 +232,7 @@ class Worker(LocalOrDistributedWorkerBase):
         logger.info(
             "Memory profiling results: total_gpu_memory=%.2fGiB"
             " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
-            " memory_usage_post_profile=%.2fGib"
+            " memory_usage_post_profile=%.2fGiB"
             " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
             " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
             (total_gpu_memory - free_memory_pre_profile) / (1024**3),
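
The last hunk only corrects the unit suffix from "Gib" to "GiB"; the logged values are byte counts divided by 1024**3, as in this small sketch (the helper name bytes_to_gib is hypothetical):

def bytes_to_gib(num_bytes: float) -> float:
    # 1 GiB = 1024**3 bytes, matching the "%.2fGiB" placeholders above.
    return num_bytes / (1024 ** 3)


print(f"{bytes_to_gib(24 * 1024 ** 3):.2f}GiB")  # -> 24.00GiB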