[LoRA] ReplicatedLinear support LoRA (#7081)
parent fb2c1c86c1
commit 99d7cabd7b
@@ -22,6 +22,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
@@ -31,6 +32,7 @@ from vllm.lora.punica import PunicaWrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -545,6 +547,107 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
                        atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_linear_replicated(dist_init, num_loras, device, stage) -> None:

    torch.set_default_device(device)
    punica_wrapper = PunicaWrapper(8192, 256, device)
    max_loras = 8
    lora_config = LoRAConfig(max_loras=max_loras,
                             max_lora_rank=8,
                             lora_dtype=torch.float16)

    def create_random_linear_replicated_layer():

        linear = ReplicatedLinear(4096,
                                  4096,
                                  bias=False,
                                  params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = ReplicatedLinearWithLoRA(linear)

        lora_linear.create_lora_weights(max_loras, lora_config)

        return linear, lora_linear

    for i in range(10):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, lora_linear = create_random_linear_replicated_layer()
        lora_linear.set_mapping(punica_wrapper)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_linear,
            layer_weights=linear.weight,
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            512,
            lora_config.lora_extra_vocab_size,
        )

        lora_result = lora_linear(torch.cat(inputs))[0]

        expected_results: List[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = linear(input_)[0]
            result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)

        # Check that resetting the lora weights succeeds

        for slot_idx in range(max_loras):
            lora_linear.reset_lora(slot_idx)

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=32 * num_loras,
            input_size=(1, 4096),
            input_range=(0, 1),
            input_type=torch.float16,
        )
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)

        punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
                                       512, lora_config.lora_extra_vocab_size)

        lora_result = lora_linear(torch.cat(inputs))[0]
        expected_result = linear(torch.cat(inputs))[0]

        rtol, atol = TOLERANCES[lora_result.dtype]
        assert torch.allclose(lora_result,
                              expected_result,
                              rtol=rtol,
                              atol=atol)


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
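The expected-result loop in the test above is just the textbook LoRA update, y = x W^T + (x A B) * scaling, applied per input against the LoRA its prompt maps to. Below is a minimal, standalone sketch of that reference math in plain torch; the feature/rank sizes are scaled down from the test's 4096/8, and the scaling value is an arbitrary illustration, not taken from the diff.

# Standalone sketch of the reference math asserted by test_linear_replicated:
#   y = x @ W.T + (x @ A @ B) * scaling
# Sizes are smaller than the test's (4096 features, rank 8) to keep it light,
# and float32 is used so the comparison is tight on CPU.
import torch

torch.manual_seed(0)
in_features, out_features, rank, scaling = 64, 64, 8, 0.5

x = torch.rand(1, in_features)
weight = torch.rand(out_features, in_features)  # base ReplicatedLinear-style weight
lora_a = torch.rand(in_features, rank)          # LoRA "A" (down-projection)
lora_b = torch.rand(rank, out_features)         # LoRA "B" (up-projection)

base_out = x @ weight.t()
lora_out = base_out + (x @ lora_a @ lora_b) * scaling

# Equivalent check: folding the low-rank update into the weight gives the same output.
merged_weight = weight + (lora_a @ lora_b).t() * scaling
assert torch.allclose(lora_out, x @ merged_weight.t(), rtol=1e-5, atol=1e-5)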
@@ -21,6 +21,7 @@ from vllm.lora.punica import PunicaWrapper
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import (
@@ -262,6 +263,99 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
        return type(source_layer) is VocabParallelEmbedding


class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        self.output_size = self.base_layer.output_size
        self.device = _get_lora_device(self.base_layer)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        self.lora_config = lora_config
        lora_a_output_size = lora_config.max_lora_rank
        self.lora_a_stacked = torch.zeros(
            max_loras,
            1,
            lora_a_output_size,
            self.input_size,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            max_loras,
            1,
            self.output_size,
            lora_config.max_lora_rank,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, 1.0)
        return output

    def forward(self, input_):
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = (self.base_layer.bias
                if not self.base_layer.skip_bias_add else None)

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = (self.base_layer.bias
                       if self.base_layer.skip_bias_add else None)
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        return type(source_layer) is ReplicatedLinear


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
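To make the buffer layout above concrete: create_lora_weights allocates lora_a_stacked with shape (max_loras, 1, max_lora_rank, input_size) and lora_b_stacked with shape (max_loras, 1, output_size, max_lora_rank), and set_lora copies an adapter's weights in transposed into one slot. The plain-torch sketch below mirrors that layout and the delta that apply() adds via punica_wrapper.add_lora; the sizes and the slot index are illustrative only, not taken from the diff.

# Plain-torch sketch of the slot layout used by ReplicatedLinearWithLoRA above.
# Sizes and the slot index are illustrative; this is not vLLM code.
import torch

max_loras, max_rank = 8, 8
input_size, output_size = 64, 64

# Same shapes as create_lora_weights(): a leading slot dim, a sub-slice dim of 1,
# and weights stored transposed relative to the checkpoint layout.
lora_a_stacked = torch.zeros(max_loras, 1, max_rank, input_size)
lora_b_stacked = torch.zeros(max_loras, 1, output_size, max_rank)

# A checkpoint-style adapter: lora_a is (input_size, rank), lora_b is (rank, output_size).
index, rank = 3, 4
lora_a = torch.rand(input_size, rank)
lora_b = torch.rand(rank, output_size)

# set_lora(): zero the slot, then copy the transposed weights into its leading corner.
lora_a_stacked[index] = 0
lora_b_stacked[index] = 0
lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_(lora_a.T)
lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_(lora_b.T)

# The delta that apply() adds on top of the base layer's output for tokens mapped
# to this slot (the real code does this with a batched Punica kernel over all slots).
x = torch.rand(2, input_size)
delta = x @ lora_a_stacked[index, 0].T @ lora_b_stacked[index, 0].T
assert torch.allclose(delta, x @ lora_a @ lora_b, rtol=1e-5, atol=1e-5)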
@@ -23,6 +23,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
@@ -38,6 +39,7 @@ _all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLora,
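_all_lora_classes is the registry consulted when deciding which LoRA wrapper, if any, should replace a given module: each candidate's can_replace_layer classmethod is checked in turn, which is why ReplicatedLinearWithLoRA has to be registered here in addition to being defined in layers.py. Below is a simplified, self-contained illustration of that dispatch pattern; the classes and the wrap_if_supported helper are invented for the sketch and are not vLLM's own API.

# Simplified illustration of registry-driven layer replacement, in the spirit of
# _all_lora_classes. All names below are invented for the sketch.
from typing import Set, Type

import torch.nn as nn


class BaseWithLoRA(nn.Module):

    def __init__(self, base_layer: nn.Module) -> None:
        super().__init__()
        self.base_layer = base_layer

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module) -> bool:
        raise NotImplementedError


class LinearWithLoRA(BaseWithLoRA):

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module) -> bool:
        # Exact type check, mirroring the `type(source_layer) is ...` tests above.
        return type(source_layer) is nn.Linear


class EmbeddingWithLoRA(BaseWithLoRA):

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module) -> bool:
        return type(source_layer) is nn.Embedding


_registry: Set[Type[BaseWithLoRA]] = {LinearWithLoRA, EmbeddingWithLoRA}


def wrap_if_supported(layer: nn.Module) -> nn.Module:
    """Return a LoRA wrapper for `layer` if some registered class claims it."""
    for lora_cls in _registry:
        if lora_cls.can_replace_layer(layer):
            return lora_cls(layer)
    return layer


print(type(wrap_if_supported(nn.Linear(16, 16))).__name__)  # LinearWithLoRA
print(type(wrap_if_supported(nn.LayerNorm(16))).__name__)   # LayerNorm (unchanged)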