Enable more models to inference based on LoRA (#3382)

Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
This commit is contained in:
Jee Li 2024-03-26 09:09:31 +08:00 committed by GitHub
parent dfeb2ecc3a
commit 8af890a865
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 401 additions and 44 deletions

View File

@ -16,10 +16,13 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1536) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2304) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \
@ -27,10 +30,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 4608) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
@ -45,6 +50,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \

View File

@ -134,6 +134,16 @@ def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
@pytest.fixture(scope="session")
def chatglm3_lora_files():
return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")
@pytest.fixture(scope="session")
def baichuan_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
@pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()

108
tests/lora/test_baichuan.py Normal file
View File

@ -0,0 +1,108 @@
import pytest
import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup
MODEL_PATH = "baichuan-inc/Baichuan-7B"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
query=
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
),
PROMPT_TEMPLATE.format(
query=
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=256)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_baichuan_lora(baichuan_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True)
expected_lora_output = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE Country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age ASC",
]
output1 = do_sample(llm, baichuan_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
output2 = do_sample(llm, baichuan_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]
@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(baichuan_lora_files):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=1,
trust_remote_code=True)
output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1)
del llm_tp1
cleanup()
llm_tp2 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True)
output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2)
del llm_tp2
cleanup()
assert output_tp1 == output_tp2
llm_tp4 = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True)
output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2)
del llm_tp4
cleanup()
assert output_tp1 == output_tp4

View File

@ -0,0 +1,57 @@
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "THUDM/chatglm3-6b"
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
def do_sample(llm, lora_path: str, lora_id: int) -> str:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
query=
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
),
PROMPT_TEMPLATE.format(
query=
"Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501
),
]
print(prompts)
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_chatglm3_lora(chatglm3_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True)
expected_lora_output = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT name , country , age FROM singer ORDER BY age",
]
output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
for i in range(len(expected_lora_output)):
assert output1[i] == expected_lora_output[i]
output2 = do_sample(llm, chatglm3_lora_files, lora_id=2)
for i in range(len(expected_lora_output)):
assert output2[i] == expected_lora_output[i]

View File

@ -8,12 +8,16 @@ import torch
import torch.nn.functional as F
from vllm.config import LoRAConfig
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LogitsProcessorWithLoRA, LoRAMapping,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLora,
QKVParallelLinearWithLora,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights,
convert_mapping)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@ -93,8 +97,7 @@ def populate_loras(
lora_dict: Dict[int, LoRALayerWeights] = dict()
# Dictionary that maps the lora ID to the
# corresponding subloras. Only useful when
# repeats > 1.
# corresponding subloras.
sublora_dict: Dict[int, List[LoRALayerWeights]] = dict()
for slot_idx, lora_id in enumerate(id_to_index):
@ -607,7 +610,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [2, 3])
@pytest.mark.parametrize("repeats", [1, 2, 3])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
@ -623,6 +626,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedColumnParallelLinearWithLoRA(linear)
elif repeats == 3:
linear = QKVParallelLinear(4096, 64, 32, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = MergedQKVParallelLinearWithLora(linear)
else:
linear = QKVParallelLinear(4096, 64, 32, bias=False)
linear.weight.data = torch.rand_like(linear.weight.data)

View File

@ -43,9 +43,10 @@ def _lora_ref_impl(
H1 = H2 = [
128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
22016, 24576, 32000, 32256, 32512, 32768, 33024
128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
32768, 33024
]
SEED = [0xabcdabcd987]

View File

@ -1,7 +1,8 @@
# pylint: disable=unused-argument
import inspect
import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type
import torch
import torch.nn as nn
@ -114,8 +115,11 @@ class LoRAMapping:
class BaseLayerWithLoRA(nn.Module):
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
model_config: PretrainedConfig) -> None:
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
"""Initializes lora matrices."""
...
@ -144,6 +148,13 @@ class BaseLayerWithLoRA(nn.Module):
"""Sets the mapping indices."""
...
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
@ -278,12 +289,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self.indices[:self.indices_len[0]], 0, 1.0)
return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is VocabParallelEmbedding
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__()
self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size()
def create_lora_weights(
self,
@ -309,7 +327,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None
self.output_dim = self.lora_b_stacked.shape[1]
self.output_dim = self.lora_b_stacked.shape[2]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
@ -323,7 +341,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_b = lora_b[:, start_idx:end_idx]
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
@ -383,6 +406,14 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def linear_weights(self):
return self.base_layer.linear_weights
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 1)
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
@ -485,8 +516,80 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
)
return output
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is MergedColumnParallelLinear and len(
packed_modules_list) == 2
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chtglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b
must be accurately partitioned according to the respective ranks.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
self.tp_size = get_tensor_model_parallel_world_size()
self.q_proj_total_size = (self.base_layer.total_num_heads *
self.base_layer.head_size)
self.q_proj_shard_size = (self.base_layer.num_heads *
self.base_layer.head_size)
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
self.base_layer.head_size)
self.kv_proj_total_size = (self.base_layer.total_num_kv_heads *
self.base_layer.head_size)
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
tp_rank = get_tensor_model_parallel_rank()
self.q_shard_id = tp_rank
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[:, self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[:, k_offset + self.kv_proj_shard_size *
self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[:, v_offset + self.kv_proj_shard_size *
self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True)
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 1
class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
@ -654,6 +757,13 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
)
return output
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
packed_modules_list) == 3
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@ -780,6 +890,12 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def weight(self):
return self.base_layer.weight
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is RowParallelLinear
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
@ -900,7 +1016,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
hidden_states: torch.Tensor,
embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
@ -949,22 +1065,30 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
@classmethod
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor.
return False
def from_layer(
layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
supported_layer_types = {
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLora,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer)
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
cls
for cls in globals().values() if inspect.isclass(cls)
and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
}
def from_layer(layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: List,
model_config: Optional[PretrainedConfig] = None) -> nn.Module:
for lora_cls in _all_lora_classes:
if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
return layer

View File

@ -413,11 +413,12 @@ class LoRAModelManager:
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name):
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
self.model.config))
packed_moduled_lst, self.model.config))
# (yard1): TODO make this more robust
if "lm_head" in module_name:
logits_processor_module = self.model.get_submodule(
@ -510,8 +511,10 @@ class LoRAModelManager:
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name)
if not replacements:
replacements = self.packed_modules_mapping.get(module_name, [])
# When replacements is less than or equal to 1, it indicates that this
# module is not a packed module.
if len(replacements) <= 1:
return
prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [

View File

@ -26,6 +26,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@ -282,11 +283,30 @@ class BaiChuanModel(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module):
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
"W_pack",
"o_proj",
"gate_up_proj",
"down_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
def __init__(
self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
@ -371,19 +391,25 @@ class BaiChuanBaseForCausalLM(nn.Module):
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B."""
def __init__(self,
config,
linear_method: Optional[LinearMethodBase] = None):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", linear_method)
super().__init__(config, "ROPE", linear_method, lora_config)
else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", linear_method)
super().__init__(config, "ALIBI", linear_method, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 7B."""
def __init__(self,
config,
linear_method: Optional[LinearMethodBase] = None):
super().__init__(config, "ROPE", linear_method)
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, "ROPE", linear_method, lora_config)

View File

@ -9,6 +9,7 @@ from torch import nn
from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@ -317,11 +318,25 @@ class ChatGLMModel(nn.Module):
class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config