mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-13 22:37:38 +08:00
[Model] Add LoRA support for TransformersModel (#13770)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
d54990da47
commit
cc5e8f6db8
@ -275,7 +275,7 @@ steps:
|
|||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
- vllm/lora
|
- vllm/lora
|
||||||
- tests/lora
|
- tests/lora
|
||||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
|
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
|
||||||
parallelism: 4
|
parallelism: 4
|
||||||
|
|
||||||
- label: PyTorch Fullgraph Smoke Test # 9min
|
- label: PyTorch Fullgraph Smoke Test # 9min
|
||||||
@ -589,6 +589,7 @@ steps:
|
|||||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||||
- pytest -v -s -x lora/test_llama_tp.py
|
- pytest -v -s -x lora/test_llama_tp.py
|
||||||
- pytest -v -s -x lora/test_minicpmv_tp.py
|
- pytest -v -s -x lora/test_minicpmv_tp.py
|
||||||
|
- pytest -v -s -x lora/test_transfomers_model.py
|
||||||
|
|
||||||
|
|
||||||
- label: Weight Loading Multiple GPU Test # 33min
|
- label: Weight Loading Multiple GPU Test # 33min
|
||||||
|
|||||||
@ -62,20 +62,7 @@ Transformers fallback has supported most of available quantization in vLLM (exce
|
|||||||
|
|
||||||
##### LoRA
|
##### LoRA
|
||||||
|
|
||||||
LoRA hasn't supported on transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
|
Transformers fallback has supported LoRA. The usage way is identical to how LoRA works with models supported by vLLM. If you encounter any issues, please open an issue.
|
||||||
|
|
||||||
Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly.
|
|
||||||
|
|
||||||
Hints as to how this would look like:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class TransformersModel(nn.Module, SupportsLoRA):
|
|
||||||
def __init__(*):
|
|
||||||
...
|
|
||||||
self.model.load_adapter(vllm_config.load_config.model_loader_extra_config["qlora_adapter_name_or_path"])
|
|
||||||
```
|
|
||||||
|
|
||||||
Blocker is that you need to specify supported lora layers, when we would ideally want to load whatever is inside the checkpoint!
|
|
||||||
|
|
||||||
##### Remote code
|
##### Remote code
|
||||||
|
|
||||||
|
|||||||
@ -240,6 +240,11 @@ def baichuan_regex_lora_files():
|
|||||||
return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
|
return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def ilama_lora_files():
|
||||||
|
return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def minicpmv_lora_files():
|
def minicpmv_lora_files():
|
||||||
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
|
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
|
||||||
|
|||||||
120
tests/lora/test_transfomers_model.py
Normal file
120
tests/lora/test_transfomers_model.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import vllm
|
||||||
|
from tests.utils import fork_new_process_for_each_test
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
from ..utils import multi_gpu_test
|
||||||
|
|
||||||
|
MODEL_PATH = "ArthurZ/ilama-3.2-1B"
|
||||||
|
|
||||||
|
PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501
|
||||||
|
|
||||||
|
EXPECTED_LORA_OUTPUT = [
|
||||||
|
"SELECT count(*) FROM singer",
|
||||||
|
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
|
||||||
|
"SELECT DISTINCT Country FROM singer WHERE Age > 20",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
|
||||||
|
prompts = [
|
||||||
|
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
|
||||||
|
PROMPT_TEMPLATE.format(
|
||||||
|
query=
|
||||||
|
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
|
||||||
|
),
|
||||||
|
PROMPT_TEMPLATE.format(
|
||||||
|
query=
|
||||||
|
"What are all distinct countries where singers above age 20 are from?" # noqa: E501
|
||||||
|
),
|
||||||
|
]
|
||||||
|
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts,
|
||||||
|
sampling_params,
|
||||||
|
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
|
||||||
|
if lora_id else None)
|
||||||
|
# Print the outputs.
|
||||||
|
generated_texts: List[str] = []
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text.strip()
|
||||||
|
generated_texts.append(generated_text)
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
return generated_texts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def v1(run_with_both_engines_lora):
|
||||||
|
# Simple autouse wrapper to run both engines for each test
|
||||||
|
# This can be promoted up to conftest.py to run for every
|
||||||
|
# test in a package
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_v1
|
||||||
|
@fork_new_process_for_each_test
|
||||||
|
def test_ilama_lora(ilama_lora_files):
|
||||||
|
llm = vllm.LLM(MODEL_PATH,
|
||||||
|
max_model_len=1024,
|
||||||
|
enable_lora=True,
|
||||||
|
max_loras=4,
|
||||||
|
max_lora_rank=16,
|
||||||
|
tensor_parallel_size=1,
|
||||||
|
trust_remote_code=True,
|
||||||
|
enable_chunked_prefill=True)
|
||||||
|
|
||||||
|
output1 = do_sample(llm, ilama_lora_files, lora_id=1)
|
||||||
|
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||||
|
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
|
||||||
|
output2 = do_sample(llm, ilama_lora_files, lora_id=2)
|
||||||
|
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||||
|
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_v1
|
||||||
|
@multi_gpu_test(num_gpus=4)
|
||||||
|
@fork_new_process_for_each_test
|
||||||
|
def test_ilama_lora_tp4(ilama_lora_files):
|
||||||
|
llm = vllm.LLM(MODEL_PATH,
|
||||||
|
max_model_len=1024,
|
||||||
|
enable_lora=True,
|
||||||
|
max_loras=4,
|
||||||
|
max_lora_rank=16,
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
trust_remote_code=True,
|
||||||
|
fully_sharded_loras=False,
|
||||||
|
enable_chunked_prefill=True)
|
||||||
|
|
||||||
|
output1 = do_sample(llm, ilama_lora_files, lora_id=1)
|
||||||
|
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||||
|
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
|
||||||
|
output2 = do_sample(llm, ilama_lora_files, lora_id=2)
|
||||||
|
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||||
|
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_v1
|
||||||
|
@multi_gpu_test(num_gpus=4)
|
||||||
|
@fork_new_process_for_each_test
|
||||||
|
def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files):
|
||||||
|
llm = vllm.LLM(MODEL_PATH,
|
||||||
|
max_model_len=1024,
|
||||||
|
enable_lora=True,
|
||||||
|
max_loras=4,
|
||||||
|
max_lora_rank=16,
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
trust_remote_code=True,
|
||||||
|
fully_sharded_loras=True,
|
||||||
|
enable_chunked_prefill=True)
|
||||||
|
output1 = do_sample(llm, ilama_lora_files, lora_id=1)
|
||||||
|
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||||
|
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
|
||||||
|
output2 = do_sample(llm, ilama_lora_files, lora_id=2)
|
||||||
|
for i in range(len(EXPECTED_LORA_OUTPUT)):
|
||||||
|
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
|
||||||
@ -401,6 +401,11 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
|||||||
self.output_slices)
|
self.output_slices)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_source_layer(cls, source_layer: nn.Module) -> type:
|
||||||
|
# Check parent_cls in case source_layer is a HFCompatibleLinear.
|
||||||
|
return getattr(source_layer, "parent_cls", type(source_layer))
|
||||||
|
|
||||||
|
|
||||||
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||||
|
|
||||||
@ -443,7 +448,8 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
|||||||
packed_modules_list: List,
|
packed_modules_list: List,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: Optional[PretrainedConfig],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return type(source_layer) is ReplicatedLinear
|
source_layer = cls.get_source_layer(source_layer)
|
||||||
|
return source_layer is ReplicatedLinear
|
||||||
|
|
||||||
|
|
||||||
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||||
@ -539,8 +545,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
|||||||
packed_modules_list: List,
|
packed_modules_list: List,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: Optional[PretrainedConfig],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return type(source_layer) is ColumnParallelLinear or (
|
source_layer = cls.get_source_layer(source_layer)
|
||||||
type(source_layer) is MergedColumnParallelLinear
|
return source_layer is ColumnParallelLinear or (
|
||||||
|
source_layer is MergedColumnParallelLinear
|
||||||
and len(packed_modules_list) == 1)
|
and len(packed_modules_list) == 1)
|
||||||
|
|
||||||
|
|
||||||
@ -682,7 +689,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
packed_modules_list: List,
|
packed_modules_list: List,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: Optional[PretrainedConfig],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return (type(source_layer) is MergedColumnParallelLinear
|
source_layer = cls.get_source_layer(source_layer)
|
||||||
|
return (source_layer is MergedColumnParallelLinear
|
||||||
and len(packed_modules_list) == 2)
|
and len(packed_modules_list) == 2)
|
||||||
|
|
||||||
|
|
||||||
@ -750,7 +758,8 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||||||
def can_replace_layer(cls, source_layer: nn.Module,
|
def can_replace_layer(cls, source_layer: nn.Module,
|
||||||
lora_config: LoRAConfig, packed_modules_list: List,
|
lora_config: LoRAConfig, packed_modules_list: List,
|
||||||
model_config: Optional[PretrainedConfig]) -> bool:
|
model_config: Optional[PretrainedConfig]) -> bool:
|
||||||
return type(source_layer) is QKVParallelLinear and len(
|
source_layer = cls.get_source_layer(source_layer)
|
||||||
|
return source_layer is QKVParallelLinear and len(
|
||||||
packed_modules_list) == 1
|
packed_modules_list) == 1
|
||||||
|
|
||||||
|
|
||||||
@ -811,7 +820,8 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
|
|||||||
packed_modules_list: List,
|
packed_modules_list: List,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: Optional[PretrainedConfig],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return (type(source_layer) is QKVParallelLinear
|
source_layer = cls.get_source_layer(source_layer)
|
||||||
|
return (source_layer is QKVParallelLinear
|
||||||
and len(packed_modules_list) == 3)
|
and len(packed_modules_list) == 3)
|
||||||
|
|
||||||
|
|
||||||
@ -896,7 +906,8 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
|||||||
packed_modules_list: List,
|
packed_modules_list: List,
|
||||||
model_config: Optional[PretrainedConfig],
|
model_config: Optional[PretrainedConfig],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return type(source_layer) is RowParallelLinear
|
source_layer = cls.get_source_layer(source_layer)
|
||||||
|
return source_layer is RowParallelLinear
|
||||||
|
|
||||||
|
|
||||||
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||||
|
|||||||
@ -66,17 +66,20 @@ def from_layer(layer: nn.Module,
|
|||||||
lora_config=lora_config,
|
lora_config=lora_config,
|
||||||
packed_modules_list=packed_modules_list,
|
packed_modules_list=packed_modules_list,
|
||||||
model_config=model_config):
|
model_config=model_config):
|
||||||
ret = lora_cls(layer)
|
instance_layer = lora_cls(layer)
|
||||||
ret.create_lora_weights(max_loras, lora_config, model_config)
|
if layer.__class__.__name__ == "HFCompatibleLinear":
|
||||||
return ret
|
# HACK: Make the forward method compatible with the original
|
||||||
|
# forward method of the instance_layer.
|
||||||
|
original_forward = instance_layer.forward
|
||||||
|
|
||||||
# The Case for HFCompatibleLinear
|
def new_forward(input):
|
||||||
if (hasattr(layer, "get_lora_class")
|
input = input.squeeze(0)
|
||||||
and layer.__class__.__name__ == "HFCompatibleLinear"):
|
return original_forward(input)[0] # noqa: B023
|
||||||
lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
|
|
||||||
ret = lora_cls(layer)
|
instance_layer.forward = new_forward
|
||||||
ret.create_lora_weights(max_loras, lora_config, model_config)
|
instance_layer.create_lora_weights(max_loras, lora_config,
|
||||||
return ret
|
model_config)
|
||||||
|
return instance_layer
|
||||||
return layer
|
return layer
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -27,11 +27,6 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
from vllm.distributed.utils import divide
|
from vllm.distributed.utils import divide
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.fully_sharded_layers import (
|
|
||||||
ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA)
|
|
||||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
|
||||||
ReplicatedLinearWithLoRA,
|
|
||||||
RowParallelLinearWithLoRA)
|
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
@ -43,7 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsQuant
|
from .interfaces import SupportsLoRA, SupportsQuant
|
||||||
from .utils import maybe_prefix
|
from .utils import maybe_prefix
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -102,44 +97,18 @@ def replace_linear_class(
|
|||||||
"rowwise": RowParallelLinear,
|
"rowwise": RowParallelLinear,
|
||||||
}.get(style, ReplicatedLinear)
|
}.get(style, ReplicatedLinear)
|
||||||
|
|
||||||
lora_linear_cls = {
|
|
||||||
ColumnParallelLinear: {
|
|
||||||
True: ColumnParallelLinearWithShardedLoRA, # fully sharded
|
|
||||||
False: ColumnParallelLinearWithLoRA # not fully sharded
|
|
||||||
},
|
|
||||||
RowParallelLinear: {
|
|
||||||
True: RowParallelLinearWithShardedLoRA,
|
|
||||||
False: RowParallelLinearWithLoRA
|
|
||||||
},
|
|
||||||
# ReplicatedLinear doesn't support fully sharded LoRA yet,
|
|
||||||
# so we use the same class for both cases.
|
|
||||||
ReplicatedLinear: {
|
|
||||||
True: ReplicatedLinearWithLoRA,
|
|
||||||
False: ReplicatedLinearWithLoRA
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
class HFCompatibleLinear(vllm_linear_cls):
|
class HFCompatibleLinear(vllm_linear_cls):
|
||||||
"""
|
"""
|
||||||
Wrapper class that removes `output_bias` from returned output.
|
Wrapper class that removes `output_bias` from returned output.
|
||||||
"""
|
"""
|
||||||
|
# NOTE: The LoRA layer needs to use `parent_cls`.
|
||||||
|
@property
|
||||||
|
def parent_cls(self) -> type:
|
||||||
|
return vllm_linear_cls
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
return super().forward(input)[0]
|
return super().forward(input)[0]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_lora_class(cls, fully_sharded: bool = False):
|
|
||||||
"""
|
|
||||||
Get the LoRA class corresponding to the current transformer
|
|
||||||
linear class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
fully_sharded (bool): If True, select the LoRA class variant
|
|
||||||
that supports fully sharded LoRA. Defaults to False.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return lora_linear_cls[vllm_linear_cls][fully_sharded]
|
|
||||||
|
|
||||||
return HFCompatibleLinear(
|
return HFCompatibleLinear(
|
||||||
input_size=linear.in_features,
|
input_size=linear.in_features,
|
||||||
output_size=linear.out_features,
|
output_size=linear.out_features,
|
||||||
@ -148,7 +117,7 @@ def replace_linear_class(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransformersModel(nn.Module, SupportsQuant):
|
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
|
||||||
embedding_padding_modules = ["lm_head"]
|
embedding_padding_modules = ["lm_head"]
|
||||||
embedding_modules = ["embed_tokens"
|
embedding_modules = ["embed_tokens"
|
||||||
] # TODO transformers will have a util to get it
|
] # TODO transformers will have a util to get it
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user