[Model] Use merge_by_field_config for MM models (M-N) (#26710)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung 2025-10-14 01:27:01 +08:00 committed by GitHub
parent e3b90c1ba2
commit afc47e4de7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 127 additions and 331 deletions

View File

@ -631,8 +631,11 @@ class InternS1ForConditionalGeneration(
) )
image_token_id = kwargs["image_token_id"] image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor) if isinstance(image_token_id, torch.Tensor):
self.img_context_token_id = image_token_id.flatten().unique().item() image_token_id = image_token_id.flatten().unique().item()
assert isinstance(image_token_id, int)
self.img_context_token_id = image_token_id
if pixel_values is not None: if pixel_values is not None:
h, w = self.config.vision_config.image_size h, w = self.config.vision_config.image_size
@ -665,8 +668,11 @@ class InternS1ForConditionalGeneration(
) )
video_token_id = kwargs["video_token_id"] video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor) if isinstance(video_token_id, torch.Tensor):
self.video_context_token_id = video_token_id.flatten().unique().item() video_token_id = video_token_id.flatten().unique().item()
assert isinstance(video_token_id, int)
self.video_context_token_id = video_token_id
if pixel_values_flat_video is not None: if pixel_values_flat_video is not None:
h, w = self.config.vision_config.image_size h, w = self.config.vision_config.image_size

View File

@ -1232,8 +1232,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
) )
image_token_id = kwargs["image_token_id"] image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor) if isinstance(image_token_id, torch.Tensor):
self.img_context_token_id = image_token_id.flatten().unique().item() image_token_id = image_token_id.flatten().unique().item()
assert isinstance(image_token_id, int)
self.img_context_token_id = image_token_id
if pixel_values_flat is not None: if pixel_values_flat is not None:
expected_h = expected_w = self.config.vision_config.image_size expected_h = expected_w = self.config.vision_config.image_size
@ -1265,8 +1268,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA)
) )
video_token_id = kwargs["video_token_id"] video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor) if isinstance(video_token_id, torch.Tensor):
self.video_context_token_id = video_token_id.flatten().unique().item() video_token_id = video_token_id.flatten().unique().item()
assert isinstance(video_token_id, int)
self.video_context_token_id = video_token_id
if pixel_values_flat_video is not None: if pixel_values_flat_video is not None:
expected_h = expected_w = self.config.vision_config.image_size expected_h = expected_w = self.config.vision_config.image_size

View File

@ -26,7 +26,7 @@
import collections import collections
import collections.abc import collections.abc
from collections.abc import Callable, Iterable, Mapping, Sequence from collections.abc import Callable, Iterable, Mapping, Sequence
from typing import Any, TypeAlias, TypedDict, cast from typing import Annotated, Any, TypeAlias, cast
import numpy as np import numpy as np
import torch import torch
@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.midashenglm import DashengConfig from vllm.transformers_utils.configs.midashenglm import DashengConfig
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@ -508,11 +509,16 @@ class AudioProjectorSubsample(nn.Module):
# === Audio Inputs === # # === Audio Inputs === #
class MiDashengLMAudioInputs(TypedDict): class MiDashengLMAudioInputs(TensorSchema):
input_values: torch.Tensor """
"""Shape: `(num_audios, num_sampling_points)`"""
audio_length: torch.Tensor Dimensions:
"""Shape: `(num_audios, 1)`""" - bn: Batch size * number of audios
- p: Number of sampling points
"""
input_values: Annotated[torch.Tensor, TensorShape("n", "p")]
audio_length: Annotated[torch.Tensor, TensorShape("n")]
class MiDashengLMProcessingInfo(BaseProcessingInfo): class MiDashengLMProcessingInfo(BaseProcessingInfo):
@ -676,6 +682,8 @@ class MiDashengLMMultiModalProcessor(
dummy_inputs=MiDashengLMDummyInputsBuilder, dummy_inputs=MiDashengLMDummyInputsBuilder,
) )
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP): class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
@ -728,26 +736,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
self.decoder.make_empty_intermediate_tensors self.decoder.make_empty_intermediate_tensors
) )
def _validate_and_reshape_mm_tensor(
self, mm_input: object, name: str
) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
return mm_input.reshape(-1, *mm_input.shape[2:])
if name == "input_values":
max_length = max(tensor.shape[1] for tensor in mm_input)
padded_mm_input = [
torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
if tensor.shape[1] < max_length
else tensor
for tensor in mm_input
]
return torch.concat(padded_mm_input)
return torch.concat(mm_input)
def _parse_and_validate_audio_input( def _parse_and_validate_audio_input(
self, **kwargs: object self, **kwargs: object
) -> MiDashengLMAudioInputs | None: ) -> MiDashengLMAudioInputs | None:
@ -756,16 +744,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
if input_values is None: if input_values is None:
return None return None
input_values = self._validate_and_reshape_mm_tensor(
input_values, "input_values" if isinstance(input_values, list):
) input_values = torch.nn.utils.rnn.pad_sequence(
audio_length = self._validate_and_reshape_mm_tensor( input_values,
audio_length, "audio_length" batch_first=True,
)
if not isinstance(input_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of audio input features. "
f"Got type: {type(input_values)}"
) )
return MiDashengLMAudioInputs( return MiDashengLMAudioInputs(
@ -773,7 +756,10 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_length=audio_length, audio_length=audio_length,
) )
def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor: def _process_audio_input(
self,
audio_input: MiDashengLMAudioInputs,
) -> tuple[torch.Tensor, ...]:
# Process audio through encoder and projector # Process audio through encoder and projector
input_values = audio_input["input_values"] input_values = audio_input["input_values"]
audio_length = audio_input["audio_length"] audio_length = audio_input["audio_length"]
@ -783,17 +769,13 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype) audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
audio_length_np = (
audio_length.cpu().numpy()
if isinstance(audio_length, torch.Tensor)
else audio_length
)
audio_output_lengths = [ audio_output_lengths = [
max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame
for length in audio_length_np for length in audio_length.tolist()
] ]
audio_output_lengths = torch.tensor(audio_output_lengths).to( audio_output_lengths = torch.tensor(
audio_embeddings.device audio_output_lengths,
device=audio_embeddings.device,
) )
audio_feature_mask = torch.arange( audio_feature_mask = torch.arange(
@ -826,14 +808,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> torch.Tensor | IntermediateTensors: ) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None: if intermediate_tensors is not None:
inputs_embeds = None inputs_embeds = None
elif inputs_embeds is None:
multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(
input_ids,
multimodal_embeddings,
is_multimodal=input_ids == self.config.audio_token_id,
)
input_ids = None
return self.decoder.model( return self.decoder.model(
input_ids, input_ids,

View File

@ -71,7 +71,7 @@ from .minicpmv import (
MiniCPMVProcessingInfo, MiniCPMVProcessingInfo,
_minicpmv_field_config, _minicpmv_field_config,
) )
from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
CPU_DEVICE = torch.device("cpu") CPU_DEVICE = torch.device("cpu")
@ -132,15 +132,11 @@ MiniCPMOAudioInputs: TypeAlias = (
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features = hf_inputs.get("audio_features", torch.empty(0))
num_audios = len(audio_features)
return dict( return dict(
**_minicpmv_field_config(hf_inputs), **_minicpmv_field_config(hf_inputs),
audio_features=MultiModalFieldConfig.batched("audio"), audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"), audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"), audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
) )
@ -332,10 +328,6 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
] ]
audio_inputs["audio_features"] = unpadded_audio_features audio_inputs["audio_features"] = unpadded_audio_features
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
return audio_inputs return audio_inputs
def process_mm_inputs( def process_mm_inputs(
@ -436,12 +428,10 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
attention_mask: torch.Tensor, attention_mask: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
past_key_values = None
hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states, attn_weights, past_key_values = self.self_attn( hidden_states, _ = self.self_attn(
hidden_states=hidden_states, hidden_states=hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
past_key_value=past_key_values,
) )
hidden_states = nn.functional.dropout( hidden_states = nn.functional.dropout(
hidden_states, p=self.dropout, training=self.training hidden_states, p=self.dropout, training=self.training
@ -567,8 +557,6 @@ class MiniCPMO(MiniCPMV2_6):
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
) )
self.audio_token_id = None
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""): def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Do not use parameters temporarily # Do not use parameters temporarily
audio_config = self.config.audio_config audio_config = self.config.audio_config
@ -731,43 +719,18 @@ class MiniCPMO(MiniCPMV2_6):
if audio_features is None and audio_embeds is None: if audio_features is None and audio_embeds is None:
return None return None
audio_token_id = kwargs.pop("audio_token_id")
if audio_token_id is not None:
assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
if audio_embeds is not None: if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}"
)
audio_embeds_flat = flatten_bn(audio_embeds)
return MiniCPMOAudioEmbeddingInputs( return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds", type="audio_embeds",
audio_embeds=audio_embeds_flat, audio_embeds=audio_embeds,
)
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of audio_features. Got type: {type(audio_features)}"
) )
audio_feature_lens = kwargs.pop("audio_feature_lens") audio_feature_lens = kwargs.pop("audio_feature_lens")
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of audio_feature_lens. "
f"Got type: {type(audio_feature_lens)}"
)
audio_features_flat = flatten_bn(audio_features)
audio_feature_lens_flat = flatten_bn(audio_feature_lens)
return MiniCPMOAudioFeatureInputs( return MiniCPMOAudioFeatureInputs(
type="audio_features", type="audio_features",
audio_features=audio_features_flat, audio_features=audio_features,
audio_feature_lens=audio_feature_lens_flat, audio_feature_lens=audio_feature_lens,
) )
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:

View File

@ -114,7 +114,7 @@ class MiniCPMVImagePixelInputs(TensorSchema):
type: Literal["pixel_values"] = "pixel_values" type: Literal["pixel_values"] = "pixel_values"
# Note that the image size may vary, so we pass it as a list instead of a # Note that the patch size may vary, so we pass it as a list instead of a
# batched tensor. # batched tensor.
pixel_values: Annotated[ pixel_values: Annotated[
list[torch.Tensor], list[torch.Tensor],
@ -453,12 +453,6 @@ def get_version_by_config(config: PretrainedConfig) -> tuple[int, ...]:
def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]): def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
pixel_values = hf_inputs.get("pixel_values", torch.empty(0))
num_images = len(pixel_values)
video_pixel_values = hf_inputs.get("video_pixel_values", torch.empty(0))
num_videos = len(video_pixel_values)
return dict( return dict(
pixel_values=MultiModalFieldConfig.batched("image"), pixel_values=MultiModalFieldConfig.batched("image"),
image_sizes=MultiModalFieldConfig.batched("image"), image_sizes=MultiModalFieldConfig.batched("image"),
@ -468,8 +462,6 @@ def _minicpmv_field_config(hf_inputs: Mapping[str, torch.Tensor]):
video_image_sizes=MultiModalFieldConfig.batched("video"), video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.batched("video"), video_tgt_sizes=MultiModalFieldConfig.batched("video"),
video_embeds=MultiModalFieldConfig.batched("video"), video_embeds=MultiModalFieldConfig.batched("video"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
) )
@ -792,10 +784,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
out_keys={"pixel_values", "image_sizes", "tgt_sizes"}, out_keys={"pixel_values", "image_sizes", "tgt_sizes"},
) )
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
image_inputs["image_token_id"] = torch.tensor(unk_token_id)
return image_inputs return image_inputs
def process_videos( def process_videos(
@ -831,10 +819,6 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
video_inputs = {f"video_{k}": v for k, v in video_inputs.items()} video_inputs = {f"video_{k}": v for k, v in video_inputs.items()}
tokenizer = self.info.get_tokenizer()
unk_token_id = tokenizer.get_vocab()["<unk>"]
video_inputs["video_token_id"] = torch.tensor(unk_token_id)
return video_inputs return video_inputs
def process_mm_inputs( def process_mm_inputs(
@ -1021,6 +1005,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
instantiated. instantiated.
""" """
merge_by_field_config = True
supports_encoder_tp_data = True supports_encoder_tp_data = True
@classmethod @classmethod
@ -1066,7 +1052,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "resampler"), prefix=maybe_prefix(prefix, "resampler"),
) )
self.mm_token_ids = set[int]()
self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors self.make_empty_intermediate_tensors = self.llm.make_empty_intermediate_tensors
def _parse_and_validate_vision_input( def _parse_and_validate_vision_input(
@ -1080,43 +1065,17 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
if pixel_values is None and image_embeds is None: if pixel_values is None and image_embeds is None:
return None return None
image_token_id = kwargs.pop("image_token_id")
if image_token_id is not None:
assert isinstance(image_token_id, torch.Tensor)
self.mm_token_ids.add(image_token_id.flatten().unique().item())
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of image_embeds for {modality=}. "
f"Got type: {type(image_embeds)}"
)
image_embeds_flat = flatten_bn(image_embeds)
return MiniCPMVImageEmbeddingInputs( return MiniCPMVImageEmbeddingInputs(
type="image_embeds", type="image_embeds",
image_embeds=image_embeds_flat, image_embeds=image_embeds,
)
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of pixel_values for {modality=}. "
f"Got type: {type(pixel_values)}"
) )
tgt_sizes = kwargs.pop("tgt_sizes") tgt_sizes = kwargs.pop("tgt_sizes")
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of tgt_sizes for {modality=}. "
f"Got type: {type(tgt_sizes)}"
)
num_slices = [[len(p) for p in ps] for ps in pixel_values] num_slices_flat = torch.tensor([len(ps) for ps in pixel_values])
num_slices_flat = flatten_bn(torch.tensor(num_slices)) pixel_values_flat = flatten_bn(pixel_values)
tgt_sizes_flat = flatten_bn(tgt_sizes, concat=True)
pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
return MiniCPMVImagePixelInputs( return MiniCPMVImagePixelInputs(
type="pixel_values", type="pixel_values",
@ -1142,15 +1101,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
input_key in ("video_pixel_values", "video_embeds") input_key in ("video_pixel_values", "video_embeds")
and "videos" not in modalities and "videos" not in modalities
): ):
def _image_key(video_key: str):
if video_key == "video_token_id":
return "image_token_id"
return video_key.removeprefix("video_")
modalities["videos"] = self._parse_and_validate_vision_input( modalities["videos"] = self._parse_and_validate_vision_input(
"videos", **{_image_key(k): v for k, v in kwargs.items()} "videos", **{k.removeprefix("video_"): v for k, v in kwargs.items()}
) )
return modalities return modalities

View File

@ -71,7 +71,7 @@ from .interfaces import (
SupportsPP, SupportsPP,
) )
from .llama4 import Llama4ForCausalLM from .llama4 import Llama4ForCausalLM
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix from .utils import AutoWeightsLoader, maybe_prefix
from .vision import run_dp_sharded_vision_model from .vision import run_dp_sharded_vision_model
@ -86,7 +86,7 @@ class Llama4ImagePatchInputs(TensorSchema):
type: Literal["pixel_values"] = "pixel_values" type: Literal["pixel_values"] = "pixel_values"
flat_data: Annotated[ pixel_values: Annotated[
torch.Tensor, torch.Tensor,
TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"), TensorShape("total_num_chunks", "num_channels", "image_size", "image_size"),
] ]
@ -96,7 +96,7 @@ class Llama4ImagePatchInputs(TensorSchema):
The number of total patches for each image in the batch. The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`. flattened just like `pixel_values`.
""" """
aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)] aspect_ratios: Annotated[torch.Tensor, TensorShape("batch_size", 2)]
@ -725,6 +725,8 @@ class Mllama4DummyInputsBuilder(BaseDummyInputsBuilder[Mllama4ProcessingInfo]):
class Llama4ForConditionalGeneration( class Llama4ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
): ):
merge_by_field_config = True
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"], "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
@ -798,17 +800,12 @@ class Llama4ForConditionalGeneration(
if pixel_values is None: if pixel_values is None:
return None return None
# num_images x num_chunks, channel, image_size, image_size patches_per_image = kwargs.pop("patches_per_image")
# TODO: confirm handling for variable lengths
flat_pixel_values = flatten_bn(pixel_values, concat=True)
patches_per_image = flatten_bn(kwargs.pop("patches_per_image"))
aspect_ratios = kwargs.pop("aspect_ratios") aspect_ratios = kwargs.pop("aspect_ratios")
if aspect_ratios.ndim == 3:
aspect_ratios = aspect_ratios.squeeze(1)
return Llama4ImagePatchInputs( return Llama4ImagePatchInputs(
type="pixel_values", type="pixel_values",
flat_data=flat_pixel_values, pixel_values=pixel_values,
patches_per_image=patches_per_image, patches_per_image=patches_per_image,
aspect_ratios=aspect_ratios, aspect_ratios=aspect_ratios,
) )
@ -817,16 +814,16 @@ class Llama4ForConditionalGeneration(
self, image_input: Llama4ImagePatchInputs self, image_input: Llama4ImagePatchInputs
) -> MultiModalEmbeddings: ) -> MultiModalEmbeddings:
assert self.vision_model and self.multi_modal_projector assert self.vision_model and self.multi_modal_projector
flat_data = image_input["flat_data"] pixel_values = image_input["pixel_values"]
patches_per_image = image_input["patches_per_image"].tolist() patches_per_image = image_input["patches_per_image"].tolist()
# shard image input # shard image input
if self.use_data_parallel: if self.use_data_parallel:
vision_embeddings_flat = run_dp_sharded_vision_model( vision_embeddings_flat = run_dp_sharded_vision_model(
flat_data, self.vision_model pixel_values, self.vision_model
) )
else: else:
vision_embeddings_flat = self.vision_model(flat_data) vision_embeddings_flat = self.vision_model(pixel_values)
vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat) vision_embeddings_flat = self.multi_modal_projector(vision_embeddings_flat)

View File

@ -75,7 +75,6 @@ from .interfaces import (
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
flatten_bn,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_empty_intermediate_tensors_factory,
make_layers, make_layers,
@ -97,28 +96,19 @@ class MolmoImageInputs(TensorSchema):
""" """
Dimensions: Dimensions:
- bn: Batch size * number of images - bn: Batch size * number of images
- nc: Number of crops (dynamic) - bnc: Batch size * number of images * number of crops (dynamic)
- np: Number of patches - np: Number of patches
- tp: Token sequence positions - tp: Token sequence positions
- pd: Patch dimension - pd: Patch dimension
""" """
images: Annotated[ images: Annotated[torch.Tensor, TensorShape("bnc", "np", "pd")]
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
]
# Number of crops may vary per batch and image, so pass it as a list.
image_masks: Annotated[ image_masks: Annotated[torch.Tensor | None, TensorShape("bnc", "np")]
torch.Tensor | list[torch.Tensor] | None,
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}), image_input_idx: Annotated[torch.Tensor, TensorShape("bnc", "tp")]
] """An index tensor that maps image features to their corresponding patch tokens."""
image_input_idx: Annotated[
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# An index tensor that maps image features to their corresponding patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")] num_crops: Annotated[torch.Tensor, TensorShape("bn")]
@ -1363,6 +1353,8 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
class MolmoForCausalLM( class MolmoForCausalLM(
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsQuant
): ):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper( hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={ orig_to_new_substr={
# vision backbone mapping # vision backbone mapping
@ -1451,18 +1443,12 @@ class MolmoForCausalLM(
if images is None: if images is None:
return None return None
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of num_crops. Got type: {type(num_crops)}"
)
num_crops = flatten_bn(num_crops, concat=True)
img_patch_id = kwargs.pop("img_patch_id", None) img_patch_id = kwargs.pop("img_patch_id", None)
if not isinstance(img_patch_id, torch.Tensor): if isinstance(img_patch_id, torch.Tensor):
raise ValueError( img_patch_id = img_patch_id.item()
f"Incorrect type of img_patch_id. Got type: {type(img_patch_id)}"
) assert isinstance(img_patch_id, int)
self.img_patch_id = img_patch_id.flatten().unique().item() self.img_patch_id = img_patch_id
return MolmoImageInputs( return MolmoImageInputs(
images=images, images=images,
@ -1481,17 +1467,9 @@ class MolmoForCausalLM(
num_crops = image_input["num_crops"] num_crops = image_input["num_crops"]
# Call the vision backbone on the whole batch at once # Call the vision backbone on the whole batch at once
images_flat = flatten_bn(images, concat=True) image_features = self.vision_backbone(
image_masks_flat = ( images=images.unsqueeze(0),
None if image_masks is None else flatten_bn(image_masks, concat=True) image_masks=None if image_masks is None else image_masks.unsqueeze(0),
)
image_input_idx_flat = flatten_bn(image_input_idx, concat=True)
image_features_flat = self.vision_backbone(
images=images_flat.unsqueeze(0),
image_masks=(
None if image_masks_flat is None else image_masks_flat.unsqueeze(0)
),
).squeeze(0) ).squeeze(0)
# Only the features corresponding to patch tokens are relevant # Only the features corresponding to patch tokens are relevant
@ -1499,8 +1477,8 @@ class MolmoForCausalLM(
results = [] results = []
num_crops_list = num_crops.tolist() num_crops_list = num_crops.tolist()
for feats, img_idx in zip( for feats, img_idx in zip(
image_features_flat.split(num_crops_list), image_features.split(num_crops_list),
image_input_idx_flat.split(num_crops_list), image_input_idx.split(num_crops_list),
): ):
is_valid = img_idx >= 0 is_valid = img_idx >= 0
valid_img_idx = img_idx[is_valid] valid_img_idx = img_idx[is_valid]

View File

@ -11,7 +11,7 @@ import copy
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias, TypedDict, TypeVar from typing import Annotated, Any, Literal, TypeAlias, TypeVar
import numpy.typing as npt import numpy.typing as npt
import torch import torch
@ -40,7 +40,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
from vllm.model_executor.models.radio import RadioModel from vllm.model_executor.models.radio import RadioModel
from vllm.model_executor.models.utils import ( from vllm.model_executor.models.utils import (
flatten_bn,
init_vllm_registered_model, init_vllm_registered_model,
maybe_prefix, maybe_prefix,
) )
@ -96,31 +95,35 @@ MAX_FRAMES = 16
DEFAULT_NUM_TILES = 12 DEFAULT_NUM_TILES = 12
class NanoNemotronVLImagePixelInputs(TypedDict): class NanoNemotronVLImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- bnp: Batch size * number of images * (1 + num_patches)
- c: Number of channels (3)
- h: Height of each image patch
- w: Width of each image patch
"""
type: Literal["pixel_values"] type: Literal["pixel_values"]
pixel_values_flat: torch.Tensor pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
""" """
Shape: Dimensions:
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)` - n: Number of images
- f: Total image feature size
- h: Hidden size (must match the hidden size of language model backbone)
""" """
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
class NanoNemotronVLImageEmbeddinInputs(TypedDict):
type: Literal["image_embeds"] type: Literal["image_embeds"]
data: torch.Tensor | list[torch.Tensor] data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]
"""
A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
or a list of tensors of shape `(total_image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
NanoNemotronVLImageInputs: TypeAlias = ( NanoNemotronVLImageInputs: TypeAlias = (
NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddinInputs NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs
) )
@ -710,37 +713,12 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]): class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Basic image-only MultiModalProcessor for InternVL-style models.""" """Basic image-only MultiModalProcessor for InternVL-style models."""
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_token_id = hf_processor.image_token_id
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
return processed_outputs
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]: ) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0)) image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
num_images = len(image_num_patches)
return dict( return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes( pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
@ -748,7 +726,6 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
), ),
image_num_patches=MultiModalFieldConfig.batched("image"), image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"), image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.shared("image", num_images),
) )
def _get_prompt_updates( def _get_prompt_updates(
@ -814,25 +791,6 @@ class NanoNemotronVLMultiModalProcessor(
): ):
"""MultiModalProcessor extended for video support""" """MultiModalProcessor extended for video support"""
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt, mm_data, mm_kwargs, tok_kwargs
)
hf_processor = self.info.get_hf_processor(**mm_kwargs)
if (
self.info.supports_video
and (video_token_id := hf_processor.video_token_id) is not None
):
processed_outputs["video_token_id"] = torch.tensor(video_token_id)
return processed_outputs
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
@ -841,13 +799,12 @@ class NanoNemotronVLMultiModalProcessor(
image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs) image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)
if self.info.supports_video: if self.info.supports_video:
video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0)) video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
num_videos = len(video_num_patches)
video_fields = dict( video_fields = dict(
pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes( pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_patches "video", video_num_patches
), ),
video_num_patches=MultiModalFieldConfig.batched("video"), video_num_patches=MultiModalFieldConfig.batched("video"),
video_token_id=MultiModalFieldConfig.shared("video", num_videos),
) )
else: else:
video_fields = {} video_fields = {}
@ -999,6 +956,8 @@ class NanoNemotronVLDummyInputsBuilder(
class NemotronH_Nano_VL_V2( class NemotronH_Nano_VL_V2(
nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
): ):
merge_by_field_config = True
@classmethod @classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None: def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"): if modality.startswith("image"):
@ -1051,8 +1010,6 @@ class NemotronH_Nano_VL_V2(
) )
self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype) self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)
self.img_context_token_id = None
self.video_context_token_id = None
self.config = config self.config = config
self.model_config = vllm_config.model_config self.model_config = vllm_config.model_config
@ -1106,37 +1063,12 @@ class NemotronH_Nano_VL_V2(
return None return None
if image_embeds is not None: if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)): return NanoNemotronVLImageEmbeddingInputs(
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
return NanoNemotronVLImageEmbeddinInputs(
type="image_embeds", type="image_embeds",
data=flatten_bn(image_embeds), data=image_embeds,
) )
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values_flat is not None: if pixel_values_flat is not None:
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat)}"
)
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}"
)
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
return NanoNemotronVLImagePixelInputs( return NanoNemotronVLImagePixelInputs(
type="pixel_values", type="pixel_values",
pixel_values_flat=pixel_values_flat, pixel_values_flat=pixel_values_flat,
@ -1285,28 +1217,10 @@ class NemotronH_Nano_VL_V2(
if video_embeds is not None: if video_embeds is not None:
return NanoNemotronVLVideoEmbeddingInputs( return NanoNemotronVLVideoEmbeddingInputs(
type="video_embeds", type="video_embeds",
data=flatten_bn(video_embeds), data=video_embeds,
) )
video_token_id = kwargs["video_token_id"]
assert isinstance(video_token_id, torch.Tensor)
self.video_context_token_id = video_token_id.flatten().unique().item()
if pixel_values_flat_video is not None: if pixel_values_flat_video is not None:
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat_video)}"
)
if not isinstance(video_num_patches, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image_num_patches. "
f"Got type: {type(video_num_patches)}"
)
pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True)
expected_h = expected_w = self.config.force_image_size expected_h = expected_w = self.config.force_image_size
resolve_bindings = {"h": expected_h, "w": expected_w} resolve_bindings = {"h": expected_h, "w": expected_w}

View File

@ -496,8 +496,11 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, Suppor
) )
image_token_id = kwargs["image_token_id"] image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor) if isinstance(image_token_id, torch.Tensor):
self.img_context_token_id = image_token_id.flatten().unique().item() image_token_id = image_token_id.flatten().unique().item()
assert isinstance(image_token_id, int)
self.img_context_token_id = image_token_id
if pixel_values_flat is not None: if pixel_values_flat is not None:
return InternVLImagePixelInputs( return InternVLImagePixelInputs(

View File

@ -814,8 +814,11 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
) )
image_token_id = kwargs["image_token_id"] image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor) if isinstance(image_token_id, torch.Tensor):
self.img_context_token_id = image_token_id.flatten().unique().item() image_token_id = image_token_id.flatten().unique().item()
assert isinstance(image_token_id, int)
self.img_context_token_id = image_token_id
if pixel_values_flat is not None: if pixel_values_flat is not None:
return SkyworkR1VImagePixelInputs( return SkyworkR1VImagePixelInputs(

View File

@ -432,7 +432,7 @@ def group_mm_kwargs_by_modality(
if device is not None: if device is not None:
mm_kwargs_group = json_map_leaves( mm_kwargs_group = json_map_leaves(
lambda x: x.to(device=device), lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
mm_kwargs_group, mm_kwargs_group,
) )
else: else: