[model] Support MiniCPM-V 4.0 (#22166)

Co-authored-by: imning3 <hbning@pku.edu.cn>

commit 41b67f4263 (parent e8961e963a)
@@ -622,7 +622,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
 | `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
 | `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ |
+| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, `openbmb/MiniCPM-V-4`, etc. | ✅︎ | | ✅︎ |
 | `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ |
 | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
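Since the table above now lists `openbmb/MiniCPM-V-4`, here is a minimal offline-inference sketch using vLLM's `LLM.generate`. The prompt string is illustrative only (the real format comes from the model's chat template/processor), and `trust_remote_code=True` mirrors the registry entry in the next hunk.

```python
# Minimal sketch: run the newly listed MiniCPM-V 4.0 checkpoint offline.
# The prompt format below is an assumption; consult the model's chat template.
from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(
    model="openbmb/MiniCPM-V-4",
    trust_remote_code=True,  # MiniCPM-V checkpoints ship custom code
    max_model_len=4096,
)

image = Image.open("example.jpg")
outputs = llm.generate(
    {
        "prompt": "(<image>./</image>)\nWhat is shown in this image?",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```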
@@ -427,7 +427,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "MiniCPMO": _HfExamplesInfo("openbmb/MiniCPM-o-2_6",
                                 trust_remote_code=True),
     "MiniCPMV": _HfExamplesInfo("openbmb/MiniCPM-Llama3-V-2_5",
-                                extras={"2.6": "openbmb/MiniCPM-V-2_6"},  # noqa: E501
+                                extras={"2.6": "openbmb/MiniCPM-V-2_6", "4.0": "openbmb/MiniCPM-V-4"},  # noqa: E501
                                 trust_remote_code=True),
     "MiniMaxVL01ForConditionalGeneration": _HfExamplesInfo("MiniMaxAI/MiniMax-VL-01",  # noqa: E501
                                                             trust_remote_code=True,
@@ -38,6 +38,8 @@ from typing_extensions import TypeVar

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
                                                   get_2d_sincos_pos_embed)
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
@@ -339,7 +341,9 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         mm_limits = {"image": None}
-        if self.get_model_version() == (2, 6):
+        if self.get_model_version() == (2,
+                                        6) or self.get_model_version() == (4,
+                                                                           0):
             mm_limits["video"] = None

         return mm_limits
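The version gate above now accepts both (2, 6) and (4, 0). Purely as a readability note and not what the commit does, the same behaviour can be spelled as a tuple-membership test; a small self-contained sketch:

```python
# Sketch only: an equivalent way to gate video support on the model version.
# The commit keeps the explicit `==` comparisons shown in the hunk above.
from collections.abc import Mapping
from typing import Optional

VIDEO_CAPABLE_VERSIONS = {(2, 6), (4, 0)}


def get_supported_mm_limits(version: tuple[int, int]) -> Mapping[str, Optional[int]]:
    mm_limits: dict[str, Optional[int]] = {"image": None}
    if version in VIDEO_CAPABLE_VERSIONS:
        mm_limits["video"] = None
    return mm_limits


assert get_supported_mm_limits((4, 0)) == {"image": None, "video": None}
assert get_supported_mm_limits((2, 5)) == {"image": None}
```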
@@ -620,7 +624,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         out_keys: set[str],
     ) -> dict[str, NestedTensors]:
         # This processor supports zipping prompt and mm_data together
-        if self.info.get_model_version() == (2, 6):
+        if self.info.get_model_version() == (
+                2, 6) or self.info.get_model_version() == (4, 0):
             inputs = super()._call_hf_processor(
                 prompt=prompts,  # type: ignore
                 mm_data=mm_data,
@@ -679,10 +684,18 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
-        placeholder = {
-            "image": self.info.image_pattern,
-            "video": self.info.video_pattern,
-        }
+        placeholders = [("image", self.info.image_pattern),
+                        ("video", self.info.video_pattern)]
+
+        # hard code for inconsistency of encode-decode image_pattern
+        additional_placeholders = []
+        tokenizer = self.info.get_tokenizer()
+        for modality, pattern in placeholders:
+            sub_pattern = tokenizer.decode(
+                tokenizer.encode(pattern, add_special_tokens=False))
+            if sub_pattern != pattern:
+                additional_placeholders.append((modality, sub_pattern))
+        placeholders += additional_placeholders

         def get_image_replacement(item_idx: int):
             images = mm_items.get_items(
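The loop added above guards against a tokenizer whose encode-then-decode round trip does not reproduce the placeholder string exactly (for example, an extra space around a special token); the decoded spelling then has to be matched as well. Below is a standalone sketch of that check with a toy stand-in tokenizer; the placeholder literal is illustrative, not taken from the commit.

```python
# Standalone sketch of the round-trip guard above: if decode(encode(pattern))
# differs from the placeholder, register the decoded spelling as an extra
# prompt target. ToyTokenizer deliberately reassembles with an extra space.
class ToyTokenizer:
    def encode(self, text, add_special_tokens=False):
        return text.split("./")            # pretend tokenization is lossy here

    def decode(self, ids):
        return "./ ".join(ids)             # rejoin with a stray space


placeholders = [("image", "(<image>./</image>)")]
tokenizer = ToyTokenizer()

additional_placeholders = []
for modality, pattern in placeholders:
    sub_pattern = tokenizer.decode(
        tokenizer.encode(pattern, add_special_tokens=False))
    if sub_pattern != pattern:
        additional_placeholders.append((modality, sub_pattern))
placeholders += additional_placeholders

print(placeholders)
# [('image', '(<image>./</image>)'), ('image', '(<image>./ </image>)')]
```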
@@ -714,9 +727,9 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):

         return [
             PromptReplacement(modality=modality,
-                              target=placeholder[modality],
+                              target=pattern,
                               replacement=get_replacement[modality])
-            for modality in ("image", "video")
+            for modality, pattern in placeholders
         ]

     def _get_mm_fields_config(
@@ -1262,11 +1275,124 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):

         return self.resampler(vision_embedding, tgt_sizes)

+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self,
+                                   skip_prefixes=["apm.", "audio", "tts"])
+        return loader.load_weights(weights)
+
+
+class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        assert self.version == (4, 0)
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
+            return None
+        return quant_config
+
+    def init_llm(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> nn.Module:
+        return LlamaForCausalLM(vllm_config=vllm_config, prefix=prefix)
+
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        quant_config = self._maybe_ignore_quant_config(quant_config)
+        model = Idefics2VisionTransformer(config.vision_config,
+                                          quant_config=quant_config,
+                                          prefix=prefix)
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+        return model
+
+    def init_resampler(
+        self,
+        embed_dim: int,
+        vision_dim: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> nn.Module:
+        quant_config = self._maybe_ignore_quant_config(quant_config)
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
+
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
+
+    def get_vision_hidden_states(
+            self, data: MiniCPMVImagePixelInputs) -> torch.Tensor:
+        pixel_values = data["pixel_values"]
+        tgt_sizes = data["tgt_sizes"]
+
+        B = len(pixel_values)
+        P = pixel_values[0].shape[-2]
+        L = max(item.shape[-1] for item in pixel_values)
+        device = pixel_values[0].device
+        dtype = pixel_values[0].dtype
+
+        all_pixel_values = torch.zeros((B, 3, P, L),
+                                       dtype=dtype,
+                                       device=device)
+        for i, pixel_values_item in enumerate(pixel_values):
+            L_item = pixel_values_item.shape[-1]
+            all_pixel_values[i, ..., :L_item] = pixel_values_item
+
+        num_patches = tgt_sizes.prod(-1)
+        max_patches = num_patches.max().item()
+        assert isinstance(max_patches, int)
+
+        patch_attn_mask = torch.zeros((B, max_patches),
+                                      dtype=torch.bool,
+                                      device=device)
+        for i, num_patches_item in enumerate(num_patches):
+            patch_attn_mask[i, :num_patches_item] = True
+
+        vision_embedding = self.vpm(
+            all_pixel_values,
+            patch_attention_mask=patch_attn_mask.unsqueeze(1),
+            tgt_sizes=tgt_sizes,
+        )
+
+        return self.resampler(vision_embedding, tgt_sizes)
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(self,
+                                   skip_prefixes=["apm.", "audio", "tts"])
+        return loader.load_weights(weights)
+
+
 _SUPPORT_VERSION = {
     (2, 0): MiniCPMV2_0,
     (2, 5): MiniCPMV2_5,
     (2, 6): MiniCPMV2_6,
+    (4, 0): MiniCPMV4_0,
 }

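The new `get_vision_hidden_states` pads per-image patch sequences of different lengths into one batch tensor and derives a boolean patch mask from `tgt_sizes`. A toy sketch of just that padding and masking step, with made-up shapes:

```python
# Toy sketch of the batching logic in MiniCPMV4_0.get_vision_hidden_states:
# pad per-image patch sequences to a common length and mark valid patches.
# Shapes are invented; only the padding/mask construction mirrors the code.
import torch

pixel_values = [torch.randn(3, 14, 14 * 4), torch.randn(3, 14, 14 * 9)]
tgt_sizes = torch.tensor([[2, 2], [3, 3]])  # (h_patches, w_patches) per image

B = len(pixel_values)
P = pixel_values[0].shape[-2]
L = max(item.shape[-1] for item in pixel_values)

all_pixel_values = torch.zeros((B, 3, P, L), dtype=pixel_values[0].dtype)
for i, item in enumerate(pixel_values):
    all_pixel_values[i, ..., :item.shape[-1]] = item  # right-pad with zeros

num_patches = tgt_sizes.prod(-1)            # patches per image: [4, 9]
patch_attn_mask = torch.zeros((B, int(num_patches.max())), dtype=torch.bool)
for i, n in enumerate(num_patches):
    patch_attn_mask[i, :n] = True            # valid patches only

print(all_pixel_values.shape, patch_attn_mask.sum(-1))  # (2, 3, 14, 126), [4, 9]
```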
@@ -1294,8 +1420,10 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsMultiModal, SupportsLoRA):
         # Dispatch class based on version
         instance_cls = _SUPPORT_VERSION.get(version)
         if instance_cls is None:
-            raise ValueError(
-                "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
+            supported_versions = ", ".join(
+                [f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION.keys())])
+            raise ValueError(f"Currently, MiniCPMV only supports versions "
+                             f"{supported_versions}. Got version: {version}")

         # quant_config references base class members,
         # so update values before init is called
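With the new `(4, 0)` entry in `_SUPPORT_VERSION`, the rewritten error message enumerates the supported versions from the dispatch table instead of hard-coding them. A quick sketch of the string it produces, with the class values stubbed out:

```python
# Sketch of the error text produced by the rewritten dispatch above.
# Mapping values are stubs; only the message formatting is exercised.
_SUPPORT_VERSION = {(2, 0): object, (2, 5): object, (2, 6): object, (4, 0): object}

version = (3, 0)  # an unsupported version, for illustration
supported_versions = ", ".join(
    f"{v[0]}.{v[1]}" for v in sorted(_SUPPORT_VERSION))
message = (f"Currently, MiniCPMV only supports versions "
           f"{supported_versions}. Got version: {version}")
print(message)
# Currently, MiniCPMV only supports versions 2.0, 2.5, 2.6, 4.0. Got version: (3, 0)
```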