From c10f5653baeb6d1e1d684161c37bd59ce02edcf7 Mon Sep 17 00:00:00 2001
From: Yang
Date: Wed, 26 Nov 2025 15:09:47 -0800
Subject: [PATCH] 1. Add support for the Isaac model in the registry and
 documentation. 2. Optimize the Isaac model implementation.

Signed-off-by: Yang
---
 docs/models/supported_models.md     |   1 +
 tests/models/registry.py            |   4 +
 vllm/model_executor/models/isaac.py | 394 +++++++++++++++-------------
 3 files changed, 212 insertions(+), 187 deletions(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9ba0f4ca9096e..470807ff8da91 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -679,6 +679,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `H2OVLChatModel` | H2OVL | T + IE+ | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ |
 | `HunYuanVLForConditionalGeneration` | HunyuanOCR | T + IE+ | `tencent/HunyuanOCR`, etc. | ✅︎ | ✅︎ |
 | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | |
 | `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ |
 | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ |
 | `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ |
+| `IsaacForConditionalGeneration` | Isaac | T + I+ | `PerceptronAI/Isaac-0.1` | ✅︎ | ✅︎ |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index c5d72b5d581b9..7ce22f4238167 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -646,6 +646,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
         "HuggingFaceM4/Idefics3-8B-Llama3",
         extras={"tiny": "HuggingFaceTB/SmolVLM-256M-Instruct"},
     ),
     "InternS1ForConditionalGeneration": _HfExamplesInfo(
         "internlm/Intern-S1", trust_remote_code=True
     ),
+    "IsaacForConditionalGeneration": _HfExamplesInfo(
+        "PerceptronAI/Isaac-0.1",
+        trust_remote_code=True,
+    ),
diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py
index d2d980a9aadf4..82dae62cb56e4 100644
--- a/vllm/model_executor/models/isaac.py
+++ b/vllm/model_executor/models/isaac.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 import itertools
 import math
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Iterable, Iterator, Mapping, Sequence
 from dataclasses import dataclass
 from enum import Enum
 from typing import Any
@@ -15,7 +15,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import PretrainedConfig, Qwen3Config
+from transformers import Qwen3Config
 from transformers.image_processing_utils import BatchFeature
 from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
 from transformers.tokenization_utils import TensorType
@@ -30,8 +30,10 @@ from vllm.attention.ops.vit_attn_wrappers import (
     vit_xformers_attn_wrapper,
 )
 from vllm.config import VllmConfig
+from vllm.config.model import ModelConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
    QKVParallelLinear,
@@ -50,18 +52,18 @@ from vllm.model_executor.models.interfaces import (
     SupportsPP,
 )
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
 from vllm.model_executor.models.siglip import SiglipMLP
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    _merge_multimodal_embeddings,
+    init_vllm_registered_model,
     maybe_prefix,
 )
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
+    MultiModalFeatureSpec,
     MultiModalFieldConfig,
     MultiModalKwargs,
 )
@@ -73,6 +75,13 @@ from vllm.multimodal.processing import (
     PromptUpdate,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.tokenizer import (
+    get_cached_tokenizer,
+    get_tokenizer,
+)
+
+logger = init_logger(__name__)
 
 # ===== TensorStream Compatibility Layer for Isaac MRoPE =====
 # Minimal implementation of TensorStream classes needed for Isaac's 3D positional
@@ -286,12 +295,14 @@ def compute_mrope_pos_tensor(ts: TensorStream, n_pos_dims: int = 3) -> torch.Ten
         dims = (event.dims() or [1]) + [1] * (n_pos_dims - len(event.dims() or []))
 
         # Create ranges for each dimension (similar to old _finalize implementation)
-        first_dim = range(cumulative_offset, cumulative_offset + dims[0])
+        first_dim = list(range(cumulative_offset, cumulative_offset + dims[0]))
         cumulative_offset += dims[0]  # advance time for the next event
-        other_dims = [range(d) for d in dims[1:]]
-        # Use itertools.product to create all coordinate combinations
-        full_coords = list(itertools.product(first_dim, *other_dims))
+        if event.modality_type != VisionType.image:
+            full_coords = [(t, t, t) for t in first_dim]
+        else:
+            other_dims = [range(d) for d in dims[1:]]
+            full_coords = list(itertools.product(first_dim, *other_dims))
 
         # Slice if the event is partial
         s, e = event.idx_range
@@ -307,6 +318,19 @@ def compute_mrope_pos_tensor(ts: TensorStream, n_pos_dims: int = 3) -> torch.Ten
     )
 
 
+def _resolve_vision_token_id(model_config: ModelConfig, vision_token: str) -> int:
+    tokenizer_name = model_config.tokenizer or model_config.model
+    tokenizer = get_cached_tokenizer(
+        get_tokenizer(
+            tokenizer_name,
+            tokenizer_mode=model_config.tokenizer_mode,
+            trust_remote_code=model_config.trust_remote_code,
+            revision=model_config.tokenizer_revision or model_config.revision,
+        )
+    )
+    return tokenizer.encode(vision_token, add_special_tokens=False)[0]
+
+
 def modality_mask(ts: TensorStream, modality_type: ModalityType) -> torch.Tensor:
     """Create boolean mask for specific modality type in the tensor stream."""
     B, T = ts.shape
@@ -883,7 +907,8 @@ class IsaacConfig(Qwen3Config):
         vision_min_num_patches: int | None = None,
         pixel_shuffle_scale: int = 1,
         max_sequence_length: int = 16384,
         vision_token: str = "<|image_pad|>",
+        vision_attn_implementation: str | None = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -899,10 +924,25 @@ class IsaacConfig(Qwen3Config):
         self.vision_token = vision_token
 
         # Handle vision config - PixelShuffleSiglip2VisionConfig instance
-        self.vision_config = PixelShuffleSiglip2VisionConfig(
-            pixel_shuffle_scale_factor=pixel_shuffle_scale,
-            num_patches=vision_max_num_patches,
+        if isinstance(vision_config, dict):
+            self.vision_config = PixelShuffleSiglip2VisionConfig(**vision_config)
+        elif vision_config is 
None:
+            self.vision_config = PixelShuffleSiglip2VisionConfig()
+        else:
+            self.vision_config = vision_config
+
+        # Ensure compatibility with pretrained checkpoints
+        self.vision_config.pixel_shuffle_scale_factor = getattr(
+            self.vision_config,
+            "pixel_shuffle_scale_factor",
+            pixel_shuffle_scale,
         )
+        self.vision_config.num_patches = getattr(
+            self.vision_config,
+            "num_patches",
+            vision_max_num_patches,
+        )
+        self.vision_attn_implementation = vision_attn_implementation
 
 
 class IsaacImageProcessorKwargs(TypedDict, total=False):
@@ -991,9 +1031,9 @@ class IsaacProcessor:
     tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
 
     def __init__(self, image_processor=None, tokenizer=None, **kwargs):
+        self.image_token = kwargs.pop("image_token", "<|image_pad|>")
         self.image_processor = image_processor or IsaacImageProcessor(kwargs)
         self.tokenizer = tokenizer
-        self.image_token = "<|image_pad|>"
 
     def __call__(self, text=None, images=None, **kwargs) -> BatchFeature:
         result = {}
@@ -1062,12 +1102,20 @@ class IsaacProcessingInfo(BaseProcessingInfo):
                 max_sequence_length=getattr(
                     original_config, "max_sequence_length", 16384
                 ),
-                vision_token="<|image_pad|>",
+                vision_token=getattr(original_config, "vision_token", "<|image_pad|>"),
+                vision_attn_implementation=getattr(
+                    original_config, "vision_attn_implementation", None
+                ),
             )
 
         return IsaacConfig()
 
     def get_hf_processor(self, **kwargs) -> IsaacProcessor:
-        return self.ctx.get_hf_processor(IsaacProcessor, **kwargs)
+        hf_config = self.get_hf_config()
+        processor_kwargs = {
+            "image_token": hf_config.vision_token,
+        }
+        processor_kwargs.update(kwargs)
+        return self.ctx.get_hf_processor(IsaacProcessor, **processor_kwargs)
 
     def get_tokenizer(self):
         return self.ctx.tokenizer
@@ -1157,11 +1205,13 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
         # hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        hf_config = self.info.get_hf_config()
         image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
         tokenizer = self.info.get_tokenizer()
-
-        vocab = tokenizer.get_vocab()
-        placeholder_id = vocab.get("<|image_pad|>", 151655)
+        placeholder_ids = tokenizer.encode(
+            hf_config.vision_token,
+            add_special_tokens=False,
+        )
 
         pixel_shuffle_scale = getattr(image_processor, "pixel_shuffle_scale", 2)
         merge_length = pixel_shuffle_scale**2
@@ -1172,12 +1222,12 @@ class IsaacMultiModalProcessor(BaseMultiModalProcessor):
             assert isinstance(grid_thw, torch.Tensor)
 
             num_tokens = int(grid_thw.prod()) // merge_length
-            return [placeholder_id] * num_tokens
+            return placeholder_ids * num_tokens
 
         return [
             PromptReplacement(
                 modality="image",
-                target=[placeholder_id],
+                target=placeholder_ids,
                 replacement=get_replacement_isaac,
             )
         ]
@@ -1278,16 +1328,7 @@ class Siglip2VisionAttention(nn.Module):
 
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
         seq_len, bs, _ = qkv.shape
-        if self.tp_size > 1:
-            qkv = all_gather_interleave(qkv, self.qkv_proj.hidden_size, self.tp_size)
-
         q, k, v = qkv.chunk(3, dim=2)
 
-        if self.tp_size > 1:
-            q = dist_utils.split_tensor_along_last_dim(q, self.tp_size)[self.tp_rank]
-            k = dist_utils.split_tensor_along_last_dim(k, self.tp_size)[self.tp_rank]
-            v = dist_utils.split_tensor_along_last_dim(v, self.tp_size)[self.tp_rank]
-
         new_shape = (
             seq_len,
             bs,
@@ -1604,7 +1645,8 @@ class IsaacVisionEmbedding(nn.Module):
         vision_cfg: PixelShuffleSiglip2VisionConfig,
         hidden_dim: int,
         output_dim: int,
-        prefix: str,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
super().__init__()
 
         self.transformer = Siglip2VisionTransformer(
@@ -1614,6 +1656,7 @@ class IsaacVisionEmbedding(nn.Module):
             hidden_dim,
             4 * hidden_dim,
             bias=False,
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_embedding.1"),
             return_bias=False,
         )
@@ -1622,6 +1665,7 @@ class IsaacVisionEmbedding(nn.Module):
             4 * hidden_dim,
             output_dim,
             bias=False,
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "vision_embedding.3"),
             return_bias=False,
         )
@@ -1642,8 +1686,9 @@
     dummy_inputs=IsaacDummyInputsBuilder,
 )
 class IsaacForConditionalGeneration(
-    Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
+    nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1661,221 +1706,196 @@ class IsaacForConditionalGeneration(
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
+            "lm_head.": "language_model.lm_head.",
+            "model.vision_embedding.0": "vision_embedding.transformer",
+            "model.vision_embedding.1": "vision_embedding.linear_fc1",
+            "model.vision_embedding.2": "vision_embedding.act",
+            "model.vision_embedding.3": "vision_embedding.linear_fc2",
             "model.vision_embedding.": "vision_embedding.",
-            "vision_embedding.0": "vision_embedding.transformer",
-            "vision_embedding.1": "vision_embedding.linear_fc1",
-            "vision_embedding.2": "vision_embedding.act",
-            "vision_embedding.3": "vision_embedding.linear_fc2",
+            "model.": "language_model.model.",
         }
     )
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
             return "<|image_pad|>"
         raise ValueError("Only image modality is supported")
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
+        super().__init__()
         config: IsaacConfig = vllm_config.model_config.hf_config
-        head_dim = config.head_dim
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.multimodal_config = vllm_config.model_config.multimodal_config
 
+        head_dim = config.head_dim
         calculated_mrope_section = [
             head_dim // 4,  # 2x more for temporal dim
             head_dim // 8,
             head_dim // 8,
         ]
+        self.vision_token_id = _resolve_vision_token_id(
+            vllm_config.model_config, config.vision_token
+        )
+        config.image_token_id = self.vision_token_id
+
+        logger.debug("vllm config: %s", vllm_config)
         config.rope_scaling["mrope_section"] = calculated_mrope_section
-        self.config = config
-
-        # Initialize the parent class with updated config
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
-
-        # Create the language model module to match checkpoint structure
-        self.language_model = nn.ModuleDict(
-            {
-                "embed_tokens": self.model.embed_tokens,
-                "layers": self.model.layers,
-                "norm": self.model.norm,
-            }
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            architectures=["Qwen3ForCausalLM"],
+            prefix=maybe_prefix(prefix, "language_model"),
+        )
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors
         )
 
-        config.vision_config.preserve_original_pe = True
-        config.vision_config.use_rope = False
-        config.vision_config.hidden_stride = (
-            config.vision_config.pixel_shuffle_scale_factor
-        )
-        config.vision_config.window_size = 32 * 2
-        config.vision_config.fullatt_block_indexes = None
         vision_cfg = config.vision_config
         if vision_cfg is None:
             raise ValueError("IsaacConfig should always have vision_config")
 
+        
vision_cfg.preserve_original_pe = True + vision_cfg.use_rope = False + vision_cfg.hidden_stride = vision_cfg.pixel_shuffle_scale_factor + vision_cfg.window_size = 32 * 2 + vision_cfg.fullatt_block_indexes = None + attn_impl = ( + config.vision_attn_implementation + if config.vision_attn_implementation is not None + else getattr(config, "_attn_implementation", None) + ) + if attn_impl is not None: + vision_cfg._attn_implementation = attn_impl hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) self.vision_embedding = IsaacVisionEmbedding( vision_cfg=vision_cfg, hidden_dim=hidden_dim, output_dim=config.hidden_size, - prefix=prefix, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_embedding"), ) + def iter_mm_grid_hw( + self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] + ) -> Iterator[tuple[int, int, int]]: + spatial_merge_size = self.config.vision_config.pixel_shuffle_scale_factor + for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): + offset = mm_feature.mm_position.offset + if mm_feature.modality == "image": + t, h, w = mm_feature.data["image_grid_thw"].data.tolist() + assert t == 1, f"Image must have 1 frame, got {t}" + yield offset, h // spatial_merge_size, w // spatial_merge_size + else: + raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def get_mrope_input_positions( self, input_tokens: list[int], - hf_config: PretrainedConfig, - image_grid_thw: list[list[int]] | torch.Tensor, - video_grid_thw: list[list[int]] | torch.Tensor, - context_len: int = 0, - seq_len: int | None = None, - second_per_grid_ts: list[float] | None = None, - audio_feature_lengths: torch.Tensor | None = None, - use_audio_in_video: bool = False, + mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - """Get mrope input positions and delta value.""" - - vision_token_id = getattr(self.config, "image_token_id", 151655) - spatial_merge_size = hf_config.vision_config.pixel_shuffle_scale_factor - input_tokens_tensor = torch.tensor(input_tokens) - - # Find image token positions - image_positions = torch.where(input_tokens_tensor == vision_token_id)[ - 0 - ].tolist() - - # For text-only inputs, use Isaac's original logic from - # compute_position_ids_input_ids() - if len(image_positions) == 0: - seq_len = len(input_tokens) - # Create 3D positions where all dimensions get the same 1D temporal - # progression - position_ids = torch.arange(seq_len, dtype=torch.long) - position_ids = position_ids.view(1, -1).expand(1, -1) # [1, seq_len] - position_ids = position_ids.unsqueeze(2).expand( - -1, -1, 3 - ) # [1, seq_len, 3] - - # vLLM expects shape [3, seq_len], so transpose - position_ids = position_ids.squeeze(0).transpose(0, 1) # [3, seq_len] - - return position_ids, 0 - - events = [] - image_idx = 0 - current_pos = 0 - last_processed_pos = -1 - - for image_pos in image_positions: - if image_pos <= last_processed_pos: - continue # Skip already processed positions - - # Add any text before this image - if image_pos > current_pos: - text_tokens = image_pos - current_pos - text_event = Event( - modality_type=TextType.text, - dims_virtual=[text_tokens, 1], - idx_range=(0, text_tokens), - ) - events.append(text_event) - - # Add image - t, h, w = image_grid_thw[image_idx] - llm_grid_h, llm_grid_w = h // spatial_merge_size, w // spatial_merge_size - image_tokens = t * llm_grid_h * llm_grid_w - - image_event = Event( - modality_type=VisionType.image, - dims_virtual=[t, llm_grid_h, llm_grid_w], - idx_range=(0, 
image_tokens),
+        llm_pos_ids_list = []
+        st = 0
+        for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw(
+            input_tokens, mm_features
+        ):
+            text_len = offset - st
+            st_idx = llm_pos_ids_list[-1][0, -1] + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )
-            events.append(image_event)
-            current_pos = image_pos + image_tokens
-            last_processed_pos = (
-                current_pos - 1
-            )  # Mark up to this position as processed
-            image_idx += 1
+            grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)
+            grid_indices[0, :] = grid_indices[0, :] + text_len + st_idx
+            llm_pos_ids_list.append(grid_indices)
+            st = offset + llm_grid_h * llm_grid_w
 
-        # Add final text segment if any
-        if current_pos < len(input_tokens):
-            text_tokens = len(input_tokens) - current_pos
-            text_event = Event(
-                modality_type=TextType.text,
-                dims_virtual=[text_tokens, 1],
-                idx_range=(0, text_tokens),
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1][0, -1] + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
             )
-            events.append(text_event)
 
-        stream = Stream(events)
-        tensor_stream = TensorStream([stream])
+        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
 
-        # Use Isaac's native MRoPE calculation
-        position_ids = compute_mrope_pos_tensor(tensor_stream, n_pos_dims=3)
+        return torch.from_numpy(llm_positions), mrope_position_delta
 
-        # Max position per batch across the 3 planes and sequence dimension: (B,)
-        m_per_batch = position_ids.amax(dim=(1, 2))
+    def _parse_and_validate_image_input(
+        self, **kwargs: object
+    ) -> dict[str, torch.Tensor] | None:
+        pixel_values = kwargs.get("pixel_values")
+        image_grid_thw = kwargs.get("image_grid_thw")
+        if pixel_values is None or image_grid_thw is None:
+            return None
+        return {
+            "pixel_values": pixel_values,
+            "image_grid_thw": image_grid_thw,
+        }
 
-        mrope_position_delta = (m_per_batch + 1 - len(input_tokens)).item()
+    def _process_image_input(
+        self,
+        image_input: dict[str, torch.Tensor],
+    ) -> tuple[torch.Tensor, ...]:
+        pixel_values = image_input["pixel_values"]
+        image_grid_thw = image_input["image_grid_thw"]
+        if pixel_values.numel() == 0:
+            return ()
 
-        # vLLM expects shape [3, seq_len] but Isaac returns [batch, seq_len, 3]
-        # Transpose to match vLLM's expected format
-        position_ids = position_ids.squeeze(0).transpose(0, 1)
+        device = next(self.language_model.parameters()).device
+        dtype = self.vision_embedding.linear_fc1.weight.dtype
+        pixel_values = pixel_values.to(device=device, dtype=dtype)
+        if image_grid_thw.dim() == 3:
+            image_grid_thw = image_grid_thw[0]
+        spatial_grids = image_grid_thw[:, 1:3].to(device, dtype=torch.int32)
 
-        return position_ids, mrope_position_delta
+        vision_embeddings = self.vision_embedding((pixel_values, spatial_grids))
+        merge_size = self.config.vision_config.pixel_shuffle_scale_factor
+        sizes = spatial_grids.prod(-1) // (merge_size * merge_size)
+        return tuple(vision_embeddings.split(sizes.tolist()))
+
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return ()
+        return self._process_image_input(image_input)
 
     def get_multimodal_embeddings(
         self, **kwargs: object
    ) -> MultiModalEmbeddings | None:
-        pixel_values = 
kwargs.get("pixel_values") - image_grid_thw = kwargs.get("image_grid_thw") - - if pixel_values is None: + # Backward compatibility for older runners. + embeddings = self.embed_multimodal(**kwargs) + if not embeddings: return [] + return embeddings - # Convert image_grid_thw from [batch, 1, [T, H, W]] to [batch, [H, W]] - spatial_grids = image_grid_thw[ - :, 0, 1:3 - ] # Extract H, W from [T, H, W] for each image + def get_language_model(self) -> torch.nn.Module: + return self.language_model - # Process packed sequence patches through vision_embedding module - vision_embeddings = self.vision_embedding((pixel_values, spatial_grids)) - - # Split concatenated embeddings for each image item (following Qwen2-VL pattern) - merge_size = ( - self.config.vision_config.pixel_shuffle_scale_factor - ) # Isaac uses pixel shuffle - sizes = spatial_grids.prod(-1) // ( - merge_size * merge_size - ) # H * W / (merge_size^2) - - return vision_embeddings.split(sizes.tolist()) - - def get_input_embeddings( + def forward( self, input_ids: torch.Tensor, - multimodal_embeddings: MultiModalEmbeddings | None = None, - *, - is_multimodal: torch.Tensor | None = None, - handle_oov_mm_token: bool = False, - ) -> torch.Tensor: - # Get text embeddings from the base language model - inputs_embeds = super().get_input_embeddings(input_ids) + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + return self.language_model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) - # If we have multimodal embeddings, merge them with text embeddings - if multimodal_embeddings is not None and len(multimodal_embeddings) != 0: - inputs_embeds = _merge_multimodal_embeddings( - inputs_embeds=inputs_embeds, - multimodal_embeddings=multimodal_embeddings, - is_multimodal=is_multimodal, - ) - - return inputs_embeds + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - skip_prefixes = [] - - loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + loader = AutoWeightsLoader(self) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: