From cc5e8f6db826158ffa730f74ed779d328fa93885 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Sun, 2 Mar 2025 09:17:34 +0800
Subject: [PATCH] [Model] Add LoRA support for TransformersModel (#13770)

Signed-off-by: Jee Jee Li
---
 .buildkite/test-pipeline.yaml              |   3 +-
 docs/source/models/supported_models.md     |  15 +--
 tests/lora/conftest.py                     |   5 +
 tests/lora/test_transfomers_model.py       | 120 +++++++++++++++++++++
 vllm/lora/layers.py                        |  25 +++--
 vllm/lora/utils.py                         |  23 ++--
 vllm/model_executor/models/transformers.py |  43 ++------
 7 files changed, 165 insertions(+), 69 deletions(-)
 create mode 100644 tests/lora/test_transfomers_model.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 05c4d26169906..d0f5c94ffd8db 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -275,7 +275,7 @@ steps:
   source_file_dependencies:
   - vllm/lora
   - tests/lora
-  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
+  command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
   parallelism: 4
 
 - label: PyTorch Fullgraph Smoke Test # 9min
@@ -589,6 +589,7 @@ steps:
   - pytest -v -s -x lora/test_chatglm3_tp.py
   - pytest -v -s -x lora/test_llama_tp.py
   - pytest -v -s -x lora/test_minicpmv_tp.py
+  - pytest -v -s -x lora/test_transfomers_model.py
 
 
 - label: Weight Loading Multiple GPU Test # 33min
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index 4b1f3e180ed57..0e93a15b84fc9 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -62,20 +62,7 @@ Transformers fallback has supported most of available quantization in vLLM (exce
 
 ##### LoRA
 
-LoRA hasn't supported on transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
-
-Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly.
-
-Hints as to how this would look like:
-
-```python
-class TransformersModel(nn.Module, SupportsLoRA):
-    def __init__(*):
-        ...
-        self.model.load_adapter(vllm_config.load_config.model_loader_extra_config["qlora_adapter_name_or_path"])
-```
-
-Blocker is that you need to specify supported lora layers, when we would ideally want to load whatever is inside the checkpoint!
+The Transformers fallback now supports LoRA. Usage is identical to LoRA with models natively supported by vLLM; if you encounter any issues, please open an issue.
 
##### Remote code diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index a414c3bcb6f01..59c1570b542e9 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -240,6 +240,11 @@ def baichuan_regex_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex") +@pytest.fixture(scope="session") +def ilama_lora_files(): + return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider") + + @pytest.fixture(scope="session") def minicpmv_lora_files(): return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon") diff --git a/tests/lora/test_transfomers_model.py b/tests/lora/test_transfomers_model.py new file mode 100644 index 0000000000000..07af1e9f449da --- /dev/null +++ b/tests/lora/test_transfomers_model.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import pytest + +import vllm +from tests.utils import fork_new_process_for_each_test +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "ArthurZ/ilama-3.2-1B" + +PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 + +EXPECTED_LORA_OUTPUT = [ + "SELECT count(*) FROM singer", + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 + "SELECT DISTINCT Country FROM singer WHERE Age > 20", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + prompts = [ + PROMPT_TEMPLATE.format(query="How many singers do we have?"), + PROMPT_TEMPLATE.format( + query= + "What is the average, minimum, and maximum age of all singers from France?" # noqa: E501 + ), + PROMPT_TEMPLATE.format( + query= + "What are all distinct countries where singers above age 20 are from?" # noqa: E501 + ), + ] + sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. 
+ generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + +@pytest.mark.skip_v1 +@fork_new_process_for_each_test +def test_ilama_lora(ilama_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=1, + trust_remote_code=True, + enable_chunked_prefill=True) + + output1 = do_sample(llm, ilama_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, ilama_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] + + +@pytest.mark.skip_v1 +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_ilama_lora_tp4(ilama_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False, + enable_chunked_prefill=True) + + output1 = do_sample(llm, ilama_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, ilama_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] + + +@pytest.mark.skip_v1 +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=16, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True, + enable_chunked_prefill=True) + output1 = do_sample(llm, ilama_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, ilama_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 6c48173c201b3..5a4d991da1b53 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -401,6 +401,11 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): self.output_slices) return output + @classmethod + def get_source_layer(cls, source_layer: nn.Module) -> type: + # Check parent_cls in case source_layer is a HFCompatibleLinear. 
+ return getattr(source_layer, "parent_cls", type(source_layer)) + class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): @@ -443,7 +448,8 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA): packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return type(source_layer) is ReplicatedLinear + source_layer = cls.get_source_layer(source_layer) + return source_layer is ReplicatedLinear class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): @@ -539,8 +545,9 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA): packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return type(source_layer) is ColumnParallelLinear or ( - type(source_layer) is MergedColumnParallelLinear + source_layer = cls.get_source_layer(source_layer) + return source_layer is ColumnParallelLinear or ( + source_layer is MergedColumnParallelLinear and len(packed_modules_list) == 1) @@ -682,7 +689,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return (type(source_layer) is MergedColumnParallelLinear + source_layer = cls.get_source_layer(source_layer) + return (source_layer is MergedColumnParallelLinear and len(packed_modules_list) == 2) @@ -750,7 +758,8 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: - return type(source_layer) is QKVParallelLinear and len( + source_layer = cls.get_source_layer(source_layer) + return source_layer is QKVParallelLinear and len( packed_modules_list) == 1 @@ -811,7 +820,8 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA): packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return (type(source_layer) is QKVParallelLinear + source_layer = cls.get_source_layer(source_layer) + return (source_layer is QKVParallelLinear and len(packed_modules_list) == 3) @@ -896,7 +906,8 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA): packed_modules_list: List, model_config: Optional[PretrainedConfig], ) -> bool: - return type(source_layer) is RowParallelLinear + source_layer = cls.get_source_layer(source_layer) + return source_layer is RowParallelLinear class LogitsProcessorWithLoRA(BaseLayerWithLoRA): diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 63b465fdf7432..9f1b14b49704a 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -66,17 +66,20 @@ def from_layer(layer: nn.Module, lora_config=lora_config, packed_modules_list=packed_modules_list, model_config=model_config): - ret = lora_cls(layer) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret + instance_layer = lora_cls(layer) + if layer.__class__.__name__ == "HFCompatibleLinear": + # HACK: Make the forward method compatible with the original + # forward method of the instance_layer. 
+ original_forward = instance_layer.forward - # The Case for HFCompatibleLinear - if (hasattr(layer, "get_lora_class") - and layer.__class__.__name__ == "HFCompatibleLinear"): - lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras) - ret = lora_cls(layer) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret + def new_forward(input): + input = input.squeeze(0) + return original_forward(input)[0] # noqa: B023 + + instance_layer.forward = new_forward + instance_layer.create_lora_weights(max_loras, lora_config, + model_config) + return instance_layer return layer diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 1c3c443b29413..61cfc566dd31a 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -27,11 +27,6 @@ from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.utils import divide from vllm.logger import init_logger -from vllm.lora.fully_sharded_layers import ( - ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA) -from vllm.lora.layers import (ColumnParallelLinearWithLoRA, - ReplicatedLinearWithLoRA, - RowParallelLinearWithLoRA) from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -43,7 +38,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsQuant +from .interfaces import SupportsLoRA, SupportsQuant from .utils import maybe_prefix logger = init_logger(__name__) @@ -102,44 +97,18 @@ def replace_linear_class( "rowwise": RowParallelLinear, }.get(style, ReplicatedLinear) - lora_linear_cls = { - ColumnParallelLinear: { - True: ColumnParallelLinearWithShardedLoRA, # fully sharded - False: ColumnParallelLinearWithLoRA # not fully sharded - }, - RowParallelLinear: { - True: RowParallelLinearWithShardedLoRA, - False: RowParallelLinearWithLoRA - }, - # ReplicatedLinear doesn't support fully sharded LoRA yet, - # so we use the same class for both cases. - ReplicatedLinear: { - True: ReplicatedLinearWithLoRA, - False: ReplicatedLinearWithLoRA - } - } - class HFCompatibleLinear(vllm_linear_cls): """ Wrapper class that removes `output_bias` from returned output. """ + # NOTE: The LoRA layer needs to use `parent_cls`. + @property + def parent_cls(self) -> type: + return vllm_linear_cls def forward(self, input: torch.Tensor) -> torch.Tensor: return super().forward(input)[0] - @classmethod - def get_lora_class(cls, fully_sharded: bool = False): - """ - Get the LoRA class corresponding to the current transformer - linear class. - - Args: - fully_sharded (bool): If True, select the LoRA class variant - that supports fully sharded LoRA. Defaults to False. - - """ - return lora_linear_cls[vllm_linear_cls][fully_sharded] - return HFCompatibleLinear( input_size=linear.in_features, output_size=linear.out_features, @@ -148,7 +117,7 @@ def replace_linear_class( ) -class TransformersModel(nn.Module, SupportsQuant): +class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): embedding_padding_modules = ["lm_head"] embedding_modules = ["embed_tokens" ] # TODO transformers will have a util to get it
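For reference, below is a minimal offline-inference sketch of the feature this patch adds, assembled from the new test and conftest fixture rather than taken verbatim from the patch. The model and adapter repos (`ArthurZ/ilama-3.2-1B`, `jeeejeee/ilama-text2sql-spider`) are the ones the test uses; the prompt and adapter name are placeholders, since the test wraps each question in a long text-to-SQL template.

```python
# Sketch only: LoRA through the Transformers fallback, mirroring
# tests/lora/test_transfomers_model.py. Needs a GPU and network access.
from huggingface_hub import snapshot_download

import vllm
from vllm.lora.request import LoRARequest

# Same adapter repo the new ilama_lora_files fixture downloads.
lora_path = snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")

llm = vllm.LLM(
    "ArthurZ/ilama-3.2-1B",  # served via the Transformers fallback
    max_model_len=1024,
    enable_lora=True,
    max_lora_rank=16,
    trust_remote_code=True,
)

outputs = llm.generate(
    ["How many singers do we have?"],  # placeholder prompt
    vllm.SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("sql-lora", 1, lora_path),  # name/id are arbitrary
)
print(outputs[0].outputs[0].text)
```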
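The `parent_cls` property plus `get_source_layer` is what lets the existing `can_replace_layer` checks see through the dynamically created `HFCompatibleLinear` subclass, whose concrete type no longer equals any vLLM linear class. A standalone illustration of that dispatch, using stand-in class names rather than the real vLLM layers:

```python
class ColumnParallelLinear:  # stand-in for the vLLM linear layer
    pass


def make_hf_compatible(vllm_linear_cls: type) -> type:
    class HFCompatibleLinear(vllm_linear_cls):
        # Expose the wrapped vLLM class so LoRA matching can see through the wrapper.
        @property
        def parent_cls(self) -> type:
            return vllm_linear_cls

    return HFCompatibleLinear


def get_source_layer(source_layer) -> type:
    # Same fallback the patch uses: plain vLLM layers carry no `parent_cls`.
    return getattr(source_layer, "parent_cls", type(source_layer))


wrapped = make_hf_compatible(ColumnParallelLinear)()
assert type(wrapped) is not ColumnParallelLinear          # a bare `type(...) is X` check fails
assert get_source_layer(wrapped) is ColumnParallelLinear  # the parent_cls lookup still matches
```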
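The forward shim installed in `from_layer` papers over a calling-convention mismatch: vLLM's LoRA linear layers return an `(output, bias)` tuple, while the Transformers modeling code expects the module to behave like a plain `nn.Linear` and hand back a single tensor, and the `squeeze(0)` appears to drop a leading dimension before the LoRA kernels run (that reading of the shapes is an interpretation, not something the patch states). A toy reproduction of the monkey-patch pattern:

```python
import torch


class FakeLoRALinear:
    """Stand-in for a vLLM LoRA linear layer: tuple output, 2-D input."""

    def forward(self, x: torch.Tensor):
        return x * 2, None  # (output, bias), as vLLM linear layers return


layer = FakeLoRALinear()
original_forward = layer.forward


def new_forward(input: torch.Tensor) -> torch.Tensor:
    input = input.squeeze(0)           # e.g. [1, num_tokens, hidden] -> [num_tokens, hidden]
    return original_forward(input)[0]  # keep only the output tensor, like HFCompatibleLinear


layer.forward = new_forward

assert layer.forward(torch.ones(1, 4, 8)).shape == (4, 8)
```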