mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-10 02:09:08 +08:00
[LoRA] Adds support for bias in LoRA (#5733)
Signed-off-by: Umesh Deshpande <udeshpa@us.ibm.com> Co-authored-by: Umesh Deshpande <udeshpa@us.ibm.com>
This commit is contained in:
parent
b41fb9d3b1
commit
8a06428c70
@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id):
|
|||||||
return snapshot_download(repo_id=sql_lora_huggingface_id)
|
return snapshot_download(repo_id=sql_lora_huggingface_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def lora_bias_files():
|
||||||
|
return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def mixtral_lora_files():
|
def mixtral_lora_files():
|
||||||
# Note: this module has incorrect adapter_config.json to test
|
# Note: this module has incorrect adapter_config.json to test
|
||||||
|
|||||||
52
tests/lora/test_lora_bias_e2e.py
Normal file
52
tests/lora/test_lora_bias_e2e.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import vllm
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
MODEL_PATH = "ibm-granite/granite-3b-code-base"
|
||||||
|
|
||||||
|
|
||||||
|
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
|
||||||
|
prompts = [
|
||||||
|
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
|
||||||
|
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
|
||||||
|
]
|
||||||
|
sampling_params = vllm.SamplingParams(temperature=0,
|
||||||
|
max_tokens=256,
|
||||||
|
stop=["[/assistant]"])
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
|
||||||
|
if lora_id else None)
|
||||||
|
generated_texts: List[str] = []
|
||||||
|
for output in outputs:
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
generated_texts.append(generated_text)
|
||||||
|
return generated_texts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("lora_bias", [True])
|
||||||
|
@pytest.mark.parametrize("fully_sharded", [True, False])
|
||||||
|
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
|
||||||
|
llm = vllm.LLM(MODEL_PATH,
|
||||||
|
enable_lora=True,
|
||||||
|
max_num_seqs=16,
|
||||||
|
max_lora_rank=8,
|
||||||
|
max_loras=1,
|
||||||
|
enable_lora_bias=lora_bias,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
fully_sharded_loras=fully_sharded)
|
||||||
|
|
||||||
|
print("lora adapter created")
|
||||||
|
output1 = do_sample(llm, lora_bias_files, lora_id=0)
|
||||||
|
|
||||||
|
print("lora")
|
||||||
|
output2 = do_sample(llm, lora_bias_files, lora_id=1)
|
||||||
|
|
||||||
|
if lora_bias:
|
||||||
|
assert output1 != output2
|
||||||
|
else:
|
||||||
|
assert output1 == output2
|
||||||
@ -12,36 +12,40 @@ from vllm.utils import LRUCache
|
|||||||
|
|
||||||
def test_parse_fine_tuned_lora_name_valid():
|
def test_parse_fine_tuned_lora_name_valid():
|
||||||
fixture = {
|
fixture = {
|
||||||
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
|
("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
|
||||||
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
|
("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
|
||||||
(
|
(
|
||||||
"base_model.model.model.embed_tokens.lora_embedding_A",
|
"base_model.model.model.embed_tokens.lora_embedding_A",
|
||||||
"model.embed_tokens",
|
"model.embed_tokens",
|
||||||
True,
|
True,
|
||||||
|
False,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"base_model.model.model.embed_tokens.lora_embedding_B",
|
"base_model.model.model.embed_tokens.lora_embedding_B",
|
||||||
"model.embed_tokens",
|
"model.embed_tokens",
|
||||||
False,
|
False,
|
||||||
|
False,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
|
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
|
||||||
"model.layers.9.mlp.down_proj",
|
"model.layers.9.mlp.down_proj",
|
||||||
True,
|
True,
|
||||||
|
False,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
|
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
|
||||||
"model.layers.9.mlp.down_proj",
|
"model.layers.9.mlp.down_proj",
|
||||||
False,
|
False,
|
||||||
|
False,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
for name, module_name, is_lora_a in fixture:
|
for name, module_name, is_lora_a, is_bias in fixture:
|
||||||
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
|
assert (module_name, is_lora_a,
|
||||||
|
is_bias) == parse_fine_tuned_lora_name(name)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_fine_tuned_lora_name_invalid():
|
def test_parse_fine_tuned_lora_name_invalid():
|
||||||
fixture = {
|
fixture = {
|
||||||
"weight",
|
|
||||||
"base_model.weight",
|
"base_model.weight",
|
||||||
"base_model.model.weight",
|
"base_model.model.weight",
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1687,6 +1687,7 @@ class LoRAConfig:
|
|||||||
# This is a constant.
|
# This is a constant.
|
||||||
lora_vocab_padding_size: ClassVar[int] = 256
|
lora_vocab_padding_size: ClassVar[int] = 256
|
||||||
long_lora_scaling_factors: Optional[Tuple[float]] = None
|
long_lora_scaling_factors: Optional[Tuple[float]] = None
|
||||||
|
bias_enabled: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Setting the maximum rank to 256 should be able to satisfy the vast
|
# Setting the maximum rank to 256 should be able to satisfy the vast
|
||||||
|
|||||||
@ -143,6 +143,7 @@ class EngineArgs:
|
|||||||
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
|
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
|
||||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
||||||
enable_lora: bool = False
|
enable_lora: bool = False
|
||||||
|
enable_lora_bias: bool = False
|
||||||
max_loras: int = 1
|
max_loras: int = 1
|
||||||
max_lora_rank: int = 16
|
max_lora_rank: int = 16
|
||||||
enable_prompt_adapter: bool = False
|
enable_prompt_adapter: bool = False
|
||||||
@ -584,6 +585,9 @@ class EngineArgs:
|
|||||||
parser.add_argument('--enable-lora',
|
parser.add_argument('--enable-lora',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='If True, enable handling of LoRA adapters.')
|
help='If True, enable handling of LoRA adapters.')
|
||||||
|
parser.add_argument('--enable-lora-bias',
|
||||||
|
action='store_true',
|
||||||
|
help='If True, enable bias for LoRA adapters.')
|
||||||
parser.add_argument('--max-loras',
|
parser.add_argument('--max-loras',
|
||||||
type=int,
|
type=int,
|
||||||
default=EngineArgs.max_loras,
|
default=EngineArgs.max_loras,
|
||||||
@ -1148,6 +1152,7 @@ class EngineArgs:
|
|||||||
and parallel_config.use_ray),
|
and parallel_config.use_ray),
|
||||||
policy=self.scheduling_policy)
|
policy=self.scheduling_policy)
|
||||||
lora_config = LoRAConfig(
|
lora_config = LoRAConfig(
|
||||||
|
bias_enabled=self.enable_lora_bias,
|
||||||
max_lora_rank=self.max_lora_rank,
|
max_lora_rank=self.max_lora_rank,
|
||||||
max_loras=self.max_loras,
|
max_loras=self.max_loras,
|
||||||
fully_sharded_loras=self.fully_sharded_loras,
|
fully_sharded_loras=self.fully_sharded_loras,
|
||||||
|
|||||||
@ -70,6 +70,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
self.lora_b_stacked,
|
self.lora_b_stacked,
|
||||||
add_input=True)
|
add_input=True)
|
||||||
# now have column partitioned output
|
# now have column partitioned output
|
||||||
|
|
||||||
|
if self.bias_stacked is not None:
|
||||||
|
self.bias_stacked = self.bias_stacked.view(
|
||||||
|
-1, self.bias_stacked.shape[-1])
|
||||||
|
self.bias_stacked = self.bias_stacked[
|
||||||
|
self.punica_wrapper.token_lora_indices]
|
||||||
|
output += self.bias_stacked
|
||||||
|
|
||||||
output = output.view(*out_orig_shape)
|
output = output.view(*out_orig_shape)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
|
|||||||
left_offset = 0
|
left_offset = 0
|
||||||
for idx in range(n):
|
for idx in range(n):
|
||||||
shard_size = layer.lora_b_stacked[idx].shape[2]
|
shard_size = layer.lora_b_stacked[idx].shape[2]
|
||||||
|
|
||||||
|
if layer.bias_stacked is not None:
|
||||||
|
bias = layer.bias_stacked[idx]
|
||||||
|
if bias is not None:
|
||||||
|
bias = bias.view(-1, bias.shape[-1])
|
||||||
|
bias = bias[layer.punica_wrapper.token_lora_indices]
|
||||||
|
bias[layer.punica_wrapper.token_lora_indices == -1] = 0
|
||||||
|
output[:, left_offset:left_offset + shard_size] += bias
|
||||||
|
|
||||||
layer.punica_wrapper.add_expand_slice(
|
layer.punica_wrapper.add_expand_slice(
|
||||||
output,
|
output,
|
||||||
buffers[idx],
|
buffers[idx],
|
||||||
@ -295,6 +312,15 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
|||||||
lora_b = lora_b[:, start_idx:end_idx]
|
lora_b = lora_b[:, start_idx:end_idx]
|
||||||
return lora_b
|
return lora_b
|
||||||
|
|
||||||
|
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
||||||
|
if bias is None:
|
||||||
|
return bias
|
||||||
|
shard_size = self.bias_stacked.shape[2]
|
||||||
|
start_idx = self.tp_rank * shard_size
|
||||||
|
end_idx = (self.tp_rank + 1) * shard_size
|
||||||
|
bias = bias[start_idx:end_idx]
|
||||||
|
return bias
|
||||||
|
|
||||||
def apply(self, x: torch.Tensor) -> torch.Tensor:
|
def apply(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
output = self.base_layer.quant_method.apply(self.base_layer, x)
|
output = self.base_layer.quant_method.apply(self.base_layer, x)
|
||||||
|
|
||||||
@ -318,6 +344,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
|||||||
# reduced before being used
|
# reduced before being used
|
||||||
shard_size = self.lora_b_stacked.shape[2]
|
shard_size = self.lora_b_stacked.shape[2]
|
||||||
start_idx = self.tp_rank * shard_size
|
start_idx = self.tp_rank * shard_size
|
||||||
|
|
||||||
|
if self.bias_stacked is not None:
|
||||||
|
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
|
||||||
|
bias = bias[self.punica_wrapper.token_lora_indices]
|
||||||
|
bias[self.punica_wrapper.token_lora_indices == -1] = 0
|
||||||
|
output += bias
|
||||||
|
|
||||||
self.punica_wrapper.add_expand_slice(output, buffer,
|
self.punica_wrapper.add_expand_slice(output, buffer,
|
||||||
self.lora_b_stacked, start_idx,
|
self.lora_b_stacked, start_idx,
|
||||||
shard_size)
|
shard_size)
|
||||||
|
|||||||
@ -67,6 +67,63 @@ def _not_fully_sharded_can_replace(can_replace):
|
|||||||
return dec
|
return dec
|
||||||
|
|
||||||
|
|
||||||
|
def apply_bias(
|
||||||
|
indices: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
bias_stacked: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""Applies bias to output
|
||||||
|
|
||||||
|
Input shapes:
|
||||||
|
bias_stacked: (num_loras, output_dim)
|
||||||
|
indices: (batch_size)
|
||||||
|
output: (batch_size, output_dim)
|
||||||
|
"""
|
||||||
|
org_output = output
|
||||||
|
output = output.view(-1, output.shape[-1])
|
||||||
|
indices = indices.view(-1)
|
||||||
|
|
||||||
|
bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
|
||||||
|
bias_stacked = bias_stacked[indices]
|
||||||
|
bias_stacked[indices == -1] = 0
|
||||||
|
output += bias_stacked
|
||||||
|
|
||||||
|
return output.view_as(org_output)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_bias_packed_nslice(
|
||||||
|
indices: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
output_slices: Tuple[int, ...],
|
||||||
|
bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||||
|
):
|
||||||
|
"""Applies bias to output
|
||||||
|
|
||||||
|
Input shapes:
|
||||||
|
bias_stacked: 3 element tuple of (num_loras, output_dim)
|
||||||
|
indices: (batch_size)
|
||||||
|
output: (batch_size, q_slice_size + 2*kv_slice_size)
|
||||||
|
output_slices: n-1 element tuple of (slice_size...),
|
||||||
|
where n is number of slices
|
||||||
|
"""
|
||||||
|
org_output = output
|
||||||
|
output = output.view(-1, output.shape[-1])
|
||||||
|
indices = indices.view(-1)
|
||||||
|
|
||||||
|
offset_left = 0
|
||||||
|
for slice_idx, slice in enumerate(output_slices):
|
||||||
|
bias = bias_stacked[slice_idx]
|
||||||
|
if bias is not None:
|
||||||
|
bias = bias.view(-1, bias.shape[-1])
|
||||||
|
bias = bias[indices]
|
||||||
|
bias[indices == -1] = 0
|
||||||
|
output[:, offset_left:offset_left + slice] += bias
|
||||||
|
|
||||||
|
offset_left += slice
|
||||||
|
|
||||||
|
return output.view_as(org_output)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoRAMapping(AdapterMapping):
|
class LoRAMapping(AdapterMapping):
|
||||||
is_prefill: bool = False
|
is_prefill: bool = False
|
||||||
@ -105,6 +162,7 @@ class BaseLayerWithLoRA(nn.Module):
|
|||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
"""Overwrites lora tensors at index."""
|
"""Overwrites lora tensors at index."""
|
||||||
...
|
...
|
||||||
@ -203,6 +261,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
|||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.reset_lora(index)
|
self.reset_lora(index)
|
||||||
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
|
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
|
||||||
@ -299,10 +358,22 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
dtype=lora_config.lora_dtype,
|
dtype=lora_config.lora_dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
if lora_config.bias_enabled:
|
||||||
|
self.bias_stacked = torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.output_size,
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.bias_stacked = None
|
||||||
|
|
||||||
def reset_lora(self, index: int):
|
def reset_lora(self, index: int):
|
||||||
self.lora_a_stacked[index] = 0
|
self.lora_a_stacked[index] = 0
|
||||||
self.lora_b_stacked[index] = 0
|
self.lora_b_stacked[index] = 0
|
||||||
|
if self.lora_config.bias_enabled:
|
||||||
|
self.bias_stacked[index] = 0
|
||||||
|
|
||||||
def set_lora(
|
def set_lora(
|
||||||
self,
|
self,
|
||||||
@ -310,6 +381,7 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.reset_lora(index)
|
self.reset_lora(index)
|
||||||
|
|
||||||
@ -319,10 +391,21 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.lora_b_stacked[index,
|
self.lora_b_stacked[index,
|
||||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||||
lora_b.T, non_blocking=True)
|
lora_b.T, non_blocking=True)
|
||||||
|
if bias is not None:
|
||||||
|
self.bias_stacked[index,
|
||||||
|
0, :bias.shape[0]].copy_(bias.T,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
def apply(self, x: torch.Tensor,
|
def apply(self, x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||||
|
if self.bias_stacked is not None:
|
||||||
|
self.indices = self.punica_wrapper.token_lora_indices
|
||||||
|
output = apply_bias(
|
||||||
|
self.indices,
|
||||||
|
output,
|
||||||
|
self.bias_stacked,
|
||||||
|
)
|
||||||
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
||||||
self.lora_b_stacked, 1.0)
|
self.lora_b_stacked, 1.0)
|
||||||
return output
|
return output
|
||||||
@ -401,11 +484,25 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
dtype=lora_config.lora_dtype,
|
dtype=lora_config.lora_dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if lora_config.bias_enabled:
|
||||||
|
self.bias_stacked = torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.output_size,
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.bias_stacked = None
|
||||||
|
|
||||||
self.output_dim = self.lora_b_stacked.shape[2]
|
self.output_dim = self.lora_b_stacked.shape[2]
|
||||||
|
|
||||||
def reset_lora(self, index: int):
|
def reset_lora(self, index: int):
|
||||||
self.lora_a_stacked[index] = 0
|
self.lora_a_stacked[index] = 0
|
||||||
self.lora_b_stacked[index] = 0
|
self.lora_b_stacked[index] = 0
|
||||||
|
if self.lora_config.bias_enabled:
|
||||||
|
self.bias_stacked[index] = 0
|
||||||
|
|
||||||
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
||||||
return lora_a
|
return lora_a
|
||||||
@ -418,18 +515,30 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
lora_b = lora_b[:, start_idx:end_idx]
|
lora_b = lora_b[:, start_idx:end_idx]
|
||||||
return lora_b
|
return lora_b
|
||||||
|
|
||||||
|
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
||||||
|
if bias is None:
|
||||||
|
return bias
|
||||||
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
|
shard_size = self.output_dim
|
||||||
|
start_idx = tensor_model_parallel_rank * shard_size
|
||||||
|
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||||
|
bias = bias[start_idx:end_idx]
|
||||||
|
return bias
|
||||||
|
|
||||||
def set_lora(
|
def set_lora(
|
||||||
self,
|
self,
|
||||||
index: int,
|
index: int,
|
||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.reset_lora(index)
|
self.reset_lora(index)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
lora_a = self.slice_lora_a(lora_a)
|
lora_a = self.slice_lora_a(lora_a)
|
||||||
lora_b = self.slice_lora_b(lora_b)
|
lora_b = self.slice_lora_b(lora_b)
|
||||||
|
bias = self.slice_bias(bias)
|
||||||
|
|
||||||
self.lora_a_stacked[index,
|
self.lora_a_stacked[index,
|
||||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||||
@ -437,10 +546,21 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.lora_b_stacked[index,
|
self.lora_b_stacked[index,
|
||||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||||
lora_b.T, non_blocking=True)
|
lora_b.T, non_blocking=True)
|
||||||
|
if bias is not None:
|
||||||
|
self.bias_stacked[index,
|
||||||
|
0, :bias.shape[0]].copy_(bias.T,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
def apply(self, x: torch.Tensor,
|
def apply(self, x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||||
|
if self.bias_stacked is not None:
|
||||||
|
self.indices = self.punica_wrapper.token_lora_indices
|
||||||
|
output = apply_bias(
|
||||||
|
self.indices,
|
||||||
|
output,
|
||||||
|
self.bias_stacked,
|
||||||
|
)
|
||||||
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
||||||
self.lora_b_stacked, 1.0)
|
self.lora_b_stacked, 1.0)
|
||||||
return output
|
return output
|
||||||
@ -534,6 +654,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
dtype=lora_config.lora_dtype,
|
dtype=lora_config.lora_dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
) for _ in range(n_slices))
|
) for _ in range(n_slices))
|
||||||
|
if lora_config.bias_enabled:
|
||||||
|
self.bias_stacked = tuple(
|
||||||
|
torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.output_size // 2,
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.device,
|
||||||
|
) for _ in range(n_slices))
|
||||||
|
else:
|
||||||
|
self.bias_stacked = None
|
||||||
|
|
||||||
self.output_dim = self.lora_b_stacked[0].shape[2]
|
self.output_dim = self.lora_b_stacked[0].shape[2]
|
||||||
|
|
||||||
@ -542,6 +673,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
self.lora_a_stacked[1][index] = 0
|
self.lora_a_stacked[1][index] = 0
|
||||||
self.lora_b_stacked[0][index] = 0
|
self.lora_b_stacked[0][index] = 0
|
||||||
self.lora_b_stacked[1][index] = 0
|
self.lora_b_stacked[1][index] = 0
|
||||||
|
if self.lora_config.bias_enabled:
|
||||||
|
self.bias_stacked[0][index] = 0
|
||||||
|
self.bias_stacked[1][index] = 0
|
||||||
|
|
||||||
def slice_lora_a(
|
def slice_lora_a(
|
||||||
self, lora_a: List[Union[torch.Tensor, None]]
|
self, lora_a: List[Union[torch.Tensor, None]]
|
||||||
@ -562,18 +696,32 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
]
|
]
|
||||||
return lora_b
|
return lora_b
|
||||||
|
|
||||||
|
def slice_bias(
|
||||||
|
self, bias: List[Union[torch.Tensor,
|
||||||
|
None]]) -> List[Union[torch.Tensor, None]]:
|
||||||
|
if bias[0] is None or bias[1] is None:
|
||||||
|
return bias
|
||||||
|
shard_size = self.output_dim
|
||||||
|
start_idx = self.tp_rank * shard_size
|
||||||
|
end_idx = (self.tp_rank + 1) * shard_size
|
||||||
|
bias = [bias[0][start_idx:end_idx], bias[1][start_idx:end_idx]]
|
||||||
|
return bias
|
||||||
|
|
||||||
def set_lora(
|
def set_lora(
|
||||||
self,
|
self,
|
||||||
index: int,
|
index: int,
|
||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.reset_lora(index)
|
self.reset_lora(index)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
lora_a = self.slice_lora_a(lora_a)
|
lora_a = self.slice_lora_a(lora_a)
|
||||||
lora_b = self.slice_lora_b(lora_b)
|
lora_b = self.slice_lora_b(lora_b)
|
||||||
|
if bias is not None:
|
||||||
|
bias = self.slice_bias(bias)
|
||||||
|
|
||||||
if lora_a[0] is not None:
|
if lora_a[0] is not None:
|
||||||
self.lora_a_stacked[0][
|
self.lora_a_stacked[0][
|
||||||
@ -582,6 +730,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
self.lora_b_stacked[0][
|
self.lora_b_stacked[0][
|
||||||
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
|
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
|
||||||
lora_b[0].T, non_blocking=True)
|
lora_b[0].T, non_blocking=True)
|
||||||
|
if bias is not None and bias[0] is not None:
|
||||||
|
self.bias_stacked[0][index,
|
||||||
|
0, :bias[0].shape[0]].copy_(bias[0].T,
|
||||||
|
non_blocking=True)
|
||||||
if lora_a[1] is not None:
|
if lora_a[1] is not None:
|
||||||
self.lora_a_stacked[1][
|
self.lora_a_stacked[1][
|
||||||
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
||||||
@ -589,10 +741,22 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
self.lora_b_stacked[1][
|
self.lora_b_stacked[1][
|
||||||
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
||||||
lora_b[1].T, non_blocking=True)
|
lora_b[1].T, non_blocking=True)
|
||||||
|
if bias is not None and bias[1] is not None:
|
||||||
|
self.bias_stacked[1][index,
|
||||||
|
0, :bias[1].shape[0]].copy_(bias[1].T,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
def apply(self, x: torch.Tensor,
|
def apply(self, x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||||
|
if self.bias_stacked is not None:
|
||||||
|
self.indices = self.punica_wrapper.token_lora_indices
|
||||||
|
output = apply_bias_packed_nslice(
|
||||||
|
self.indices,
|
||||||
|
output,
|
||||||
|
(self.output_dim, self.output_dim),
|
||||||
|
self.bias_stacked,
|
||||||
|
)
|
||||||
self.punica_wrapper.add_lora_packed_nslice(
|
self.punica_wrapper.add_lora_packed_nslice(
|
||||||
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
|
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
|
||||||
(self.output_dim, self.output_dim))
|
(self.output_dim, self.output_dim))
|
||||||
@ -654,17 +818,35 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
|
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
|
||||||
return lora_b
|
return lora_b
|
||||||
|
|
||||||
|
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
||||||
|
bias_q = bias[self.q_proj_shard_size *
|
||||||
|
self.q_shard_id:self.q_proj_shard_size *
|
||||||
|
(self.q_shard_id + 1)]
|
||||||
|
k_offset = self.q_proj_total_size
|
||||||
|
bias_k = bias[k_offset +
|
||||||
|
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
|
||||||
|
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
|
||||||
|
v_offset = k_offset + self.kv_proj_total_size
|
||||||
|
bias_v = bias[v_offset +
|
||||||
|
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
|
||||||
|
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
|
||||||
|
bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
|
||||||
|
return bias
|
||||||
|
|
||||||
def set_lora(
|
def set_lora(
|
||||||
self,
|
self,
|
||||||
index: int,
|
index: int,
|
||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.reset_lora(index)
|
self.reset_lora(index)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
lora_a = self.slice_lora_a(lora_a)
|
lora_a = self.slice_lora_a(lora_a)
|
||||||
lora_b = self.slice_lora_b(lora_b)
|
lora_b = self.slice_lora_b(lora_b)
|
||||||
|
if bias is not None:
|
||||||
|
bias = self.slice_bias(bias)
|
||||||
|
|
||||||
self.lora_a_stacked[index,
|
self.lora_a_stacked[index,
|
||||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||||
@ -672,6 +854,10 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
self.lora_b_stacked[index,
|
self.lora_b_stacked[index,
|
||||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||||
lora_b.T, non_blocking=True)
|
lora_b.T, non_blocking=True)
|
||||||
|
if bias is not None:
|
||||||
|
self.bias_stacked[index,
|
||||||
|
0, :bias.shape[0]].copy_(bias.T,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@_not_fully_sharded_can_replace
|
@_not_fully_sharded_can_replace
|
||||||
@ -768,6 +954,32 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
if lora_config.bias_enabled:
|
||||||
|
self.bias_stacked = (
|
||||||
|
torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.q_proj_shard_size,
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.kv_proj_shard_size,
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
torch.zeros(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.kv_proj_shard_size,
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.bias_stacked = None
|
||||||
|
|
||||||
self.output_slices = (
|
self.output_slices = (
|
||||||
self.q_proj_shard_size,
|
self.q_proj_shard_size,
|
||||||
@ -787,6 +999,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
self.lora_b_stacked[1][index] = 0
|
self.lora_b_stacked[1][index] = 0
|
||||||
self.lora_a_stacked[2][index] = 0
|
self.lora_a_stacked[2][index] = 0
|
||||||
self.lora_b_stacked[2][index] = 0
|
self.lora_b_stacked[2][index] = 0
|
||||||
|
if self.lora_config.bias_enabled:
|
||||||
|
self.bias_stacked[0][index] = 0
|
||||||
|
self.bias_stacked[1][index] = 0
|
||||||
|
self.bias_stacked[2][index] = 0
|
||||||
|
|
||||||
def slice_lora_a(
|
def slice_lora_a(
|
||||||
self, lora_a: List[Union[torch.Tensor, None]]
|
self, lora_a: List[Union[torch.Tensor, None]]
|
||||||
@ -812,18 +1028,40 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
lora_b = [lora_b_q, lora_b_k, lora_b_v]
|
lora_b = [lora_b_q, lora_b_k, lora_b_v]
|
||||||
return lora_b
|
return lora_b
|
||||||
|
|
||||||
|
def slice_bias(
|
||||||
|
self, bias: List[Union[torch.Tensor,
|
||||||
|
None]]) -> List[Union[torch.Tensor, None]]:
|
||||||
|
bias_q, bias_k, bias_v = bias
|
||||||
|
if bias_q is not None:
|
||||||
|
bias_q = bias_q[self.q_proj_shard_size *
|
||||||
|
self.q_shard_id:self.q_proj_shard_size *
|
||||||
|
(self.q_shard_id + 1)]
|
||||||
|
if bias_k is not None:
|
||||||
|
bias_k = bias_k[self.kv_proj_shard_size *
|
||||||
|
self.kv_shard_id:self.kv_proj_shard_size *
|
||||||
|
(self.kv_shard_id + 1)]
|
||||||
|
if bias_v is not None:
|
||||||
|
bias_v = bias_v[self.kv_proj_shard_size *
|
||||||
|
self.kv_shard_id:self.kv_proj_shard_size *
|
||||||
|
(self.kv_shard_id + 1)]
|
||||||
|
bias = [bias_q, bias_k, bias_v]
|
||||||
|
return bias
|
||||||
|
|
||||||
def set_lora(
|
def set_lora(
|
||||||
self,
|
self,
|
||||||
index: int,
|
index: int,
|
||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.reset_lora(index)
|
self.reset_lora(index)
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
lora_a = self.slice_lora_a(lora_a)
|
lora_a = self.slice_lora_a(lora_a)
|
||||||
lora_b = self.slice_lora_b(lora_b)
|
lora_b = self.slice_lora_b(lora_b)
|
||||||
|
if bias is not None:
|
||||||
|
bias = self.slice_bias(bias)
|
||||||
|
|
||||||
if lora_b[0] is not None:
|
if lora_b[0] is not None:
|
||||||
lora_b_q = lora_b[0]
|
lora_b_q = lora_b[0]
|
||||||
@ -854,9 +1092,28 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||||||
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
|
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
|
||||||
lora_a[2].T, non_blocking=True)
|
lora_a[2].T, non_blocking=True)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
if bias[0] is not None:
|
||||||
|
self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_(
|
||||||
|
bias[0].T, non_blocking=True)
|
||||||
|
if bias[1] is not None:
|
||||||
|
self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_(
|
||||||
|
bias[1].T, non_blocking=True)
|
||||||
|
if bias[2] is not None:
|
||||||
|
self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_(
|
||||||
|
bias[2].T, non_blocking=True)
|
||||||
|
|
||||||
def apply(self, x: torch.Tensor,
|
def apply(self, x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||||
|
if self.bias_stacked is not None:
|
||||||
|
self.indices = self.punica_wrapper.token_lora_indices
|
||||||
|
output = apply_bias_packed_nslice(
|
||||||
|
self.indices,
|
||||||
|
output,
|
||||||
|
self.output_slices,
|
||||||
|
self.bias_stacked,
|
||||||
|
)
|
||||||
self.punica_wrapper.add_lora_packed_nslice(output, x,
|
self.punica_wrapper.add_lora_packed_nslice(output, x,
|
||||||
self.lora_a_stacked,
|
self.lora_a_stacked,
|
||||||
self.lora_b_stacked, 1.0,
|
self.lora_b_stacked, 1.0,
|
||||||
@ -919,9 +1176,27 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if lora_config.bias_enabled:
|
||||||
|
self.bias_stacked = torch.zeros(
|
||||||
|
(
|
||||||
|
max_loras,
|
||||||
|
1,
|
||||||
|
self.output_size,
|
||||||
|
),
|
||||||
|
dtype=lora_config.lora_dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.bias_stacked = None
|
||||||
|
# Lazily initialized
|
||||||
|
self.indices: torch.Tensor
|
||||||
|
self.indices_len: List[int]
|
||||||
|
|
||||||
def reset_lora(self, index: int):
|
def reset_lora(self, index: int):
|
||||||
self.lora_a_stacked[index] = 0
|
self.lora_a_stacked[index] = 0
|
||||||
self.lora_b_stacked[index] = 0
|
self.lora_b_stacked[index] = 0
|
||||||
|
if self.lora_config.bias_enabled:
|
||||||
|
self.bias_stacked[index] = 0
|
||||||
|
|
||||||
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
||||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||||
@ -934,18 +1209,24 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
|
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
|
||||||
return lora_b
|
return lora_b
|
||||||
|
|
||||||
|
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
|
||||||
|
return bias
|
||||||
|
|
||||||
def set_lora(
|
def set_lora(
|
||||||
self,
|
self,
|
||||||
index: int,
|
index: int,
|
||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.reset_lora(index)
|
self.reset_lora(index)
|
||||||
|
|
||||||
if self.base_layer.tp_size > 1:
|
if self.base_layer.tp_size > 1:
|
||||||
lora_a = self.slice_lora_a(lora_a)
|
lora_a = self.slice_lora_a(lora_a)
|
||||||
lora_b = self.slice_lora_b(lora_b)
|
lora_b = self.slice_lora_b(lora_b)
|
||||||
|
if bias is not None:
|
||||||
|
bias = self.slice_bias(bias)
|
||||||
|
|
||||||
self.lora_a_stacked[index,
|
self.lora_a_stacked[index,
|
||||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||||
@ -953,9 +1234,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.lora_b_stacked[index,
|
self.lora_b_stacked[index,
|
||||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||||
lora_b.T, non_blocking=True)
|
lora_b.T, non_blocking=True)
|
||||||
|
if bias is not None:
|
||||||
|
self.bias_stacked[index,
|
||||||
|
0, :bias.shape[0]].copy_(bias.T,
|
||||||
|
non_blocking=True)
|
||||||
|
|
||||||
def apply(self, x: torch.Tensor) -> torch.Tensor:
|
def apply(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
output = self.base_layer.quant_method.apply(self.base_layer, x)
|
output = self.base_layer.quant_method.apply(self.base_layer, x)
|
||||||
|
if self.bias_stacked is not None:
|
||||||
|
self.indices = self.punica_wrapper.token_lora_indices
|
||||||
|
output = apply_bias(
|
||||||
|
self.indices,
|
||||||
|
output,
|
||||||
|
self.bias_stacked,
|
||||||
|
)
|
||||||
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
||||||
self.lora_b_stacked, 1.0)
|
self.lora_b_stacked, 1.0)
|
||||||
return output
|
return output
|
||||||
@ -1132,6 +1424,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
|||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
self.reset_lora(index)
|
self.reset_lora(index)
|
||||||
self.lora_a_stacked[index,
|
self.lora_a_stacked[index,
|
||||||
@ -1199,7 +1492,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
|||||||
neginf=float("-inf")))
|
neginf=float("-inf")))
|
||||||
logits[:,
|
logits[:,
|
||||||
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
|
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
|
||||||
lora_logits.shape[1], ] = lora_logits
|
lora_logits.shape[1]] = lora_logits
|
||||||
|
|
||||||
# LogitsProcessorWithLoRA always using bgmv
|
# LogitsProcessorWithLoRA always using bgmv
|
||||||
self.punica_wrapper.add_lora_logits(logits, hidden_states,
|
self.punica_wrapper.add_lora_logits(logits, hidden_states,
|
||||||
@ -1276,6 +1569,7 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
|
|||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
embeddings_tensor: Optional[torch.Tensor],
|
embeddings_tensor: Optional[torch.Tensor],
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|||||||
@ -17,6 +17,7 @@ class LoRALayerWeights:
|
|||||||
lora_alpha: int,
|
lora_alpha: int,
|
||||||
lora_a: torch.Tensor,
|
lora_a: torch.Tensor,
|
||||||
lora_b: torch.Tensor,
|
lora_b: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor] = None,
|
||||||
embeddings_tensor: Optional[torch.Tensor] = None,
|
embeddings_tensor: Optional[torch.Tensor] = None,
|
||||||
scaling: Optional[float] = None,
|
scaling: Optional[float] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -25,6 +26,7 @@ class LoRALayerWeights:
|
|||||||
self.lora_alpha = lora_alpha
|
self.lora_alpha = lora_alpha
|
||||||
self.lora_a = lora_a
|
self.lora_a = lora_a
|
||||||
self.lora_b = lora_b
|
self.lora_b = lora_b
|
||||||
|
self.bias = bias
|
||||||
self.embeddings_tensor = embeddings_tensor
|
self.embeddings_tensor = embeddings_tensor
|
||||||
|
|
||||||
if scaling is None:
|
if scaling is None:
|
||||||
@ -66,7 +68,8 @@ class LoRALayerWeights:
|
|||||||
rank: int,
|
rank: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
device: torch.types.Device,
|
device: torch.types.Device,
|
||||||
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
|
embeddings_tensor_dim: Optional[int] = None,
|
||||||
|
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
|
||||||
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
||||||
lora_a = torch.zeros([input_dim, rank],
|
lora_a = torch.zeros([input_dim, rank],
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
@ -76,6 +79,14 @@ class LoRALayerWeights:
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
pin_memory=pin_memory)
|
pin_memory=pin_memory)
|
||||||
|
if bias_enabled:
|
||||||
|
bias = torch.zeros([output_dim],
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
pin_memory=pin_memory)
|
||||||
|
else:
|
||||||
|
bias = None
|
||||||
|
|
||||||
embeddings_tensor = torch.rand(
|
embeddings_tensor = torch.rand(
|
||||||
10,
|
10,
|
||||||
embeddings_tensor_dim,
|
embeddings_tensor_dim,
|
||||||
@ -88,6 +99,7 @@ class LoRALayerWeights:
|
|||||||
lora_alpha=1,
|
lora_alpha=1,
|
||||||
lora_a=lora_a,
|
lora_a=lora_a,
|
||||||
lora_b=lora_b,
|
lora_b=lora_b,
|
||||||
|
bias=bias,
|
||||||
embeddings_tensor=embeddings_tensor,
|
embeddings_tensor=embeddings_tensor,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -102,6 +114,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
|
|||||||
lora_alphas: List[Optional[int]],
|
lora_alphas: List[Optional[int]],
|
||||||
lora_a: List[Optional[torch.Tensor]],
|
lora_a: List[Optional[torch.Tensor]],
|
||||||
lora_b: List[Optional[torch.Tensor]],
|
lora_b: List[Optional[torch.Tensor]],
|
||||||
|
bias: Optional[List[Optional[torch.Tensor]]] = None,
|
||||||
scaling: Optional[List[float]] = None,
|
scaling: Optional[List[float]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -110,6 +123,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
|
|||||||
lora_alpha=0,
|
lora_alpha=0,
|
||||||
lora_a=lora_a,
|
lora_a=lora_a,
|
||||||
lora_b=lora_b,
|
lora_b=lora_b,
|
||||||
|
bias=bias,
|
||||||
scaling=scaling, # type: ignore
|
scaling=scaling, # type: ignore
|
||||||
embeddings_tensor=None,
|
embeddings_tensor=None,
|
||||||
)
|
)
|
||||||
@ -141,6 +155,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
|
|||||||
[lora.lora_alpha if lora is not None else None for lora in loras],
|
[lora.lora_alpha if lora is not None else None for lora in loras],
|
||||||
[lora.lora_a if lora is not None else None for lora in loras],
|
[lora.lora_a if lora is not None else None for lora in loras],
|
||||||
[lora.lora_b if lora is not None else None for lora in loras],
|
[lora.lora_b if lora is not None else None for lora in loras],
|
||||||
|
[lora.bias if lora is not None else None for lora in loras],
|
||||||
scaling=[
|
scaling=[
|
||||||
1 if lora is not None else None # type: ignore
|
1 if lora is not None else None # type: ignore
|
||||||
for lora in loras
|
for lora in loras
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Dict, List, Optional, Type
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
|
||||||
|
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import torch
|
import torch
|
||||||
@ -119,7 +119,8 @@ class LoRAModel(AdapterModel):
|
|||||||
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
pin_memory = str(device) == "cpu" and is_pin_memory_available()
|
||||||
loras: Dict[str, LoRALayerWeights] = {}
|
loras: Dict[str, LoRALayerWeights] = {}
|
||||||
for tensor_name, tensor in tensors.items():
|
for tensor_name, tensor in tensors.items():
|
||||||
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
|
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
|
||||||
|
tensor_name)
|
||||||
if module_name not in loras:
|
if module_name not in loras:
|
||||||
lora_embeddings_tensor = None
|
lora_embeddings_tensor = None
|
||||||
if embeddings:
|
if embeddings:
|
||||||
@ -136,8 +137,16 @@ class LoRAModel(AdapterModel):
|
|||||||
lora_embeddings_tensor.pin_memory())
|
lora_embeddings_tensor.pin_memory())
|
||||||
loras[module_name] = LoRALayerWeights(module_name, rank,
|
loras[module_name] = LoRALayerWeights(module_name, rank,
|
||||||
lora_alpha, None, None,
|
lora_alpha, None, None,
|
||||||
|
None,
|
||||||
lora_embeddings_tensor)
|
lora_embeddings_tensor)
|
||||||
if is_lora_a:
|
if is_bias:
|
||||||
|
loras[module_name].bias = tensor.to(device=device,
|
||||||
|
dtype=dtype).t()
|
||||||
|
bias = tensor.to(device=device, dtype=dtype).t()
|
||||||
|
if pin_memory:
|
||||||
|
bias = bias.pin_memory()
|
||||||
|
loras[module_name].bias = bias
|
||||||
|
elif is_lora_a:
|
||||||
loras[module_name].lora_a = tensor.to(device=device,
|
loras[module_name].lora_a = tensor.to(device=device,
|
||||||
dtype=dtype).t()
|
dtype=dtype).t()
|
||||||
if pin_memory:
|
if pin_memory:
|
||||||
@ -215,7 +224,7 @@ class LoRAModel(AdapterModel):
|
|||||||
with safetensors.safe_open(lora_tensor_path,
|
with safetensors.safe_open(lora_tensor_path,
|
||||||
framework="pt") as f: # type: ignore
|
framework="pt") as f: # type: ignore
|
||||||
for lora_module in f.keys(): # noqa
|
for lora_module in f.keys(): # noqa
|
||||||
module_name, _ = parse_fine_tuned_lora_name(lora_module)
|
module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
|
||||||
part_name = module_name.split(".")[-1]
|
part_name = module_name.split(".")[-1]
|
||||||
if part_name not in expected_lora_modules:
|
if part_name not in expected_lora_modules:
|
||||||
unexpected_modules.append(module_name)
|
unexpected_modules.append(module_name)
|
||||||
@ -386,8 +395,19 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
module_lora = lora_model.get_lora(module_name)
|
module_lora = lora_model.get_lora(module_name)
|
||||||
if module_lora:
|
if module_lora:
|
||||||
module_lora.optimize()
|
module_lora.optimize()
|
||||||
|
# Bias is not explicitly enabled with the flag enable_lora_bias.
|
||||||
|
bias = module_lora.bias
|
||||||
|
if ((torch.is_tensor(bias) or
|
||||||
|
(isinstance(bias, Sequence) and any(b is not None
|
||||||
|
for b in bias)))
|
||||||
|
and not self.lora_config.bias_enabled):
|
||||||
|
module_lora.bias = None
|
||||||
|
raise ValueError(
|
||||||
|
f"Adapter bias cannot be used for {module_name}"
|
||||||
|
" without --enable-lora-bias.")
|
||||||
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
|
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
|
||||||
module_lora.embeddings_tensor)
|
module_lora.embeddings_tensor,
|
||||||
|
module_lora.bias)
|
||||||
else:
|
else:
|
||||||
module.reset_lora(index)
|
module.reset_lora(index)
|
||||||
return True
|
return True
|
||||||
@ -509,6 +529,7 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
"""Create zero-initialized LoRAModel for warmup."""
|
"""Create zero-initialized LoRAModel for warmup."""
|
||||||
model = LoRAModel(lora_id, rank, {}, scaling_factor)
|
model = LoRAModel(lora_id, rank, {}, scaling_factor)
|
||||||
for module_name, module in self.model.named_modules():
|
for module_name, module in self.model.named_modules():
|
||||||
|
bias_enabled = self.lora_config.bias_enabled
|
||||||
if (not self._match_target_modules(module_name)
|
if (not self._match_target_modules(module_name)
|
||||||
or not isinstance(module, BaseLayerWithLoRA)
|
or not isinstance(module, BaseLayerWithLoRA)
|
||||||
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
|
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
|
||||||
@ -536,7 +557,8 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
rank,
|
rank,
|
||||||
module.lora_a_stacked.dtype,
|
module.lora_a_stacked.dtype,
|
||||||
"cpu",
|
"cpu",
|
||||||
embeddings_tensor_dim=embeddings_tensor_dim)
|
embeddings_tensor_dim=embeddings_tensor_dim,
|
||||||
|
bias_enabled=bias_enabled)
|
||||||
else:
|
else:
|
||||||
lora = LoRALayerWeights.create_dummy_lora_weights(
|
lora = LoRALayerWeights.create_dummy_lora_weights(
|
||||||
module_name,
|
module_name,
|
||||||
@ -545,6 +567,7 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
rank,
|
rank,
|
||||||
module.lora_a_stacked.dtype,
|
module.lora_a_stacked.dtype,
|
||||||
"cpu",
|
"cpu",
|
||||||
|
bias_enabled=bias_enabled,
|
||||||
)
|
)
|
||||||
lora.optimize()
|
lora.optimize()
|
||||||
else:
|
else:
|
||||||
@ -559,6 +582,7 @@ class LoRAModelManager(AdapterModelManager):
|
|||||||
rank,
|
rank,
|
||||||
module.lora_a_stacked[i].dtype,
|
module.lora_a_stacked[i].dtype,
|
||||||
"cpu",
|
"cpu",
|
||||||
|
bias_enabled=bias_enabled,
|
||||||
)
|
)
|
||||||
lora.optimize()
|
lora.optimize()
|
||||||
subloras.append(lora)
|
subloras.append(lora)
|
||||||
|
|||||||
@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str,
|
|||||||
return new_module
|
return new_module
|
||||||
|
|
||||||
|
|
||||||
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
|
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
|
||||||
"""Parse the name of lora weights.
|
"""Parse the name of lora weights.
|
||||||
|
|
||||||
args:
|
args:
|
||||||
@ -101,15 +101,18 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
|
|||||||
Tuple(module_name, is_lora_a):
|
Tuple(module_name, is_lora_a):
|
||||||
module_name: the name of the module, e.g. model.dense1,
|
module_name: the name of the module, e.g. model.dense1,
|
||||||
is_lora_a whether the tensor is lora_a or lora_b.
|
is_lora_a whether the tensor is lora_a or lora_b.
|
||||||
|
is_bias whether the tensor is lora bias.
|
||||||
"""
|
"""
|
||||||
parts = name.split(".")
|
parts = name.split(".")
|
||||||
|
if parts[-1] == "weight" and (parts[-2] == "lora_A"
|
||||||
|
or parts[-2] == "lora_B"):
|
||||||
|
return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
|
||||||
|
|
||||||
if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
|
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
|
||||||
if parts[-1] == "weight":
|
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
|
||||||
if parts[-2] == "lora_A" or parts[-2] == "lora_B":
|
|
||||||
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
|
if parts[-1] == "bias":
|
||||||
elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
|
return ".".join(parts[2:-2]), False, True
|
||||||
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
|
|
||||||
|
|
||||||
raise ValueError(f"{name} is unsupported LoRA weight")
|
raise ValueError(f"{name} is unsupported LoRA weight")
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user