[Model][LoRA]LoRA support added for MiniCPMV2.5 (#7199)

This commit is contained in:
Jee Jee Li 2024-09-29 14:59:45 +08:00 committed by GitHub
parent bc2ef1f77c
commit 3d49776bbb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 377 additions and 30 deletions

View File

@ -194,6 +194,11 @@ def baichuan_zero_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
@pytest.fixture(scope="session")
def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")

View File

@ -0,0 +1,71 @@
from typing import List
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")
IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]
# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)
inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]
outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
def test_minicpmv_lora(minicpmv_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_num_seqs=2,
enable_lora=True,
max_loras=4,
max_lora_rank=64,
trust_remote_code=True,
)
output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output1[i])
output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output2[i])

View File

@ -0,0 +1,95 @@
from typing import List
import pytest
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest
from ..utils import multi_gpu_test
MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"
PROMPT_TEMPLATE = (
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
"(<image>./</image>)\nWhat is in the image?<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n")
IMAGE_ASSETS = [
ImageAsset("stop_sign"),
ImageAsset("cherry_blossom"),
]
# After fine-tuning with LoRA, all generated content should start begin `A`.
EXPECTED_OUTPUT = [
"A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501
"A pink cherry blossom tree with a blue sky in the background.",
]
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
sampling_params = vllm.SamplingParams(
temperature=0,
max_tokens=5,
stop_token_ids=[128001, 128009], # eos_id, eot_id
)
inputs = [{
"prompt": PROMPT_TEMPLATE,
"multi_modal_data": {
"image": asset.pil_image
},
} for asset in IMAGE_ASSETS]
outputs = llm.generate(
inputs,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None,
)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=2,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=2,
max_loras=4,
max_lora_rank=64,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=fully_sharded,
)
output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
for i in range(len(EXPECTED_OUTPUT)):
assert EXPECTED_OUTPUT[i].startswith(output_tp[i])

View File

@ -24,7 +24,9 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.model_executor.models.interfaces import (SupportsLoRA,
supports_multimodal)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.utils import is_pin_memory_available
@ -332,6 +334,8 @@ class LoRAModelManager(AdapterModelManager):
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = copy.deepcopy(
self.model.packed_modules_mapping)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = supports_multimodal(self.model)
self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
# Dict instead of a Set for compatibility with LRUCache.
@ -437,12 +441,22 @@ class LoRAModelManager(AdapterModelManager):
continue
if not self._match_target_modules(module_name):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_mm_module(module_name):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
module_name,
)
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
self.model, module_name,
from_layer(module, self.lora_slots, self.lora_config,
packed_moduled_lst, self.model.config))
# LinearScalingRotaryEmbeddingWithLora is used to handle
# long context lora. Register relevant metadata.
if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
@ -460,6 +474,15 @@ class LoRAModelManager(AdapterModelManager):
module, self.lora_slots,
self.lora_config,
self.model.config))
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if self.supports_mm and not isinstance(new_module,
BaseLayerWithLoRA):
continue
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
@ -478,9 +501,10 @@ class LoRAModelManager(AdapterModelManager):
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
if not self._match_target_modules(module_name) or not isinstance(
module, BaseLayerWithLoRA) or isinstance(
module, LinearScalingRotaryEmbeddingWithLora):
if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
or self._filter_unsupported_mm_module(module_name)):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
@ -541,6 +565,19 @@ class LoRAModelManager(AdapterModelManager):
module_name) or target_module == module_name
for target_module in self.supported_lora_modules)
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:
prefix = module_name.split(".")[0]
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
return (prefix in module_mapping.connector
or prefix in module_mapping.tower_model)
return False
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]

View File

@ -36,7 +36,7 @@ from transformers import PretrainedConfig
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -50,7 +50,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.utils import LLMWrapper
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@ -59,10 +61,10 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA
_KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
"llm.model": "llm",
}
@ -621,6 +623,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
default_weight_loader)
weight_loader(param, loaded_weight)
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(language_model="llm",
connector="resampler",
tower_model="vpm")
def init_llm(
self,
config: PretrainedConfig,
@ -669,9 +679,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
return MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config)
return LLMWrapper(MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
# TODO :refactor this vision model
@ -697,6 +709,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return model
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_tokens(input_ids)
def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
with set_default_torch_dtype(torch.float16):
resampler = Resampler2(
@ -743,7 +758,34 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
return "resampler" in name or "vpm" in name
class MiniCPMV2_5(MiniCPMVBaseModel):
class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
# LoRA specific attributes
supported_lora_modules = [
# vision encoder
"fc1",
"fc2",
"out_proj",
# language model
"qkv_proj", # same name with vision encoder
"o_proj",
"gate_up_proj",
"down_proj",
# resampler
"kv_proj",
]
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
@ -751,6 +793,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, multimodal_config, cache_config, quant_config)
assert self.version == (2, 5)
@ -761,9 +804,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
return LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config)
return LLMWrapper(LlamaModel(config,
cache_config=cache_config,
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
model = Idefics2VisionTransformer(self.config.vision_config)
@ -843,9 +887,11 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> nn.Module:
return Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config)
return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config),
name="model")
def init_vision_module(self) -> nn.Module:
# A custom version of SiglipVisionTransformer, won't work with TP
@ -870,7 +916,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
num_heads=embed_dim // 128,
kv_dim=vision_dim,
)
return resampler
def get_vision_embedding(
@ -934,20 +979,25 @@ _SUPPORT_VERSION = {
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
class MiniCPMV(MiniCPMVBaseModel):
class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
"""
Different versions of MiniCPMV use different visual encoders and LLMs,
which is not conducive to the current integration logic of LoRA and
bitsandbytes in vLLM. Therefore, it is necessary to separate them.
"""
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
embedding_padding_modules = []
def __new__(
cls,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
def __new__(cls,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
if not hasattr(config, "version"):
if config.hidden_size == 2304 and config.query_num == 64:
version = (2, 0)

View File

@ -0,0 +1,69 @@
# Adapted from
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py
from dataclasses import dataclass, field
from typing import List, Union
@dataclass
class ModelKeys:
model_type: str = None
module_list: str = None
embedding: str = None
mlp: str = None
down_proj: str = None
attention: str = None
o_proj: str = None
q_proj: str = None
k_proj: str = None
v_proj: str = None
qkv_proj: str = None
qk_proj: str = None
qa_proj: str = None
qb_proj: str = None
kva_proj: str = None
kvb_proj: str = None
output: str = None
@dataclass
class MultiModelKeys(ModelKeys):
language_model: List[str] = field(default_factory=list)
connector: List[str] = field(default_factory=list)
# vision tower and audio tower
tower_model: List[str] = field(default_factory=list)
generator: List[str] = field(default_factory=list)
@staticmethod
def from_string_field(language_model: Union[str, List[str]] = None,
connector: Union[str, List[str]] = None,
tower_model: Union[str, List[str]] = None,
generator: Union[str, List[str]] = None,
**kwargs) -> 'MultiModelKeys':
def to_list(value):
if value is None:
return []
return [value] if isinstance(value, str) else list(value)
return MultiModelKeys(language_model=to_list(language_model),
connector=to_list(connector),
tower_model=to_list(tower_model),
generator=to_list(generator),
**kwargs)

View File

@ -1,7 +1,7 @@
import itertools
from collections import UserDict
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
Union, overload)
from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
Tuple, Union, overload)
import torch
import torch.nn as nn
@ -329,3 +329,21 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
})
return make_empty_intermediate_tensors
class LLMWrapper(nn.Module):
"""
To align with the key names of LoRA trained with PEFT, we need to add an
additional layer to the llm's implementation.
"""
def __init__(self, llm: nn.Module, name: str) -> None:
super().__init__()
self.model_name = name
setattr(self, name, llm)
def forward(self, *args, **kwargs) -> Any:
return getattr(self, self.model_name)(*args, **kwargs)
def embed_tokens(self, *args, **kwargs) -> Any:
return getattr(self, self.model_name).embed_tokens(*args, **kwargs)

View File

@ -1034,10 +1034,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model_memory_usage / float(2**30))
if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA"
assert not supports_multimodal(
assert supports_lora(
self.model
), "To be tested: Multi-modal model with LoRA settings."
), f"{self.model.__class__.__name__} does not support LoRA yet."
if supports_multimodal(self.model):
logger.warning("Regarding multimodal models, vLLM currently "
"only supports adding LoRA to language model.")
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,