diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 626904a974155..404519f887dc6 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -701,6 +701,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I+ | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ |
 | `MolmoForCausalLM` | Molmo | T + I+ | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ |
 | `NVLM_D_Model` | NVLM-D 1.0 | T + I+ | `nvidia/NVLM-D-72B`, etc. | | ✅︎ |
+| `OpenCUAForConditionalGeneration` | OpenCUA-7B | T + IE+ | `xlangai/OpenCUA-7B` | ✅︎ | ✅︎ |
 | `Ovis` | Ovis2, Ovis1.6 | T + I+ | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ |
 | `Ovis2_5` | Ovis2.5 | T + I+ + V | `AIDC-AI/Ovis2.5-9B`, etc. | | |
 | `PaddleOCRVLForConditionalGeneration` | Paddle-OCR | T + I+ | `PaddlePaddle/PaddleOCR-VL`, etc. | | |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index b088e16756d7a..758ec54493aa3 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -725,6 +725,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "NemotronH_Nano_VL_V2": _HfExamplesInfo(
         "nano_vl_dummy", is_available_online=False, trust_remote_code=True
     ),
+    "OpenCUAForConditionalGeneration": _HfExamplesInfo(
+        "xlangai/OpenCUA-7B", trust_remote_code=True
+    ),
     "Ovis": _HfExamplesInfo(
         "AIDC-AI/Ovis2-1B",
         trust_remote_code=True,
diff --git a/vllm/model_executor/models/opencua.py b/vllm/model_executor/models/opencua.py
new file mode 100644
index 0000000000000..121bf896fa6ba
--- /dev/null
+++ b/vllm/model_executor/models/opencua.py
@@ -0,0 +1,271 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Adapted from the Qwen2.5-VL implementation
+# Copyright 2025 The vLLM team.
+# Copyright 2025 XLANG Lab, The University of Hong Kong
+
+"""Inference-only OpenCUA-7B model compatible with HuggingFace weights."""
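+#
+# Example usage (a minimal sketch; the exact prompt format and the image
+# object type are assumptions, not part of this module):
+#
+#     from vllm import LLM, SamplingParams
+#
+#     llm = LLM(model="xlangai/OpenCUA-7B", trust_remote_code=True)
+#     outputs = llm.generate(
+#         {
+#             "prompt": "<|media_placeholder|>\nDescribe this screenshot.",
+#             "multi_modal_data": {"image": screenshot},  # e.g. a PIL.Image
+#         },
+#         SamplingParams(max_tokens=64),
+#     )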
+
+from collections.abc import Mapping, Sequence
+from typing import Any
+
+import torch
+import torch.nn as nn
+from transformers import BatchFeature
+from transformers.models.qwen2_vl import (
+    Qwen2VLImageProcessor,
+    Qwen2VLProcessor,
+    Qwen2VLVideoProcessor,
+)
+
+from vllm.config import VllmConfig
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+    MultiModalFieldConfig,
+    MultiModalKwargs,
+)
+from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    PromptReplacement,
+    PromptUpdate,
+)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+from .qwen2_5_vl import (
+    Qwen2_5_VisionTransformer as OpenCUAVisionTransformer,
+)
+from .qwen2_5_vl import (
+    Qwen2_5_VLForConditionalGeneration,
+)
+from .qwen2_vl import (
+    Qwen2VLDummyInputsBuilder,
+    Qwen2VLMultiModalDataParser,
+    Qwen2VLProcessingInfo,
+    _create_qwen2vl_field_factory,
+)
+from .utils import (
+    WeightsMapper,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
+
+
+class OpenCUAProcessingInfo(Qwen2VLProcessingInfo):
+    def get_hf_config(self):
+        return self.ctx.get_hf_config()
+
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+        return {"image": None}
+
+    def get_hf_processor(self, **kwargs: object):
+        """Load the OpenCUA processor."""
+        tokenizer = self.get_tokenizer()
+        vision_config = self.ctx.get_hf_image_processor_config()
+        return OpenCUAProcessor(
+            vision_config=vision_config,
+            tokenizer=tokenizer,
+            **kwargs,
+        )
+
+
+class OpenCUAProcessor(Qwen2VLProcessor):
+    def check_argument_for_proper_class(self, attribute_name: str, arg: object) -> None:
+        if attribute_name == "tokenizer":
+            return
+        return super().check_argument_for_proper_class(attribute_name, arg)
+
+    def __init__(
+        self,
+        vision_config: dict,
+        tokenizer: AnyTokenizer,
+        **kwargs,
+    ):
+        image_processor = Qwen2VLImageProcessor(**vision_config)
+        video_processor = Qwen2VLVideoProcessor(**vision_config)
+        chat_template = kwargs.pop("chat_template", None)
+
+        super().__init__(
+            image_processor=image_processor,
+            tokenizer=tokenizer,
+            video_processor=video_processor,
+            chat_template=chat_template,
+            **kwargs,
+        )
+
+        self.image_token = "<|media_placeholder|>"
+
+    def __call__(
+        self,
+        text=None,
+        images=None,
+        return_tensors=None,
+        **kwargs,
+    ):
+        if text is not None:
+            if not isinstance(text, list):
+                text = [text]
+            text_inputs = self.tokenizer(text, **kwargs)
+        else:
+            text_inputs = {}
+
+        image_inputs = {}
+        if images is not None:
+            if not isinstance(images, list):
+                images = [images]
+            if len(images) > 0:
+                image_inputs = self.image_processor(
+                    images, return_tensors=return_tensors or "pt"
+                )
+
+        combined_inputs = {**text_inputs, **image_inputs}
+
+        return BatchFeature(combined_inputs, tensor_type=return_tensors)
+
+
+class OpenCUAMultiModalProcessor(BaseMultiModalProcessor[OpenCUAProcessingInfo]):
+    def _get_data_parser(self) -> MultiModalDataParser:
+        return Qwen2VLMultiModalDataParser(
+            self.info.get_hf_config().vision_config.spatial_merge_size
+        )
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return _create_qwen2vl_field_factory(
+            self.info.get_hf_config().vision_config.spatial_merge_size
+        )(hf_inputs)
+
+    def _hf_processor_applies_updates(
+        self,
+        prompt_text: str,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        tokenization_kwargs: Mapping[str, object],
+    ) -> bool:
+        """Return False so that vLLM applies the prompt updates itself."""
+        return False
+
+    def _get_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, Any],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> Sequence[PromptUpdate]:
+        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        hf_config = self.info.get_hf_config()
+
+        image_token_str = getattr(hf_processor, "image_token", "<|media_placeholder|>")
+        image_token_id = vocab.get(
+            image_token_str,
+            getattr(hf_config, "media_placeholder_token_id", 151664),
+        )
+
+        merge_length = image_processor.merge_size**2
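+
+        # Worked example (grid values are illustrative): an image processed
+        # into a (t, h, w) = (1, 36, 52) patch grid with merge_size = 2 has
+        # merge_length = 4, so the single placeholder token below expands to
+        # 1 * 36 * 52 // 4 = 468 image tokens -- one per post-merge vision
+        # embedding produced by the Qwen2.5-VL-style encoder.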
+        def get_replacement_opencua(item_idx: int):
+            out_item = out_mm_kwargs["image"][item_idx]
+            grid_thw = out_item["image_grid_thw"].data
+            assert isinstance(grid_thw, torch.Tensor)
+
+            num_tokens = int(grid_thw.prod()) // merge_length
+            return [image_token_id] * num_tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[image_token_id],
+                replacement=get_replacement_opencua,
+            )
+        ]
+
+
+class OpenCUADummyInputsBuilder(Qwen2VLDummyInputsBuilder):
+    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+        num_images = mm_counts.get("image", 0)
+
+        image_token = "<|media_placeholder|>"
+
+        return image_token * num_images
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+    OpenCUAMultiModalProcessor,
+    info=OpenCUAProcessingInfo,
+    dummy_inputs=OpenCUADummyInputsBuilder,
+)
+class OpenCUAForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
+    merge_by_field_config = True
+    multimodal_cpu_fields = {"image_grid_thw"}
+
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "model.language_model.": "language_model.model.",
+            "model.visual.": "visual.",
+            "vision_tower.": "visual.",
+            "lm_head.": "language_model.lm_head.",
+            "model.": "language_model.model.",
+        }
+    )
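+    # Prefix remapping examples (checkpoint key names are illustrative):
+    #   "model.visual.patch_embed.proj.weight"
+    #       -> "visual.patch_embed.proj.weight"
+    #   "model.language_model.layers.0.mlp.gate_proj.weight"
+    #       -> "language_model.model.layers.0.mlp.gate_proj.weight"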
+
+    supports_encoder_tp_data = True
+
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        if modality.startswith("image"):
+            return "<|media_placeholder|>"
+        raise ValueError("Only image modality is supported")
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        multimodal_config = vllm_config.model_config.multimodal_config
+
+        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
+        self.config = config
+        self.vllm_config = vllm_config
+        self.multimodal_config = multimodal_config
+        self.quant_config = quant_config
+        self.is_multimodal_pruning_enabled = (
+            multimodal_config.is_multimodal_pruning_enabled()
+        )
+
+        if multimodal_config.get_limit_per_prompt("image"):
+            attn_backend_override = multimodal_config.mm_encoder_attn_backend
+            self.visual = OpenCUAVisionTransformer(
+                vision_config=config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=self.quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+                use_data_parallel=self.use_data_parallel,
+                attn_backend_override=attn_backend_override,
+            )
+        else:
+            self.visual = None
+
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=config.text_config,
+            prefix=maybe_prefix(prefix, "language_model"),
+            architectures=["Qwen2ForCausalLM"],
+        )
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
+        )
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 4943987606201..b3da64af750c7 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -289,6 +289,10 @@ _MULTIMODAL_MODELS = {
     "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
     "InternVLChatModel": ("internvl", "InternVLChatModel"),
     "NemotronH_Nano_VL_V2": ("nano_nemotron_vl", "NemotronH_Nano_VL_V2"),
+    "OpenCUAForConditionalGeneration": (
+        "opencua",
+        "OpenCUAForConditionalGeneration",
+    ),
     "InternS1ForConditionalGeneration": (
         "interns1",
         "InternS1ForConditionalGeneration",