[LoRA] ReplicatedLinear support LoRA (#7081)

Jee Jee Li 2024-08-03 13:40:19 +08:00 committed by GitHub
parent fb2c1c86c1
commit 99d7cabd7b
3 changed files with 199 additions and 0 deletions
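
For reference, the update this commit exercises is the standard LoRA delta applied on top of an unsharded linear layer: y = x W^T + (x A) B * scaling. Below is a minimal standalone PyTorch sketch of that reference computation (plain matmuls, not the Punica kernels that the real apply() uses; sizes are illustrative and mirror the new test):

import torch

# Illustrative sizes only; they match test_linear_replicated in this commit.
input_size, output_size, rank, scaling = 4096, 4096, 8, 1.0
x = torch.rand(1, input_size, dtype=torch.float16)
weight = torch.rand(output_size, input_size, dtype=torch.float16)   # base ReplicatedLinear weight
lora_a = torch.rand(input_size, rank, dtype=torch.float16)          # LoRA A: [in, rank]
lora_b = torch.rand(rank, output_size, dtype=torch.float16)         # LoRA B: [rank, out]

base_out = x @ weight.T                    # base layer forward (bias omitted)
lora_out = x @ lora_a @ lora_b * scaling   # LoRA contribution, as in the test's expected result
y = base_out + lora_out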

File 1 of 3: LoRA layer tests

@@ -22,6 +22,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
@@ -31,6 +32,7 @@ from vllm.lora.punica import PunicaWrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -545,6 +547,107 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                          atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:

    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_linear_replicated_layer():

        linear = ReplicatedLinear(4096,
                                  4096,
                                  bias=False,
                                  params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds
        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512,
                                       lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])

File 2 of 3: LoRA layer implementations

@@ -21,6 +21,7 @@ from vllm.lora.punica import PunicaWrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
@@ -262,6 +263,99 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
        return type(source_layer) is VocabParallelEmbedding


class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        lora_a_output_size = lora_config.max_lora_rank
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ReplicatedLinear


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
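
A quick shape sketch (illustrative values only) of the stacked buffers that create_lora_weights allocates and that set_lora fills per adapter slot; the real buffers live on the layer's device and are consumed by the Punica kernels:

import torch

# Illustrative sizes; they mirror the test configuration in this commit.
max_loras, max_lora_rank, input_size, output_size = 8, 8, 4096, 4096
lora_a_stacked = torch.zeros(max_loras, 1, max_lora_rank, input_size,
                             dtype=torch.float16)  # [slots, 1, rank, in]
lora_b_stacked = torch.zeros(max_loras, 1, output_size, max_lora_rank,
                             dtype=torch.float16)  # [slots, 1, out, rank]

# set_lora copies the transposed per-adapter matrices into slot `index`;
# adapters with rank < max_lora_rank leave the remaining rows/columns zero.
index, rank = 0, 8
lora_a = torch.rand(input_size, rank, dtype=torch.float16)   # [in, rank]
lora_b = torch.rand(rank, output_size, dtype=torch.float16)  # [rank, out]
lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(lora_a.T)
lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(lora_b.T)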

File 3 of 3: LoRA layer class registry

@@ -23,6 +23,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
@@ -38,6 +39,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLora,
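
For context on why the registry entry matters: at adapter load time a set like _all_lora_classes is walked and the first class whose can_replace_layer() accepts the source module wraps it, which is how ReplicatedLinear now picks up ReplicatedLinearWithLoRA. A rough, hypothetical sketch of that dispatch (wrap_layer_sketch and its signature are illustrative, not vLLM's actual helper):

def wrap_layer_sketch(layer, registry, max_loras, lora_config,
                      packed_modules_list, model_config=None):
    """Hypothetical dispatch over a registry such as _all_lora_classes."""
    for lora_cls in registry:
        if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
                                      model_config):
            wrapped = lora_cls(layer)  # e.g. ReplicatedLinearWithLoRA(layer)
            wrapped.create_lora_weights(max_loras, lora_config, model_config)
            return wrapped
    return layer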