Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-16 11:06:25 +08:00
[Bugfix] Fixed error in slice_lora_b for MergedQKVParallelLinearWithLora (#4609)
This commit is contained in:
parent: 478aed5827
commit: 10760da800
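What this fixes, in brief: when a LoRA adapter targets only a subset of the merged q/k/v projections, the lora_b list handed to MergedQKVParallelLinearWithLora.slice_lora_b contains None entries, and the pre-fix code sliced every entry unconditionally, raising TypeError: 'NoneType' object is not subscriptable. Below is a minimal standalone sketch of the failure mode and the guarded behavior; the shapes, slice bounds, and helper names are illustrative, not vllm's actual API.

import torch

rank, q_size = 8, 128
# Adapter trained only on q_proj: the k and v slots are None.
lora_b = [torch.randn(rank, q_size), None, None]

def slice_lora_b_old(lora_b, start=0, end=32):
    # Pre-fix behavior: slices every entry unconditionally.
    return [t[:, start:end] for t in lora_b]  # TypeError on a None entry

def slice_lora_b_fixed(lora_b, start=0, end=32):
    # Post-fix behavior: slice a projection only if it is present.
    return [t[:, start:end] if t is not None else None for t in lora_b]

print([None if t is None else tuple(t.shape) for t in slice_lora_b_fixed(lora_b)])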
diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py
@@ -1,5 +1,5 @@
 # pylint: disable=unused-argument
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
 import torch.nn as nn
@@ -51,10 +51,9 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
         lora_a = lora_a[:, start_idx:start_idx + shard_size]
         return lora_a

-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

         x = x.view(-1, x.shape[-1])
         output, out_orig_shape = output.view(-1,
@@ -88,7 +87,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
     )


-def _mcp_apply_weights(x, bias, layer):
+def _mcp_apply(x, bias, layer):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    QKVParallelLinearWithShardedLora share the same
@@ -100,8 +99,7 @@ def _mcp_apply_weights(x, bias, layer):
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
-    output = layer.base_layer.linear_method.apply_weights(
-        layer.base_layer, x, bias)
+    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
@@ -136,18 +134,23 @@ class MergedColumnParallelLinearWithShardedLoRA(
     Based on S-LoRA, slicing happens along the rank dim.
     """

-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_a[0] is None or lora_a[1] is None:
+            return lora_a
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
         lora_a = [
-            lora_a[i][:, output_start_idx:output_start_idx + output_shard_size]
-            for i in range(2)
+            lora_a[0][:,
+                      output_start_idx:output_start_idx + output_shard_size],
+            lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
         ]
         return lora_a

-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return _mcp_apply_weights(x, bias, self)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return _mcp_apply(x, bias, self)

     @classmethod
     @_fully_sharded_can_replace
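The docstring above notes that, following S-LoRA, the fully sharded variants slice lora_a along the rank dimension. A toy illustration of that slicing; tp_size, tp_rank, and the tensor shapes are made up for the example:

import torch

tp_size, tp_rank = 4, 1
lora_a = torch.randn(4096, 32)              # (hidden_dim, lora_rank)
shard_size = lora_a.shape[1] // tp_size     # rank columns kept per TP rank
start_idx = tp_rank * shard_size
lora_a_shard = lora_a[:, start_idx:start_idx + shard_size]
assert lora_a_shard.shape == (4096, 8)      # each rank keeps 1/4 of the rank dim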
@@ -172,19 +175,23 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
     Based on S-LoRA, slicing happens along the rank dim.
     """

-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
+            return lora_a
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
         lora_a = [
-            lora_a[i][:, start_idx[i]:start_idx[i] +
-                      shard_size[i]] if lora_a[i] is not None else None
-            for i in range(3)
+            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
+            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
+            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
         ]
         return lora_a

-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        return _mcp_apply_weights(x, bias, self)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        return _mcp_apply(x, bias, self)

     @classmethod
     @_fully_sharded_can_replace
@@ -218,9 +225,8 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         lora_b = lora_b[:, start_idx:end_idx]
         return lora_b

-    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer, x)
+    def apply(self, x: torch.Tensor) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x)

         x = x.view(-1, x.shape[-1])
         output, out_orig_shape = output.view(-1,
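The apply methods above share one reshape pattern: flatten the input to 2-D for the LoRA kernels, remember the original output shape, and restore it afterwards. A toy version of that pattern, with illustrative shapes:

import torch

x = torch.randn(2, 16, 4096)                    # (batch, seq, hidden)
x = x.view(-1, x.shape[-1])                     # (batch*seq, hidden)
output = torch.randn(2, 16, 6144)               # stand-in for the base-layer output
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
# ... LoRA shrink/expand kernels operate on the 2-D views here ...
output = output.view(out_orig_shape)            # restore (batch, seq, out_features)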
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
@@ -1,7 +1,7 @@
 # pylint: disable=unused-argument
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -145,11 +145,15 @@ class LoRAMapping:

 class BaseLayerWithLoRA(nn.Module):

-    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+    def slice_lora_a(
+        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
         """Slice lora a if splitting for tensor parallelism."""
         ...

-    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+    def slice_lora_b(
+        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
         """Slice lora b if splitting with tensor parallelism."""
         ...

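The widened base-class signatures above accept either a single tensor or a per-projection list whose entries may be None. If one wanted to name that union once, a hypothetical alias (not part of the diff) would look like this:

from typing import List, Union

import torch

LoRAWeight = Union[torch.Tensor, List[Union[torch.Tensor, None]]]

def slice_lora_a(lora_a: LoRAWeight) -> LoRAWeight:
    """Slice lora a if splitting for tensor parallelism."""
    ...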
@@ -539,10 +543,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.lora_b_stacked[0][index] = 0
         self.lora_b_stacked[1][index] = 0

-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
         return lora_a

-    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_b(
+        self, lora_b: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        if lora_b[0] is None or lora_b[1] is None:
+            return lora_b
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
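A self-contained sketch of the guarded output-dim slicing added above, for a merged column-parallel layer with two sub-loras (e.g. gate/up); output_dim and tp_rank here are stand-ins for the layer's attributes, and the shapes are illustrative:

import torch

def slice_lora_b(lora_b, output_dim, tp_rank):
    if lora_b[0] is None or lora_b[1] is None:
        return lora_b                       # pass partial adapters through untouched
    start_idx = tp_rank * output_dim
    end_idx = (tp_rank + 1) * output_dim
    return [lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]]

lora_b = [torch.randn(16, 11008), torch.randn(16, 11008)]
print([tuple(t.shape) for t in slice_lora_b(lora_b, 11008 // 4, tp_rank=1)])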
@@ -767,10 +777,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.lora_a_stacked[2][index] = 0
         self.lora_b_stacked[2][index] = 0

-    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_a(
+        self, lora_a: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
         return lora_a

-    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+    def slice_lora_b(
+        self, lora_b: List[Union[torch.Tensor, None]]
+    ) -> List[Union[torch.Tensor, None]]:
+        lora_b_q, lora_b_k, lora_b_v = None, None, None
+        if lora_b[0] is not None:
+            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
+                                 self.q_shard_id:self.q_proj_shard_size *
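The q branch visible at the end of the hunk generalizes to all three projections: each present sub-lora is sliced to its shard, and absent ones stay None. A sketch with made-up shard sizes, where q_proj_shard_size and q_shard_id mimic the layer's attributes:

import torch

q_proj_shard_size, q_shard_id = 1024, 1
lora_b = [torch.randn(16, 4096), None, None]    # adapter touches q only

lora_b_q, lora_b_k, lora_b_v = None, None, None
if lora_b[0] is not None:
    lora_b_q = lora_b[0][:, q_proj_shard_size *
                         q_shard_id:q_proj_shard_size * (q_shard_id + 1)]
print(None if lora_b_q is None else tuple(lora_b_q.shape))   # (16, 1024)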
@@ -992,7 +1007,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):

     @property
     def weight(self):
-
         return self.base_layer.weight if hasattr(
             self.base_layer, "weight") else self.base_layer.qweight

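The weight property at the end of the diff falls back to qweight because quantized base layers store their packed weights under that name instead of weight. A toy illustration of the pattern, using placeholder classes rather than vllm's:

class QuantizedLinear:                  # placeholder quantized base layer
    qweight = "packed 4-bit tensor"

class LoRAWrapper:
    def __init__(self, base_layer):
        self.base_layer = base_layer

    @property
    def weight(self):
        # Prefer the float weight; fall back to the quantized one.
        return self.base_layer.weight if hasattr(
            self.base_layer, "weight") else self.base_layer.qweight

print(LoRAWrapper(QuantizedLinear()).weight)    # falls back to qweight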