Enable more models to run inference with LoRA (#3382)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent dfeb2ecc3a
commit 8af890a865
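This commit wires LoRA support into ChatGLM3 and Baichuan and adds the Punica kernel widths those models need. A minimal usage sketch, distilled from the new tests added below; the adapter repo name comes from the new conftest fixtures, and the prompt and sampling settings are illustrative only:

import vllm
from vllm.lora.request import LoRARequest
from huggingface_hub import snapshot_download

# Baichuan-7B base model with LoRA enabled, as in tests/lora/test_baichuan.py.
llm = vllm.LLM("baichuan-inc/Baichuan-7B",
               max_model_len=1024,
               enable_lora=True,
               max_loras=4,
               max_lora_rank=64,
               trust_remote_code=True)

# Text2SQL adapter used by the new test fixtures.
lora_path = snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")

outputs = llm.generate(
    ["How many singers do we have?"],  # shortened prompt for illustration
    vllm.SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("sql-adapter", 1, lora_path))
print(outputs[0].outputs[0].text)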
@@ -16,10 +16,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    f(in_T, out_T, W_T, narrow, 512) \
    f(in_T, out_T, W_T, narrow, 768) \
    f(in_T, out_T, W_T, narrow, 1024) \
    f(in_T, out_T, W_T, narrow, 1152) \
    f(in_T, out_T, W_T, narrow, 1280) \
    f(in_T, out_T, W_T, narrow, 1536) \
    f(in_T, out_T, W_T, narrow, 1728) \
    f(in_T, out_T, W_T, narrow, 1792) \
    f(in_T, out_T, W_T, narrow, 2048) \
    f(in_T, out_T, W_T, narrow, 2304) \
    f(in_T, out_T, W_T, narrow, 2560) \
    f(in_T, out_T, W_T, narrow, 2752) \
    f(in_T, out_T, W_T, narrow, 2816) \
@@ -27,10 +30,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    f(in_T, out_T, W_T, narrow, 3456) \
    f(in_T, out_T, W_T, narrow, 3584) \
    f(in_T, out_T, W_T, narrow, 4096) \
    f(in_T, out_T, W_T, narrow, 4608) \
    f(in_T, out_T, W_T, narrow, 5120) \
    f(in_T, out_T, W_T, narrow, 5504) \
    f(in_T, out_T, W_T, narrow, 5632) \
    f(in_T, out_T, W_T, narrow, 6144) \
    f(in_T, out_T, W_T, narrow, 6848) \
    f(in_T, out_T, W_T, narrow, 6912) \
    f(in_T, out_T, W_T, narrow, 7168) \
    f(in_T, out_T, W_T, narrow, 8192) \
@@ -45,6 +50,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
    f(in_T, out_T, W_T, narrow, 20480) \
    f(in_T, out_T, W_T, narrow, 22016) \
    f(in_T, out_T, W_T, narrow, 24576) \
    f(in_T, out_T, W_T, narrow, 27392) \
    f(in_T, out_T, W_T, narrow, 28672) \
    f(in_T, out_T, W_T, narrow, 32000) \
    f(in_T, out_T, W_T, narrow, 32256) \
@@ -134,6 +134,16 @@ def gemma_lora_files():
    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")


@pytest.fixture(scope="session")
def chatglm3_lora_files():
    return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_lora_files():
    return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")


@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
    cleanup()
tests/lora/test_baichuan.py (new file, 108 lines)
@@ -0,0 +1,108 @@
import pytest

import vllm
from vllm.lora.request import LoRARequest

from .conftest import cleanup

MODEL_PATH = "baichuan-inc/Baichuan-7B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501


def do_sample(llm, lora_path: str, lora_id: int) -> str:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    print(prompts)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_baichuan_lora(baichuan_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age ASC",
    ]

    output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i] == expected_lora_output[i]
    output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i] == expected_lora_output[i]


@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(baichuan_lora_files):
    # Cannot use as it will initialize torch.cuda too early...
    # if torch.cuda.device_count() < 4:
    #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

    llm_tp1 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=1,
                       trust_remote_code=True)
    output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)

    del llm_tp1
    cleanup()

    llm_tp2 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=2,
                       trust_remote_code=True)
    output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)

    del llm_tp2
    cleanup()

    assert output_tp1 == output_tp2

    llm_tp4 = vllm.LLM(MODEL_PATH,
                       enable_lora=True,
                       max_num_seqs=16,
                       max_loras=4,
                       max_lora_rank=64,
                       tensor_parallel_size=4,
                       trust_remote_code=True)
    output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)

    del llm_tp4
    cleanup()

    assert output_tp1 == output_tp4
tests/lora/test_chatglm3.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "THUDM/chatglm3-6b"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:"""  # noqa: E501


def do_sample(llm, lora_path: str, lora_id: int) -> str:
    prompts = [
        PROMPT_TEMPLATE.format(query="How many singers do we have?"),
        PROMPT_TEMPLATE.format(
            query=
            "What is the average, minimum, and maximum age of all singers from France?"  # noqa: E501
        ),
        PROMPT_TEMPLATE.format(
            query=
            "Show name, country, age for all singers ordered by age from the oldest to the youngest."  # noqa: E501
        ),
    ]
    print(prompts)
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_chatglm3_lora(chatglm3_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=4,
                   max_lora_rank=64,
                   trust_remote_code=True)

    expected_lora_output = [
        "SELECT count(*) FROM singer",
        "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'",  # noqa: E501
        "SELECT name , country , age FROM singer ORDER BY age",
    ]

    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
    for i in range(len(expected_lora_output)):
        assert output1[i] == expected_lora_output[i]
    output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
    for i in range(len(expected_lora_output)):
        assert output2[i] == expected_lora_output[i]
@@ -8,12 +8,16 @@ import torch
import torch.nn.functional as F

from vllm.config import LoRAConfig
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              LogitsProcessorWithLoRA, LoRAMapping,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
                              convert_mapping)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -93,8 +97,7 @@ def populate_loras(
    lora_dict: Dict[int, LoRALayerWeights] = dict()

    # Dictionary that maps the lora ID to the
    # corresponding subloras. Only useful when
    # repeats > 1.
    # corresponding subloras.
    sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()

    for slot_idx, lora_id in enumerate(id_to_index):
@@ -607,7 +610,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:

@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [2, 3])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:

@@ -623,6 +626,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
                                                bias=False)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = MergedColumnParallelLinearWithLoRA(linear)
    elif repeats == 3:
        linear = QKVParallelLinear(4096, 64, 32, bias=False)
        linear.weight.data = torch.rand_like(linear.weight.data)
        lora_linear = MergedQKVParallelLinearWithLora(linear)
    else:
        linear = QKVParallelLinear(4096, 64, 32, bias=False)
        linear.weight.data = torch.rand_like(linear.weight.data)
@@ -43,9 +43,10 @@ def _lora_ref_impl(


H1 = H2 = [
    128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
    5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
    22016, 24576, 32000, 32256, 32512, 32768, 33024
    128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
    3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
    10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
    32768, 33024
]
SEED = [0xabcdabcd987]
@@ -1,7 +1,8 @@
# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type

import torch
import torch.nn as nn
@@ -114,8 +115,11 @@ class LoRAMapping:

class BaseLayerWithLoRA(nn.Module):

    def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
                            model_config: PretrainedConfig) -> None:
    def create_lora_weights(
            self,
            max_loras: int,
            lora_config: LoRAConfig,
            model_config: Optional[PretrainedConfig] = None) -> None:
        """Initializes lora matrices."""
        ...

@@ -144,6 +148,13 @@ class BaseLayerWithLoRA(nn.Module):
        """Sets the mapping indices."""
        ...

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
@@ -278,12 +289,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
            self.indices[:self.indices_len[0]], 0, 1.0)
        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is VocabParallelEmbedding


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.tp_size = get_tensor_model_parallel_world_size()

    def create_lora_weights(
        self,
@@ -309,7 +327,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):

        self.indices: Optional[torch.Tensor] = None
        self.indices_len: Optional[List[int]] = None
        self.output_dim = self.lora_b_stacked.shape[1]
        self.output_dim = self.lora_b_stacked.shape[2]

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
@@ -323,7 +341,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
            shard_size = self.output_dim
            start_idx = tensor_model_parallel_rank * shard_size
            end_idx = (tensor_model_parallel_rank + 1) * shard_size
            lora_b = lora_b[:, start_idx:end_idx]
        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
@@ -383,6 +406,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    def linear_weights(self):
        return self.base_layer.linear_weights

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1)


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
@@ -485,8 +516,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
        )
        return output

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is MergedColumnParallelLinear and len(
            packed_modules_list) == 2


class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    The Q slice may have a different shape than the K and V slices
    (which both have the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.tp_size = get_tensor_model_parallel_world_size()
        self.q_proj_total_size = (self.base_layer.total_num_heads *
                                  self.base_layer.head_size)
        self.q_proj_shard_size = (self.base_layer.num_heads *
                                  self.base_layer.head_size)
        self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                   self.base_layer.head_size)
        self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
                                   self.base_layer.head_size)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
    ):
        self.reset_lora(index)
        if self.tp_size > 1:
            tp_rank = get_tensor_model_parallel_rank()
            self.q_shard_id = tp_rank
            self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
            lora_b_q = lora_b[:, self.q_proj_shard_size *
                              self.q_shard_id:self.q_proj_shard_size *
                              (self.q_shard_id + 1)]
            k_offset = self.q_proj_total_size
            lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:k_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            v_offset = k_offset + self.kv_proj_total_size
            lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
                              self.kv_shard_id:v_offset +
                              self.kv_proj_shard_size * (self.kv_shard_id + 1)]
            lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)

        self.lora_a_stacked[index,
                            0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
                                lora_a.T, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                lora_b.T, non_blocking=True)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 1


class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).
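For clarity, a standalone sketch (not part of the diff) of the lora_b column ranges that QKVParallelLinearWithLora.set_lora selects for one tensor-parallel rank. It is plain Python with no vLLM imports, and the head counts and head size at the bottom are made-up example values:

def qkv_lora_b_slices(total_num_heads, total_num_kv_heads, head_size,
                      tp_size, tp_rank):
    # Mirrors the shard/offset arithmetic in set_lora() above.
    q_shard = (total_num_heads // tp_size) * head_size
    kv_shard = max(total_num_kv_heads // tp_size, 1) * head_size
    q_total = total_num_heads * head_size
    kv_total = total_num_kv_heads * head_size
    num_kv_head_replicas = max(tp_size // total_num_kv_heads, 1)

    q_shard_id = tp_rank
    kv_shard_id = tp_rank // num_kv_head_replicas

    q_cols = (q_shard * q_shard_id, q_shard * (q_shard_id + 1))
    k_cols = (q_total + kv_shard * kv_shard_id,
              q_total + kv_shard * (kv_shard_id + 1))
    v_cols = (q_total + kv_total + kv_shard * kv_shard_id,
              q_total + kv_total + kv_shard * (kv_shard_id + 1))
    return q_cols, k_cols, v_cols


# Example: 32 query heads, 32 KV heads, head_size 128, TP=2, rank 1.
print(qkv_lora_b_slices(32, 32, 128, tp_size=2, tp_rank=1))

The three per-rank ranges are then concatenated, which is what the torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) above does with the sliced lora_b.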
@@ -654,6 +757,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
        )
        return output

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is QKVParallelLinear and len(
            packed_modules_list) == 3


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@@ -780,6 +890,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def weight(self):
        return self.base_layer.weight

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        return type(source_layer) is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):

@@ -900,7 +1016,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
        hidden_states: torch.Tensor,
        embedding: torch.Tensor,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
    ) -> Optional[torch.Tensor]:
        # Get the logits for the next tokens.
        logits = torch.matmul(hidden_states, embedding.t())
        if embedding_bias is not None:
@@ -949,22 +1065,30 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Special handling for the LogitsProcessor.
        return False


def from_layer(
        layer: nn.Module,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
    supported_layer_types = {
        VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
        ColumnParallelLinear: ColumnParallelLinearWithLoRA,
        QKVParallelLinear: QKVParallelLinearWithLora,
        MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
        RowParallelLinear: RowParallelLinearWithLoRA,
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if type(layer) is src_layer_type:  # pylint: disable=unidiomatic-typecheck
            ret = lora_layer_type(layer)

_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    cls
    for cls in globals().values() if inspect.isclass(cls)
    and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
                                      model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
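With the hard-coded table replaced by the _all_lora_classes set, a wrapper opts in purely through can_replace_layer. A toy illustration of the same discovery-and-dispatch pattern; all names here are invented and it does not import vLLM:

import inspect
from typing import List, Set, Type


class ToyLinear:  # stands in for a base model layer type
    pass


class ToyBaseWithLoRA:  # stands in for BaseLayerWithLoRA
    def __init__(self, base_layer):
        self.base_layer = base_layer

    @classmethod
    def can_replace_layer(cls, source_layer, packed_modules_list: List) -> bool:
        raise NotImplementedError


class ToyLinearWithLoRA(ToyBaseWithLoRA):
    @classmethod
    def can_replace_layer(cls, source_layer, packed_modules_list: List) -> bool:
        return type(source_layer) is ToyLinear


# Collected the same way as _all_lora_classes above.
_toy_classes: Set[Type[ToyBaseWithLoRA]] = {
    cls
    for cls in globals().values() if inspect.isclass(cls)
    and issubclass(cls, ToyBaseWithLoRA) and cls is not ToyBaseWithLoRA
}


def toy_from_layer(layer, packed_modules_list: List):
    # First wrapper that claims the layer wins; otherwise return it unchanged.
    for lora_cls in _toy_classes:
        if lora_cls.can_replace_layer(layer, packed_modules_list):
            return lora_cls(layer)
    return layer


assert isinstance(toy_from_layer(ToyLinear(), []), ToyLinearWithLoRA)
assert not isinstance(toy_from_layer(42, []), ToyBaseWithLoRA)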
@@ -413,11 +413,12 @@ class LoRAModelManager:
        for module_name, module in self.model.named_modules():
            if not self._match_target_modules(module_name):
                continue

            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           self.model.config))
                           packed_moduled_lst, self.model.config))
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
@@ -510,8 +511,10 @@ class LoRAModelManager:
    def _register_packed_modules(self, module_full_name: str) -> None:
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name)
        if not replacements:
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
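In practice the len(replacements) <= 1 check means single-entry mappings are left alone while multi-entry ones are registered as packed modules. A small illustration using mapping entries taken from the model files changed further below; plain Python, no vLLM imports:

# Mapping entries copied from the Baichuan and ChatGLM3 changes in this commit.
packed_modules_mapping = {
    "W_pack": ["W_pack"],                      # Baichuan QKV: already fused, not packed
    "query_key_value": ["query_key_value"],    # ChatGLM3 QKV: already fused, not packed
    "gate_up_proj": ["gate_proj", "up_proj"],  # Baichuan MLP: packed from two projections
}

for module_name, replacements in packed_modules_mapping.items():
    status = "packed" if len(replacements) > 1 else "not packed"
    print(f"{module_name}: {status}")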
@@ -26,6 +26,7 @@ from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -282,11 +283,30 @@ class BaiChuanModel(nn.Module):


class BaiChuanBaseForCausalLM(nn.Module):
    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self,
                 config,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
    def __init__(
        self,
        config,
        position_embedding: str,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
@@ -371,19 +391,25 @@ class BaiChuanBaseForCausalLM(nn.Module):
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B."""

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        if config.hidden_size == 4096:  # baichuan2 7b
            super().__init__(config, "ROPE", linear_method)
            super().__init__(config, "ROPE", linear_method, lora_config)
        else:  # baichuan 13b, baichuan2 13b
            super().__init__(config, "ALIBI", linear_method)
            super().__init__(config, "ALIBI", linear_method, lora_config)


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B."""

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__(config, "ROPE", linear_method)
    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__(config, "ROPE", linear_method, lora_config)
@@ -9,6 +9,7 @@ from torch import nn
from torch.nn import LayerNorm

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -317,11 +318,25 @@ class ChatGLMModel(nn.Module):


class ChatGLMForCausalLM(nn.Module):
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"]
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: ChatGLMConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__()
        self.config: ChatGLMConfig = config