From 105b8ce4c04022d96ca401907303f04752a6ad6b Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 22 Feb 2025 16:21:30 +0800 Subject: [PATCH] [Misc] Reduce LoRA-related static variable (#13166) --- tests/lora/conftest.py | 17 +++++++-- tests/lora/test_lora_checkpoints.py | 13 ++++--- tests/lora/test_lora_huggingface.py | 7 ++-- tests/lora/test_lora_manager.py | 26 +++++--------- vllm/lora/models.py | 21 +++++------ vllm/lora/utils.py | 26 ++++++++++++++ vllm/lora/worker_manager.py | 8 +++-- vllm/model_executor/models/baichuan.py | 9 ----- vllm/model_executor/models/bamba.py | 6 ---- vllm/model_executor/models/chatglm.py | 10 ------ vllm/model_executor/models/commandr.py | 4 --- vllm/model_executor/models/exaone.py | 8 ----- vllm/model_executor/models/gemma.py | 12 ------- vllm/model_executor/models/gemma2.py | 11 ------ vllm/model_executor/models/glm4v.py | 15 -------- vllm/model_executor/models/gpt_bigcode.py | 5 +-- vllm/model_executor/models/granite.py | 4 --- vllm/model_executor/models/granitemoe.py | 7 ---- vllm/model_executor/models/idefics3.py | 15 -------- vllm/model_executor/models/interfaces.py | 12 +++---- vllm/model_executor/models/internlm2.py | 10 ------ vllm/model_executor/models/jamba.py | 4 --- vllm/model_executor/models/llama.py | 4 --- vllm/model_executor/models/minicpm.py | 8 ----- vllm/model_executor/models/minicpm3.py | 16 --------- vllm/model_executor/models/minicpmv.py | 42 ---------------------- vllm/model_executor/models/mixtral.py | 4 --- vllm/model_executor/models/molmo.py | 20 ----------- vllm/model_executor/models/nemotron.py | 3 -- vllm/model_executor/models/phi.py | 11 ------ vllm/model_executor/models/phimoe.py | 10 ------ vllm/model_executor/models/qwen.py | 9 ----- vllm/model_executor/models/qwen2.py | 20 ----------- vllm/model_executor/models/qwen2_5_vl.py | 21 ----------- vllm/model_executor/models/qwen2_rm.py | 10 ------ vllm/model_executor/models/qwen2_vl.py | 18 ---------- vllm/model_executor/models/qwen_vl.py | 15 -------- vllm/model_executor/models/solar.py | 8 ----- vllm/model_executor/models/transformers.py | 35 ++++++++++++++++++ vllm/model_executor/models/ultravox.py | 8 ----- vllm/worker/hpu_model_runner.py | 3 -- 41 files changed, 120 insertions(+), 395 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 92ff52b839ed8..a414c3bcb6f01 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -23,6 +23,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.interfaces import SupportsLoRA from vllm.platforms import current_platform @@ -98,9 +99,13 @@ def dist_init_torch_only(): backend=backend) +class DummyLoRAModel(nn.Sequential, SupportsLoRA): + pass + + @pytest.fixture def dummy_model() -> nn.Module: - model = nn.Sequential( + model = DummyLoRAModel( OrderedDict([ ("dense1", ColumnParallelLinear(764, 100)), ("dense2", RowParallelLinear(100, 50)), @@ -121,12 +126,13 @@ def dummy_model() -> nn.Module: ("sampler", Sampler()) ])) model.config = MagicMock() + model.embedding_modules = {"lm_head": "lm_head"} return model @pytest.fixture def dummy_model_gate_up() -> nn.Module: - model = nn.Sequential( + model = DummyLoRAModel( OrderedDict([ ("dense1", ColumnParallelLinear(764, 100)), ("dense2", RowParallelLinear(100, 50)), @@ -147,6 +153,13 @@ def dummy_model_gate_up() -> 
nn.Module: ("sampler", Sampler()) ])) model.config = MagicMock() + model.packed_modules_mapping = { + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + model.embedding_modules = {"lm_head": "lm_head"} return model diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index d2a4b901bd8d7..e2c3d20d327fe 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -12,6 +12,12 @@ from vllm.model_executor.models.utils import WeightsMapper lora_lst = [ "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b" ] +BAICHUAN_LORA_MODULES = [ + "W_pack", + "o_proj", + "gate_up_proj", + "down_proj", +] @pytest.mark.parametrize("lora_name", lora_lst) @@ -22,12 +28,11 @@ def test_load_checkpoints( baichuan_regex_lora_files, chatglm3_lora_files, ): - supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules expected_lora_modules: List[str] = [] - for module in supported_lora_modules: + for module in BAICHUAN_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) else: @@ -90,12 +95,12 @@ def test_load_checkpoints( def test_lora_weights_mapping(baichuan_lora_files): - supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules + packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping embedding_modules = BaiChuanBaseForCausalLM.embedding_modules embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules expected_lora_modules: List[str] = [] - for module in supported_lora_modules: + for module in BAICHUAN_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) else: diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index 273fe9ae0eb55..44d111732d2ae 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -11,17 +11,20 @@ from vllm.model_executor.models.llama import LlamaForCausalLM # Provide absolute path and huggingface lora ids lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"] +LLAMA_LORA_MODULES = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" +] @pytest.mark.parametrize("lora_fixture_name", lora_fixture_name) def test_load_checkpoints_from_huggingface(lora_fixture_name, request): lora_name = request.getfixturevalue(lora_fixture_name) - supported_lora_modules = LlamaForCausalLM.supported_lora_modules packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping embedding_modules = LlamaForCausalLM.embedding_modules embed_padding_modules = LlamaForCausalLM.embedding_padding_modules expected_lora_modules: List[str] = [] - for module in supported_lora_modules: + for module in LLAMA_LORA_MODULES: if module in packed_modules_mapping: expected_lora_modules.extend(packed_modules_mapping[module]) else: diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 9fecd11f57afe..7ab46b7ff9c9c 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -19,7 +19,6 @@ from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, WorkerLoRAManager) -from vllm.model_executor.layers.linear import 
RowParallelLinear from vllm.platforms import current_platform EMBEDDING_MODULES = { @@ -114,19 +113,16 @@ def create_packed_lora( def test_replace_submodules(dist_init, dummy_model): model = dummy_model - model.supported_lora_modules = ["dense1", "layer1.dense2"] - model.packed_modules_mapping = {} manager = LoRAModelManager( model, 1, 1, 1, LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8), torch.device(DEVICES[0])) model = manager.model - assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA) assert isinstance(model.get_submodule("layer1.dense1"), ColumnParallelLinearWithLoRA) - assert isinstance(model.get_submodule("dense2"), RowParallelLinear) + assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA) assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA) @@ -134,8 +130,6 @@ def test_replace_submodules(dist_init, dummy_model): @pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(dist_init, dummy_model, device): model = dummy_model - model.supported_lora_modules = ["dense1", "dense2", "lm_head"] - model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"], device=device) @@ -190,13 +184,18 @@ def test_lora_model_manager(dist_init, dummy_model, device): assert manager.device == device assert manager.punica_wrapper.device == device + assert hasattr(manager, "supported_lora_modules") + assert sorted(manager.supported_lora_modules) == [ + "dense1", + "dense2", + "lm_head", + "output", + ] @pytest.mark.parametrize("device", DEVICES) def test_lora_lru_cache_model_manager(dist_init, dummy_model, device): model = dummy_model - model.supported_lora_modules = ["dense1", "dense2", "lm_head"] - model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"], device=device) @@ -289,8 +288,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device): # This tests just the LRU cache functionality, everything else is # tested in test_lora_model_manager model = dummy_model - model.supported_lora_modules = ["dense1", "dense2", "lm_head"] - model.packed_modules_mapping = {} model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"], device=device) @@ -572,13 +569,6 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, @pytest.mark.parametrize("device", DEVICES) def test_packed_loras(dist_init, dummy_model_gate_up, device): model = dummy_model_gate_up - model.supported_lora_modules = ["gate_up_proj"] - model.packed_modules_mapping = { - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } model_lora = create_packed_lora( 1, model, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index b7403980d0b0d..eb53513a28307 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -26,6 +26,7 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, + get_supported_lora_modules, is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models import SupportsLoRA, supports_multimodal @@ -332,15 +333,15 @@ class LoRAModelManager(AdapterModelManager): # Used for long context lora. 
         self.scaling_factor_to_offset: Dict[float, int] = {}
         super().__init__(model)
-        if hasattr(self.model, "supported_lora_modules"):
-            self.supported_lora_modules = copy.deepcopy(
-                self.model.supported_lora_modules)
-            if lora_config.long_lora_scaling_factors:
-                # We need to replace rotary emb layer to do batch computation
-                # for long lora.
-                self.supported_lora_modules.append("rotary_emb")
-            self.packed_modules_mapping = copy.deepcopy(
-                self.model.packed_modules_mapping)
+        self.supported_lora_modules = get_supported_lora_modules(self.model)
+        assert self.supported_lora_modules, (
+            f"No supported LoRA modules found in {type(self.model).__name__}.")
+        if lora_config.long_lora_scaling_factors:
+            # We need to replace rotary emb layer to do batch computation
+            # for long lora.
+            self.supported_lora_modules.append("rotary_emb")
+        self.packed_modules_mapping = copy.deepcopy(
+            self.model.packed_modules_mapping)
         # Used to indicate whether the model is a multimodal model
         self.supports_mm: bool = (
             supports_multimodal(self.model)
@@ -756,7 +757,7 @@ def create_lora_manager(
         lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
         **kwargs) -> LoRAModelManager:
     """Create a LoRA adapter for a given model."""
-    if not hasattr(model, "supported_lora_modules"):
+    if not hasattr(model, "packed_modules_mapping"):
         raise ValueError(f"Model {type(model)} is not supported for LoRA.")
     lora_manager = lora_manager_cls(
         model=model,
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index f47b0af155226..361dac5b33139 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -29,6 +29,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                               ReplicatedLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               VocabParallelEmbeddingWithLoRA)
+from vllm.model_executor.layers.linear import LinearBase
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -68,6 +69,14 @@ def from_layer(layer: nn.Module,
             ret = lora_cls(layer)
             ret.create_lora_weights(max_loras, lora_config, model_config)
             return ret
+
+    # Special case for HFCompatibleLinear
+    if (hasattr(layer, "get_lora_class")
+            and layer.__class__.__name__ == "HFCompatibleLinear"):
+        lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
+        ret = lora_cls(layer)
+        ret.create_lora_weights(max_loras, lora_config, model_config)
+        return ret
     return layer
 
 
@@ -170,6 +179,23 @@ def is_regex_target_modules(load_modules: Union[str, List[str]],
     return False
 
 
+def get_supported_lora_modules(model: nn.Module) -> List[str]:
+    """
+    In vLLM, all linear layers support LoRA.
+    """
+    supported_lora_modules: Set[str] = set()
+    # step 1: traverse the model to get all the linear suffixes.
+    for name, module in model.named_modules():
+        if isinstance(module, (LinearBase, )):
+            supported_lora_modules.add(name.split(".")[-1])
+    # step 2: get the embedding modules if the model's embedding_modules
+    # is not empty.
+    if model.embedding_modules:
+        for name in model.embedding_modules:
+            supported_lora_modules.add(name)
+    return list(supported_lora_modules)
+
+
 def get_adapter_absolute_path(lora_path: str) -> str:
     """
     Resolves the given lora_path to an absolute local path.
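For illustration only (this snippet is not part of the patch): a minimal, self-contained sketch of the discovery logic that the new get_supported_lora_modules helper implements. It swaps vLLM's LinearBase check for plain torch nn.Linear so it runs without tensor-parallel initialization, and the names toy_supported_lora_modules, ToyBlock and ToyModel are invented for this example.

# Simplified analogue of get_supported_lora_modules. The real helper checks
# isinstance(module, LinearBase); nn.Linear stands in here so the sketch runs
# standalone. toy_supported_lora_modules/ToyBlock/ToyModel are hypothetical.
from typing import List, Set

import torch.nn as nn


def toy_supported_lora_modules(model: nn.Module) -> List[str]:
    supported: Set[str] = set()
    # Collect the last component ("suffix") of every linear module's name.
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            supported.add(name.split(".")[-1])
    # Add any declared embedding modules (e.g. {"lm_head": "lm_head"}).
    for name in getattr(model, "embedding_modules", {}):
        supported.add(name)
    return list(supported)


class ToyBlock(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.dense1 = nn.Linear(16, 16)
        self.dense2 = nn.Linear(16, 16)


class ToyModel(nn.Module):
    embedding_modules = {"lm_head": "lm_head"}

    def __init__(self) -> None:
        super().__init__()
        self.layer1 = ToyBlock()
        self.layer2 = ToyBlock()


print(sorted(toy_supported_lora_modules(ToyModel())))
# -> ['dense1', 'dense2', 'lm_head']: suffixes repeated across submodules
#    collapse to a single entry, which is why the dummy_model assertion in
#    tests/lora/test_lora_manager.py above lists only four names.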
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index b103acefe4aaf..108beb34b244a 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -84,9 +84,10 @@ class WorkerLoRAManager(AbstractWorkerManager): def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - model = self._adapter_manager.model - supported_lora_modules = model.supported_lora_modules - packed_modules_mapping = model.packed_modules_mapping + supported_lora_modules = ( + self._adapter_manager.supported_lora_modules) + packed_modules_mapping = ( + self._adapter_manager.packed_modules_mapping) expected_lora_modules: List[str] = [] for module in supported_lora_modules: if module in packed_modules_mapping: @@ -107,6 +108,7 @@ class WorkerLoRAManager(AbstractWorkerManager): # For some models like Qwen2VL, we need to use hf_to_vllm_mapper # to ensure correct loading of lora weights. + model = self._adapter_manager.model hf_to_vllm_mapper = None if (hasattr(model, "hf_to_vllm_mapper") and model.hf_to_vllm_mapper is not None): diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 5dfaa727b75ae..b613b70a7564a 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -342,15 +342,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - "W_pack", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] def __init__( self, diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index b9310108543c2..22ae1775c3d97 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -389,12 +389,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 26b4a95c530e8..ecf417655452a 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -477,16 +477,6 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP): "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] } - # LoRA specific attributes - supported_lora_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index e73627da05d40..0ceefc3e93aaa 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -357,11 +357,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens" - ] embedding_modules = {"embed_tokens": "input_embeddings"} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 2eb91a682242c..e795c7e288c44 100644 --- 
a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -415,14 +415,6 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "out_proj", - "gate_up_proj", - "c_proj", - "wte", - "lm_head", - ] embedding_modules = { "wte": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index cb81aa41e2542..d0589e60a72b6 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -344,18 +344,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - - # Gemma does not apply LoRA to the embedding layer. - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index a6dc8f84772b4..6ee257d65c50a 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -390,17 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - # Gemma does not apply LoRA to the embedding layer. - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py index 40010ec559066..8fc5a797f8243 100644 --- a/vllm/model_executor/models/glm4v.py +++ b/vllm/model_executor/models/glm4v.py @@ -534,21 +534,6 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, "dense_h_to_4h": ["dense_h_to_4h"], "merged_proj": ["gate_proj", "dense_h_to_4h"] } - # LoRA specific attributes - supported_lora_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", - # vision - "fc1", - "fc2", - "merged_proj", - "linear_proj" - ] - - embedding_modules = {} - embedding_padding_modules = [] def get_mm_mapping(self) -> MultiModelKeys: """ diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 887a444748ae2..799edff46ea31 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -261,15 +261,12 @@ class GPTBigCodeModel(nn.Module): class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = {"c_attn": ["c_attn"]} - supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] - + # LoRA specific attributes embedding_modules = { "wte": "input_embeddings", "lm_head": "output_embeddings", } - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index 85911a0f41c2f..2aeb179ee9326 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -351,10 +351,6 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - 
"qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 8ae661bf15c49..40df9c72c5617 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -329,13 +329,6 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - "layer", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 579253632c81e..3a7e2a9a6a576 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -597,21 +597,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - # vision_model - "fc1", - "fc2", - "out_proj", - # text_model - "qkv_proj", # same name with vision encoder - "o_proj", - "gate_up_proj", - "down_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index bd6661d668d9f..47bd05f140c81 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -118,11 +118,11 @@ class SupportsLoRA(Protocol): There is no need to redefine this flag if this class is in the MRO of your model class. """ - - packed_modules_mapping: ClassVar[Dict[str, List[str]]] - supported_lora_modules: ClassVar[List[str]] - embedding_modules: ClassVar[Dict[str, str]] - embedding_padding_modules: ClassVar[List[str]] + # The `embedding_module` and `embedding_padding_modules` + # are empty by default. 
+ embedding_modules: ClassVar[Dict[str, str]] = {} + embedding_padding_modules: ClassVar[List[str]] = [] + packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {} # We can't use runtime_checkable with ClassVar for issubclass checks @@ -132,7 +132,6 @@ class _SupportsLoRAType(Protocol): supports_lora: Literal[True] packed_modules_mapping: Dict[str, List[str]] - supported_lora_modules: List[str] embedding_modules: Dict[str, str] embedding_padding_modules: List[str] @@ -155,7 +154,6 @@ def supports_lora( if not result: lora_attrs = ( "packed_modules_mapping", - "supported_lora_modules", "embedding_modules", "embedding_padding_modules", ) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index c211ca5f4f8e9..b21933dd5da7c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -329,16 +329,6 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA): "gate_up_proj": ["w1", "w3"], } - # LoRA specific attributes - supported_lora_modules = [ - "wqkv", - "wo", - "gate_up_proj", - "w2", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index efc1496d44f05..5530e3ca708ca 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -380,10 +380,6 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "embed_tokens", "lm_head", "up_proj", - "down_proj", "gate_proj", "out_proj", "in_proj", "x_proj" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2ff52dd789125..011d0a7aafaa3 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -452,10 +452,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", - "lm_head" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings" diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 29473f5bbaa0a..52ab89488785e 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -522,14 +522,6 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 878f0c895c34b..b85306c408804 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -227,21 +227,5 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): ], } - # LoRA specific attributes - supported_lora_modules = [ - "kv_a_proj_with_mqa", - "q_a_proj", - "q_b_proj", - "kv_b_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] - - # `embedding_modules` and `embedding_padding_modules` - # are inherited from MiniCPMForCausalLM - def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): return 
MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 97596f9e82c64..1f278b65740c4 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1228,23 +1228,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - # vision encoder - "fc1", - "fc2", - "out_proj", - # language model - "qkv_proj", # same name with vision encoder - "o_proj", - "gate_up_proj", - "down_proj", - # resampler - "kv_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -1338,23 +1321,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - # vision encoder - "fc1", - "fc2", - "out_proj", - # language model - "qkv_proj", # same name with vision encoder - "o_proj", - "gate_up_proj", - "down_proj", - # resampler - "kv_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) @@ -1460,13 +1426,6 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): which is not conducive to the current integration logic of LoRA and bitsandbytes in vLLM. Therefore, it is necessary to separate them. """ - # Ensure that the LoRA support check passes when the class is not - # initialized, but set all these attributes to empty. - # These will be updated when an instance class is selected - packed_modules_mapping = {} - supported_lora_modules = [] - embedding_modules = {} - embedding_padding_modules = [] def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config @@ -1487,7 +1446,6 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA): # quant_config references base class members, # so update values before init is called cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) - cls.supported_lora_modules += instance_cls.supported_lora_modules cls.embedding_modules.update(instance_cls.embedding_modules) cls.embedding_padding_modules += instance_cls.embedding_padding_modules return instance_cls(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 70880eb752246..b83b69fd2c2df 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -332,10 +332,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3", - "gate" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 1d84d25c96acb..6ce9fbda182f5 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1440,26 +1440,6 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, "merged_linear": ["gate_proj", "up_proj"] # image_projector } - # LoRA specific attributes - supported_lora_modules = [ - # language model - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", # same name with 
image_projector - # vision tower - "wq", - "wk", - "wv", - "wo", - "w1", - "w2", - # image_projector - "merged_linear", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 6f0b831ac2727..a42734edb39a6 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -389,9 +389,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", "o_proj", "up_proj", "down_proj", "embed_tokens", "lm_head" - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 6b05bfee94922..1ca8cad22ad93 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -273,17 +273,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ] } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "dense", - "fc1", - "fc2", - ] - - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index aa4bb52c444f7..17369cb58e36b 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -526,16 +526,6 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "embed_tokens", - "lm_head", - "w1", - "w2", - "w3", - "gate", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a45e9463ab67b..7c46270362030 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -354,15 +354,6 @@ class QWenLMHeadModel(QWenBaseModel, SupportsPP, SupportsLoRA): "w1", ], } - # LoRA specific attributes - supported_lora_modules = [ - "c_attn", - "gate_up_proj", - "c_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index e3de6b64fbb39..7da6e558ff33a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -430,16 +430,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -528,16 +518,6 @@ class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) def __init__(self, *, vllm_config: VllmConfig, 
prefix: str = ""): diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index ff10fcb4315cc..ef31f18445fd3 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -734,27 +734,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "up_proj", ], } - # LoRA specific attributes - supported_lora_modules = [ - # language model - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", # Same name with vision encoder - # vision tower - "qkv", - "gate_proj", - "up_proj", - "attn.proj", # Distinguish patch_embed.proj - "fc1", - "fc2", - # projector - "mlp.0", - "mlp.2" - ] - - embedding_modules = {} - embedding_padding_modules = [] # To ensure correct weight loading and mapping. hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 00e4159e28cf7..c6588a47d8810 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -47,16 +47,6 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - ] - embedding_modules = {} - embedding_padding_modules = [] - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 919445267f4a6..31701abd33396 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -1048,24 +1048,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ], } - # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - # vision tower - "qkv", - "attn.proj", # Distinguish patch_embed.proj - "fc1", - "fc2", - # projector - "mlp.0", - "mlp.2" - ] - embedding_modules = {} - embedding_padding_modules = [] - # To ensure correct weight loading and mapping. 
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ "lm_head.": "language_model.lm_head.", diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py index 61a4584abf852..56faa390fc5d1 100644 --- a/vllm/model_executor/models/qwen_vl.py +++ b/vllm/model_executor/models/qwen_vl.py @@ -667,21 +667,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA, "w1", ], } - # LoRA specific attributes - supported_lora_modules = [ - "c_attn", - "gate_up_proj", - "c_proj", - # visual module - "out_proj", - "in_proj", - "c_fc", - # resampler - "kv_proj", - ] - - embedding_modules = {} - embedding_padding_modules = [] def get_mm_mapping(self) -> MultiModelKeys: """ diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 6215ed814bf42..ad98f3b07034b 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -386,14 +386,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 9b456b2489525..b431abb76b693 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -27,6 +27,11 @@ from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.utils import divide from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA) +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA) from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) @@ -103,6 +108,23 @@ def replace_linear_class( "rowwise": RowParallelLinear, }.get(style, ReplicatedLinear) + lora_linear_cls = { + ColumnParallelLinear: { + True: ColumnParallelLinearWithShardedLoRA, # fully sharded + False: ColumnParallelLinearWithLoRA # not fully sharded + }, + RowParallelLinear: { + True: RowParallelLinearWithShardedLoRA, + False: RowParallelLinearWithLoRA + }, + # ReplicatedLinear doesn't support fully sharded LoRA yet, + # so we use the same class for both cases. + ReplicatedLinear: { + True: ReplicatedLinearWithLoRA, + False: ReplicatedLinearWithLoRA + } + } + class HFCompatibleLinear(vllm_linear_cls): """ Wrapper class that removes `output_bias` from returned output. @@ -111,6 +133,19 @@ def replace_linear_class( def forward(self, input: torch.Tensor) -> torch.Tensor: return super().forward(input)[0] + @classmethod + def get_lora_class(cls, fully_sharded: bool = False): + """ + Get the LoRA class corresponding to the current transformer + linear class. + + Args: + fully_sharded (bool): If True, select the LoRA class variant + that supports fully sharded LoRA. Defaults to False. 
+ + """ + return lora_linear_cls[vllm_linear_cls][fully_sharded] + return HFCompatibleLinear( input_size=linear.in_features, output_size=linear.out_features, diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index e24b4aeb8ae84..b99094e5d4ca6 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -360,14 +360,6 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): "gate_up_proj": ["gate_proj", "up_proj"] } - # LoRA specific attributes - # TODO : Add LoRA to the audio tower and projector. - supported_lora_modules = [ - "qkv_proj", "o_proj", "gate_up_proj", "down_proj" - ] - embedding_modules = {} - embedding_padding_modules = [] - hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index fe7c776d0a238..f22526cfad70b 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -650,9 +650,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): logger.info(msg) if self.lora_config: - assert hasattr(self.model, "supported_lora_modules" - ) and self.model.supported_lora_modules, ( - "Model does not support LoRA") assert hasattr(self.model, "embedding_modules" ), "Model does not have embedding_modules" assert hasattr(