From 67312cad11835bd75ca55fda83708d4806b82436 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 9 Dec 2025 00:59:31 +0800 Subject: [PATCH] [Misc] Split the LoRA code (#30253) Signed-off-by: Jee Jee Li --- tests/lora/test_layers.py | 2 +- tests/lora/test_lora_checkpoints.py | 2 +- tests/lora/test_lora_huggingface.py | 2 +- tests/lora/test_lora_manager.py | 4 +- tests/lora/test_worker.py | 2 +- vllm/lora/lora_model.py | 246 ++++++++++++++++++++++ vllm/lora/{models.py => model_manager.py} | 237 +-------------------- vllm/lora/utils.py | 9 + vllm/lora/worker_manager.py | 4 +- 9 files changed, 265 insertions(+), 243 deletions(-) create mode 100644 vllm/lora/lora_model.py rename vllm/lora/{models.py => model_manager.py} (74%) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 9df3a07a9e5e..47d1fcfe9a0c 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -28,7 +28,7 @@ from vllm.lora.layers import ( RowParallelLinearWithShardedLoRA, VocabParallelEmbeddingWithLoRA, ) -from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.model_executor.layers.linear import ( ColumnParallelLinear, diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py index e9653a2fedfa..e6816e83da00 100644 --- a/tests/lora/test_lora_checkpoints.py +++ b/tests/lora/test_lora_checkpoints.py @@ -3,7 +3,7 @@ import pytest -from vllm.lora.models import LoRAModel +from vllm.lora.lora_model import LoRAModel from vllm.lora.peft_helper import PEFTHelper from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM from vllm.model_executor.models.utils import WeightsMapper diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index 3348d2f8ce65..7c7f4eb4b626 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -3,7 +3,7 @@ import pytest -from vllm.lora.models import LoRAModel +from vllm.lora.lora_model import LoRAModel from vllm.lora.peft_helper import PEFTHelper from vllm.lora.utils import get_adapter_absolute_path from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 081f14d6fabf..50f17ced5dd7 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -15,10 +15,10 @@ from vllm.lora.layers import ( MergedColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA, ) +from vllm.lora.lora_model import LoRAModel from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.models import ( +from vllm.lora.model_manager import ( LoRAMapping, - LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, ) diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 54059ec56190..445aaf9cb7d1 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -16,7 +16,7 @@ from vllm.config import ( ) from vllm.config.load import LoadConfig from vllm.config.lora import LoRAConfig -from vllm.lora.models import LoRAMapping +from vllm.lora.model_manager import LoRAMapping from vllm.lora.request import LoRARequest from vllm.v1.worker.gpu_worker import Worker diff --git a/vllm/lora/lora_model.py b/vllm/lora/lora_model.py new file mode 100644 index 000000000000..db170f13ae1c --- /dev/null +++ b/vllm/lora/lora_model.py @@ -0,0 +1,246 @@ +# SPDX-License-Identifier: Apache-2.0 +# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os + +import safetensors.torch +import torch + +from vllm.logger import init_logger +from vllm.lora.lora_weights import LoRALayerWeights +from vllm.lora.peft_helper import PEFTHelper +from vllm.lora.utils import ( + get_lora_id, + is_base_embeddding_weights, + is_regex_target_modules, + parse_fine_tuned_lora_name, +) +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models.utils import WeightsMapper +from vllm.utils.platform_utils import is_pin_memory_available + +logger = init_logger(__name__) + + +class LoRAModel: + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: dict[str, LoRALayerWeights], + ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + + """ + self.id = lora_model_id + + assert lora_model_id > 0, ( + f"a valid lora id should be greater than 0, got {self.id}" + ) + self.rank = rank + self.loras: dict[str, LoRALayerWeights] = loras + + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. + + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + + def get_lora(self, module_name: str) -> LoRALayerWeights | None: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + def check_lora_name(self, lora_name: str) -> bool: + return lora_name in self.loras + + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + tensors: dict[str, torch.Tensor], + peft_helper: PEFTHelper, + device: str = "cuda", + dtype: torch.dtype | None = None, + model_vocab_size: int | None = None, + weights_mapper: WeightsMapper | None = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and is_pin_memory_available() + loras: dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + if is_base_embeddding_weights(tensor_name): + continue + module_name, is_lora_a = parse_fine_tuned_lora_name( + tensor_name, weights_mapper + ) + if module_name not in loras: + loras[module_name] = LoRALayerWeights.from_config( + module_name, peft_helper + ) + + if is_lora_a: + if ( + "lora_embedding_A" in tensor_name + and model_vocab_size is not None + and model_vocab_size != tensor.shape[1] + ): + raise RuntimeError( + f"The embedding LoRA size({tensor.shape[1]}) must be consistent" + f" with the base model's vocabulary size({model_vocab_size})." + ) + loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) + if pin_memory: + loras[module_name].lora_a = loras[module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, dtype=dtype) + + if pin_memory: + loras[module_name].lora_b = loras[module_name].lora_b.pin_memory() + + return cls(lora_model_id, peft_helper.r, loras) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + expected_lora_modules: set[str], + peft_helper: PEFTHelper, + *, + lora_model_id: int | None = None, + device: str = "cuda", + dtype: torch.dtype | None = None, + model_vocab_size: int | None = None, + weights_mapper: WeightsMapper | None = None, + tensorizer_config_dict: dict | None = None, + ) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint. 
+ + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + peft_helper: Loaded lora configuration information. + lora_model_id: LoRA model id. If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") + + tensors: dict[str, torch.Tensor] = {} + unexpected_modules: list[list[str] | str] = [] + + def check_unexpected_modules(modules: dict): + for lora_module in modules.keys(): # noqa + if is_base_embeddding_weights(lora_module): + continue + # Handle PEFT file format where experts.base_layer is the + # gate_up_proj and experts is the down_proj + if "base_layer" in lora_module: + continue + module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) + # Case for expert lora weights + if ".experts" in module_name: + expert_idx = module_name.find(".experts") + expert_suffix = module_name[expert_idx + 1 :] + if expert_suffix not in expected_lora_modules: + unexpected_modules.append(module_name) + + elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules: + unexpected_modules.append(module_name) + + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct" + ) + + if tensorizer_config_dict: + from tensorizer import TensorDeserializer + + tensorizer_config = TensorizerConfig(**tensorizer_config_dict) + lora_tensor_path = os.path.join( + tensorizer_config.tensorizer_dir, "adapter_model.tensors" + ) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + tensors = TensorDeserializer( + lora_tensor_path, + dtype=tensorizer_config.dtype, + **tensorizer_args.deserialization_kwargs, + ) + check_unexpected_modules(tensors) + + elif os.path.isfile(lora_tensor_path): + # Find unexpected modules. + # Use safetensor key as a source of truth to find expected modules. + # in peft if you have target_modules A, B, C and C does not exist + # in the model it won’t error and model will be trained with A, B + # loraified. C won’t exist in the safetensor but it will exist in + # the target_modules of the adapter_config.json. + unexpected_modules = [] + with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore + # Load tensors if there are only expected modules. + check_unexpected_modules(f) + for module in f.keys(): # noqa + tensors[module] = f.get_tensor(module) + elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path): + # When a bin/pt file is provided, we rely on config to find + # unexpected modules. + unexpected_modules = [] + target_modules = peft_helper.target_modules + if not isinstance(target_modules, list): + target_modules = [target_modules] + for module in target_modules: + # Compatible with more modules, + # such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of + # expected_lora_modules. It is not reliable. See + # https://github.com/vllm-project/vllm/pull/5909. 
But there's no + # other better mechanism. + if unexpected_modules and not is_regex_target_modules( + peft_helper.target_modules, expected_lora_modules + ): + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct" + ) + lora_file_path = ( + lora_bin_file_path + if os.path.isfile(lora_bin_file_path) + else lora_pt_file_path + ) + tensors = torch.load(lora_file_path, map_location=device, weights_only=True) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + return cls.from_lora_tensors( + lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, + tensors=tensors, + peft_helper=peft_helper, + device=device, + dtype=dtype, + model_vocab_size=model_vocab_size, + weights_mapper=weights_mapper, + ) diff --git a/vllm/lora/models.py b/vllm/lora/model_manager.py similarity index 74% rename from vllm/lora/models.py rename to vllm/lora/model_manager.py index 567ffce4e75f..44e0448d92de 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/model_manager.py @@ -2,38 +2,32 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math -import os from collections.abc import Callable from typing import TypeVar import regex as re -import safetensors.torch import torch from torch import nn from vllm.config.lora import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping +from vllm.lora.lora_model import LoRAModel from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper from vllm.lora.utils import ( from_layer, from_layer_logits_processor, get_supported_lora_modules, - is_base_embeddding_weights, is_moe_model, - is_regex_target_modules, - parse_fine_tuned_lora_name, process_packed_modules_mapping, replace_submodule, ) from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper +from vllm.model_executor.models.utils import PPMissingLayer from vllm.utils.cache import LRUCache from vllm.utils.platform_utils import is_pin_memory_available @@ -53,233 +47,6 @@ class AdapterLRUCache(LRUCache[int, T]): return super()._on_remove(key, value) -_GLOBAL_LORA_ID = 0 - - -def get_lora_id(): - global _GLOBAL_LORA_ID - _GLOBAL_LORA_ID += 1 - return _GLOBAL_LORA_ID - - -class LoRAModel: - """A LoRA fine-tuned model.""" - - def __init__( - self, - lora_model_id: int, - rank: int, - loras: dict[str, LoRALayerWeights], - ) -> None: - """ - Args: - lora_model_id: The integer id for the lora model. - rank: lora rank. - loras: module name -> weights for lora-replaced layers. - - """ - self.id = lora_model_id - - assert lora_model_id > 0, ( - f"a valid lora id should be greater than 0, got {self.id}" - ) - self.rank = rank - self.loras: dict[str, LoRALayerWeights] = loras - - def clone(self, lora_model_id: int) -> "LoRAModel": - """Return a copy of the object with different ids. 
- - Will share the underlying tensors.""" - return self.__class__( - lora_model_id, - rank=self.rank, - loras=self.loras.copy(), - ) - - def get_lora(self, module_name: str) -> LoRALayerWeights | None: - """Get LoRA for a given module by name""" - return self.loras.get(module_name, None) - - def check_lora_name(self, lora_name: str) -> bool: - return lora_name in self.loras - - @classmethod - def from_lora_tensors( - cls, - lora_model_id: int, - tensors: dict[str, torch.Tensor], - peft_helper: PEFTHelper, - device: str = "cuda", - dtype: torch.dtype | None = None, - model_vocab_size: int | None = None, - weights_mapper: WeightsMapper | None = None, - ) -> "LoRAModel": - """Create a LoRAModel from a dictionary of tensors.""" - - loras: dict[str, LoRALayerWeights] = {} - for tensor_name, tensor in tensors.items(): - if is_base_embeddding_weights(tensor_name): - continue - module_name, is_lora_a = parse_fine_tuned_lora_name( - tensor_name, weights_mapper - ) - if module_name not in loras: - loras[module_name] = LoRALayerWeights.from_config( - module_name, peft_helper - ) - - if is_lora_a: - if ( - "lora_embedding_A" in tensor_name - and model_vocab_size is not None - and model_vocab_size != tensor.shape[1] - ): - raise RuntimeError( - f"The embedding LoRA size({tensor.shape[1]}) must be consistent" - f" with the base model's vocabulary size({model_vocab_size})." - ) - loras[module_name].lora_a = tensor.to(device=device, dtype=dtype) - else: - loras[module_name].lora_b = tensor.to(device=device, dtype=dtype) - return cls(lora_model_id, peft_helper.r, loras) - - @classmethod - def from_local_checkpoint( - cls, - lora_dir: str, - expected_lora_modules: set[str], - peft_helper: PEFTHelper, - *, - lora_model_id: int | None = None, - device: str = "cuda", - dtype: torch.dtype | None = None, - model_vocab_size: int | None = None, - weights_mapper: WeightsMapper | None = None, - tensorizer_config_dict: dict | None = None, - ) -> "LoRAModel": - """Create a LoRAModel from a local checkpoint. - - Args: - lora_dir: The local path that has lora data. - expected_lora_modules: Name of modules that are expected to be - replaced by lora. - peft_helper: Loaded lora configuration information. - lora_model_id: LoRA model id. If not given, automatically set by - a global counter. - device: Device where the lora model is loaded. - dtype: dtype of the lora model weights. - - Returns: - Loaded LoRA Model. 
- """ - lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") - lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") - lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt") - - tensors: dict[str, torch.Tensor] = {} - unexpected_modules: list[list[str] | str] = [] - - def check_unexpected_modules(modules: dict): - for lora_module in modules.keys(): # noqa - if is_base_embeddding_weights(lora_module): - continue - # Handle PEFT file format where experts.base_layer is the - # gate_up_proj and experts is the down_proj - if "base_layer" in lora_module: - continue - module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper) - # Case for expert lora weights - if ".experts" in module_name: - expert_idx = module_name.find(".experts") - expert_suffix = module_name[expert_idx + 1 :] - if expert_suffix not in expected_lora_modules: - unexpected_modules.append(module_name) - - elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules: - unexpected_modules.append(module_name) - - if unexpected_modules: - raise ValueError( - f"While loading {lora_dir}, expected" - f" target modules in {expected_lora_modules}" - f" but received {unexpected_modules}." - f" Please verify that the loaded LoRA module is correct" - ) - - if tensorizer_config_dict: - from tensorizer import TensorDeserializer - - tensorizer_config = TensorizerConfig(**tensorizer_config_dict) - lora_tensor_path = os.path.join( - tensorizer_config.tensorizer_dir, "adapter_model.tensors" - ) - tensorizer_args = tensorizer_config._construct_tensorizer_args() - tensors = TensorDeserializer( - lora_tensor_path, - dtype=tensorizer_config.dtype, - **tensorizer_args.deserialization_kwargs, - ) - check_unexpected_modules(tensors) - - elif os.path.isfile(lora_tensor_path): - # Find unexpected modules. - # Use safetensor key as a source of truth to find expected modules. - # in peft if you have target_modules A, B, C and C does not exist - # in the model it won’t error and model will be trained with A, B - # loraified. C won’t exist in the safetensor but it will exist in - # the target_modules of the adapter_config.json. - unexpected_modules = [] - with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore - # Load tensors if there are only expected modules. - check_unexpected_modules(f) - for module in f.keys(): # noqa - tensors[module] = f.get_tensor(module) - elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path): - # When a bin/pt file is provided, we rely on config to find - # unexpected modules. - unexpected_modules = [] - target_modules = peft_helper.target_modules - if not isinstance(target_modules, list): - target_modules = [target_modules] - for module in target_modules: - # Compatible with more modules, - # such as:layers.11.self_attn.k_proj - part_name = module.split(".")[-1] - if part_name not in expected_lora_modules: - unexpected_modules.append(module) - # loaded lora's target modules must be a subset of - # expected_lora_modules. It is not reliable. See - # https://github.com/vllm-project/vllm/pull/5909. But there's no - # other better mechanism. - if unexpected_modules and not is_regex_target_modules( - peft_helper.target_modules, expected_lora_modules - ): - raise ValueError( - f"While loading {lora_dir}, expected" - f" target modules in {expected_lora_modules}" - f" but received {unexpected_modules}." 
- f" Please verify that the loaded LoRA module is correct" - ) - lora_file_path = ( - lora_bin_file_path - if os.path.isfile(lora_bin_file_path) - else lora_pt_file_path - ) - tensors = torch.load(lora_file_path, map_location=device, weights_only=True) - else: - raise ValueError(f"{lora_dir} doesn't contain tensors") - - return cls.from_lora_tensors( - lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id, - tensors=tensors, - peft_helper=peft_helper, - device=device, - dtype=dtype, - model_vocab_size=model_vocab_size, - weights_mapper=weights_mapper, - ) - - class LoRAModelManager: """A manager that manages multiple LoRA-fine-tuned models.""" diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 47484b2b984d..4d264c06826b 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -48,6 +48,15 @@ if TYPE_CHECKING: logger = init_logger(__name__) +_GLOBAL_LORA_ID = 0 + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + _all_lora_classes: set[type[BaseLayerWithLoRA]] = { VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 7d77ba7247ef..28c2a53d84e4 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -8,8 +8,8 @@ import torch from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.lora.models import ( - LoRAModel, +from vllm.lora.lora_model import LoRAModel +from vllm.lora.model_manager import ( LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager,