diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 7308d0010690..831bfb1e939e 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -652,6 +652,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + IE+ + VE+ | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + IE+ + VE+ | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + IE+ + VE+ + A+ | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
+| `RForConditionalGeneration` | R-VL-4B | T + IE+ | `YannQi/R-4B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `Step3VLForConditionalGeneration` | Step3-VL | T + I+ | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 88bbbfdfbd18..e7a7a30dd31a 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -1436,6 +1436,28 @@ def run_qwen2_5_omni(questions: list[str], modality: str):
)
+# R-4B
+def run_r_vl(questions: list[str], modality: str) -> ModelRequestData:
+ assert modality == "image"
+ model_name = "YannQi/R-4B"
+
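+    # R-4B follows the LLaVA-OneVision chat format, so each prompt carries an
+    # explicit "<image>" placeholder ahead of the question text.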
+ prompts = [
+ f"<|im_start|>user \n{question}<|im_end|><|im_start|>assistant\n"
+ for question in questions
+ ]
+
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=16384,
+ limit_mm_per_prompt={modality: 1},
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompts=prompts,
+ )
+
+
# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1622,6 +1644,7 @@ model_example_map = {
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
"qwen2_5_omni": run_qwen2_5_omni,
+ "rvl": run_r_vl,
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
"step3": run_step3,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index eabd9453f3c5..d9242efa8547 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -992,6 +992,39 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
)
+def load_r_vl(question: str, image_urls: list[str]) -> ModelRequestData:
+ model_name = "YannQi/R-4B"
+ engine_args = EngineArgs(
+ model=model_name,
+ max_model_len=16384,
+ max_num_seqs=16,
+ limit_mm_per_prompt={"image": len(image_urls)},
+ )
+
+ placeholders = [{"type": "image", "image": url} for url in image_urls]
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ *placeholders,
+ {"type": "text", "text": question},
+ ],
+ }
+ ]
+
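+    # R-4B ships a custom processor on the Hub, so trust_remote_code=True is
+    # required to build the prompt from the model's own chat template.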
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+
+ prompt = processor.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ return ModelRequestData(
+ engine_args=engine_args,
+ prompt=prompt,
+ image_data=[fetch_image(url) for url in image_urls],
+ )
+
+
def load_smolvlm(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
@@ -1193,6 +1226,7 @@ model_example_map = {
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
"qwen2_5_vl": load_qwen2_5_vl,
+ "rvl": load_r_vl,
"smolvlm": load_smolvlm,
"step3": load_step3,
"tarsier": load_tarsier,
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 02aecfad8281..adc8b2510d67 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -316,6 +316,7 @@ def _test_processing_correctness_one(
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-3B",
+ "YannQi/R-4B",
"Skywork/Skywork-R1V-38B",
"HuggingFaceTB/SmolVLM2-2.2B-Instruct",
"stepfun-ai/step3",
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 6e6acfb8cd22..4f69f90b6aae 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -489,6 +489,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
max_model_len=4096),
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"),
"Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501
+ "RForConditionalGeneration": _HfExamplesInfo("YannQi/R-4B",
+ trust_remote_code=True),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
trust_remote_code=True),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct", # noqa: E501
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 78ef270598b8..39a3e425a46d 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -217,6 +217,7 @@ _MULTIMODAL_MODELS = {
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
+ "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
diff --git a/vllm/model_executor/models/rvl.py b/vllm/model_executor/models/rvl.py
new file mode 100644
index 000000000000..efdb01004663
--- /dev/null
+++ b/vllm/model_executor/models/rvl.py
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Mapping
+
+import torch
+import torch.nn as nn
+from transformers.activations import GELUActivation
+
+from vllm.config import VllmConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalDataDict
+
+from .llava_next import (LlavaDummyInputsBuilder, LlavaNextMultiModalProcessor,
+ LlavaNextProcessingInfo)
+from .llava_onevision import LlavaOnevisionForConditionalGeneration
+from .utils import WeightsMapper
+
+
+class RVLProcessingInfo(LlavaNextProcessingInfo):
+
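+    # R-4B's config and processor classes live in the model repo
+    # (trust_remote_code), so no specific HF class is pinned here.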
+ def get_hf_config(self):
+ return self.ctx.get_hf_config()
+
+ def get_hf_processor(self, **kwargs: object):
+ return self.ctx.get_hf_processor(**kwargs)
+
+
+class RVLDummyInputsBuilder(LlavaDummyInputsBuilder[RVLProcessingInfo]):
+
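+    # The dummy prompt repeats the "<image>" placeholder once per image so
+    # that memory profiling exercises the multimodal path.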
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_images = mm_counts.get("image", 0)
+        image_token = "<image>"
+
+ return image_token * num_images
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ ) -> MultiModalDataDict:
+ num_images = mm_counts.get("image", 0)
+
+ target_width, target_height = (
+ self.info.get_image_size_with_most_features())
+
+ return {
+ "image":
+ self._get_dummy_images(width=target_width,
+ height=target_height,
+ num_images=num_images),
+ }
+
+
+class RVLMultiModalProjector(nn.Module):
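+    """Vision-to-text projector used by R-4B.
+
+    Applies a LayerNorm to the vision-tower features before a two-layer MLP
+    that maps them into the language model's hidden size.
+    """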
+
+ def __init__(self, config):
+ super().__init__()
+ self.pre_norm = nn.LayerNorm(config.vision_config.hidden_size,
+ eps=1e-06)
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=True,
+ )
+ self.act = GELUActivation()
+ self.linear_2 = nn.Linear(
+ config.text_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=True,
+ )
+
+ def forward(self, image_feature: torch.Tensor) -> torch.Tensor:
+ image_feature = self.pre_norm(image_feature)
+ hidden_states = self.linear_1(image_feature)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+
+ return hidden_states
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ LlavaNextMultiModalProcessor,
+ info=RVLProcessingInfo,
+ dummy_inputs=RVLDummyInputsBuilder,
+)
+class RForConditionalGeneration(LlavaOnevisionForConditionalGeneration):
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_prefix={
+            # mapping for new names in checkpoints saved after
+            # transformers v4.52
+ "model.language_model.": "language_model.model.",
+ "model.vision_tower.": "vision_tower.",
+ "model.multi_modal_projector.": "multi_modal_projector.",
+ "model.image_newline": "image_newline",
+ "lm_head.": "language_model.lm_head.",
+ })
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+ super().__init__(vllm_config=vllm_config, prefix=prefix)
+ config = vllm_config.model_config.hf_config
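+        # The parent constructor builds the stock LLaVA-OneVision projector;
+        # replace it with the R-4B variant (LayerNorm + MLP) defined above.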
+ self.multi_modal_projector = RVLMultiModalProjector(config)