From 655a09f6538e6b09af23771dcc4fcebd72a15b23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=A5=87=28yann=20qi=29?= <51905299+yannqi@users.noreply.github.com> Date: Thu, 21 Aug 2025 12:08:52 +0800 Subject: [PATCH] [Model][VLM] Support R-4B Model (#23246) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: yannqi Signed-off-by: 杨奇(yann qi) <51905299+yannqi@users.noreply.github.com> Signed-off-by: Cyrus Leung Co-authored-by: yannqiyang Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung --- docs/models/supported_models.md | 1 + examples/offline_inference/vision_language.py | 23 ++++ .../vision_language_multi_image.py | 34 ++++++ .../multimodal/processing/test_common.py | 1 + tests/models/registry.py | 2 + vllm/model_executor/models/registry.py | 1 + vllm/model_executor/models/rvl.py | 103 ++++++++++++++++++ 7 files changed, 165 insertions(+) create mode 100644 vllm/model_executor/models/rvl.py diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7308d0010690..831bfb1e939e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -652,6 +652,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | | `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | +| `RForConditionalGeneration` | R-VL-4B | T + IE+ | `YannQi/R-4B` | | ✅︎ | ✅︎ | | `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | | `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | | `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ | ✅︎ | diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 88bbbfdfbd18..e7a7a30dd31a 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -1436,6 +1436,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str): ) +# R-4B +def run_r_vl(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + model_name = "YannQi/R-4B" + + prompts = [ + f"<|im_start|>user \n{question}<|im_end|><|im_start|>assistant\n" + for question in questions + ] + + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + limit_mm_per_prompt={modality: 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + # SkyworkR1V def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" @@ -1622,6 +1644,7 @@ model_example_map = { "qwen2_vl": run_qwen2_vl, "qwen2_5_vl": run_qwen2_5_vl, "qwen2_5_omni": run_qwen2_5_omni, + "rvl": run_r_vl, "skywork_chat": run_skyworkr1v, "smolvlm": run_smolvlm, "step3": run_step3, diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index eabd9453f3c5..d9242efa8547 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -992,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "YannQi/R-4B" + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct" @@ -1193,6 +1226,7 @@ model_example_map = { "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, "qwen2_5_vl": load_qwen2_5_vl, + "rvl": load_r_vl, "smolvlm": load_smolvlm, "step3": load_step3, "tarsier": load_tarsier, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 02aecfad8281..adc8b2510d67 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -316,6 +316,7 @@ def _test_processing_correctness_one( "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "Qwen/Qwen2.5-Omni-3B", + "YannQi/R-4B", "Skywork/Skywork-R1V-38B", "HuggingFaceTB/SmolVLM2-2.2B-Instruct", "stepfun-ai/step3", diff --git a/tests/models/registry.py b/tests/models/registry.py index 6e6acfb8cd22..4f69f90b6aae 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -489,6 +489,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 + "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B", + trust_remote_code=True), "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B", trust_remote_code=True), "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501 diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 78ef270598b8..39a3e425a46d 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -217,6 +217,7 @@ _MULTIMODAL_MODELS = { "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), + "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py new file mode 100644 index 000000000000..efdb01004663 --- /dev/null +++ b/vllm/model_executor/models/rvl.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Mapping + +import torch +import torch.nn as nn +from transformers.activations import GELUActivation + +from vllm.config import VllmConfig +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalDataDict + +from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor, + LlavaNextProcessingInfo) +from .llava_onevision import LlavaOnevisionForConditionalGeneration +from .utils import WeightsMapper + + +class RVLProcessingInfo(LlavaNextProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self, **kwargs: object): + return self.ctx.get_hf_processor(**kwargs) + + +class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + image_token = "" + + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = ( + self.info.get_image_size_with_most_features()) + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class RVLMultiModalProjector(nn.Module): + + def __init__(self, config): + super().__init__() + self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size, + eps=1e-06) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + self.act = GELUActivation() + self.linear_2 = nn.Linear( + config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + + def forward(self, image_feature: torch.Tensor) -> torch.Tensor: + image_feature = self.pre_norm(image_feature) + hidden_states = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states + + +@MULTIMODAL_REGISTRY.register_processor( + LlavaNextMultiModalProcessor, + info=RVLProcessingInfo, + dummy_inputs=RVLDummyInputsBuilder, +) +class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration): + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + # mapping for new names in checkpoint saved after transformers + # v4.52 + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.multi_modal_projector.": "multi_modal_projector.", + "model.image_newline": "image_newline", + "lm_head.": "language_model.lm_head.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config + self.multi_modal_projector = RVLMultiModalProjector(config)