mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
[Bugfix] Fix Qwen2-VL LoRA weight loading (#11430)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent 9edca6bf8f
commit b1b1038fbd
@@ -200,6 +200,11 @@ def minicpmv_lora_files():
     return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
 
 
+@pytest.fixture(scope="session")
+def qwen2vl_lora_files():
+    return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
@@ -4,6 +4,7 @@ import pytest
 
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
+from vllm.model_executor.models.utils import WeightsMapper
 
 lora_lst = [
     "baichuan7B", "baichuan7B-zero", "baichuan7B-zero-regex", "chatglm3-6b"
@@ -71,3 +72,32 @@ def test_load_checkpoints(
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+
+
+def test_lora_weights_mapping(baichuan_lora_files, ):
+    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
+    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
+    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
+    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
+    expected_lora_modules: List[str] = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "model.": "language_model.model.",
+    }, )
+
+    lora_model = LoRAModel.from_local_checkpoint(
+        baichuan_lora_files,
+        expected_lora_modules,
+        lora_model_id=1,
+        device="cpu",
+        embedding_modules=embedding_modules,
+        embedding_padding_modules=embed_padding_modules,
+        weights_mapper=hf_to_vllm_mapper,
+    )
+    for name in lora_model.loras:
+        assert name.startswith(hf_to_vllm_mapper.orig_to_new_prefix["model."])
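The packed-module expansion in the new test is easiest to see in isolation. Below is a minimal standalone sketch of the same loop; the mappings are made-up examples rather than the real Baichuan definitions:

# Standalone illustration of the packed-module expansion used in the test
# above. The mappings below are made-up examples, not the real Baichuan ones.
packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}
supported_lora_modules = ["qkv_proj", "gate_up_proj", "down_proj"]

expected_lora_modules = []
for module in supported_lora_modules:
    if module in packed_modules_mapping:
        # A packed (fused) vLLM module expands to the individual HF-side
        # projections that a LoRA checkpoint actually names.
        expected_lora_modules.extend(packed_modules_mapping[module])
    else:
        expected_lora_modules.append(module)

print(expected_lora_modules)
# ['qkv_proj', 'gate_proj', 'up_proj', 'down_proj']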
tests/lora/test_qwen2vl.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+from typing import List
+
+import pytest
+
+import vllm
+from vllm.assets.image import ImageAsset
+from vllm.lora.request import LoRARequest
+from vllm.platforms import current_platform
+
+MODEL_PATH = "Qwen/Qwen2-VL-7B-Instruct"
+
+PROMPT_TEMPLATE = (
+    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
+    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+    "What is in the image?<|im_end|>\n"
+    "<|im_start|>assistant\n")
+
+IMAGE_ASSETS = [
+    ImageAsset("stop_sign"),
+    ImageAsset("cherry_blossom"),
+]
+
+# After fine-tuning with LoRA, all generated content should begin with `A`.
+EXPECTED_OUTPUT = [
+    "A stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
+    "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
+]
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
+    sampling_params = vllm.SamplingParams(
+        temperature=0,
+        max_tokens=5,
+    )
+
+    inputs = [{
+        "prompt": PROMPT_TEMPLATE,
+        "multi_modal_data": {
+            "image": asset.pil_image
+        },
+    } for asset in IMAGE_ASSETS]
+
+    outputs = llm.generate(
+        inputs,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None,
+    )
+    # Print the outputs.
+    generated_texts: List[str] = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+@pytest.mark.xfail(current_platform.is_rocm(),
+                   reason="Qwen2-VL dependency xformers incompatible with ROCm"
+                   )
+def test_qwen2vl_lora(qwen2vl_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_num_seqs=2,
+        enable_lora=True,
+        max_loras=2,
+        max_lora_rank=16,
+        trust_remote_code=True,
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+        },
+        max_model_len=4096,
+    )
+    output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
+    for i in range(len(EXPECTED_OUTPUT)):
+        assert EXPECTED_OUTPUT[i].startswith(output1[i])
@@ -28,7 +28,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              parse_fine_tuned_lora_name, replace_submodule)
 from vllm.model_executor.models import SupportsLoRA, supports_multimodal
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.models.utils import PPMissingLayer
+from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -113,13 +113,14 @@ class LoRAModel(AdapterModel):
         target_embedding_padding: Optional[int] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
+        weights_mapper: Optional[WeightsMapper] = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a dictionary of tensors."""
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         loras: Dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
             module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
-                tensor_name)
+                tensor_name, weights_mapper)
             if module_name not in loras:
                 lora_embeddings_tensor = None
                 if embeddings:
@@ -187,6 +188,7 @@ class LoRAModel(AdapterModel):
         target_embedding_padding: Optional[int] = None,
         embedding_modules: Optional[Dict[str, str]] = None,
         embedding_padding_modules: Optional[List[str]] = None,
+        weights_mapper: Optional[WeightsMapper] = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a local checkpoint.
 
@@ -289,7 +291,8 @@ class LoRAModel(AdapterModel):
             embeddings=embeddings,
             target_embedding_padding=target_embedding_padding,
             embedding_modules=embedding_modules,
-            embedding_padding_modules=embedding_padding_modules)
+            embedding_padding_modules=embedding_padding_modules,
+            weights_mapper=weights_mapper)
 
 
 class LoRAModelManager(AdapterModelManager):
@@ -1,3 +1,4 @@
+import copy
 import os
 import re
 from typing import List, Optional, Set, Tuple, Type, Union
@@ -30,6 +31,8 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
 # yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.models.utils import WeightsMapper
+from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
 
@@ -91,28 +94,54 @@ def replace_submodule(model: nn.Module, module_name: str,
     return new_module
 
 
-def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
+def parse_fine_tuned_lora_name(
+        name: str,
+        weights_mapper: Optional[WeightsMapper] = None
+) -> Tuple[str, bool, bool]:
     """Parse the name of lora weights.
 
     args:
         name: the name of the fine-tuned LoRA, e.g.
            base_model.model.dense1.weight
+        weights_mapper: maps the name of weight, e.g.
+            `model.` -> `language_model.model.`
     return:
        Tuple(module_name, is_lora_a, is_bias):
            module_name: the name of the module, e.g. model.dense1,
            is_lora_a: whether the tensor is lora_a or lora_b.
            is_bias: whether the tensor is lora bias.
     """
 
+    w_mapper = None
+    if weights_mapper:
+        w_mapper = copy.deepcopy(weights_mapper)
+        # TODO: Currently only prefix mapping is supported; mapping for
+        # substr and suffix will be supported in the future.
+        for attr, mapping in [
+                ("orig_to_new_substr", w_mapper.orig_to_new_substr),
+                ("orig_to_new_suffix", w_mapper.orig_to_new_suffix),
+        ]:
+            if mapping:
+                print_warning_once(
+                    f"vLLM currently does not support mapping of LoRA weights "
+                    f"for {mapping}.")
+                setattr(w_mapper, attr, {})
+
+    mapper = (lambda name: w_mapper._map_name(name)
+              if w_mapper is not None else name)
+
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
                                   or parts[-2] == "lora_B"):
-        return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
+        new_name = ".".join(parts[2:-2])
+        return mapper(new_name), parts[-2] == "lora_A", False
 
     if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
-        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
+        new_name = ".".join(parts[2:-1])
+        return mapper(new_name), parts[-1] == "lora_embedding_A", False
 
     if parts[-1] == "bias":
-        return ".".join(parts[2:-2]), False, True
+        new_name = ".".join(parts[2:-2])
+        return mapper(new_name), False, True
 
     raise ValueError(f"{name} is unsupported LoRA weight")
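To make the new behaviour concrete, here is what the parser is expected to produce for typical adapter tensor names, with and without a prefix map. This is a standalone sketch: `parse` and `_apply_prefix_map` are illustrative stand-ins that approximate WeightsMapper's prefix handling, not vLLM APIs.

from typing import Dict, Optional, Tuple

def _apply_prefix_map(name: str, prefix_map: Optional[Dict[str, str]]) -> str:
    # Stand-in for WeightsMapper's prefix handling: rewrite the first
    # matching prefix, leave the name untouched otherwise.
    if prefix_map:
        for old, new in prefix_map.items():
            if name.startswith(old):
                return new + name[len(old):]
    return name

def parse(name: str,
          prefix_map: Optional[Dict[str, str]] = None) -> Tuple[str, bool, bool]:
    # Same parsing rules as parse_fine_tuned_lora_name above: strip the
    # "base_model.model." wrapper, then classify the tensor.
    parts = name.split(".")
    if parts[-1] == "weight" and parts[-2] in ("lora_A", "lora_B"):
        module = _apply_prefix_map(".".join(parts[2:-2]), prefix_map)
        return module, parts[-2] == "lora_A", False
    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        module = _apply_prefix_map(".".join(parts[2:-1]), prefix_map)
        return module, parts[-1] == "lora_embedding_A", False
    if parts[-1] == "bias":
        return _apply_prefix_map(".".join(parts[2:-2]), prefix_map), False, True
    raise ValueError(f"{name} is unsupported LoRA weight")

# Without a mapper (Qwen2-VL before this fix): the module name keeps the HF
# "model." prefix and never matches vLLM's "language_model.model." modules.
print(parse("base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"))
# -> ('model.layers.0.self_attn.q_proj', True, False)

# With the Qwen2-VL prefix map, the name resolves to the vLLM module path.
print(parse("base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight",
            {"model.": "language_model.model."}))
# -> ('language_model.model.layers.0.self_attn.q_proj', True, False)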
@@ -92,6 +92,14 @@ class WorkerLoRAManager(AbstractWorkerManager):
                 else:
                     expected_lora_modules.append(module)
             lora_path = get_adapter_absolute_path(lora_request.lora_path)
+
+            # For some models like Qwen2-VL, we need to use the model's
+            # hf_to_vllm_mapper to ensure correct loading of LoRA weights.
+            hf_to_vllm_mapper = None
+            if (hasattr(model, "hf_to_vllm_mapper")
+                    and model.hf_to_vllm_mapper is not None):
+                hf_to_vllm_mapper = model.hf_to_vllm_mapper
+
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,
                 expected_lora_modules,
@@ -103,7 +111,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
                 self.lora_config.lora_extra_vocab_size,
                 embedding_modules=self.embedding_modules,
                 embedding_padding_modules=self.embedding_padding_modules,
-            )
+                weights_mapper=hf_to_vllm_mapper)
+
         except Exception as e:
             raise RuntimeError(f"Loading lora {lora_path} failed") from e
         if lora.rank > self.lora_config.max_lora_rank:
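The worker-manager change is deliberately generic: it does not special-case Qwen2-VL, it simply asks the model for an optional hf_to_vllm_mapper attribute and forwards it. A minimal sketch of that lookup pattern, using made-up stand-in classes rather than vLLM's:

# FakeMapper/FakeModel/PlainModel are illustrative stand-ins, not vLLM classes.
class FakeMapper:
    def __init__(self, orig_to_new_prefix):
        self.orig_to_new_prefix = orig_to_new_prefix

class FakeModel:
    # Class-level attribute, mirroring Qwen2VLForConditionalGeneration below.
    hf_to_vllm_mapper = FakeMapper({"model.": "language_model.model."})

class PlainModel:
    pass  # no mapper attribute -> LoRA names are used as-is

def resolve_mapper(model):
    # Equivalent to the hasattr/is-not-None check in the diff above.
    return getattr(model, "hf_to_vllm_mapper", None)

assert resolve_mapper(FakeModel()) is FakeModel.hf_to_vllm_mapper
assert resolve_mapper(PlainModel()) is None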
@@ -901,6 +901,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     ]
     embedding_modules = {}
     embedding_padding_modules = []
+    # To ensure correct weight loading and mapping.
+    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+        "lm_head.": "language_model.lm_head.",
+        "model.": "language_model.model.",
+    })
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -1190,11 +1195,6 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        hf_to_vllm_mapper = WeightsMapper(
-            orig_to_new_prefix={
-                "lm_head.": "language_model.lm_head.",
-                "model.": "language_model.model.",
-            })
 
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
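Putting it together, a usage sketch of what this fix enables, mirroring the new test: the model name, adapter repo, prompt template, and LoRA-related engine arguments follow tests/lora/test_qwen2vl.py above; the sampling settings are illustrative.

# Serve Qwen2-VL with LoRA enabled and attach the adapter per request.
from huggingface_hub import snapshot_download

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

lora_path = snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")

llm = vllm.LLM(
    "Qwen/Qwen2-VL-7B-Instruct",
    enable_lora=True,
    max_loras=2,
    max_lora_rank=16,
    max_model_len=4096,
)

prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is in the image?<|im_end|>\n"
    "<|im_start|>assistant\n")

outputs = llm.generate(
    [{
        "prompt": prompt,
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    }],
    vllm.SamplingParams(temperature=0, max_tokens=64),
    # LoRA tensor names are now remapped via hf_to_vllm_mapper, so the
    # adapter's "model.*" weights attach to "language_model.model.*" modules.
    lora_request=LoRARequest("pokemon-lora", 1, lora_path),
)
print(outputs[0].outputs[0].text)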