[Model][LoRA] LoRA support added for MiniCPMV2.5 (#7199)

parent bc2ef1f77c
commit 3d49776bbb
@@ -194,6 +194,11 @@ def baichuan_zero_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


+@pytest.fixture(scope="session")
+def minicpmv_lora_files():
+    return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
tests/lora/test_minicpmv.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from typing import List

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "(<image>./</image>)\nWhat is in the image?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n")

IMAGE_ASSETS = [
    ImageAsset("stop_sign"),
    ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start with `A`.
EXPECTED_OUTPUT = [
    "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.",  # noqa: E501
    "A pink cherry blossom tree with a blue sky in the background.",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=5,
        stop_token_ids=[128001, 128009],  # eos_id, eot_id
    )

    inputs = [{
        "prompt": PROMPT_TEMPLATE,
        "multi_modal_data": {
            "image": asset.pil_image
        },
    } for asset in IMAGE_ASSETS]

    outputs = llm.generate(
        inputs,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None,
    )
    # Print the outputs.
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_minicpmv_lora(minicpmv_lora_files):
    llm = vllm.LLM(
        MODEL_PATH,
        max_num_seqs=2,
        enable_lora=True,
        max_loras=4,
        max_lora_rank=64,
        trust_remote_code=True,
    )

    output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
    for i in range(len(EXPECTED_OUTPUT)):
        assert EXPECTED_OUTPUT[i].startswith(output1[i])
    output2 = do_sample(llm, minicpmv_lora_files, lora_id=2)
    for i in range(len(EXPECTED_OUTPUT)):
        assert EXPECTED_OUTPUT[i].startswith(output2[i])
tests/lora/test_minicpmv_tp.py (new file, 95 lines)
@@ -0,0 +1,95 @@
from typing import List

import pytest

import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5"

PROMPT_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "(<image>./</image>)\nWhat is in the image?<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n")

IMAGE_ASSETS = [
    ImageAsset("stop_sign"),
    ImageAsset("cherry_blossom"),
]

# After fine-tuning with LoRA, all generated content should start with `A`.
EXPECTED_OUTPUT = [
    "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.",  # noqa: E501
    "A pink cherry blossom tree with a blue sky in the background.",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
    sampling_params = vllm.SamplingParams(
        temperature=0,
        max_tokens=5,
        stop_token_ids=[128001, 128009],  # eos_id, eot_id
    )

    inputs = [{
        "prompt": PROMPT_TEMPLATE,
        "multi_modal_data": {
            "image": asset.pil_image
        },
    } for asset in IMAGE_ASSETS]

    outputs = llm.generate(
        inputs,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None,
    )
    # Print the outputs.
    generated_texts: List[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=2,
        trust_remote_code=True,
        fully_sharded_loras=fully_sharded,
    )

    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)

    for i in range(len(EXPECTED_OUTPUT)):
        assert EXPECTED_OUTPUT[i].startswith(output_tp[i])


@multi_gpu_test(num_gpus=4)
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
    llm = vllm.LLM(
        MODEL_PATH,
        enable_lora=True,
        max_num_seqs=2,
        max_loras=4,
        max_lora_rank=64,
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=fully_sharded,
    )
    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
    for i in range(len(EXPECTED_OUTPUT)):
        assert EXPECTED_OUTPUT[i].startswith(output_tp[i])
@@ -24,7 +24,9 @@ from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.punica import PunicaWrapper
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              parse_fine_tuned_lora_name, replace_submodule)
-from vllm.model_executor.models.interfaces import SupportsLoRA
+from vllm.model_executor.models.interfaces import (SupportsLoRA,
+                                                   supports_multimodal)
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.utils import PPMissingLayer
 from vllm.utils import is_pin_memory_available
@@ -332,6 +334,8 @@ class LoRAModelManager(AdapterModelManager):
             self.supported_lora_modules.append("rotary_emb")
         self.packed_modules_mapping = copy.deepcopy(
             self.model.packed_modules_mapping)
+        # Used to indicate whether the model is a multimodal model
+        self.supports_mm: bool = supports_multimodal(self.model)
         self.packed_modules: Dict[str, List[str]] = {}
         self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
         # Dict instead of a Set for compatibility with LRUCache.
@@ -437,12 +441,22 @@ class LoRAModelManager(AdapterModelManager):
                 continue
             if not self._match_target_modules(module_name):
                 continue
+            # A temporary approach for multimodal models to support LoRA
+            # TODO: Remove this restriction
+            if self._filter_unsupported_mm_module(module_name):
+                logger.warning(
+                    "Regarding multimodal models, vLLM currently only supports "
+                    "adding LoRA to language model, %s will be ignored.",
+                    module_name,
+                )
+                continue
             parts = module_name.split(".")[-1]
             packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
             new_module = replace_submodule(
                 self.model, module_name,
                 from_layer(module, self.lora_slots, self.lora_config,
                            packed_moduled_lst, self.model.config))

             # LinearScalingRotaryEmbeddingWithLora is used to handle
             # long context lora. Register relevant metadata.
             if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
@@ -460,6 +474,15 @@ class LoRAModelManager(AdapterModelManager):
                         module, self.lora_slots,
                         self.lora_config,
                         self.model.config))
+
+            # In some models, especially multimodal ones, layers with the same
+            # name may have different types, such as nn.Linear and
+            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
+            # LoRA layers, leading to an assertion error. The following check
+            # aims to prevent this error.
+            if self.supports_mm and not isinstance(new_module,
+                                                   BaseLayerWithLoRA):
+                continue
             self.register_module(module_name, new_module)
             self._register_packed_modules(module_name)
             # All lora layers share the same punica_wrapper based on reference.
@@ -478,9 +501,10 @@ class LoRAModelManager(AdapterModelManager):
         """Create zero-initialized LoRAModel for warmup."""
         model = LoRAModel(lora_id, rank, {}, scaling_factor)
         for module_name, module in self.model.named_modules():
-            if not self._match_target_modules(module_name) or not isinstance(
-                    module, BaseLayerWithLoRA) or isinstance(
-                        module, LinearScalingRotaryEmbeddingWithLora):
+            if (not self._match_target_modules(module_name)
+                    or not isinstance(module, BaseLayerWithLoRA)
+                    or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
+                    or self._filter_unsupported_mm_module(module_name)):
                 continue
             parts = module_name.split(".")
             if module_name not in self.packed_modules:
@@ -541,6 +565,19 @@ class LoRAModelManager(AdapterModelManager):
                 module_name) or target_module == module_name
             for target_module in self.supported_lora_modules)

+    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
+        """
+        Regarding multimodal models, vLLM currently only supports adding LoRA
+        to the language model. LoRA for other modules, such as the vision
+        tower, will be filtered out.
+        """
+        if self.supports_mm:
+            prefix = module_name.split(".")[0]
+            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
+            return (prefix in module_mapping.connector
+                    or prefix in module_mapping.tower_model)
+        return False
+
     def _register_packed_modules(self, module_full_name: str) -> None:
         parts = module_full_name.split(".")
         module_name = parts[-1]
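The filtering added above keys off nothing more than the first dotted component of a module name, compared against the connector and tower prefixes that the model's get_mm_mapping() reports; only language-model modules keep their LoRA replacement. A minimal standalone sketch of that check follows; the module names and the local _MMKeys stand-in are illustrative assumptions, not code from this commit.

# Standalone sketch of the prefix check; _MMKeys stands in for MultiModelKeys.
from dataclasses import dataclass, field
from typing import List


@dataclass
class _MMKeys:
    language_model: List[str] = field(default_factory=lambda: ["llm"])
    connector: List[str] = field(default_factory=lambda: ["resampler"])
    tower_model: List[str] = field(default_factory=lambda: ["vpm"])


def filter_unsupported_mm_module(module_name: str, mapping: _MMKeys) -> bool:
    # Mirrors module_name.split(".")[0] in _filter_unsupported_mm_module:
    # only the first dotted component decides whether the module is skipped.
    prefix = module_name.split(".")[0]
    return prefix in mapping.connector or prefix in mapping.tower_model


mapping = _MMKeys()
# Language-model layers keep their LoRA replacement.
assert not filter_unsupported_mm_module(
    "llm.model.layers.0.self_attn.qkv_proj", mapping)
# Vision-tower and resampler layers are skipped with a warning.
assert filter_unsupported_mm_module(
    "vpm.encoder.layers.0.self_attn.out_proj", mapping)
assert filter_unsupported_mm_module("resampler.kv_proj", mapping)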
@@ -36,7 +36,7 @@ from transformers import PretrainedConfig
 from typing_extensions import NotRequired

 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -50,7 +50,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.models.minicpm import MiniCPMModel
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.models.utils import LLMWrapper
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -59,10 +61,10 @@ from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData

 from .idefics2_vision_model import Idefics2VisionTransformer
+from .interfaces import SupportsLoRA

 _KEYS_TO_MODIFY_MAPPING = {
     "llm.lm_head": "lm_head",
-    "llm.model": "llm",
 }


@@ -621,6 +623,14 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(language_model="llm",
+                                                connector="resampler",
+                                                tower_model="vpm")
+
     def init_llm(
         self,
         config: PretrainedConfig,
@@ -669,9 +679,11 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> nn.Module:
-        return MiniCPMModel(config,
-                            cache_config=cache_config,
-                            quant_config=quant_config)
+        return LLMWrapper(MiniCPMModel(config,
+                                       cache_config=cache_config,
+                                       quant_config=quant_config),
+                          name="model")

     def init_vision_module(self) -> nn.Module:
         # TODO :refactor this vision model
@@ -697,6 +709,9 @@ class MiniCPMV2_0(MiniCPMVBaseModel):

         return model

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.embed_tokens(input_ids)
+
     def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             resampler = Resampler2(
@@ -743,7 +758,34 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
         return "resampler" in name or "vpm" in name


-class MiniCPMV2_5(MiniCPMVBaseModel):
+class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []

     def __init__(
         self,
@@ -751,6 +793,7 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
         multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__(config, multimodal_config, cache_config, quant_config)
         assert self.version == (2, 5)
@@ -761,9 +804,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> nn.Module:
-        return LlamaModel(config,
-                          cache_config=cache_config,
-                          quant_config=quant_config)
+        return LLMWrapper(LlamaModel(config,
+                                     cache_config=cache_config,
+                                     quant_config=quant_config),
+                          name="model")

     def init_vision_module(self) -> nn.Module:
         model = Idefics2VisionTransformer(self.config.vision_config)
@@ -843,9 +887,11 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> nn.Module:
-        return Qwen2Model(config,
-                          cache_config=cache_config,
-                          quant_config=quant_config)
+        return LLMWrapper(Qwen2Model(config,
+                                     cache_config=cache_config,
+                                     quant_config=quant_config),
+                          name="model")

     def init_vision_module(self) -> nn.Module:
         # A custom version of SiglipVisionTransformer, won't work with TP
@@ -870,7 +916,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
             num_heads=embed_dim // 128,
             kv_dim=vision_dim,
         )
-
         return resampler

     def get_vision_embedding(
@@ -934,20 +979,25 @@ _SUPPORT_VERSION = {
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv)
-class MiniCPMV(MiniCPMVBaseModel):
+class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
     """
     Different versions of MiniCPMV use different visual encoders and LLMs,
     which is not conducive to the current integration logic of LoRA and
     bitsandbytes in vLLM. Therefore, it is necessary to separate them.
     """
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []

-    def __new__(
-        cls,
-        config: PretrainedConfig,
-        multimodal_config: MultiModalConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __new__(cls,
+                config: PretrainedConfig,
+                multimodal_config: MultiModalConfig,
+                cache_config: Optional[CacheConfig] = None,
+                quant_config: Optional[QuantizationConfig] = None,
+                lora_config: Optional[LoRAConfig] = None):
         if not hasattr(config, "version"):
             if config.hidden_size == 2304 and config.query_num == 64:
                 version = (2, 0)
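The packed_modules_mapping declared on MiniCPMV2_5 lets adapters trained against the unfused Hugging Face projection names (q_proj/k_proj/v_proj, gate_proj/up_proj) be repacked into vLLM's fused qkv_proj and gate_up_proj layers at load time. As a rough illustration only, a PEFT configuration along these lines could produce such an adapter; the rank and target_modules below are assumptions (the tests above only require a rank no larger than max_lora_rank=64), not the actual settings of the jeeejeee/minicpmv25-lora-pokemon adapter.

# Hypothetical PEFT setup for a language-model-only LoRA on MiniCPM-Llama3-V-2_5.
# Rank and target_modules are illustrative; vLLM repacks q/k/v -> qkv_proj and
# gate/up -> gate_up_proj through packed_modules_mapping when loading it.
from peft import LoraConfig

lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    task_type="CAUSAL_LM",
)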
vllm/model_executor/models/module_mapping.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Adapted from
# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py

from dataclasses import dataclass, field
from typing import List, Union


@dataclass
class ModelKeys:
    model_type: str = None

    module_list: str = None

    embedding: str = None

    mlp: str = None

    down_proj: str = None

    attention: str = None

    o_proj: str = None

    q_proj: str = None

    k_proj: str = None

    v_proj: str = None

    qkv_proj: str = None

    qk_proj: str = None

    qa_proj: str = None

    qb_proj: str = None

    kva_proj: str = None

    kvb_proj: str = None

    output: str = None


@dataclass
class MultiModelKeys(ModelKeys):
    language_model: List[str] = field(default_factory=list)
    connector: List[str] = field(default_factory=list)
    # vision tower and audio tower
    tower_model: List[str] = field(default_factory=list)
    generator: List[str] = field(default_factory=list)

    @staticmethod
    def from_string_field(language_model: Union[str, List[str]] = None,
                          connector: Union[str, List[str]] = None,
                          tower_model: Union[str, List[str]] = None,
                          generator: Union[str, List[str]] = None,
                          **kwargs) -> 'MultiModelKeys':

        def to_list(value):
            if value is None:
                return []
            return [value] if isinstance(value, str) else list(value)

        return MultiModelKeys(language_model=to_list(language_model),
                              connector=to_list(connector),
                              tower_model=to_list(tower_model),
                              generator=to_list(generator),
                              **kwargs)
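For reference, the call that MiniCPMVBaseModel.get_mm_mapping() makes resolves as follows: string arguments are normalized to single-element lists and unspecified fields default to empty lists. A small sketch, assuming this commit's module path is importable:

# from_string_field normalizes str fields to lists, as used by get_mm_mapping().
from vllm.model_executor.models.module_mapping import MultiModelKeys

mm_keys = MultiModelKeys.from_string_field(language_model="llm",
                                            connector="resampler",
                                            tower_model="vpm")
assert mm_keys.language_model == ["llm"]
assert mm_keys.connector == ["resampler"]
assert mm_keys.tower_model == ["vpm"]
assert mm_keys.generator == []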
@@ -1,7 +1,7 @@
 import itertools
 from collections import UserDict
-from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
-                    Union, overload)
+from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
+                    Tuple, Union, overload)

 import torch
 import torch.nn as nn
@@ -329,3 +329,21 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
         })

     return make_empty_intermediate_tensors
+
+
+class LLMWrapper(nn.Module):
+    """
+    To align with the key names of LoRA trained with PEFT, we need to add an
+    additional layer to the llm's implementation.
+    """
+
+    def __init__(self, llm: nn.Module, name: str) -> None:
+        super().__init__()
+        self.model_name = name
+        setattr(self, name, llm)
+
+    def forward(self, *args, **kwargs) -> Any:
+        return getattr(self, self.model_name)(*args, **kwargs)
+
+    def embed_tokens(self, *args, **kwargs) -> Any:
+        return getattr(self, self.model_name).embed_tokens(*args, **kwargs)
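The docstring is terse, so a small sketch of the effect may help: registering the wrapped module under an attribute named "model" pushes every parameter name down one level, which is what lines vLLM's parameter names up with the llm.model.* keys used by PEFT-trained LoRA checkpoints. The nn.Linear stand-in below is an assumption for illustration, not part of the commit.

# Toy demonstration of LLMWrapper's effect on parameter names.
import torch.nn as nn

from vllm.model_executor.models.utils import LLMWrapper

inner = nn.Linear(4, 4)  # stand-in for LlamaModel
wrapped = LLMWrapper(inner, name="model")

print([n for n, _ in inner.named_parameters()])
# ['weight', 'bias']
print([n for n, _ in wrapped.named_parameters()])
# ['model.weight', 'model.bias']
# Registered as `self.llm` on MiniCPMV2_5, the full names become
# "llm.model.weight" / "llm.model.bias", matching the llm.model.* prefix
# found in PEFT-style LoRA weight names.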
@@ -1034,10 +1034,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                         self.model_memory_usage / float(2**30))

         if self.lora_config:
-            assert supports_lora(self.model), "Model does not support LoRA"
-            assert not supports_multimodal(
-                self.model
-            ), "To be tested: Multi-modal model with LoRA settings."
+            assert supports_lora(
+                self.model
+            ), f"{self.model.__class__.__name__} does not support LoRA yet."
+            if supports_multimodal(self.model):
+                logger.warning("Regarding multimodal models, vLLM currently "
+                               "only supports adding LoRA to language model.")

             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,