mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 02:55:40 +08:00
[Model][VLM] Support R-4B Model (#23246)
Signed-off-by: yannqi <yannqi@qq.com> Signed-off-by: 杨奇(yann qi) <51905299+yannqi@users.noreply.github.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: yannqiyang <yannqiyang@tencent.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
f94bf9b924
commit
655a09f653
@ -652,6 +652,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
|
||||
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
|
||||
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
|
||||
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
|
||||
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
|
||||
|
||||
@ -1436,6 +1436,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
|
||||
)
|
||||
|
||||
|
||||
# R-4B
|
||||
def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
model_name = "YannQi/R-4B"
|
||||
|
||||
prompts = [
|
||||
f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
|
||||
for question in questions
|
||||
]
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=16384,
|
||||
limit_mm_per_prompt={modality: 1},
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompts=prompts,
|
||||
)
|
||||
|
||||
|
||||
# SkyworkR1V
|
||||
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
|
||||
assert modality == "image"
|
||||
@ -1622,6 +1644,7 @@ model_example_map = {
|
||||
"qwen2_vl": run_qwen2_vl,
|
||||
"qwen2_5_vl": run_qwen2_5_vl,
|
||||
"qwen2_5_omni": run_qwen2_5_omni,
|
||||
"rvl": run_r_vl,
|
||||
"skywork_chat": run_skyworkr1v,
|
||||
"smolvlm": run_smolvlm,
|
||||
"step3": run_step3,
|
||||
|
||||
@ -992,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
)
|
||||
|
||||
|
||||
def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "YannQi/R-4B"
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
max_model_len=16384,
|
||||
max_num_seqs=16,
|
||||
limit_mm_per_prompt={"image": len(image_urls)},
|
||||
)
|
||||
|
||||
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*placeholders,
|
||||
{"type": "text", "text": question},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
prompt = processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
return ModelRequestData(
|
||||
engine_args=engine_args,
|
||||
prompt=prompt,
|
||||
image_data=[fetch_image(url) for url in image_urls],
|
||||
)
|
||||
|
||||
|
||||
def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
|
||||
|
||||
@ -1193,6 +1226,7 @@ model_example_map = {
|
||||
"qwen_vl_chat": load_qwen_vl_chat,
|
||||
"qwen2_vl": load_qwen2_vl,
|
||||
"qwen2_5_vl": load_qwen2_5_vl,
|
||||
"rvl": load_r_vl,
|
||||
"smolvlm": load_smolvlm,
|
||||
"step3": load_step3,
|
||||
"tarsier": load_tarsier,
|
||||
|
||||
@ -316,6 +316,7 @@ def _test_processing_correctness_one(
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||
"Qwen/Qwen2.5-Omni-3B",
|
||||
"YannQi/R-4B",
|
||||
"Skywork/Skywork-R1V-38B",
|
||||
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
|
||||
"stepfun-ai/step3",
|
||||
|
||||
@ -489,6 +489,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
max_model_len=4096),
|
||||
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
|
||||
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
|
||||
"RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
|
||||
trust_remote_code=True),
|
||||
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
|
||||
trust_remote_code=True),
|
||||
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501
|
||||
|
||||
@ -217,6 +217,7 @@ _MULTIMODAL_MODELS = {
|
||||
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
|
||||
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
|
||||
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
|
||||
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
|
||||
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
|
||||
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
|
||||
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
|
||||
|
||||
103
vllm/model_executor/models/rvl.py
Normal file
103
vllm/model_executor/models/rvl.py
Normal file
@ -0,0 +1,103 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers.activations import GELUActivation
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalDataDict
|
||||
|
||||
from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor,
|
||||
LlavaNextProcessingInfo)
|
||||
from .llava_onevision import LlavaOnevisionForConditionalGeneration
|
||||
from .utils import WeightsMapper
|
||||
|
||||
|
||||
class RVLProcessingInfo(LlavaNextProcessingInfo):
|
||||
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config()
|
||||
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
return self.ctx.get_hf_processor(**kwargs)
|
||||
|
||||
|
||||
class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]):
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
image_token = "<image>"
|
||||
|
||||
return image_token * num_images
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
target_width, target_height = (
|
||||
self.info.get_image_size_with_most_features())
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images),
|
||||
}
|
||||
|
||||
|
||||
class RVLMultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size,
|
||||
eps=1e-06)
|
||||
self.linear_1 = nn.Linear(
|
||||
config.vision_config.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
self.act = GELUActivation()
|
||||
self.linear_2 = nn.Linear(
|
||||
config.text_config.hidden_size,
|
||||
config.text_config.hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, image_feature: torch.Tensor) -> torch.Tensor:
|
||||
image_feature = self.pre_norm(image_feature)
|
||||
hidden_states = self.linear_1(image_feature)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
LlavaNextMultiModalProcessor,
|
||||
info=RVLProcessingInfo,
|
||||
dummy_inputs=RVLDummyInputsBuilder,
|
||||
)
|
||||
class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration):
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
orig_to_new_prefix={
|
||||
# mapping for new names in checkpoint saved after transformers
|
||||
# v4.52
|
||||
"model.language_model.": "language_model.model.",
|
||||
"model.vision_tower.": "vision_tower.",
|
||||
"model.multi_modal_projector.": "multi_modal_projector.",
|
||||
"model.image_newline": "image_newline",
|
||||
"lm_head.": "language_model.lm_head.",
|
||||
})
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.multi_modal_projector = RVLMultiModalProjector(config)
|
||||
Loading…
x
Reference in New Issue
Block a user