[Misc] Split the LoRA code (#30253)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent 87aee9ed2b
commit 67312cad11
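
Read together, the hunks below show the split: the LoRAModel class moves into a new vllm/lora/lora_model.py, the manager-side names (LoRAMapping, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) are now imported from vllm.lora.model_manager instead of vllm.lora.models, LoRALayerWeights and PackedLoRALayerWeights are imported from vllm.lora.lora_weights, and the get_lora_id counter is re-homed next to the LoRA layer registry. A rough sketch of the resulting import surface follows; the module paths come from this diff, while the grouping itself is only an illustration:

# Sketch of the import layout after this commit (paths taken from the hunks
# below; this snippet is an illustration, not part of the diff).
from vllm.lora.lora_model import LoRAModel  # previously vllm.lora.models
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.model_manager import (  # previously vllm.lora.models
    LoRAMapping,
    LoRAModelManager,
    LRUCacheLoRAModelManager,
    create_lora_manager,
)
from vllm.lora.utils import get_lora_id  # counter moved next to the layer registry
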
@@ -28,7 +28,7 @@ from vllm.lora.layers import (
     RowParallelLinearWithShardedLoRA,
     VocabParallelEmbeddingWithLoRA,
 )
-from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,

@@ -3,7 +3,7 @@

 import pytest

-from vllm.lora.models import LoRAModel
+from vllm.lora.lora_model import LoRAModel
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 from vllm.model_executor.models.utils import WeightsMapper

@@ -3,7 +3,7 @@

 import pytest

-from vllm.lora.models import LoRAModel
+from vllm.lora.lora_model import LoRAModel
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM

@@ -15,10 +15,10 @@ from vllm.lora.layers import (
     MergedColumnParallelLinearWithLoRA,
     RowParallelLinearWithLoRA,
 )
+from vllm.lora.lora_model import LoRAModel
 from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
-from vllm.lora.models import (
+from vllm.lora.model_manager import (
     LoRAMapping,
-    LoRAModel,
     LoRAModelManager,
     LRUCacheLoRAModelManager,
 )

@@ -16,7 +16,7 @@ from vllm.config import (
 )
 from vllm.config.load import LoadConfig
 from vllm.config.lora import LoRAConfig
-from vllm.lora.models import LoRAMapping
+from vllm.lora.model_manager import LoRAMapping
 from vllm.lora.request import LoRARequest
 from vllm.v1.worker.gpu_worker import Worker

vllm/lora/lora_model.py (new file, 246 lines)
@@ -0,0 +1,246 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+
+import safetensors.torch
+import torch
+
+from vllm.logger import init_logger
+from vllm.lora.lora_weights import LoRALayerWeights
+from vllm.lora.peft_helper import PEFTHelper
+from vllm.lora.utils import (
+    get_lora_id,
+    is_base_embeddding_weights,
+    is_regex_target_modules,
+    parse_fine_tuned_lora_name,
+)
+from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
+from vllm.model_executor.models.utils import WeightsMapper
+from vllm.utils.platform_utils import is_pin_memory_available
+
+logger = init_logger(__name__)
+
+
+class LoRAModel:
+    """A LoRA fine-tuned model."""
+
+    def __init__(
+        self,
+        lora_model_id: int,
+        rank: int,
+        loras: dict[str, LoRALayerWeights],
+    ) -> None:
+        """
+        Args:
+            lora_model_id: The integer id for the lora model.
+            rank: lora rank.
+            loras: module name -> weights for lora-replaced layers.
+
+        """
+        self.id = lora_model_id
+
+        assert lora_model_id > 0, (
+            f"a valid lora id should be greater than 0, got {self.id}"
+        )
+        self.rank = rank
+        self.loras: dict[str, LoRALayerWeights] = loras
+
+    def clone(self, lora_model_id: int) -> "LoRAModel":
+        """Return a copy of the object with different ids.
+
+        Will share the underlying tensors."""
+        return self.__class__(
+            lora_model_id,
+            rank=self.rank,
+            loras=self.loras.copy(),
+        )
+
+    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
+        """Get LoRA for a given module by name"""
+        return self.loras.get(module_name, None)
+
+    def check_lora_name(self, lora_name: str) -> bool:
+        return lora_name in self.loras
+
+    @classmethod
+    def from_lora_tensors(
+        cls,
+        lora_model_id: int,
+        tensors: dict[str, torch.Tensor],
+        peft_helper: PEFTHelper,
+        device: str = "cuda",
+        dtype: torch.dtype | None = None,
+        model_vocab_size: int | None = None,
+        weights_mapper: WeightsMapper | None = None,
+    ) -> "LoRAModel":
+        """Create a LoRAModel from a dictionary of tensors."""
+        pin_memory = str(device) == "cpu" and is_pin_memory_available()
+        loras: dict[str, LoRALayerWeights] = {}
+        for tensor_name, tensor in tensors.items():
+            if is_base_embeddding_weights(tensor_name):
+                continue
+            module_name, is_lora_a = parse_fine_tuned_lora_name(
+                tensor_name, weights_mapper
+            )
+            if module_name not in loras:
+                loras[module_name] = LoRALayerWeights.from_config(
+                    module_name, peft_helper
+                )
+
+            if is_lora_a:
+                if (
+                    "lora_embedding_A" in tensor_name
+                    and model_vocab_size is not None
+                    and model_vocab_size != tensor.shape[1]
+                ):
+                    raise RuntimeError(
+                        f"The embedding LoRA size({tensor.shape[1]}) must be consistent"
+                        f" with the base model's vocabulary size({model_vocab_size})."
+                    )
+                loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
+                if pin_memory:
+                    loras[module_name].lora_a = loras[module_name].lora_a.pin_memory()
+            else:
+                loras[module_name].lora_b = tensor.to(device=device, dtype=dtype)
+
+                if pin_memory:
+                    loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
+
+        return cls(lora_model_id, peft_helper.r, loras)
+
+    @classmethod
+    def from_local_checkpoint(
+        cls,
+        lora_dir: str,
+        expected_lora_modules: set[str],
+        peft_helper: PEFTHelper,
+        *,
+        lora_model_id: int | None = None,
+        device: str = "cuda",
+        dtype: torch.dtype | None = None,
+        model_vocab_size: int | None = None,
+        weights_mapper: WeightsMapper | None = None,
+        tensorizer_config_dict: dict | None = None,
+    ) -> "LoRAModel":
+        """Create a LoRAModel from a local checkpoint.
+
+        Args:
+            lora_dir: The local path that has lora data.
+            expected_lora_modules: Name of modules that are expected to be
+                replaced by lora.
+            peft_helper: Loaded lora configuration information.
+            lora_model_id: LoRA model id. If not given, automatically set by
+                a global counter.
+            device: Device where the lora model is loaded.
+            dtype: dtype of the lora model weights.
+
+        Returns:
+            Loaded LoRA Model.
+        """
+        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
+        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
+        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
+
+        tensors: dict[str, torch.Tensor] = {}
+        unexpected_modules: list[list[str] | str] = []
+
+        def check_unexpected_modules(modules: dict):
+            for lora_module in modules.keys():  # noqa
+                if is_base_embeddding_weights(lora_module):
+                    continue
+                # Handle PEFT file format where experts.base_layer is the
+                # gate_up_proj and experts is the down_proj
+                if "base_layer" in lora_module:
+                    continue
+                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
+                # Case for expert lora weights
+                if ".experts" in module_name:
+                    expert_idx = module_name.find(".experts")
+                    expert_suffix = module_name[expert_idx + 1 :]
+                    if expert_suffix not in expected_lora_modules:
+                        unexpected_modules.append(module_name)
+
+                elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
+                    unexpected_modules.append(module_name)
+
+            if unexpected_modules:
+                raise ValueError(
+                    f"While loading {lora_dir}, expected"
+                    f" target modules in {expected_lora_modules}"
+                    f" but received {unexpected_modules}."
+                    f" Please verify that the loaded LoRA module is correct"
+                )
+
+        if tensorizer_config_dict:
+            from tensorizer import TensorDeserializer
+
+            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
+            lora_tensor_path = os.path.join(
+                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
+            )
+            tensorizer_args = tensorizer_config._construct_tensorizer_args()
+            tensors = TensorDeserializer(
+                lora_tensor_path,
+                dtype=tensorizer_config.dtype,
+                **tensorizer_args.deserialization_kwargs,
+            )
+            check_unexpected_modules(tensors)
+
+        elif os.path.isfile(lora_tensor_path):
+            # Find unexpected modules.
+            # Use safetensor key as a source of truth to find expected modules.
+            # in peft if you have target_modules A, B, C and C does not exist
+            # in the model it won't error and model will be trained with A, B
+            # loraified. C won't exist in the safetensor but it will exist in
+            # the target_modules of the adapter_config.json.
+            unexpected_modules = []
+            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
+                # Load tensors if there are only expected modules.
+                check_unexpected_modules(f)
+                for module in f.keys():  # noqa
+                    tensors[module] = f.get_tensor(module)
+        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
+            # When a bin/pt file is provided, we rely on config to find
+            # unexpected modules.
+            unexpected_modules = []
+            target_modules = peft_helper.target_modules
+            if not isinstance(target_modules, list):
+                target_modules = [target_modules]
+            for module in target_modules:
+                # Compatible with more modules,
+                # such as:layers.11.self_attn.k_proj
+                part_name = module.split(".")[-1]
+                if part_name not in expected_lora_modules:
+                    unexpected_modules.append(module)
+            # loaded lora's target modules must be a subset of
+            # expected_lora_modules. It is not reliable. See
+            # https://github.com/vllm-project/vllm/pull/5909. But there's no
+            # other better mechanism.
+            if unexpected_modules and not is_regex_target_modules(
+                peft_helper.target_modules, expected_lora_modules
+            ):
+                raise ValueError(
+                    f"While loading {lora_dir}, expected"
+                    f" target modules in {expected_lora_modules}"
+                    f" but received {unexpected_modules}."
+                    f" Please verify that the loaded LoRA module is correct"
+                )
+            lora_file_path = (
+                lora_bin_file_path
+                if os.path.isfile(lora_bin_file_path)
+                else lora_pt_file_path
+            )
+            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
+        else:
+            raise ValueError(f"{lora_dir} doesn't contain tensors")
+
+        return cls.from_lora_tensors(
+            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
+            tensors=tensors,
+            peft_helper=peft_helper,
+            device=device,
+            dtype=dtype,
+            model_vocab_size=model_vocab_size,
+            weights_mapper=weights_mapper,
+        )
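
For orientation, a minimal usage sketch of the class added above. The adapter path, the expected module names, and the PEFTHelper.from_local_dir call are illustrative assumptions (they are not part of this diff and may differ across vLLM versions); the from_local_checkpoint signature and the id/rank/loras attributes come from the code shown here.

from vllm.lora.lora_model import LoRAModel
from vllm.lora.peft_helper import PEFTHelper

# Hypothetical adapter directory holding adapter_config.json plus
# adapter_model.safetensors (or .bin/.pt), as probed by from_local_checkpoint.
lora_dir = "/path/to/my-adapter"
# Assumption: PEFTHelper exposes a from_local_dir constructor; if the installed
# version differs, build the helper from the adapter config however it allows.
peft_helper = PEFTHelper.from_local_dir(lora_dir, max_position_embeddings=4096)

lora = LoRAModel.from_local_checkpoint(
    lora_dir,
    expected_lora_modules={"q_proj", "k_proj", "v_proj", "o_proj"},
    peft_helper=peft_helper,
    device="cpu",  # on CPU, from_lora_tensors pins the weights when pinning is available
)
print(lora.id, lora.rank, sorted(lora.loras))
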
@@ -2,38 +2,32 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import math
-import os
 from collections.abc import Callable
 from typing import TypeVar

 import regex as re
-import safetensors.torch
 import torch
 from torch import nn

 from vllm.config.lora import LoRAConfig
 from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA, FusedMoE3DWithLoRA, LoRAMapping
+from vllm.lora.lora_model import LoRAModel
 from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
-from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.punica_wrapper import get_punica_wrapper
 from vllm.lora.utils import (
     from_layer,
     from_layer_logits_processor,
     get_supported_lora_modules,
-    is_base_embeddding_weights,
     is_moe_model,
-    is_regex_target_modules,
-    parse_fine_tuned_lora_name,
     process_packed_modules_mapping,
     replace_submodule,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
 from vllm.model_executor.models.interfaces import is_pooling_model
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
+from vllm.model_executor.models.utils import PPMissingLayer
 from vllm.utils.cache import LRUCache
 from vllm.utils.platform_utils import is_pin_memory_available

@@ -53,233 +47,6 @@ class AdapterLRUCache(LRUCache[int, T]):
         return super()._on_remove(key, value)


-_GLOBAL_LORA_ID = 0
-
-
-def get_lora_id():
-    global _GLOBAL_LORA_ID
-    _GLOBAL_LORA_ID += 1
-    return _GLOBAL_LORA_ID
-
-
-class LoRAModel:
-    """A LoRA fine-tuned model."""
-
-    def __init__(
-        self,
-        lora_model_id: int,
-        rank: int,
-        loras: dict[str, LoRALayerWeights],
-    ) -> None:
-        """
-        Args:
-            lora_model_id: The integer id for the lora model.
-            rank: lora rank.
-            loras: module name -> weights for lora-replaced layers.
-
-        """
-        self.id = lora_model_id
-
-        assert lora_model_id > 0, (
-            f"a valid lora id should be greater than 0, got {self.id}"
-        )
-        self.rank = rank
-        self.loras: dict[str, LoRALayerWeights] = loras
-
-    def clone(self, lora_model_id: int) -> "LoRAModel":
-        """Return a copy of the object with different ids.
-
-        Will share the underlying tensors."""
-        return self.__class__(
-            lora_model_id,
-            rank=self.rank,
-            loras=self.loras.copy(),
-        )
-
-    def get_lora(self, module_name: str) -> LoRALayerWeights | None:
-        """Get LoRA for a given module by name"""
-        return self.loras.get(module_name, None)
-
-    def check_lora_name(self, lora_name: str) -> bool:
-        return lora_name in self.loras
-
-    @classmethod
-    def from_lora_tensors(
-        cls,
-        lora_model_id: int,
-        tensors: dict[str, torch.Tensor],
-        peft_helper: PEFTHelper,
-        device: str = "cuda",
-        dtype: torch.dtype | None = None,
-        model_vocab_size: int | None = None,
-        weights_mapper: WeightsMapper | None = None,
-    ) -> "LoRAModel":
-        """Create a LoRAModel from a dictionary of tensors."""
-
-        loras: dict[str, LoRALayerWeights] = {}
-        for tensor_name, tensor in tensors.items():
-            if is_base_embeddding_weights(tensor_name):
-                continue
-            module_name, is_lora_a = parse_fine_tuned_lora_name(
-                tensor_name, weights_mapper
-            )
-            if module_name not in loras:
-                loras[module_name] = LoRALayerWeights.from_config(
-                    module_name, peft_helper
-                )
-
-            if is_lora_a:
-                if (
-                    "lora_embedding_A" in tensor_name
-                    and model_vocab_size is not None
-                    and model_vocab_size != tensor.shape[1]
-                ):
-                    raise RuntimeError(
-                        f"The embedding LoRA size({tensor.shape[1]}) must be consistent"
-                        f" with the base model's vocabulary size({model_vocab_size})."
-                    )
-                loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
-            else:
-                loras[module_name].lora_b = tensor.to(device=device, dtype=dtype)
-        return cls(lora_model_id, peft_helper.r, loras)
-
-    @classmethod
-    def from_local_checkpoint(
-        cls,
-        lora_dir: str,
-        expected_lora_modules: set[str],
-        peft_helper: PEFTHelper,
-        *,
-        lora_model_id: int | None = None,
-        device: str = "cuda",
-        dtype: torch.dtype | None = None,
-        model_vocab_size: int | None = None,
-        weights_mapper: WeightsMapper | None = None,
-        tensorizer_config_dict: dict | None = None,
-    ) -> "LoRAModel":
-        """Create a LoRAModel from a local checkpoint.
-
-        Args:
-            lora_dir: The local path that has lora data.
-            expected_lora_modules: Name of modules that are expected to be
-                replaced by lora.
-            peft_helper: Loaded lora configuration information.
-            lora_model_id: LoRA model id. If not given, automatically set by
-                a global counter.
-            device: Device where the lora model is loaded.
-            dtype: dtype of the lora model weights.
-
-        Returns:
-            Loaded LoRA Model.
-        """
-        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
-        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
-        lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
-
-        tensors: dict[str, torch.Tensor] = {}
-        unexpected_modules: list[list[str] | str] = []
-
-        def check_unexpected_modules(modules: dict):
-            for lora_module in modules.keys():  # noqa
-                if is_base_embeddding_weights(lora_module):
-                    continue
-                # Handle PEFT file format where experts.base_layer is the
-                # gate_up_proj and experts is the down_proj
-                if "base_layer" in lora_module:
-                    continue
-                module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
-                # Case for expert lora weights
-                if ".experts" in module_name:
-                    expert_idx = module_name.find(".experts")
-                    expert_suffix = module_name[expert_idx + 1 :]
-                    if expert_suffix not in expected_lora_modules:
-                        unexpected_modules.append(module_name)
-
-                elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
-                    unexpected_modules.append(module_name)
-
-            if unexpected_modules:
-                raise ValueError(
-                    f"While loading {lora_dir}, expected"
-                    f" target modules in {expected_lora_modules}"
-                    f" but received {unexpected_modules}."
-                    f" Please verify that the loaded LoRA module is correct"
-                )
-
-        if tensorizer_config_dict:
-            from tensorizer import TensorDeserializer
-
-            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
-            lora_tensor_path = os.path.join(
-                tensorizer_config.tensorizer_dir, "adapter_model.tensors"
-            )
-            tensorizer_args = tensorizer_config._construct_tensorizer_args()
-            tensors = TensorDeserializer(
-                lora_tensor_path,
-                dtype=tensorizer_config.dtype,
-                **tensorizer_args.deserialization_kwargs,
-            )
-            check_unexpected_modules(tensors)
-
-        elif os.path.isfile(lora_tensor_path):
-            # Find unexpected modules.
-            # Use safetensor key as a source of truth to find expected modules.
-            # in peft if you have target_modules A, B, C and C does not exist
-            # in the model it won't error and model will be trained with A, B
-            # loraified. C won't exist in the safetensor but it will exist in
-            # the target_modules of the adapter_config.json.
-            unexpected_modules = []
-            with safetensors.safe_open(lora_tensor_path, framework="pt") as f:  # type: ignore
-                # Load tensors if there are only expected modules.
-                check_unexpected_modules(f)
-                for module in f.keys():  # noqa
-                    tensors[module] = f.get_tensor(module)
-        elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
-            # When a bin/pt file is provided, we rely on config to find
-            # unexpected modules.
-            unexpected_modules = []
-            target_modules = peft_helper.target_modules
-            if not isinstance(target_modules, list):
-                target_modules = [target_modules]
-            for module in target_modules:
-                # Compatible with more modules,
-                # such as:layers.11.self_attn.k_proj
-                part_name = module.split(".")[-1]
-                if part_name not in expected_lora_modules:
-                    unexpected_modules.append(module)
-            # loaded lora's target modules must be a subset of
-            # expected_lora_modules. It is not reliable. See
-            # https://github.com/vllm-project/vllm/pull/5909. But there's no
-            # other better mechanism.
-            if unexpected_modules and not is_regex_target_modules(
-                peft_helper.target_modules, expected_lora_modules
-            ):
-                raise ValueError(
-                    f"While loading {lora_dir}, expected"
-                    f" target modules in {expected_lora_modules}"
-                    f" but received {unexpected_modules}."
-                    f" Please verify that the loaded LoRA module is correct"
-                )
-            lora_file_path = (
-                lora_bin_file_path
-                if os.path.isfile(lora_bin_file_path)
-                else lora_pt_file_path
-            )
-            tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
-        else:
-            raise ValueError(f"{lora_dir} doesn't contain tensors")
-
-        return cls.from_lora_tensors(
-            lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
-            tensors=tensors,
-            peft_helper=peft_helper,
-            device=device,
-            dtype=dtype,
-            model_vocab_size=model_vocab_size,
-            weights_mapper=weights_mapper,
-        )
-
-
 class LoRAModelManager:
     """A manager that manages multiple LoRA-fine-tuned models."""

@@ -48,6 +48,15 @@ if TYPE_CHECKING:

 logger = init_logger(__name__)

+_GLOBAL_LORA_ID = 0
+
+
+def get_lora_id():
+    global _GLOBAL_LORA_ID
+    _GLOBAL_LORA_ID += 1
+    return _GLOBAL_LORA_ID
+
+
 _all_lora_classes: set[type[BaseLayerWithLoRA]] = {
     VocabParallelEmbeddingWithLoRA,
     ColumnParallelLinearWithLoRA,
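
The relocated counter is what lets from_local_checkpoint auto-assign ids when lora_model_id is None, and it always returns a positive integer, which satisfies the lora_model_id > 0 assertion in LoRAModel.__init__. A small behavioural sketch follows; the vllm.lora.utils path is inferred from the import block of the new lora_model.py above rather than shown in this hunk:

from vllm.lora.utils import get_lora_id  # path inferred from lora_model.py's imports

first = get_lora_id()   # 1 on the first call in a fresh process
second = get_lora_id()  # the counter is process-global, so ids keep increasing
assert first > 0 and second == first + 1
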
@@ -8,8 +8,8 @@ import torch

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.lora.models import (
-    LoRAModel,
+from vllm.lora.lora_model import LoRAModel
+from vllm.lora.model_manager import (
     LoRAModelManager,
     LRUCacheLoRAModelManager,
     create_lora_manager,