[Bugfix] Fixed error in slice_lora_b for MergedQKVParallelLinearWithLora (#4609)

This commit is contained in:
Austin Veselka 2024-05-07 12:59:07 -05:00 committed by GitHub
parent 478aed5827
commit 10760da800
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 32 deletions

View File

@ -1,5 +1,5 @@
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Union
import torch
import torch.nn as nn
@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
lora_a = lora_a[:, start_idx:start_idx + shard_size]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x, bias)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,
@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
)
def _mcp_apply_weights(x, bias, layer):
def _mcp_apply(x, bias, layer):
"""
MergedColumnParallelLinearWithShardedLoRA and
QKVParallelLinearWithShardedLora share the same
@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
"""
# expecting 2 for column parallel and 3 for qkv
n = len(layer.lora_a_stacked)
output = layer.base_layer.linear_method.apply_weights(
layer.base_layer, x, bias)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_a[0] is None or lora_a[1] is None:
return lora_a
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
for i in range(2)
lora_a[0][:,
output_start_idx:output_start_idx + output_shard_size],
lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
return lora_a
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[i][:, start_idx[i]:start_idx[i] +
shard_size[i]] if lora_a[i] is not None else None
for i in range(3)
lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
]
return lora_a
def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply_weights(x, bias, self)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights(
self.base_layer, x)
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1,

View File

@ -1,7 +1,7 @@
# pylint: disable=unused-argument
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -145,11 +145,15 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
def slice_lora_a(
self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
def slice_lora_b(
self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
"""Slice lora b if splitting with tensor parallelism."""
...
@ -539,10 +543,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
if lora_b[0] is None or lora_b[1] is None:
return lora_b
shard_size = self.output_dim
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
@ -767,10 +777,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
return lora_a
def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
def slice_lora_b(
self, lora_b: List[Union[torch.Tensor, None]]
) -> List[Union[torch.Tensor, None]]:
lora_b_q, lora_b_k, lora_b_v = None, None, None
if lora_b[0] is not None:
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
@ -992,7 +1007,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@property
def weight(self):
return self.base_layer.weight if hasattr(
self.base_layer, "weight") else self.base_layer.qweight