[Model][VLM] Support R-4B Model (#23246)

Signed-off-by: yannqi <yannqi@qq.com>
Signed-off-by: 杨奇(yann qi) <51905299+yannqi@users.noreply.github.com>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: yannqiyang <yannqiyang@tencent.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
杨奇(yann qi) 2025-08-21 12:08:52 +08:00 committed by GitHub
parent f94bf9b924
commit 655a09f653
7 changed files with 165 additions and 0 deletions


@@ -652,6 +652,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
| `RForConditionalGeneration` | R-VL-4B | T + I<sup>E+</sup> | `YannQi/R-4B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
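Not part of this diff, but as a quick end-to-end check of the new row, the model can be queried through vLLM's OpenAI-compatible server. A minimal sketch, assuming a server has already been started for `YannQi/R-4B` with remote code trusted on the default port, and using a placeholder image URL:

```python
# Sketch only; the image URL is a placeholder and the server is assumed to be
# started separately, e.g. `vllm serve YannQi/R-4B --trust-remote-code`.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="YannQi/R-4B",
    messages=[{
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/demo.jpg"}},
            {"type": "text", "text": "Describe this image."},
        ],
    }],
    max_tokens=128,
)
print(response.choices[0].message.content)
```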


@@ -1436,6 +1436,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
    )

# R-4B
def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

    model_name = "YannQi/R-4B"

    prompts = [
        f"<|im_start|>user <image>\n{question}<|im_end|><|im_start|>assistant\n"
        for question in questions
    ]

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=16384,
        limit_mm_per_prompt={modality: 1},
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompts=prompts,
    )
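As a usage note (not part of the diff): the returned `ModelRequestData` is consumed by the example's shared runner, which roughly amounts to the sketch below. It assumes the snippet runs in the same module as `run_r_vl`, and uses `ImageAsset`, vLLM's bundled sample-image helper, so the prompt's `<image>` placeholder has something to bind to.

```python
from dataclasses import asdict

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

req_data = run_r_vl(["What is shown in this image?"], "image")

# Build the engine from the EngineArgs returned above.
llm = LLM(**asdict(req_data.engine_args))
image = ImageAsset("cherry_blossom").pil_image  # bundled sample image

outputs = llm.generate(
    [{"prompt": p, "multi_modal_data": {"image": image}} for p in req_data.prompts],
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```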

# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
    assert modality == "image"

@@ -1622,6 +1644,7 @@ model_example_map = {
    "qwen2_vl": run_qwen2_vl,
    "qwen2_5_vl": run_qwen2_5_vl,
    "qwen2_5_omni": run_qwen2_5_omni,
    "rvl": run_r_vl,
    "skywork_chat": run_skyworkr1v,
    "smolvlm": run_smolvlm,
    "step3": run_step3,


@@ -992,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
    )

def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "YannQi/R-4B"

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=16384,
        max_num_seqs=16,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        }
    ]

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    return ModelRequestData(
        engine_args=engine_args,
        prompt=prompt,
        image_data=[fetch_image(url) for url in image_urls],
    )
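The multi-image path works the same way as the single-image sketch earlier; the only difference is that `multi_modal_data` receives the whole list of fetched images alongside the single chat-templated prompt. Again a sketch that assumes it runs in the same module as `load_r_vl`, with placeholder image URLs:

```python
from dataclasses import asdict

from vllm import LLM, SamplingParams

req_data = load_r_vl(
    "What do these two images have in common?",
    [
        # Placeholder URLs; substitute any reachable images.
        "https://example.com/first.jpg",
        "https://example.com/second.jpg",
    ],
)

llm = LLM(**asdict(req_data.engine_args))
outputs = llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {"image": req_data.image_data},  # list of PIL images
    },
    SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```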

def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
    model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"

@@ -1193,6 +1226,7 @@ model_example_map = {
    "qwen_vl_chat": load_qwen_vl_chat,
    "qwen2_vl": load_qwen2_vl,
    "qwen2_5_vl": load_qwen2_5_vl,
    "rvl": load_r_vl,
    "smolvlm": load_smolvlm,
    "step3": load_step3,
    "tarsier": load_tarsier,


@@ -316,6 +316,7 @@ def _test_processing_correctness_one(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    "Qwen/Qwen2-Audio-7B-Instruct",
    "Qwen/Qwen2.5-Omni-3B",
    "YannQi/R-4B",
    "Skywork/Skywork-R1V-38B",
    "HuggingFaceTB/SmolVLM2-2.2B-Instruct",
    "stepfun-ai/step3",


@@ -489,6 +489,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                                           max_model_len=4096),
    "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
    "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"),  # noqa: E501
    "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
                                                 trust_remote_code=True),
    "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
                                           trust_remote_code=True),
    "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct",  # noqa: E501


@@ -217,6 +217,7 @@ _MULTIMODAL_MODELS = {
    "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
    "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"),  # noqa: E501
    "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
    "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
    "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"),  # noqa: E501
    "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),


@@ -0,0 +1,103 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping

import torch
import torch.nn as nn
from transformers.activations import GELUActivation

from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict

from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor,
                         LlavaNextProcessingInfo)
from .llava_onevision import LlavaOnevisionForConditionalGeneration
from .utils import WeightsMapper


class RVLProcessingInfo(LlavaNextProcessingInfo):

    def get_hf_config(self):
        return self.ctx.get_hf_config()

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(**kwargs)


class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]):

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_images = mm_counts.get("image", 0)

        image_token = "<image>"

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_images = mm_counts.get("image", 0)

        target_width, target_height = (
            self.info.get_image_size_with_most_features())

        return {
            "image":
            self._get_dummy_images(width=target_width,
                                   height=target_height,
                                   num_images=num_images),
        }


# Vision-to-language projector used by R-4B:
# LayerNorm -> Linear -> GELU -> Linear, mapping vision-tower features into
# the language model's hidden size.
class RVLMultiModalProjector(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size,
                                     eps=1e-06)
        self.linear_1 = nn.Linear(
            config.vision_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )
        self.act = GELUActivation()
        self.linear_2 = nn.Linear(
            config.text_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

    def forward(self, image_feature: torch.Tensor) -> torch.Tensor:
        image_feature = self.pre_norm(image_feature)
        hidden_states = self.linear_1(image_feature)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_processor(
    LlavaNextMultiModalProcessor,
    info=RVLProcessingInfo,
    dummy_inputs=RVLDummyInputsBuilder,
)
class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration):

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers
            # v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        config = vllm_config.model_config.hf_config
        # Reuse the LLaVA-OneVision implementation, but swap in the
        # R-4B-specific projector defined above.
        self.multi_modal_projector = RVLMultiModalProjector(config)
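As a standalone sanity check of the projector shapes (not part of the diff): `RVLMultiModalProjector` maps vision-tower features of width `vision_config.hidden_size` to the language model's `text_config.hidden_size`. The dimensions below are made up for illustration; the real values come from the R-4B HF config.

```python
from types import SimpleNamespace

import torch

# Hypothetical sizes, chosen only to exercise the module above.
config = SimpleNamespace(
    vision_config=SimpleNamespace(hidden_size=1152),
    text_config=SimpleNamespace(hidden_size=2560),
)

projector = RVLMultiModalProjector(config)
image_features = torch.randn(1, 576, config.vision_config.hidden_size)
projected = projector(image_features)
print(projected.shape)  # torch.Size([1, 576, 2560])
```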