- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**
commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:18:24 2025 -0500
Add SPDX license headers to python source files
This commit adds SPDX license headers to python source files as
recommended to the project by the Linux Foundation. These headers
provide a concise way, both human and machine readable, of
communicating license information for each source file. It helps avoid
any ambiguity about the license of the code and can also be easily used
by tools to help manage license compliance.

The Linux Foundation runs license scans against the codebase to help
ensure we are in compliance with the licenses of the code we use,
including dependencies. Having these headers in place helps that tool
do its job.

More information can be found on the SPDX site:

- https://spdx.dev/learn/handling-license-info/
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date: Fri Jan 31 14:36:32 2025 -0500
Check for SPDX headers using pre-commit
Signed-off-by: Russell Bryant <rbryant@redhat.com>
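For context, a check like the one described in the second commit is typically wired up as a local pre-commit hook that scans staged Python files for the SPDX header line. The sketch below is only an illustration of the idea, not the project's actual hook or script; the script logic and names are assumptions.

```python
# Hypothetical sketch of a check a pre-commit hook could run; NOT the
# project's actual script. It fails if any given Python file does not
# start with the expected SPDX header line.
import sys

SPDX_HEADER = "# SPDX-License-Identifier: Apache-2.0"


def main(paths: list[str]) -> int:
    missing = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            first_line = f.readline().rstrip("\n")
        if first_line != SPDX_HEADER:
            missing.append(path)
    for path in missing:
        print(f"missing SPDX header: {path}")
    return 1 if missing else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
```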
214 lines · 7.9 KiB · Python
# SPDX-License-Identifier: Apache-2.0

import os
import re
from typing import List, Optional, Set, Tuple, Type, Union

import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
                                   HFValidationError, RepositoryNotFoundError)
from torch import nn
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
                              LinearScalingRotaryEmbeddingWithLora,
                              LogitsProcessorWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              MergedQKVParallelLinearWithLora,
                              QKVParallelLinearWithLora,
                              ReplicatedLinearWithLoRA,
                              RowParallelLinearWithLoRA,
                              VocabParallelEmbeddingWithLoRA)
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLora,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA,
    LinearScalingRotaryEmbeddingWithLora,
}


def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    for lora_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        if lora_cls.can_replace_layer(source_layer=layer,
                                      lora_config=lora_config,
                                      packed_modules_list=packed_modules_list,
                                      model_config=model_config):
            ret = lora_cls(layer)
            ret.create_lora_weights(max_loras, lora_config, model_config)
            return ret
    return layer
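

# Illustrative usage only (not part of the original module): `from_layer` is
# the dispatch point that wraps a base layer with its LoRA-capable variant.
# Assuming a hypothetical `linear_layer` (e.g. a ColumnParallelLinear) and a
# `lora_cfg` LoRAConfig built elsewhere:
#
#     wrapped = from_layer(linear_layer,
#                          max_loras=4,
#                          lora_config=lora_cfg,
#                          packed_modules_list=[])
#
# If no class in `_all_lora_classes` can replace the layer, the original
# layer is returned unchanged.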


def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                  lm_head.weight.dtype, lm_head.weight.device,
                                  lm_head.get_sharded_to_full_mapping())
    ret.create_lora_weights(max_loras, lora_config, model_config)
    return ret


def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace a submodule in a model with a new module."""
    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
    target_name = module_name.split(".")[-1]
    setattr(parent, target_name, new_module)
    return new_module
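

# Illustrative usage only (not part of the original module): for a dotted
# module path such as "model.layers.0.self_attn.qkv_proj" (a hypothetical
# example), `replace_submodule` looks up the parent "model.layers.0.self_attn"
# and rebinds the attribute "qkv_proj" to the new module:
#
#     replace_submodule(model, "model.layers.0.self_attn.qkv_proj", wrapped)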


def parse_fine_tuned_lora_name(
        name: str,
        weights_mapper: Optional[WeightsMapper] = None
) -> Tuple[str, bool, bool]:
    """Parse the name of a LoRA weight.

    Args:
        name: the name of the fine-tuned LoRA weight, e.g.
            base_model.model.dense1.weight
        weights_mapper: maps the name of the weight, e.g.
            `model.` -> `language_model.model.`.

    Returns:
        A tuple (module_name, is_lora_a, is_bias):
            module_name: the name of the module, e.g. model.dense1
            is_lora_a: whether the tensor is lora_a or lora_b.
            is_bias: whether the tensor is a lora bias.
    """

    # LoRA weight qualified names always start with `base_model.model.`,
    # so we remove the prefix `base_model.model.` so that the following
    # mapping works correctly.
    if "base_model.model." in name:
        name = name.replace("base_model.model.", "")
        name = weights_mapper._map_name(name) if weights_mapper else name
        # recover the prefix `base_model.model.`
        name = "base_model.model." + name

    parts = name.split(".")
    if parts[-1] == "weight" and (parts[-2] == "lora_A"
                                  or parts[-2] == "lora_B"):
        new_name = ".".join(parts[2:-2])
        return new_name, parts[-2] == "lora_A", False

    if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
        new_name = ".".join(parts[2:-1])
        return new_name, parts[-1] == "lora_embedding_A", False

    if parts[-1] == "bias":
        new_name = ".".join(parts[2:-2])
        return new_name, False, True

    raise ValueError(f"{name} is unsupported LoRA weight")
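

# Illustrative examples only (not part of the original module):
#
#     parse_fine_tuned_lora_name(
#         "base_model.model.model.dense1.lora_A.weight")
#     # -> ("model.dense1", True, False)
#
#     parse_fine_tuned_lora_name(
#         "base_model.model.model.dense1.lora_B.bias")
#     # -> ("model.dense1", False, True)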


def is_regex_target_modules(load_modules: Union[str, List[str]],
                            expected_lora_modules: List[str]) -> bool:
    """
    PEFT supports passing `target_modules` as a regular expression, such as
    `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
    determine whether the suffixes in the regular expression are present in
    `expected_lora_modules`.
    """

    def is_valid_regex(pattern):
        try:
            re.compile(pattern)
            return True
        except re.error:
            return False

    def is_subset(sub_list, full_list):
        return set(sub_list).issubset(set(full_list))

    # Similar to PEFT's processing logic, regex-related operations are only
    # executed when `load_modules` is a `str`.
    if not isinstance(load_modules, str):
        return False

    if is_valid_regex(load_modules):
        match = re.search(r"\((.*?)\)\$?$", load_modules)
        if match:
            suffix = match.group(1).split("|")
            return is_subset(suffix, expected_lora_modules)
    return False
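

# Illustrative example only (not part of the original module):
#
#     is_regex_target_modules("model.*(q_proj|k_proj|v_proj)$",
#                             ["q_proj", "k_proj", "v_proj", "o_proj"])
#     # -> True, because every suffix in the regex is an expected module.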


def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolves the given lora_path to an absolute local path.

    If the lora_path is identified as a Hugging Face model identifier,
    it will download the model and return the local snapshot path.
    Otherwise, it treats the lora_path as a local file path and
    converts it to an absolute path.

    Parameters:
        lora_path (str): The path to the lora model, which can be an
            absolute path, a relative path, or a Hugging Face model
            identifier.

    Returns:
        str: The resolved absolute local path to the lora model.
    """

    # Check if the path is an absolute path. Return it whether or not it
    # exists.
    if os.path.isabs(lora_path):
        return lora_path

    # If the path starts with ~, expand the user home directory.
    if lora_path.startswith('~'):
        return os.path.expanduser(lora_path)

    # Check if the expanded relative path exists locally.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # If the path does not exist locally, assume it's a Hugging Face repo.
    try:
        local_snapshot_path = huggingface_hub.snapshot_download(
            repo_id=lora_path)
    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
            HFValidationError):
        # Handle errors that may occur during the download.
        # Return the original path instead of raising an error here.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path

    return local_snapshot_path
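

# Illustrative usage only (not part of the original module); the paths and
# repo id below are hypothetical:
#
#     get_adapter_absolute_path("/abs/path/to/adapter")  # returned as-is
#     get_adapter_absolute_path("./my-adapter")          # absolute path if it exists
#     get_adapter_absolute_path("some-org/some-lora")    # treated as a HF repo id
#
# A non-existent local path that is also not a downloadable Hugging Face
# repo falls back to returning the original string.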