Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-11 02:25:01 +08:00)

[Misc][LoRA] Clean up the function interface of Punica (#10917)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

This commit is contained in: parent 39c89e71a8, commit 571da8fc43
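The diff touches five modules; judging from the test and class names, they are tests/lora/test_layers.py, vllm/lora/fully_sharded_layers.py, vllm/lora/layers.py, vllm/lora/models.py, and vllm/lora/punica.py (paths inferred, since the per-file headers are not shown here). The common thread is that every linear LoRA layer now stores its weights as tuples with one entry per output slice (`n_slices`), and the Punica wrapper is called through a small, uniform surface: `add_shrink`, `add_expand`, `add_lora_embedding`, and `add_lora_linear`, which accept those tuples together with `output_slices`. The sketch below shows the rough math behind such a call for a single adapter; it uses plain tensors and a made-up helper name, not the real `PunicaWrapper`, which additionally batches over per-token LoRA indices.

import torch

def add_lora_linear_sketch(y, x, lora_a_stacked, lora_b_stacked,
                           lora_bias_stacked, scale, output_slices):
    """Rough functional equivalent of the new per-slice LoRA application
    for a single adapter (no batching over LoRA ids)."""
    offset = 0
    for i, slice_size in enumerate(output_slices):
        a = lora_a_stacked[i]          # (rank, in_features)
        b = lora_b_stacked[i]          # (slice_size, rank)
        update = (x @ a.T) @ b.T * scale
        if lora_bias_stacked is not None:
            update = update + lora_bias_stacked[i]
        y[:, offset:offset + slice_size] += update
        offset += slice_size
    return y

x = torch.randn(4, 16)
y = torch.zeros(4, 24)
output_slices = (8, 16)                                    # assumed slice widths
lora_a = tuple(torch.randn(4, 16) for _ in output_slices)  # rank 4
lora_b = tuple(torch.randn(s, 4) for s in output_slices)
add_lora_linear_sketch(y, x, lora_a, lora_b, None, 1.0, output_slices)

With `output_slices = (8, 16)` the two LoRA pairs write into disjoint column ranges of the packed output, which is the situation for merged gate/up or q/k/v projections.

tests/lora/test_layers.py: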
@@ -565,7 +565,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("stage", STAGES)
-def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
+@pytest.mark.parametrize("bias_enabled", [True, False])
+def test_linear_replicated(dist_init, num_loras, device, stage,
+                           bias_enabled) -> None:
 
     torch.cuda.set_device(device)
     torch.set_default_device(device)
@@ -573,7 +575,8 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
     max_loras = 8
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
-                             lora_dtype=torch.float16)
+                             lora_dtype=torch.float16,
+                             bias_enabled=bias_enabled)
 
     def create_random_linear_replicated_layer():
 
@@ -585,7 +588,12 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
         lora_linear = ReplicatedLinearWithLoRA(linear)
 
         lora_linear.create_lora_weights(max_loras, lora_config)
+        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
+            lora_linear.lora_b_stacked) == 1)
+        if bias_enabled:
+            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
+        else:
+            assert lora_linear.lora_bias_stacked is None
         return linear, lora_linear
 
     for i in range(10):
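The new assertions pin down the storage invariant introduced by this change: `lora_a_stacked` and `lora_b_stacked` are tuples with one stacked tensor per slice, and `lora_bias_stacked` exists only when `bias_enabled` is set. A standalone illustration of that layout with assumed shapes (plain torch, not the vLLM classes):

import torch

max_loras, rank, in_features, out_features, n_slices = 8, 8, 32, 64, 1

lora_a_stacked = tuple(
    torch.zeros(max_loras, 1, rank, in_features) for _ in range(n_slices))
lora_b_stacked = tuple(
    torch.zeros(max_loras, 1, out_features, rank) for _ in range(n_slices))
bias_enabled = True
lora_bias_stacked = tuple(
    torch.zeros(max_loras, 1, out_features)
    for _ in range(n_slices)) if bias_enabled else None

# The same checks the updated tests perform:
assert n_slices == len(lora_a_stacked) == len(lora_b_stacked) == 1
if bias_enabled:
    assert len(lora_bias_stacked) == n_slices
else:
    assert lora_bias_stacked is None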
@@ -669,8 +677,9 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
 @pytest.mark.parametrize("fully_shard", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("stage", STAGES)
+@pytest.mark.parametrize("bias_enabled", [True, False])
 def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
-                         device, stage) -> None:
+                         device, stage, bias_enabled) -> None:
 
     torch.cuda.set_device(device)
     torch.set_default_device(device)
@@ -679,7 +688,8 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
                              fully_sharded_loras=fully_shard,
-                             lora_dtype=torch.float16)
+                             lora_dtype=torch.float16,
+                             bias_enabled=bias_enabled)
 
     def create_random_linear_parallel_layer():
         if orientation == "row":
@@ -700,7 +710,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
                        if not fully_shard else
                        ColumnParallelLinearWithShardedLoRA(linear))
         lora_linear.create_lora_weights(max_loras, lora_config)
+        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
+            lora_linear.lora_b_stacked) == 1)
+        if bias_enabled:
+            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
+        else:
+            assert lora_linear.lora_bias_stacked is None
         return linear, lora_linear
 
     for i in range(10):
@@ -784,8 +799,9 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
 @pytest.mark.parametrize("fully_shard", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("stage", STAGES)
+@pytest.mark.parametrize("bias_enabled", [True, False])
 def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
-                                device, stage) -> None:
+                                device, stage, bias_enabled) -> None:
 
     torch.cuda.set_device(device)
     torch.set_default_device(device)
@@ -794,7 +810,8 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
                              fully_sharded_loras=fully_shard,
-                             lora_dtype=torch.float16)
+                             lora_dtype=torch.float16,
+                             bias_enabled=bias_enabled)
 
     def create_column_parallel_packed_layer():
         if repeats == 2:
@@ -832,10 +849,16 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
             num_key_value_heads = 32
             num_attention_heads = 32
 
+        n_slices = repeats
         lora_linear.create_lora_weights(max_loras,
                                         lora_config,
                                         model_config=FakeConfig())
+        assert (lora_linear.n_slices == len(lora_linear.lora_a_stacked) == len(
+            lora_linear.lora_b_stacked) == n_slices)
+        if bias_enabled:
+            assert len(lora_linear.lora_bias_stacked) == lora_linear.n_slices
+        else:
+            assert lora_linear.lora_bias_stacked is None
         return linear, lora_linear
 
     for i in range(10):
@@ -911,7 +934,6 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
             512,
             lora_config.lora_extra_vocab_size,
         )
-        # lora_linear.set_mapping(*mapping_info)
 
         lora_result = lora_linear(torch.cat(inputs))[0]
         expected_result = linear(torch.cat(inputs))[0]

vllm/lora/fully_sharded_layers.py:
@@ -1,5 +1,5 @@
 # pylint: disable=unused-argument
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -32,6 +32,44 @@ def _fully_sharded_can_replace(can_replace):
     return dec
 
 
+def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
+    """
+    For `ColumnParallelLinearWithLoRA` or classes that inherit from
+    `ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
+    """
+    assert (layer.n_slices == len(layer.lora_a_stacked) == len(
+        layer.lora_b_stacked) == len(layer.output_slices))
+    if layer.lora_bias_stacked is not None:
+        assert layer.n_slices == len(layer.lora_bias_stacked)
+
+    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
+
+    x = x.view(-1, x.shape[-1])
+    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
+
+    # Since communication is needed, the buffer is directly initialized as a
+    # tensor rather than a tuple of tensor.
+    buffers = torch.zeros(
+        (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
+        dtype=torch.float32,
+        device=x.device,
+    )
+
+    layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0)
+    buffers = tensor_model_parallel_all_gather(buffers)
+    layer.punica_wrapper.add_expand(output,
+                                    buffers,
+                                    layer.lora_b_stacked,
+                                    layer.lora_bias_stacked,
+                                    layer.output_slices,
+                                    offset_start=0,
+                                    add_input=True)
+
+    output = output.view(*out_orig_shape)
+    # now have column partitioned and packed output
+    return output
+
+
 # these layers are based on the tensor parallelism strategy given in
 # Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
 # https://arxiv.org/abs/2311.03285.
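`_mcp_apply` is now the single fully sharded apply path: shrink with the rank-sharded LoRA A, all-gather the low-rank buffer across tensor-parallel ranks, then expand with LoRA B. A single-process sketch of that pipeline, simulating the gather with a concatenation (assumed sizes, plain torch, not the vLLM kernels):

import torch

tokens, in_features, rank, out_features, tp = 4, 32, 8, 64, 2

x = torch.randn(tokens, in_features)
lora_a = torch.randn(rank, in_features)       # full LoRA A
lora_b = torch.randn(out_features, rank)      # full LoRA B

# Fully sharded LoRA splits A along the rank dimension across tp ranks.
a_shards = lora_a.chunk(tp, dim=0)

# add_shrink: every rank computes only its slice of the low-rank activation.
partials = [x @ a.T for a in a_shards]        # each (tokens, rank // tp)

# tensor_model_parallel_all_gather: ranks exchange partials so each one
# ends up with the full (tokens, rank) buffer.
buffer = torch.cat(partials, dim=-1)

# add_expand: project back up with LoRA B (B itself is sharded along the
# output dimension in the real column-parallel case).
output = buffer @ lora_b.T                    # (tokens, out_features)

assert torch.allclose(output, (x @ lora_a.T) @ lora_b.T, atol=1e-5)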
@@ -51,34 +89,15 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
     # gather operation.
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.lora_a_stacked.shape[2]
+        shard_size = self.lora_a_stacked[0].shape[2]
         start_idx = tp_rank * shard_size
         lora_a = lora_a[:, start_idx:start_idx + shard_size]
         return lora_a
 
-    def apply(self, x: torch.Tensor,
-              bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-
-        x = x.view(-1, x.shape[-1])
-        output, out_orig_shape = output.view(-1,
-                                             output.shape[-1]), output.shape
-        buffer = torch.zeros(
-            (x.shape[0], self.lora_a_stacked.shape[2]),
-            dtype=torch.float32,
-            device=x.device,
-        )
-        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
-        buffer = tensor_model_parallel_all_gather(buffer)
-        self.punica_wrapper.add_expand(output,
-                                       buffer,
-                                       self.lora_b_stacked,
-                                       self.bias_stacked,
-                                       add_input=True)
-        # now have column partitioned output
-
-        output = output.view(*out_orig_shape)
-        return output
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _mcp_apply(x, bias, self)
 
     @classmethod
     @_fully_sharded_can_replace
@@ -99,46 +118,6 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
         )
 
 
-def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
-    """
-    MergedColumnParallelLinearWithShardedLoRA and
-    MergedQKVParallelLinearWithShardedLora share the same
-    LoRa weight application method.
-
-    The main difference is the step by shard_size for lora_b which can
-    vary for MergedQKVParallelLinearWithShardedLora but is constant for
-    MergedColumnParallelLinearWithShardedLoRA.
-    """
-    # expecting 2 for column parallel and 3 for qkv
-    n = len(layer.lora_a_stacked)
-    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
-
-    x = x.view(-1, x.shape[-1])
-    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
-    buffers = torch.zeros(
-        (n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
-        dtype=torch.float32,
-        device=x.device,
-    )
-    for idx in range(n):
-        layer.punica_wrapper.add_shrink(buffers[idx], x,
-                                        layer.lora_a_stacked[idx], 1.0)
-
-    buffers = tensor_model_parallel_all_gather(buffers)
-    layer.punica_wrapper.add_expand_packed_nslice(
-        output,
-        buffers,
-        layer.lora_b_stacked,
-        layer.bias_stacked,
-        1.0,
-        layer.output_slices,
-    )
-
-    output = output.view(*out_orig_shape)
-    # now have column partitioned and packed output
-    return output
-
-
 class MergedColumnParallelLinearWithShardedLoRA(
         MergedColumnParallelLinearWithLoRA):
     """
@@ -162,8 +141,9 @@ class MergedColumnParallelLinearWithShardedLoRA(
         ]
         return lora_a
 
-    def apply(self, x: torch.Tensor,
-              bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         return _mcp_apply(x, bias, self)
 
     @classmethod
@@ -195,31 +175,15 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
 
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
-        shard_size = self.lora_a_stacked.shape[2]
+        shard_size = self.lora_a_stacked[0].shape[2]
         start_idx = tp_rank * shard_size
         lora_a = lora_a[:, start_idx:start_idx + shard_size]
         return lora_a
 
-    def apply(self, x: torch.Tensor,
-              bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-
-        x = x.view(-1, x.shape[-1])
-        output, out_orig_shape = output.view(-1,
-                                             output.shape[-1]), output.shape
-        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
-                             dtype=torch.float32,
-                             device=x.device)
-        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
-        buffer = tensor_model_parallel_all_gather(buffer)
-        self.punica_wrapper.add_expand(output,
-                                       buffer,
-                                       self.lora_b_stacked,
-                                       self.bias_stacked,
-                                       add_input=True)
-        # now have column partitioned output
-        output = output.view(*out_orig_shape)
-        return output
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _mcp_apply(x, bias, self)
 
     @classmethod
     @_fully_sharded_can_replace
@@ -260,8 +224,9 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
         ]
         return lora_a
 
-    def apply(self, x: torch.Tensor,
-              bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         return _mcp_apply(x, bias, self)
 
     @classmethod
@@ -294,7 +259,7 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
     """
 
     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
-        shard_size = self.lora_b_stacked.shape[2]
+        shard_size = self.lora_b_stacked[0].shape[2]
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
         lora_b = lora_b[:, start_idx:end_idx]
@@ -303,20 +268,24 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         if bias is None:
             return bias
-        shard_size = self.bias_stacked.shape[2]
+        self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+                                      self.lora_bias_stacked)
+        shard_size = self.lora_bias_stacked[0].shape[2]
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
         bias = bias[start_idx:end_idx]
         return bias
 
-    def apply(self, x: torch.Tensor) -> torch.Tensor:
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
 
         x = x.view(-1, x.shape[-1])
         output, out_orig_shape = output.view(-1,
                                              output.shape[-1]), output.shape
         buffer = torch.zeros(
-            (x.shape[0], self.lora_a_stacked.shape[2]),
+            (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
             dtype=torch.float32,
             device=x.device,
         )
@@ -330,12 +299,18 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         # remains is a standard all_reduce. User should be aware though that
         # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
-        shard_size = self.lora_b_stacked.shape[2]
-        start_idx = self.tp_rank * shard_size
-        self.punica_wrapper.add_expand_slice(output, buffer,
-                                             self.lora_b_stacked,
-                                             self.bias_stacked, start_idx,
-                                             shard_size)
+        # NOTE offset are based on the rank.
+        shard_size = self.lora_b_stacked[0].shape[2]
+        offset_start = self.tp_rank * shard_size
+        self.punica_wrapper.add_expand(
+            output,
+            buffer,
+            self.lora_b_stacked,
+            self.lora_bias_stacked,
+            self.output_slices,
+            offset_start=offset_start,
+            add_input=True,
+        )
         output = output.view(*out_orig_shape)
         return output
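For the sharded row-parallel layer, the expand now goes through the same `add_expand` entry point with an explicit `offset_start` of `tp_rank * shard_size`. The loop below imitates what the ranks contribute collectively before the surrounding all-reduce (illustrative only, plain torch):

import torch

tokens, rank, out_features, tp = 4, 8, 64, 2
shard_size = out_features // tp

buffer = torch.randn(tokens, rank)                 # gathered low-rank activations
lora_b_shards = torch.randn(out_features, rank).chunk(tp, dim=0)

output = torch.zeros(tokens, out_features)
for tp_rank in range(tp):
    offset_start = tp_rank * shard_size            # mirrors the new code
    update = buffer @ lora_b_shards[tp_rank].T     # what add_expand contributes
    output[:, offset_start:offset_start + shard_size] += update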

vllm/lora/layers.py:
@@ -1,7 +1,7 @@
 # pylint: disable=unused-argument
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -18,11 +18,14 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               tensor_model_parallel_gather)
 from vllm.distributed.utils import divide
 from vllm.lora.punica import PunicaWrapper
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               LinearBase,
                                                MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
+# yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.rotary_embedding import (
     LinearScalingRotaryEmbedding, RotaryEmbedding)
@@ -249,13 +252,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
             full_lora_a_embeddings.shape[1],
             -1,
         )
-
-        # Embedding layer only need expand op
-        self.punica_wrapper.add_expand(full_output,
-                                       full_lora_a_embeddings,
-                                       self.lora_b_stacked,
-                                       bias_all=None,
-                                       add_input=True)
+        self.punica_wrapper.add_lora_embedding(full_output,
+                                               full_lora_a_embeddings,
+                                               self.lora_b_stacked,
+                                               add_input=True)
         return full_output.view_as(full_output_org)
 
     @classmethod
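The embedding layer switches from the generic `add_expand` to a dedicated `add_lora_embedding`, since it only needs the expand half: the shrink step is effectively the lookup of the token ids in LoRA A. A rough, single-adapter illustration (assumed shapes, plain torch):

import torch

vocab, rank, hidden = 100, 8, 32
token_ids = torch.tensor([3, 7, 42])

lora_a = torch.randn(vocab, rank)       # per-token low-rank embeddings
lora_b = torch.randn(hidden, rank)

full_output = torch.randn(len(token_ids), hidden)   # base embedding output

# Only the expand is needed: the "shrink" is just the embedding lookup.
full_lora_a_embeddings = lora_a[token_ids]           # (tokens, rank)
full_output += full_lora_a_embeddings @ lora_b.T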
@@ -269,14 +269,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         return type(source_layer) is VocabParallelEmbedding
 
 
-class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
+class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
 
-    def __init__(self, base_layer: ReplicatedLinear) -> None:
+    def __init__(self, base_layer: LinearBase):
         super().__init__()
         self.base_layer = base_layer
         self.input_size = self.base_layer.input_size
-        self.output_size = self.base_layer.output_size
         self.device = _get_lora_device(self.base_layer)
+        self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None
+
+        self.output_slices: Tuple[int, ...]
+        self.tp_size: int
+        self.output_size: int
+        self.n_slices: int
 
     def create_lora_weights(
         self,
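The new `BaseLinearLayerWithLoRA` centralizes the stacked-weight storage, `create_lora_weights`, `reset_lora`, `set_lora`, and `apply` that each linear variant previously duplicated. A skeleton of the hierarchy as it stands after this diff; real constructors and method bodies are elided, and `n_slices` is shown only as the value each class ends up with:

class BaseLayerWithLoRA:              # pre-existing vLLM base, for context
    ...

class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared storage (lora_a_stacked / lora_b_stacked / lora_bias_stacked
    tuples) plus create_lora_weights, reset_lora, set_lora and apply."""

class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    n_slices = 1                      # single slice, tp_size forced to 1

class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    n_slices = 1                      # one slice, LoRA B sliced for TP

class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    n_slices = 2                      # e.g. gate_proj / up_proj packed together

class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    n_slices = 1                      # q, k, v share one LoRA

class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    n_slices = 3                      # separate q, k, v LoRAs

class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    n_slices = 1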
@@ -285,39 +290,64 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         self.lora_config = lora_config
-        lora_a_output_size = lora_config.max_lora_rank
-        self.lora_a_stacked = torch.zeros(
-            max_loras,
-            1,
-            lora_a_output_size,
-            self.input_size,
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        self.lora_b_stacked = torch.zeros(
-            max_loras,
-            1,
-            self.output_size,
-            lora_config.max_lora_rank,
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        if lora_config.bias_enabled:
-            self.bias_stacked = torch.zeros(
+        #
+        if isinstance(self.base_layer, ReplicatedLinear):
+            lora_a_out_size = lora_config.max_lora_rank
+            lora_b_out_size = self.output_size
+
+        elif isinstance(self.base_layer, ColumnParallelLinear):
+            lora_a_out_size = (lora_config.max_lora_rank if
+                               not lora_config.fully_sharded_loras else divide(
+                                   lora_config.max_lora_rank, self.tp_size))
+            lora_b_out_size = self.output_size
+
+        elif isinstance(self.base_layer, RowParallelLinear):
+            lora_a_out_size = lora_config.max_lora_rank
+            lora_b_out_size = (self.output_size if
+                               not lora_config.fully_sharded_loras else divide(
+                                   self.output_size, self.tp_size))
+        else:
+            raise NotImplementedError
+
+        self.lora_a_stacked = tuple(
+            torch.zeros(
                 max_loras,
                 1,
-                self.output_size,
+                lora_a_out_size,
+                self.input_size,
                 dtype=lora_config.lora_dtype,
                 device=self.device,
-            )
-        else:
-            self.bias_stacked = None
+            ) for _ in range(self.n_slices))
+        self.lora_b_stacked = tuple(
+            torch.zeros(
+                max_loras,
+                1,
+                lora_b_out_size,
+                lora_config.max_lora_rank,
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            ) for _ in range(self.n_slices))
+        if lora_config.bias_enabled:
+            lora_bias_out_size = lora_b_out_size
+            self.lora_bias_stacked = tuple(
+                torch.zeros(
+                    max_loras,
+                    1,
+                    lora_bias_out_size,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ) for _ in range(self.n_slices))
+        self.output_slices = (self.lora_b_stacked[0].shape[2], )
 
     def reset_lora(self, index: int):
-        self.lora_a_stacked[index] = 0
-        self.lora_b_stacked[index] = 0
-        if self.lora_config.bias_enabled:
-            self.bias_stacked[index] = 0
+        for s_index in range(self.n_slices):
+            self.lora_a_stacked[s_index][index] = 0
+            self.lora_b_stacked[s_index][index] = 0
+            if self.lora_config.bias_enabled:
+                # Make mypy happy
+                self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+                                              self.lora_bias_stacked)
+                self.lora_bias_stacked[s_index][index] = 0
 
     def set_lora(
         self,
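The consolidated `create_lora_weights` derives the LoRA A/B output sizes from the type of the wrapped base layer and the fully-sharded setting, then allocates one stacked tensor per slice. A compact paraphrase of that decision chain (strings stand in for the isinstance checks; `divide` becomes plain integer division here):

def lora_out_sizes(base_layer_kind, max_lora_rank, output_size,
                   fully_sharded_loras, tp_size):
    """Mirrors the isinstance() chain in BaseLinearLayerWithLoRA."""
    if base_layer_kind == "ReplicatedLinear":
        return max_lora_rank, output_size
    if base_layer_kind == "ColumnParallelLinear":
        lora_a = (max_lora_rank if not fully_sharded_loras
                  else max_lora_rank // tp_size)      # divide()
        return lora_a, output_size
    if base_layer_kind == "RowParallelLinear":
        lora_b = (output_size if not fully_sharded_loras
                  else output_size // tp_size)        # divide()
        return max_lora_rank, lora_b
    raise NotImplementedError(base_layer_kind)

print(lora_out_sizes("ColumnParallelLinear", 16, 4096, True, 2))   # (8, 4096)
print(lora_out_sizes("RowParallelLinear", 16, 4096, True, 2))      # (16, 2048)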
@@ -325,29 +355,56 @@
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
+        lora_bias: Optional[torch.Tensor] = None,
     ):
+        # Except for QKVParallelLinearWithLora and
+        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
+        # store weights in a tuple of size 1. These two layers will
+        # override this function.
+        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
+                self.n_slices == 1)
+
         self.reset_lora(index)
+        if self.tp_size > 1:
+            lora_a = self.slice_lora_a(lora_a)
+            lora_b = self.slice_lora_b(lora_b)
+            if lora_bias is not None:
+                lora_bias = self.slice_bias(lora_bias)
 
-        self.lora_a_stacked[index,
-                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
-                                lora_a.T, non_blocking=True)
-        self.lora_b_stacked[index,
-                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-                                lora_b.T, non_blocking=True)
-        if bias is not None:
-            self.bias_stacked[index,
-                              0, :bias.shape[0]].copy_(bias.T,
-                                                       non_blocking=True)
+        self.lora_a_stacked[0][index,
+                               0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
+                                   lora_a.T, non_blocking=True)
+        self.lora_b_stacked[0][index,
+                               0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
+                                   lora_b.T, non_blocking=True)
+        if lora_bias is not None:
+
+            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+                                          self.lora_bias_stacked)
+            assert len(self.lora_bias_stacked)
+            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
+                lora_bias.T, non_blocking=True)
 
-    def apply(self, x: torch.Tensor,
-              bias: Optional[torch.Tensor]) -> torch.Tensor:
+    def apply(self,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, self.bias_stacked,
-                                     1.0)
+        self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
+                                            self.lora_b_stacked,
+                                            self.lora_bias_stacked, 1.0,
+                                            self.output_slices)
         return output
 
 
+class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
+
+    def __init__(self, base_layer: ReplicatedLinear) -> None:
+        super().__init__(base_layer, )
+        # To ensure interface compatibility, set to 1 always.
+        self.tp_size = 1
+        self.output_size = self.base_layer.output_size
+        self.n_slices = 1
+
     def forward(self, input_):
         """Forward of ReplicatedLinearWithLoRA
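`set_lora` keeps the old transposed copies but now indexes slice 0 of the tuples, and the bias argument is renamed to `lora_bias`. A stripped-down version of the copy for one adapter slot, with assumed shapes (not the real class):

import torch

max_loras, rank, in_features, out_features = 4, 8, 32, 64
index = 2

lora_a_stacked = (torch.zeros(max_loras, 1, rank, in_features),)
lora_b_stacked = (torch.zeros(max_loras, 1, out_features, rank),)

# Adapter-side LoRA weights arrive as (in_features, rank) / (rank, out_features),
# so they are transposed into the stacked (out, in)-style layout, as the diff shows.
lora_a = torch.randn(in_features, rank)
lora_b = torch.randn(rank, out_features)

lora_a_stacked[0][index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
    lora_a.T, non_blocking=True)
lora_b_stacked[0][index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
    lora_b.T, non_blocking=True)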
@@ -380,73 +437,26 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
         return type(source_layer) is ReplicatedLinear
 
 
-class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
+class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
     """
     LoRA on top of ColumnParallelLinear layer.
-
     LoRA B is sliced for tensor parallelism.
+    There are two types for the `base_layer`:
+    1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
+    2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
     """
 
     def __init__(self, base_layer: ColumnParallelLinear) -> None:
-        super().__init__()
+        super().__init__(base_layer)
         # The base_layer type is ColumnParallelLinear or
         # MergedColumnParallelLinear, their weight sharding logic is
         # inconsistent when TP is greater than 1.
         self.is_merged_col_linear = type(
             base_layer) is MergedColumnParallelLinear
-
-        self.base_layer = base_layer
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.input_size = self.base_layer.input_size
         self.output_size = self.base_layer.output_size_per_partition
-        self.device = _get_lora_device(self.base_layer)
-
-    def create_lora_weights(
-        self,
-        max_loras: int,
-        lora_config: LoRAConfig,
-        model_config: Optional[PretrainedConfig] = None,
-    ) -> None:
-        self.lora_config = lora_config
-        self.tp_size = get_tensor_model_parallel_world_size()
-        lora_a_output_size_per_partition = (
-            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
-            else divide(lora_config.max_lora_rank, self.tp_size))
-        self.lora_a_stacked = torch.zeros(
-            max_loras,
-            1,
-            lora_a_output_size_per_partition,
-            self.input_size,
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        self.lora_b_stacked = torch.zeros(
-            max_loras,
-            1,
-            self.output_size,
-            lora_config.max_lora_rank,
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-
-        if lora_config.bias_enabled:
-            self.bias_stacked = torch.zeros(
-                max_loras,
-                1,
-                self.output_size,
-                dtype=lora_config.lora_dtype,
-                device=self.device,
-            )
-        else:
-            self.bias_stacked = None
-
-        self.output_dim = self.lora_b_stacked.shape[2]
-
-    def reset_lora(self, index: int):
-        self.lora_a_stacked[index] = 0
-        self.lora_b_stacked[index] = 0
-        if self.lora_config.bias_enabled:
-            self.bias_stacked[index] = 0
+        # There is only one LoRA layer
+        self.n_slices = 1
 
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         return lora_a
@@ -485,40 +495,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         bias = bias[start_idx:end_idx]
         return bias
 
-    def set_lora(
-        self,
-        index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
-    ):
-        self.reset_lora(index)
-
-        if self.tp_size > 1:
-            lora_a = self.slice_lora_a(lora_a)
-            lora_b = self.slice_lora_b(lora_b)
-            bias = self.slice_bias(bias)
-
-        self.lora_a_stacked[index,
-                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
-                                lora_a.T, non_blocking=True)
-        self.lora_b_stacked[index,
-                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-                                lora_b.T, non_blocking=True)
-        if bias is not None:
-            self.bias_stacked[index,
-                              0, :bias.shape[0]].copy_(bias.T,
-                                                       non_blocking=True)
-
-    def apply(self, x: torch.Tensor,
-              bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, self.bias_stacked,
-                                     1.0)
-        return output
-
     def forward(self, input_):
         """Forward of ColumnParallelLinear
 
@@ -568,6 +544,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
 
     def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
         super().__init__(base_layer)
+        # There are two LoRA layers
+        self.n_slices = len(self.base_layer.output_sizes)
 
     def create_lora_weights(
         self,
@@ -575,9 +553,13 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         lora_config: LoRAConfig,
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
+        """
+        The main reason for overriding this function is to enhance code
+        maintainability.
+        """
         self.lora_config = lora_config
-        n_slices = 2
-        if not (len(self.base_layer.output_sizes) == n_slices
+
+        if not (len(self.base_layer.output_sizes) == self.n_slices == 2
                 and self.base_layer.output_sizes[0]
                 == self.base_layer.output_sizes[1]):
             raise ValueError(
@@ -598,7 +580,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 self.input_size,
                 dtype=lora_config.lora_dtype,
                 device=self.device,
-            ) for _ in range(n_slices))
+            ) for _ in range(self.n_slices))
         self.lora_b_stacked = tuple(
             torch.zeros(
                 max_loras,
@@ -607,30 +589,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 lora_config.max_lora_rank,
                 dtype=lora_config.lora_dtype,
                 device=self.device,
-            ) for _ in range(n_slices))
+            ) for _ in range(self.n_slices))
         if lora_config.bias_enabled:
-            self.bias_stacked = tuple(
+            self.lora_bias_stacked = tuple(
                 torch.zeros(
                     max_loras,
                     1,
                     self.output_size // 2,
                     dtype=lora_config.lora_dtype,
                     device=self.device,
-                ) for _ in range(n_slices))
-        else:
-            self.bias_stacked = None
+                ) for _ in range(self.n_slices))
         self.output_dim = self.lora_b_stacked[0].shape[2]
         self.output_slices = (self.output_dim, self.output_dim)
 
-    def reset_lora(self, index: int):
-        self.lora_a_stacked[0][index] = 0
-        self.lora_a_stacked[1][index] = 0
-        self.lora_b_stacked[0][index] = 0
-        self.lora_b_stacked[1][index] = 0
-        if self.lora_config.bias_enabled:
-            self.bias_stacked[0][index] = 0
-            self.bias_stacked[1][index] = 0
-
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
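With weights stored per slice, the subclass-specific `reset_lora` overrides deleted in this region are subsumed by the generic loop in the base class. An equivalent of that loop over two slices (the merged column-parallel case, illustrative tensors only):

import torch

n_slices, max_loras, rank, in_f, out_f = 2, 4, 8, 32, 64
lora_a_stacked = tuple(torch.randn(max_loras, 1, rank, in_f) for _ in range(n_slices))
lora_b_stacked = tuple(torch.randn(max_loras, 1, out_f, rank) for _ in range(n_slices))
lora_bias_stacked = tuple(torch.randn(max_loras, 1, out_f) for _ in range(n_slices))

def reset_lora(index, bias_enabled=True):
    # Same effect as the deleted per-slice overrides, for any n_slices.
    for s in range(n_slices):
        lora_a_stacked[s][index] = 0
        lora_b_stacked[s][index] = 0
        if bias_enabled:
            lora_bias_stacked[s][index] = 0

reset_lora(1)
assert all(t[1].abs().sum() == 0 for t in lora_a_stacked + lora_b_stacked)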
@@ -668,15 +639,15 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
+        lora_bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
-            if bias is not None:
-                bias = self.slice_bias(bias)
+            if lora_bias is not None:
+                lora_bias = self.slice_bias(lora_bias)
 
         if lora_a[0] is not None:
             self.lora_a_stacked[0][
@@ -685,10 +656,11 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             self.lora_b_stacked[0][
                 index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                     lora_b[0].T, non_blocking=True)
-        if bias is not None and bias[0] is not None:
-            self.bias_stacked[0][index,
-                                 0, :bias[0].shape[0]].copy_(bias[0].T,
-                                                             non_blocking=True)
+        if lora_bias is not None and lora_bias[0] is not None:
+            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+                                          self.lora_bias_stacked)
+            self.lora_bias_stacked[0][index, 0, :lora_bias[0].shape[0]].copy_(
+                lora_bias[0].T, non_blocking=True)
         if lora_a[1] is not None:
             self.lora_a_stacked[1][
                 index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
@@ -696,18 +668,11 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             self.lora_b_stacked[1][
                 index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                     lora_b[1].T, non_blocking=True)
-        if bias is not None and bias[1] is not None:
-            self.bias_stacked[1][index,
-                                 0, :bias[1].shape[0]].copy_(bias[1].T,
-                                                             non_blocking=True)
-
-    def apply(self, x: torch.Tensor,
-              bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        self.punica_wrapper.add_lora_packed_nslice(
-            output, x, self.lora_a_stacked, self.lora_b_stacked,
-            self.bias_stacked, 1.0, (self.output_dim, self.output_dim))
-        return output
+        if lora_bias is not None and lora_bias[1] is not None:
+            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+                                          self.lora_bias_stacked)
+            self.lora_bias_stacked[1][index, 0, :lora_bias[1].shape[0]].copy_(
+                lora_bias[1].T, non_blocking=True)
 
     @classmethod
     @_not_fully_sharded_can_replace
@@ -737,7 +702,6 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
 
     def __init__(self, base_layer: QKVParallelLinear) -> None:
         super().__init__(base_layer)
-        self.tp_size = get_tensor_model_parallel_world_size()
         self.q_proj_total_size = (self.base_layer.total_num_heads *
                                   self.base_layer.head_size)
         self.q_proj_shard_size = (self.base_layer.num_heads *
@@ -746,6 +710,8 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                                    self.base_layer.head_size)
         self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                    self.base_layer.head_size)
+        # There is only one LoRA layer
+        self.n_slices = 1
 
     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         tp_rank = get_tensor_model_parallel_rank()
@@ -780,32 +746,6 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
         return bias
 
-    def set_lora(
-        self,
-        index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
-    ):
-        self.reset_lora(index)
-        if self.tp_size > 1:
-            lora_a = self.slice_lora_a(lora_a)
-            lora_b = self.slice_lora_b(lora_b)
-            if bias is not None:
-                bias = self.slice_bias(bias)
-
-        self.lora_a_stacked[index,
-                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
-                                lora_a.T, non_blocking=True)
-        self.lora_b_stacked[index,
-                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-                                lora_b.T, non_blocking=True)
-        if bias is not None:
-            self.bias_stacked[index,
-                              0, :bias.shape[0]].copy_(bias.T,
-                                                       non_blocking=True)
-
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
@@ -828,6 +768,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
 
     def __init__(self, base_layer: QKVParallelLinear) -> None:
         super().__init__(base_layer)
+        # There are three LoRA layer.
+        self.n_slices = len(self.base_layer.output_sizes)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
 
     def create_lora_weights(
         self,
@@ -835,9 +779,16 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         lora_config: LoRAConfig,
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
+        """
+        The main reason for overloading this function is to handle inconsistent
+        weight dimensions in qkv lora.
+        """
         self.lora_config = lora_config
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
+
+        if not (len(self.base_layer.output_sizes) == self.n_slices == 3):
+            raise ValueError(
+                "LoRAColumnParallelLinear3Slice requires 3 slices.")
+
         self.q_proj_shard_size = (self.base_layer.num_heads *
                                   self.base_layer.head_size)
         self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
@@ -902,7 +853,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
             ),
         )
         if lora_config.bias_enabled:
-            self.bias_stacked = (
+            self.lora_bias_stacked = (
                 torch.zeros(
                     max_loras,
                     1,
@@ -925,9 +876,6 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                     device=self.device,
                 ),
             )
-        else:
-            self.bias_stacked = None
-
         self.output_slices = (
             self.q_proj_shard_size,
             self.kv_proj_shard_size,
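`MergedQKVParallelLinearWithLora` now fixes `n_slices` to the three packed projections and validates that before building the per-slice buffers; its `output_slices` are the per-rank q/k/v widths. The arithmetic, with hypothetical head counts (the third slice width is assumed here to match the kv shard size, as in the merged q/k/v layout):

# Hypothetical per-rank head configuration, just to show the arithmetic.
head_size = 128
num_heads = 8          # query heads on this tensor-parallel rank
num_kv_heads = 2       # key/value heads on this rank

q_proj_shard_size = num_heads * head_size
kv_proj_shard_size = num_kv_heads * head_size

n_slices = 3
output_slices = (q_proj_shard_size, kv_proj_shard_size, kv_proj_shard_size)
assert len(output_slices) == n_slices
print(output_slices)   # (1024, 256, 256)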
@@ -939,18 +887,6 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.indices: torch.Tensor
         self.indices_len: List[int]
 
-    def reset_lora(self, index: int):
-        self.lora_a_stacked[0][index] = 0
-        self.lora_b_stacked[0][index] = 0
-        self.lora_a_stacked[1][index] = 0
-        self.lora_b_stacked[1][index] = 0
-        self.lora_a_stacked[2][index] = 0
-        self.lora_b_stacked[2][index] = 0
-        if self.lora_config.bias_enabled:
-            self.bias_stacked[0][index] = 0
-            self.bias_stacked[1][index] = 0
-            self.bias_stacked[2][index] = 0
-
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
@@ -1000,15 +936,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
+        lora_bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
-            if bias is not None:
-                bias = self.slice_bias(bias)
+            if lora_bias is not None:
+                lora_bias = self.slice_bias(lora_bias)
 
         if lora_b[0] is not None:
             lora_b_q = lora_b[0]
@@ -1039,26 +975,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                 index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                     lora_a[2].T, non_blocking=True)
 
-        if bias is not None:
-            if bias[0] is not None:
-                self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_(
-                    bias[0].T, non_blocking=True)
-            if bias[1] is not None:
-                self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_(
-                    bias[1].T, non_blocking=True)
-            if bias[2] is not None:
-                self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_(
-                    bias[2].T, non_blocking=True)
-
-    def apply(self, x: torch.Tensor,
-              bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        self.punica_wrapper.add_lora_packed_nslice(output, x,
-                                                   self.lora_a_stacked,
-                                                   self.lora_b_stacked,
-                                                   self.bias_stacked, 1.0,
-                                                   self.output_slices)
-        return output
+        if lora_bias is not None:
+            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+                                          self.lora_bias_stacked)
+            if lora_bias[0] is not None:
+                self.lora_bias_stacked[0][index,
+                                          0, :lora_bias[0].shape[0]].copy_(
+                                              lora_bias[0].T,
+                                              non_blocking=True)
+            if lora_bias[1] is not None:
+                self.lora_bias_stacked[1][index,
+                                          0, :lora_bias[1].shape[0]].copy_(
+                                              lora_bias[1].T,
+                                              non_blocking=True)
+            if lora_bias[2] is not None:
+                self.lora_bias_stacked[2][index,
+                                          0, :lora_bias[2].shape[0]].copy_(
+                                              lora_bias[2].T,
+                                              non_blocking=True)
 
     @classmethod
     @_not_fully_sharded_can_replace
@@ -1073,76 +1007,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                 and len(packed_modules_list) == 3)
 
 
-class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
+class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
 
     def __init__(self, base_layer: RowParallelLinear) -> None:
-        super().__init__()
-        self.base_layer = base_layer
+        super().__init__(base_layer)
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        # reset input_size
         self.input_size = self.base_layer.input_size_per_partition
         self.output_size = self.base_layer.output_size
-        self.device = _get_lora_device(self.base_layer)
 
-    def create_lora_weights(
-        self,
-        max_loras: int,
-        lora_config: LoRAConfig,
-        model_config: Optional[PretrainedConfig] = None,
-    ) -> None:
-        self.lora_config = lora_config
         self.tp_rank = get_tensor_model_parallel_rank()
-        self.lora_a_stacked = torch.zeros(
-            (
-                max_loras,
-                1,
-                lora_config.max_lora_rank,
-                self.input_size,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-        tp_size = get_tensor_model_parallel_world_size()
-        lora_b_output_size_per_partition = (
-            self.output_size if not lora_config.fully_sharded_loras else
-            divide(self.output_size, tp_size))
-
-        self.lora_b_stacked = torch.zeros(
-            (
-                max_loras,
-                1,
-                lora_b_output_size_per_partition,
-                lora_config.max_lora_rank,
-            ),
-            dtype=lora_config.lora_dtype,
-            device=self.device,
-        )
-
-        if lora_config.bias_enabled:
-            self.bias_stacked = torch.zeros(
-                (
-                    max_loras,
-                    1,
-                    self.output_size,
-                ),
-                dtype=lora_config.lora_dtype,
-                device=self.device,
-            )
-        else:
-            self.bias_stacked = None
-        # Lazily initialized
-        self.indices: torch.Tensor
-        self.indices_len: List[int]
-
-    def reset_lora(self, index: int):
-        self.lora_a_stacked[index] = 0
-        self.lora_b_stacked[index] = 0
-        if self.lora_config.bias_enabled:
-            self.bias_stacked[index] = 0
+        # There is only one LoRA layer.
+        self.n_slices = 1
 
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
-        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         shard_size = self.input_size
-        start_idx = tensor_model_parallel_rank * shard_size
-        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
         lora_a = lora_a[start_idx:end_idx, :]
         return lora_a
 
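Row-parallel layers slice LoRA A along the input dimension, so each rank multiplies only against the input shard it already holds; the refactor caches `self.tp_rank` once instead of re-querying it in every call. Illustratively (assumed sizes):

import torch

full_in_features, rank, tp = 64, 8, 2
input_size = full_in_features // tp            # per-partition input size
lora_a = torch.randn(full_in_features, rank)   # adapter layout: (in, rank)

def slice_lora_a(lora_a, tp_rank):
    shard_size = input_size
    start_idx = tp_rank * shard_size
    end_idx = (tp_rank + 1) * shard_size
    return lora_a[start_idx:end_idx, :]

assert slice_lora_a(lora_a, 1).shape == (input_size, rank)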
@@ -1152,40 +1035,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         return bias
 
-    def set_lora(
-        self,
-        index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
-        embeddings_tensor: Optional[torch.Tensor],
-        bias: Optional[torch.Tensor] = None,
-    ):
-        self.reset_lora(index)
-
-        if self.base_layer.tp_size > 1:
-            lora_a = self.slice_lora_a(lora_a)
-            lora_b = self.slice_lora_b(lora_b)
-            if bias is not None:
-                bias = self.slice_bias(bias)
-
-        self.lora_a_stacked[index,
-                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
-                                lora_a.T, non_blocking=True)
-        self.lora_b_stacked[index,
-                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
-                                lora_b.T, non_blocking=True)
-        if bias is not None:
-            self.bias_stacked[index,
-                              0, :bias.shape[0]].copy_(bias.T,
-                                                       non_blocking=True)
-
-    def apply(self, x: torch.Tensor) -> torch.Tensor:
-        output = self.base_layer.quant_method.apply(self.base_layer, x)
-        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, self.bias_stacked,
-                                     1.0)
-        return output
-
     def forward(self, input_):
         """Forward of RowParallelLinear
 
@@ -1203,10 +1052,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             input_parallel = input_
         else:
             # TODO: simplify code below
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.base_layer.tp_size)
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()
 
         # Matrix multiply.
         output_parallel = self.apply(input_parallel)
|
|||||||
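Note: the removed per-class `set_lora`/`apply` pair is what the shared base implementation now covers for linear LoRA layers: adapter weights are copied, transposed, into zero-initialized stacked buffers, and `apply` hands those buffers to the Punica wrapper. A standalone sketch of the copy step (buffer shapes are assumptions consistent with the deleted indexing, not taken from the new base class):

import torch

# Assumed shapes, consistent with the deleted copy:
#   lora_a_stacked: (max_loras, 1, max_rank, input_size)
#   lora_b_stacked: (max_loras, 1, output_size, max_rank)
max_loras, max_rank, input_size, output_size = 4, 8, 16, 32
lora_a_stacked = torch.zeros(max_loras, 1, max_rank, input_size)
lora_b_stacked = torch.zeros(max_loras, 1, output_size, max_rank)

def set_lora_slot(index: int, lora_a: torch.Tensor, lora_b: torch.Tensor):
    # lora_a: (input_size, r), lora_b: (r, output_size); a rank-r adapter
    # fills only the top-left corner of its slot, the rest stays zero.
    lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(lora_a.T)
    lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(lora_b.T)

set_lora_slot(2, torch.randn(input_size, 4), torch.randn(4, output_size))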
@ -555,17 +555,17 @@ class LoRAModelManager(AdapterModelManager):
                     input_dim,
                     output_dim,
                     rank,
-                    module.lora_a_stacked.dtype,
+                    module.lora_a_stacked[0].dtype,
                     "cpu",
                     embeddings_tensor_dim=embeddings_tensor_dim,
                     bias_enabled=bias_enabled)
             else:
                 lora = LoRALayerWeights.create_dummy_lora_weights(
                     module_name,
-                    module.lora_a_stacked.shape[-1],
-                    module.lora_b_stacked.shape[-2],
+                    module.lora_a_stacked[0].shape[-1],
+                    module.lora_b_stacked[0].shape[-2],
                     rank,
-                    module.lora_a_stacked.dtype,
+                    module.lora_a_stacked[0].dtype,
                     "cpu",
                     bias_enabled=bias_enabled,
                 )
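Note: the only change in this hunk is the `[0]` indexing: under the cleaned-up interface a layer's `lora_a_stacked`/`lora_b_stacked` is a per-slice tuple of stacked tensors (length 1 for unsliced layers), so metadata lookups go through the first element. A hedged illustration with made-up shapes:

from typing import Tuple
import torch

lora_a_stacked: Tuple[torch.Tensor, ...] = (
    torch.zeros(4, 1, 8, 16, dtype=torch.float16),  # a single slice
)
input_dim = lora_a_stacked[0].shape[-1]  # 16
dtype = lora_a_stacked[0].dtype          # torch.float16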
@ -362,7 +362,7 @@ class PunicaWrapper:
         long_lora_len = self.indices_len[4]
         return self._long_lora_indices[:long_lora_len]
 
-    def shrink_prefill(
+    def _shrink_prefill(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
@ -380,7 +380,7 @@ class PunicaWrapper:
             scale,
         )
 
-    def shrink_decode(
+    def _shrink_decode(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
@ -389,7 +389,7 @@ class PunicaWrapper:
     ):
         bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
 
-    def expand_prefill(
+    def _expand_prefill(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
@ -407,7 +407,7 @@ class PunicaWrapper:
             add_input,
         )
 
-    def expand_decode(
+    def _expand_decode(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
@ -416,7 +416,7 @@ class PunicaWrapper:
     ):
         bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)
 
-    def expand_slice_prefill(
+    def _expand_slice_prefill(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
@ -438,7 +438,7 @@ class PunicaWrapper:
             add_input,
         )
 
-    def expand_slice_decode(
+    def _expand_slice_decode(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
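Note: these hunks are pure renames; the per-stage shrink/expand kernels become private (`_`-prefixed) because callers are now expected to go through the higher-level `add_*` methods further down. The selection pattern they are invoked through is the same everywhere; a minimal sketch of that dispatch (the kernel bodies are stand-ins, not the real kernels; the decode path in the diff calls bgmv):

from typing import Callable

class DispatchSketch:
    def __init__(self, is_prefill: bool):
        self.is_prefill = is_prefill

    def _shrink_prefill(self, y, x, w_t_all, scale):
        print("prefill shrink kernel")

    def _shrink_decode(self, y, x, w_t_all, scale):
        print("decode shrink kernel (bgmv)")

    def shrink(self, y, x, w_t_all, scale):
        # Pick the stage-specific kernel once per call, as PunicaWrapper does.
        shrink_fun: Callable = (self._shrink_prefill
                                if self.is_prefill else self._shrink_decode)
        shrink_fun(y, x, w_t_all, scale)

DispatchSketch(is_prefill=False).shrink(None, None, None, 1.0)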
@ -450,41 +450,35 @@ class PunicaWrapper:
         bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                           y_slice_size, add_input)
 
-    def apply_bias(
-        self,
-        indices: torch.Tensor,
-        output: torch.Tensor,
-        bias_stacked: torch.Tensor,
-    ):
-        """Applies bias to output
-
-        Input shapes:
-            bias_stacked: (num_loras, output_dim)
-            indices: (batch_size)
-            output: (batch_size, output_dim)
-        """
-        org_output = output
-        output = output.view(-1, output.shape[-1])
-        indices = indices.view(-1)
-
-        bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
-        bias_stacked = bias_stacked[indices]
-        bias_stacked[indices == -1] = 0
-        output += bias_stacked
-
-        return output.view_as(org_output)
-
-    def apply_bias_packed_nslice(
+    def _apply_expand(self,
+                      y: torch.Tensor,
+                      x: torch.Tensor,
+                      w_t_all: torch.Tensor,
+                      y_offset: Optional[int],
+                      y_slice_size: Optional[int],
+                      add_input: bool = True):
+        """
+        Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
+        computation, which is suitable for the
+        GEMM of lora'b.
+        """
+
+        expand_slice_fun: Callable = (self._expand_slice_prefill
+                                      if self.is_prefill else
+                                      self._expand_slice_decode)
+        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
+
+    def _apply_bias(
         self,
         indices: torch.Tensor,
         output: torch.Tensor,
         output_slices: Tuple[int, ...],
-        bias_stacked: Tuple[Optional[torch.Tensor], ...],
+        lora_bias_stacked: Tuple[Optional[torch.Tensor], ...],
     ):
         """Applies bias to output
 
         Input shapes:
-            bias_stacked: 3 element tuple of (num_loras, output_dim)
+            lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
             indices: (batch_size)
             output: (batch_size, q_slice_size + 2*kv_slice_size)
             output_slices: n-1 element tuple of (slice_size...),
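Note: the renamed `_apply_bias` keeps the semantics documented above: gather each token's bias row by its LoRA index, zero it for tokens without an adapter (index -1), and add it to the corresponding output slice. A plain-PyTorch reference of those semantics (illustrative, not the actual implementation):

import torch

def apply_bias_reference(indices, output, output_slices, lora_bias_stacked):
    # indices: (batch,) LoRA id per token, -1 meaning "no LoRA".
    offset = 0
    for slice_size, bias in zip(output_slices, lora_bias_stacked):
        if bias is not None:
            bias = bias.view(-1, bias.shape[-1])[indices]  # (batch, slice)
            bias[indices == -1] = 0
            output[:, offset:offset + slice_size] += bias
        offset += slice_size
    return output

out = torch.zeros(3, 8)
idx = torch.tensor([0, -1, 1])
apply_bias_reference(idx, out, (8,), (torch.randn(2, 8),))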
@ -496,7 +490,7 @@ class PunicaWrapper:
 
         offset_left = 0
         for slice_idx, slice in enumerate(output_slices):
-            bias = bias_stacked[slice_idx]
+            bias = lora_bias_stacked[slice_idx]
             if bias is not None:
                 bias = bias.view(-1, bias.shape[-1])
                 bias = bias[indices]
@ -506,7 +500,7 @@ class PunicaWrapper:
 
         return output.view_as(org_output)
 
-    def add_shrink(
+    def _apply_shrink(
         self,
         y: torch.Tensor,
         x: torch.Tensor,
@ -517,188 +511,215 @@ class PunicaWrapper:
         Perform the ` y+=x@w_t_all` computation, which is suitable for the
         GEMM of lora'a.
         When `is_prefill is` true, it indicates that it is currently the
-        prefill stage, and the `shrink_prefill` function should be called.
-        Otherwise, it is the decode stage, and the shrink_decode function
+        prefill stage, and the `_shrink_prefill` function should be called.
+        Otherwise, it is the decode stage, and the _shrink_decode function
         should be called.
         """
-        shrink_fun: Callable = (self.shrink_prefill
-                                if self.is_prefill else self.shrink_decode)
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        shrink_fun: Callable = (self._shrink_prefill
+                                if self.is_prefill else self._shrink_decode)
         shrink_fun(y, x, w_t_all, scale)
+        y = y.view_as(y_org)
+
+    def add_shrink(
+        self,
+        y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+        x: torch.Tensor,
+        lora_a_stacked: Tuple[torch.Tensor, ...],
+        scale: float,
+    ):
+        """
+        Performs GEMM for multiple slices of lora_a.
+        When `is_prefill is` true, it indicates that it is currently the
+        prefill stage, and the `_shrink_prefill` function should be called.
+        Otherwise, it is the decode stage, and the _shrink_decode function
+        should be called.
+
+        Semantics:
+            for i in range(len(lora_a_stacked)):
+                y[i] += (x @ lora_a_stacked[i]) * scale
+
+        Args:
+            y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
+            x (torch.Tensor): Input tensor
+            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights
+            scale (float): Scaling factor for the operation
+        """
+
+        x = x.view(-1, x.shape[-1])
+        # TODO fuse these kernels
+        for slice_idx in range(len(lora_a_stacked)):
+            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)
 
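Note: `add_shrink` is the new multi-slice entry point, and its documented semantics reduce to one scaled matmul per slice. A dense reference of exactly that loop (the real stacked weights carry extra per-LoRA dimensions and are selected per token inside the kernels; the flat `(hidden, rank)` layout here is only for illustration):

import torch

def add_shrink_reference(y_slices, x, lora_a_slices, scale):
    # Mirrors the docstring: y[i] += (x @ lora_a_stacked[i]) * scale
    for y_i, a_i in zip(y_slices, lora_a_slices):
        y_i += (x @ a_i) * scale

tokens, hidden, rank = 4, 16, 8
x = torch.randn(tokens, hidden)
y = (torch.zeros(tokens, rank),)
add_shrink_reference(y, x, (torch.randn(hidden, rank),), scale=1.0)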
+    def add_expand(
+        self,
+        y: torch.Tensor,
+        x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+        lora_b_stacked: Tuple[torch.Tensor, ...],
+        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+        output_slices: Tuple[int, ...],
+        offset_start: int = 0,
+        add_input=True,
+    ) -> None:
+        """
+        Performs GEMM and bias addition for multiple slices of lora_b.
+
+        Semantics:
+            for i in range(len(lora_b_stacked)):
+                slice = output_slices[i]
+                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
+                    lora_bias_stacked[i]
+                offset += slice
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
+            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
+            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
+                bias's weight
+            output_slices (Tuple[int, ...]): Every slice's size
+            add_input (bool): Defaults to True.
+        """
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        offset_left = offset_start
+        if lora_bias_stacked is not None:
+            self._apply_bias(self.token_lora_indices, y, output_slices,
+                             lora_bias_stacked)
+        for slice_idx in range(len(lora_b_stacked)):
+            self._apply_expand(
+                y,
+                x[slice_idx],
+                lora_b_stacked[slice_idx],
+                offset_left,
+                output_slices[slice_idx],
+                add_input=add_input,
+            )
+            offset_left += output_slices[slice_idx]
+        y = y.view_as(y_org)
 
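Note: `add_expand` writes each slice's expand result into its own column range of the shared output, optionally after the bias pass. A dense reference of the documented per-slice accumulation (again with flat weights instead of the stacked per-LoRA tensors):

import torch

def add_expand_reference(y, x_slices, lora_b_slices, output_slices,
                         offset_start=0):
    # y: (tokens, sum(output_slices)); x_slices[i]: (tokens, rank)
    # lora_b_slices[i]: (rank, output_slices[i])
    offset = offset_start
    for x_i, b_i, size in zip(x_slices, lora_b_slices, output_slices):
        y[:, offset:offset + size] += x_i @ b_i
        offset += size

tokens, rank, slices = 4, 8, (16, 16)
y = torch.zeros(tokens, sum(slices))
x = tuple(torch.randn(tokens, rank) for _ in slices)
b = tuple(torch.randn(rank, s) for s in slices)
add_expand_reference(y, x, b, slices)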
-    def add_expand(
-        self,
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        bias_all: Optional[torch.Tensor],
-        add_input: bool = True,
-    ):
-        """
-        Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the
-        GEMM of lora'b.
-        When `is_prefill` is true, it indicates that it is currently the
-        prefill stage, and the `expand_prefill` function should be called.
-        Otherwise, it is the decode stage, and the expand_decode function
-        should be called.
-        """
-        if bias_all is not None:
-            y = self.apply_bias(self.token_lora_indices, y, bias_all)
-
-        expand_fun: Callable = (self.expand_prefill
-                                if self.is_prefill else self.expand_decode)
-        expand_fun(y, x, w_t_all, add_input)
-
-    def add_expand_slice(self,
-                         y: torch.Tensor,
-                         x: torch.Tensor,
-                         w_t_all: torch.Tensor,
-                         bias_all: Optional[torch.Tensor],
-                         y_offset: Optional[int],
-                         y_slice_size: Optional[int],
-                         add_input: bool = True):
-        """
-        Similar to `add_expand`
-        """
-        if bias_all is not None:
-            y = self.apply_bias(self.token_lora_indices, y, bias_all)
-
-        expand_slice_fun: Callable = (self.expand_slice_prefill
-                                      if self.is_prefill else
-                                      self.expand_slice_decode)
-        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
-
-    def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
-                                 lora_b_stacked: Tuple[torch.Tensor, ...],
-                                 bias_stacked: Optional[Tuple[torch.Tensor,
-                                                              ...]],
-                                 scale: float,
-                                 output_slices: Tuple[int, ...]) -> None:
-        """
-        Similar to `add_expand`
-        """
-        y_org = y
-        y = y.view(-1, y.shape[-1])
-        offset_left = 0
-        if bias_stacked is not None:
-            self.apply_bias_packed_nslice(self.token_lora_indices, y,
-                                          output_slices, bias_stacked)
-        for slice_idx in range(len(lora_b_stacked)):
-            self.add_expand_slice(y,
-                                  x[slice_idx],
-                                  lora_b_stacked[slice_idx],
-                                  None,
-                                  offset_left,
-                                  output_slices[slice_idx],
-                                  add_input=True)
-            offset_left += output_slices[slice_idx]
-
-        y = y.view_as(y_org)
-
+    def add_lora_embedding(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        lora_b_stacked: torch.Tensor,
+        add_input: bool = True,
+    ):
+        """
+        Applies lora specifically for VocabParallelEmbeddingWithLoRA.
+
+        Semantics:
+            y += x @ lora_b_stacked
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensor.
+            lora_b_stacked (torch.Tensor): lora_b's weights.
+            add_input (bool): Default to True.
+        """
+        # Embedding layer only need expand op
+        expand_fun: Callable = (self._expand_prefill
                                if self.is_prefill else self._expand_decode)
+        expand_fun(y, x, lora_b_stacked, add_input)
 
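Note: `add_lora_embedding` keeps only the expand half ("Embedding layer only need expand op"); the shrink half is unnecessary there because the embedding lookup already yields the low-rank activations (this rationale is inferred, not stated in the diff). A dense reference of the documented semantics:

import torch

def add_lora_embedding_reference(y, x_low_rank, lora_b):
    # Mirrors the docstring: y += x @ lora_b_stacked
    y += x_low_rank @ lora_b

tokens, rank, hidden = 4, 8, 16
y = torch.zeros(tokens, hidden)
add_lora_embedding_reference(y, torch.randn(tokens, rank),
                             torch.randn(rank, hidden))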
-    def add_lora(self,
-                 y: torch.Tensor,
-                 x: torch.Tensor,
-                 wa_t_all: torch.Tensor,
-                 wb_t_all: torch.Tensor,
-                 bias_all: Optional[torch.Tensor],
-                 scale: float,
-                 y_offset: Optional[int] = None,
-                 y_slice_size: Optional[int] = None,
-                 *,
-                 buffer: Optional[torch.Tensor] = None) -> None:
-        """
-        Semantics:
-            y[i] += (
-                x[i].unsqueeze(0)
-                @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-                @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-                * scale
-            ).squeeze(0)+bias[i]
-
-        Args:
-            y (torch.Tensor): Output tensor. Will be changed in-place.
-            x (torch.Tensor): Input tensor
-            wa_t_all (torch.Tensor): lora_a's weight
-            wb_t_all (torch.Tensor): lora_b's weight
-            bias_all: (torch.Tensor): lora's bias
-            scale (float): Scaling factor.
-            y_offset (Optional[int], optional): Offset to apply to the starting
-                column of y.
-            y_slice_size (Optional[int], optional): Size of the y column slice.
-            buffer (Optional[torch.Tensor], optional): Defaults to None.
-        """
-        y_org = y
-        y = y.view(-1, y.shape[-1])
-        x = x.view(-1, x.shape[-1])
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default ,refer to:
-            # https://github.com/triton-lang/triton/issues/1387
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        if bias_all is not None:
-            y = self.apply_bias(self.token_lora_indices, y, bias_all)
-        self.add_shrink(buffer, x, wa_t_all, scale)
-        if y_offset is None and y_slice_size is None:
-            self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True)
-        else:
-            self.add_expand_slice(y,
-                                  buffer,
-                                  wb_t_all,
-                                  None,
-                                  y_offset,
-                                  y_slice_size,
-                                  add_input=True)
-        y = y.view_as(y_org)
-
+    def add_lora_linear(
+        self,
+        y: torch.Tensor,
+        x: torch.Tensor,
+        lora_a_stacked: Tuple[torch.Tensor, ...],
+        lora_b_stacked: Tuple[torch.Tensor, ...],
+        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+        scale: float,
+        output_slices: Tuple[int, ...],
+        *,
+        buffer: Optional[Tuple[torch.Tensor, ...]] = None) -> None:
+        """
+        Applicable to linear-related lora.
+
+        Semantics:
+            for i in range(len(lora_a_stacked)):
+                y[i] += (
+                    x[i].unsqueeze(0)
+                    @ lora_a_stacked[indices[i], layer_idx, :, :]
+                    @ lora_b_stacked[indices[i], layer_idx, :, :]
+                    * scale
+                ).squeeze(0)+lora_bias_stacked[i]
+
+        Args:
+            y (torch.Tensor): Output tensor. Will be changed in-place.
+            x (torch.Tensor): Input tensor
+            lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight.
+            lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight.
+            lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias.
+            scale (float): Scaling factor.
+            output_slices (Tuple[int, ...]): Every slice's size.
+            buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None.
+        """
+
+        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+        if lora_bias_stacked is not None:
+            assert len(lora_bias_stacked) == len(output_slices)
+            y = self._apply_bias(self.token_lora_indices, y, output_slices,
+                                 lora_bias_stacked)
+
+        if buffer is None:
+            r = lora_b_stacked[0].size(-1)
+            # We set the buffer to be float32 by default ,refer to:
+            # https://github.com/triton-lang/triton/issues/1387
+            buffer = tuple(
+                torch.zeros(
+                    (x.size(0), r), dtype=torch.float32, device=x.device)
+                for _ in range(len(output_slices)))
+        self.add_shrink(buffer, x, lora_a_stacked, scale)
+        self.add_expand(y,
+                        buffer,
+                        lora_b_stacked,
+                        None,
+                        output_slices,
+                        add_input=True)
 
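Note: the practical effect of the new interface is that a linear LoRA layer makes a single `add_lora_linear` call per forward instead of choosing between `add_lora` and `add_lora_packed_nslice`. A hedged sketch of the call shape from a layer's point of view (the mock below only checks the slice bookkeeping and does no LoRA math):

import torch
from typing import Optional, Tuple

class MockPunicaWrapper:
    # Stand-in with the same call signature as the new add_lora_linear.
    def add_lora_linear(self, y: torch.Tensor, x: torch.Tensor,
                        lora_a_stacked: Tuple[torch.Tensor, ...],
                        lora_b_stacked: Tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
                        scale: float, output_slices: Tuple[int, ...], *,
                        buffer=None) -> None:
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        assert y.shape[-1] == sum(output_slices)

wrapper = MockPunicaWrapper()
x = torch.randn(4, 16)
y = torch.zeros(4, 32)
wrapper.add_lora_linear(y, x,
                        (torch.zeros(2, 1, 8, 16),) * 2,   # one entry per slice
                        (torch.zeros(2, 1, 16, 8),) * 2,
                        None,                              # bias disabled
                        1.0, (16, 16))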
-    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
-                               lora_a_stacked: Tuple[torch.Tensor, ...],
-                               lora_b_stacked: Tuple[torch.Tensor, ...],
-                               bias_all: Tuple[Optional[torch.Tensor],
-                                               ...], scale: float,
-                               output_slices: Tuple[int, ...]) -> None:
-        """
-        Applies lora to each input. Similar to add_lora, This method is
-        used for layers that are composed of multiple sublayers
-        (slices) packed together.
-        """
-        y_org = y
-        x = x.view(-1, x.shape[-1])
-        y = y.view(-1, y.shape[-1])
-        offset_left = 0
-        if bias_all is not None:
-            y = self.apply_bias_packed_nslice(self.token_lora_indices, y,
-                                              output_slices, bias_all)
-        # TODO fuse these kernels
-        for slice_idx in range(len(output_slices)):
-            self.add_lora(y, x, lora_a_stacked[slice_idx],
-                          lora_b_stacked[slice_idx], None, scale, offset_left,
-                          output_slices[slice_idx])
-            offset_left += output_slices[slice_idx]
-
-        y = y.view_as(y_org)
-
     def add_lora_logits(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
-                        wa_t_all: torch.Tensor,
-                        wb_t_all: torch.Tensor,
+                        lora_a_stacked: torch.Tensor,
+                        lora_b_stacked: torch.Tensor,
                         scale,
                         *,
                         buffer: Optional[torch.Tensor] = None) -> None:
         """
-        LogitsProcessorWithLoRA always using bgmv
+        Applies lora specifically for LogitsProcessorWithLoRA.
+
+        Semantics:
+            buffer = (x @ lora_a_stacked) * scale
+            y += buffer @ lora_b_stacked
+
+        Args:
+            y (torch.Tensor): Output tensor.
+            x (torch.Tensor): Input tensor.
+            lora_a_stacked (torch.Tensor): lora_a's weights.
+            lora_b_stacked (torch.Tensor):lora_b's weights.
+            scale (float): Scaling factor.
+            buffer (Optional[torch.Tensor]):Default to None.
         """
         y_org = y
         y = y.view(-1, y.shape[-1])
         x = x.view(-1, x.shape[-1])
-        r = wb_t_all.size(-1)
+        r = lora_b_stacked.size(-1)
         if buffer is None:
             # We set the buffer to be float32 by default ,refer to:
             # https://github.com/triton-lang/triton/issues/1387
             buffer = torch.zeros((x.size(0), r),
                                  dtype=torch.float32,
                                  device=x.device)
-        bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
-        bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
+        # LogitsProcessorWithLoRA always using bgmv.
+        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
+        bgmv_expand(buffer,
+                    lora_b_stacked,
+                    y,
+                    self.sampler_indices,
+                    add_inputs=True)
         y = y.view_as(y_org)
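Note: `add_lora_logits` is the full shrink-then-expand pipeline for the logits head, with a float32 intermediate buffer (per the Triton issue referenced in the comment). A dense reference of the documented semantics, without the bgmv kernels or per-token index selection:

import torch

def add_lora_logits_reference(y, x, lora_a, lora_b, scale):
    # buffer = (x @ lora_a_stacked) * scale ; y += buffer @ lora_b_stacked
    buffer = (x.float() @ lora_a.float()) * scale
    y += (buffer @ lora_b.float()).to(y.dtype)

tokens, hidden, rank, vocab = 2, 16, 8, 32
y = torch.zeros(tokens, vocab, dtype=torch.float16)
add_lora_logits_reference(y, torch.randn(tokens, hidden, dtype=torch.float16),
                          torch.randn(hidden, rank), torch.randn(rank, vocab),
                          scale=1.0)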