[Bugfix] Fixed error in slice_lora_b for MergedQKVParallelLinearWithLora (#4609)

Austin Veselka 2024-05-07 12:59:07 -05:00 committed by GitHub
parent 478aed5827
commit 10760da800
2 changed files with 52 additions and 32 deletions

View File

@@ -1,5 +1,5 @@
 # pylint: disable=unused-argument
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
         lora_a = lora_a[:, start_idx:start_idx + shard_size]
         return lora_a
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
 
         x = x.view(-1, x.shape[-1])
         output, out_orig_shape = output.view(-1,
@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
     )
 
 
-def _mcp_apply_weights(x, bias, layer):
+def _mcp_apply(x, bias, layer):
     """
     MergedColumnParallelLinearWithShardedLoRA and
     QKVParallelLinearWithShardedLora share the same
@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
     """
     # expecting 2 for column parallel and 3 for qkv
     n = len(layer.lora_a_stacked)
-    output = layer.base_layer.linear_method.apply_weights(
-        layer.base_layer, x, bias)
+    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
 
     x = x.view(-1, x.shape[-1])
     output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
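Note: the renamed module-level helper `_mcp_apply` implements the pattern both sharded merged layers rely on: run the base (possibly quantized) linear layer first via `quant_method.apply`, then add the low-rank LoRA update on top. As a rough, self-contained sketch of that composition (single slot, dense weights, no tensor parallelism; the names here are illustrative, not vLLM's API):

import torch

def lora_linear(x: torch.Tensor,
                base_weight: torch.Tensor,
                lora_a: torch.Tensor,
                lora_b: torch.Tensor,
                scaling: float = 1.0) -> torch.Tensor:
    # Base layer output; in the diff this step is quant_method.apply(base_layer, x).
    base_out = x @ base_weight.t()
    # Low-rank LoRA update: (x @ A) @ B with A: (in, r) and B: (r, out).
    lora_out = (x @ lora_a) @ lora_b
    return base_out + scaling * lora_out

x = torch.randn(2, 16)
y = lora_linear(x, torch.randn(32, 16), torch.randn(16, 4), torch.randn(4, 32))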
@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
     Based on S-LoRA, slicing happens along the rank dim.
     """
 
-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_a[0] is None or lora_a[1] is None:
+            return lora_a
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
         lora_a = [
-            lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
-            for i in range(2)
+            lora_a[0][:,
+                      output_start_idx:output_start_idx + output_shard_size],
+            lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
         ]
         return lora_a
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return _mcp_apply_weights(x, bias, self)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return _mcp_apply(x, bias, self)
 
     @classmethod
     @_fully_sharded_can_replace
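Note: in the fully sharded variants, `slice_lora_a` keeps a contiguous block of the LoRA rank dimension per tensor-parallel rank (the S-LoRA-style sharding the docstring mentions), with the per-rank width read from `lora_a_stacked[0].shape[2]`. A small numeric sketch of that slice, under the assumption that each `lora_a` is laid out as `(input_dim, rank)`:

import torch

input_dim, lora_rank, tp_size, tp_rank = 16, 8, 4, 1
lora_a = [torch.randn(input_dim, lora_rank) for _ in range(2)]

shard = lora_rank // tp_size   # plays the role of lora_a_stacked[0].shape[2]
start = tp_rank * shard
sliced = [
    lora_a[0][:, start:start + shard],
    lora_a[1][:, start:start + shard],
]
assert all(t.shape == (input_dim, shard) for t in sliced)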
@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
     Based on S-LoRA, slicing happens along the rank dim.
     """
 
-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
+            return lora_a
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
         lora_a = [
-            lora_a[i][:, start_idx[i]:start_idx[i] +
-                      shard_size[i]] if lora_a[i] is not None else None
-            for i in range(3)
+            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
+            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
+            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
         ]
         return lora_a
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return _mcp_apply_weights(x, bias, self)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return _mcp_apply(x, bias, self)
 
     @classmethod
     @_fully_sharded_can_replace
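Note: behaviourally, the new `slice_lora_a` here trades the old per-element `... if lora_a[i] is not None else None` comprehension for an all-or-nothing guard: if any packed q/k/v entry is missing, the list is returned untouched; otherwise all three are sliced with their own shard size and start index. A minimal sketch of that guard (function name and arguments are illustrative, not vLLM's code):

from typing import List, Optional

import torch

def slice_all_or_nothing(
    lora_a: List[Optional[torch.Tensor]],
    shard_size: List[int],
    tp_rank: int,
) -> List[Optional[torch.Tensor]]:
    # Only slice when every packed entry exists; otherwise pass through.
    if any(t is None for t in lora_a):
        return lora_a
    start = [tp_rank * s for s in shard_size]
    return [
        lora_a[i][:, start[i]:start[i] + shard_size[i]]
        for i in range(len(lora_a))
    ]

qkv = [torch.randn(16, 32), None, torch.randn(16, 32)]
assert slice_all_or_nothing(qkv, [8, 4, 4], tp_rank=1) is qkv  # returned untouched

full = [torch.randn(16, 32), torch.randn(16, 16), torch.randn(16, 16)]
sliced = slice_all_or_nothing(full, [8, 4, 4], tp_rank=1)
assert [t.shape[1] for t in sliced] == [8, 4, 4]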
@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         lora_b = lora_b[:, start_idx:end_idx]
         return lora_b
 
-    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x)
+    def apply(self, x: torch.Tensor) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x)
 
         x = x.view(-1, x.shape[-1])
         output, out_orig_shape = output.view(-1,

View File

@@ -1,7 +1,7 @@
 # pylint: disable=unused-argument
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -145,11 +145,15 @@ class LoRAMapping:
 class BaseLayerWithLoRA(nn.Module):
 
-    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+    def slice_lora_a(
+        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
         """Slice lora a if splitting for tensor parallelism."""
         ...
 
-    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+    def slice_lora_b(
+        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
         """Slice lora b if splitting with tensor parallelism."""
         ...
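Note: the broadened base-class annotations reflect what the packed layers actually receive: either a single tensor (plain column/row-parallel layers) or a list whose entries may be None when an adapter only trains a subset of the fused projections. A hedged example of such a partially populated input (shapes and the `(rank, out_features)` layout are assumptions for illustration):

from typing import List, Optional

import torch

# An adapter that only trains the q and v parts of a fused QKV module:
# the packed lora_b list then carries None in the k slot.
rank, q_out, kv_out = 8, 32, 16
packed_lora_b: List[Optional[torch.Tensor]] = [
    torch.randn(rank, q_out),   # q
    None,                       # k: not trained by this adapter
    torch.randn(rank, kv_out),  # v
]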
@@ -539,10 +543,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.lora_b_stacked[0][index] = 0
         self.lora_b_stacked[1][index] = 0
 
-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
         return lora_a
 
-    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_b(
+        self, lora_b: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_b[0] is None or lora_b[1] is None:
+            return lora_b
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
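Note: for the merged column-parallel layer, `lora_b` is sliced along the output dimension so each tensor-parallel rank keeps the columns it owns, and the new early return skips the slicing entirely when either packed entry is missing. A small numeric sketch of the per-rank column slice, assuming `lora_b` tensors shaped `(rank, total_output)`:

import torch

lora_rank, output_dim, tp_size, tp_rank = 8, 16, 4, 2
lora_b = [torch.randn(lora_rank, output_dim * tp_size) for _ in range(2)]

start_idx = tp_rank * output_dim          # output_dim is the per-rank shard size
end_idx = (tp_rank + 1) * output_dim
sliced = [t[:, start_idx:end_idx] for t in lora_b]
assert all(t.shape == (lora_rank, output_dim) for t in sliced)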
@@ -767,10 +777,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.lora_a_stacked[2][index] = 0
         self.lora_b_stacked[2][index] = 0
 
-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
         return lora_a
 
-    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_b(
+        self, lora_b: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        lora_b_q, lora_b_k, lora_b_v = None, None, None
         if lora_b[0] is not None:
             lora_b_q = lora_b[0][:, self.q_proj_shard_size *
                                  self.q_shard_id:self.q_proj_shard_size *
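Note: this hunk is the fix the commit title refers to. `lora_b_q`, `lora_b_k` and `lora_b_v` are now initialized to None before the conditional slicing, so an adapter that supplies only some of the packed q/k/v weights no longer leaves one of these locals unassigned when they are gathered afterwards. A standalone, hedged sketch of the guarded slicing (the function name and shard bookkeeping are illustrative, not vLLM's exact code):

from typing import List, Optional

import torch

def slice_packed_lora_b(
    lora_b: List[Optional[torch.Tensor]],
    shard_sizes: List[int],
    shard_ids: List[int],
) -> List[Optional[torch.Tensor]]:
    # Initialize all outputs to None so a missing q/k/v entry simply
    # stays None instead of being referenced before assignment.
    lora_b_q, lora_b_k, lora_b_v = None, None, None
    if lora_b[0] is not None:
        start = shard_sizes[0] * shard_ids[0]
        lora_b_q = lora_b[0][:, start:start + shard_sizes[0]]
    if lora_b[1] is not None:
        start = shard_sizes[1] * shard_ids[1]
        lora_b_k = lora_b[1][:, start:start + shard_sizes[1]]
    if lora_b[2] is not None:
        start = shard_sizes[2] * shard_ids[2]
        lora_b_v = lora_b[2][:, start:start + shard_sizes[2]]
    return [lora_b_q, lora_b_k, lora_b_v]

qkv_b = [torch.randn(8, 64), None, torch.randn(8, 32)]
out = slice_packed_lora_b(qkv_b, shard_sizes=[16, 8, 8], shard_ids=[1, 0, 0])
assert out[1] is None and out[0].shape == (8, 16)

With an input like `[q_tensor, None, v_tensor]`, the k slot simply stays None rather than raising.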
@@ -992,7 +1007,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     @property
     def weight(self):
         return self.base_layer.weight if hasattr(
             self.base_layer, "weight") else self.base_layer.qweight