[Model] Use merge_by_field_config for MM models (InternVL family) (#26153)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Cyrus Leung 2025-10-03 16:59:06 +08:00 committed by yewentao256
parent edaae1825f
commit c81dc099a3
9 changed files with 84 additions and 182 deletions
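In rough terms, `merge_by_field_config = True` opts a model into the newer multimodal input path: the engine merges batched keyword arguments according to the model's `_get_mm_fields_config` before calling into the model, so each `_parse_and_validate_*_input` receives already-flattened tensors and no longer needs its own `flatten_bn` calls or isinstance guards. A minimal sketch of the pattern with illustrative names (only `merge_by_field_config` and the dropped `flatten_bn` normalization come from the diff itself):

# Rough sketch with illustrative names; only `merge_by_field_config` and the
# removed `flatten_bn` normalization are taken from the diff itself.
class ExampleVLModel:
    # Opt in: batched multimodal kwargs are merged per the processor's field
    # config before reaching the model, so tensors arrive already flattened.
    merge_by_field_config = True

    def _parse_and_validate_image_input(self, **kwargs):
        pixel_values_flat = kwargs.pop("pixel_values_flat", None)
        image_num_patches = kwargs.pop("image_num_patches", None)
        if pixel_values_flat is None:
            return None
        # Previously each model normalized the batch dimension itself, e.g.
        #   pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
        #   image_num_patches = flatten_bn(image_num_patches, concat=True)
        # and guarded it with isinstance checks; with merge_by_field_config
        # that boilerplate goes away.
        return {
            "type": "pixel_values",
            "pixel_values_flat": pixel_values_flat,
            "image_num_patches": image_num_patches,
        }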

View File

@@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
 | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
 | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
-| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |

View File

@@ -576,7 +576,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
 # Intern-S1
 def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
-    model_name = "internlm/Intern-S1"
+    model_name = "internlm/Intern-S1-mini"

     engine_args = EngineArgs(
         model=model_name,

View File

@@ -309,7 +309,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
 def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData:
-    model_name = "internlm/Intern-S1"
+    model_name = "internlm/Intern-S1-mini"

     engine_args = EngineArgs(
         model=model_name,

View File

@@ -25,7 +25,7 @@ from vllm.model_executor.models.interns1_vit import InternS1VisionModel
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, NestedTensors)
+                                    MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -39,7 +39,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
@@ -304,7 +304,7 @@ class InternS1MultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         mm_data = dict(mm_data)
         videos = mm_data.pop("videos", [])
         images = mm_data.pop("images", [])
@@ -342,7 +342,7 @@ class InternS1MultiModalProcessor(
                                         image_placeholder, 1)
             num_patches = [len(item) for item in image_pixel_values]
-            image_outputs: dict[str, NestedTensors] = {
+            image_outputs = {
                 "pixel_values": torch.concat(image_pixel_values),
                 "image_num_patches": torch.tensor(num_patches),
                 "image_token_id": torch.tensor(hf_processor.image_token_id),
@@ -370,7 +370,7 @@ class InternS1MultiModalProcessor(
                                         video_placeholder, 1)
             num_frames = [len(item) for item in video_pixel_values]
-            video_outputs: dict[str, NestedTensors] = {
+            video_outputs = {
                 "pixel_values_videos": torch.concat(video_pixel_values),
                 "video_num_patches": torch.tensor(num_frames),
                 "video_token_id": torch.tensor(video_token_id),
@@ -382,16 +382,11 @@ class InternS1MultiModalProcessor(
                                                  prompt)
         text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")

-        combined_outputs = dict(
-            **text_outputs,
-            **image_outputs,
-            **video_outputs,
-        )
-        return BatchFeature(combined_outputs)
+        return BatchFeature({**text_outputs, **image_outputs, **video_outputs})

     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
@@ -487,6 +482,7 @@ class InternS1MultiModalProcessor(
                                         dummy_inputs=InternS1DummyInputsBuilder)
 class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP, SupportsLoRA):
+    merge_by_field_config = True

     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -561,7 +557,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
             prefix=prefix,
         )

-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         return InternS1MultiModalProjector(config)

     def pixel_shuffle(self, x, scale_factor=0.5):
@@ -599,13 +595,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return InternS1ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )

         image_token_id = kwargs["image_token_id"]
@@ -613,17 +605,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.img_context_token_id = image_token_id.flatten().unique().item()

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
-
             h, w = self.config.vision_config.image_size
             return InternS1ImagePixelInputs(
                 type="pixel_values",
@@ -638,7 +619,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         raise AssertionError("This line should be unreachable.")

     def _parse_and_validate_video_input(
-            self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
+            self, **kwargs: object) -> Optional[InternS1VideoInputs]:
         pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
         video_num_patches = kwargs.pop("video_num_patches", None)
         video_embeds = kwargs.pop("video_embeds", None)
@@ -647,13 +628,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None

         if video_embeds is not None:
-            if not isinstance(video_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of video embeddings. "
-                                 f"Got type: {type(video_embeds)}")
-
-            return InternS1ImageEmbeddingInputs(
+            return InternS1VideoEmbeddingInputs(
                 type="video_embeds",
-                data=flatten_bn(video_embeds),
+                data=video_embeds,
             )

         video_token_id = kwargs["video_token_id"]
@@ -661,18 +638,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.video_context_token_id = video_token_id.flatten().unique().item()

         if pixel_values_flat_video is not None:
-            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat_video)}")
-
-            if not isinstance(video_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(video_num_patches)}")
-
-            pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
-                                                 concat=True)
-            video_num_patches = flatten_bn(video_num_patches, concat=True)
-
             h, w = self.config.vision_config.image_size
             return InternS1VideoPixelInputs(
                 type="pixel_values_videos",
@@ -686,11 +651,12 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         raise AssertionError("This line should be unreachable.")

-    def _process_image_input(
+    def _process_vision_input(
         self,
-        image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
+        image_input: Union[InternS1ImageInputs, InternS1VideoInputs],
     ) -> tuple[torch.Tensor, ...]:
-        if image_input["type"] == "image_embeds":
+        if (image_input["type"] == "image_embeds"
+                or image_input["type"] == "video_embeds"):
             return image_input["data"]

         assert self.vision_tower is not None
@@ -753,11 +719,11 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         for modality in modalities:
             if modality == "images":
                 image_input = modalities["images"]
-                vision_embeddings = self._process_image_input(image_input)
+                vision_embeddings = self._process_vision_input(image_input)
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_vision_input(video_input)
                 multimodal_embeddings += video_embeddings

         return multimodal_embeddings
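On the processor side, `_call_hf_processor` and the InternVL-style processors now return a `transformers.BatchFeature` instead of a plain mapping of nested tensors, built from a single dict merge of the text, image, and video outputs. A small, self-contained sketch of that construction (keys and shapes are placeholders, not the exact InternS1/InternVL fields):

# Self-contained sketch of the BatchFeature construction; keys and shapes are
# placeholders, not the exact InternS1/InternVL fields.
import torch
from transformers import BatchFeature

text_outputs = {"input_ids": torch.tensor([[1, 2, 3]])}  # tokenizer output stand-in
image_outputs = {
    "pixel_values": torch.zeros(3, 3, 448, 448),   # three dummy image patches
    "image_num_patches": torch.tensor([3]),
}
video_outputs = {}  # no video in this request

# One dict merge instead of an intermediate dict() wrapped afterwards.
feature = BatchFeature({**text_outputs, **image_outputs, **video_outputs})
print(list(feature.keys()))  # ['input_ids', 'pixel_values', 'image_num_patches']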

View File

@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import BatchEncoding, PretrainedConfig, TensorType
+from transformers import BatchFeature, PretrainedConfig, TensorType

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -28,7 +28,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, NestedTensors)
+                                    MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -42,8 +42,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -471,7 +470,7 @@ class BaseInternVLProcessor(ABC):
                 max_dynamic_patch=max_dynamic_patch,
                 dynamic_image_size=dynamic_image_size,
             )
-            image_inputs: dict[str, NestedTensors] = {
+            image_inputs = {
                 "pixel_values_flat":
                 torch.cat(pixel_values_lst),
                 "image_num_patches":
@@ -502,7 +501,7 @@ class BaseInternVLProcessor(ABC):
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         text, images = [self._make_batch_input(x) for x in (text, images)]

         text, image_inputs = self._preprocess_image(
@@ -515,10 +514,9 @@ class BaseInternVLProcessor(ABC):
         text_inputs = self.tokenizer(text)

-        return {
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-        }
+        combined_outputs = {**text_inputs, **image_inputs}
+
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)


 class InternVLProcessor(BaseInternVLProcessor):
@@ -598,7 +596,7 @@ class InternVLProcessor(BaseInternVLProcessor):
                 videos,
                 dynamic_image_size=dynamic_image_size,
             )
-            video_inputs: dict[str, NestedTensors] = {
+            video_inputs = {
                 "pixel_values_flat_video":
                 torch.cat(pixel_values_lst_video),
                 "video_num_patches":
@@ -622,7 +620,7 @@ class InternVLProcessor(BaseInternVLProcessor):
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         text, images, videos = [
             self._make_batch_input(x) for x in (text, images, videos)
         ]
@@ -643,11 +641,9 @@ class InternVLProcessor(BaseInternVLProcessor):
         text_inputs = self.tokenizer(text)

-        return {
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-            **video_inputs,
-        }
+        combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
+
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)

     def get_image_repl(
         self,
@@ -773,7 +769,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
@@ -793,7 +789,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -948,7 +944,7 @@ class InternVLMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(prompt, mm_data,
                                                        mm_kwargs, tok_kwargs)
@@ -960,7 +956,7 @@ class InternVLMultiModalProcessor(
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_fields = super()._get_mm_fields_config(hf_inputs,
@@ -1033,6 +1029,7 @@ class InternVLMultiModalProcessor(
                                         dummy_inputs=InternVLDummyInputsBuilder)
 class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
                         SupportsLoRA):
+    merge_by_field_config = True

     supports_encoder_tp_data = True
@@ -1126,7 +1123,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         else:
             return InternVisionPatchModel(config.vision_config)

-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.text_config.hidden_size
@@ -1175,13 +1172,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )

         image_token_id = kwargs["image_token_id"]
@@ -1189,16 +1182,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         self.img_context_token_id = image_token_id.flatten().unique().item()

         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat)}")
-
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
-
             expected_h = expected_w = self.config.vision_config.image_size
             resolve_bindings = {"h": expected_h, "w": expected_w}
@@ -1223,7 +1206,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         if video_embeds is not None:
             return InternVLVideoEmbeddingInputs(
                 type="video_embeds",
-                data=flatten_bn(video_embeds),
+                data=video_embeds,
             )

         video_token_id = kwargs["video_token_id"]
@@ -1231,17 +1214,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         self.video_context_token_id = video_token_id.flatten().unique().item()

         if pixel_values_flat_video is not None:
-            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat_video)}")
-
-            if not isinstance(video_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(video_num_patches)}")
-
-            pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
-                                                 concat=True)
-            video_num_patches = flatten_bn(video_num_patches, concat=True)
             expected_h = expected_w = self.config.vision_config.image_size
             resolve_bindings = {"h": expected_h, "w": expected_w}
@@ -1254,11 +1226,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         raise AssertionError("This line should be unreachable.")

-    def _process_image_input(
+    def _process_vision_input(
         self,
-        image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs],
+        image_input: Union[InternVLImageInputs, InternVLVideoInputs],
     ) -> tuple[torch.Tensor, ...]:
-        if image_input["type"] == "image_embeds":
+        if (image_input["type"] == "image_embeds"
+                or image_input["type"] == "video_embeds"):
             return image_input["data"]

         assert self.vision_model is not None
@@ -1326,11 +1299,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         for modality in modalities:
             if modality == "images":
                 image_input = modalities["images"]
-                vision_embeddings = self._process_image_input(image_input)
+                vision_embeddings = self._process_vision_input(image_input)
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_vision_input(video_input)
                 multimodal_embeddings += video_embeddings

         return multimodal_embeddings

View File

@@ -18,8 +18,7 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import (BatchEncoding, BatchFeature, PretrainedConfig,
-                          TensorType)
+from transformers import BatchFeature, PretrainedConfig, TensorType

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
@@ -38,8 +37,7 @@ from vllm.model_executor.models.utils import (flatten_bn,
                                               maybe_prefix)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargs, MultiModalKwargsItems,
-                                    NestedTensors)
+                                    MultiModalKwargs, MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -298,7 +296,7 @@ class BaseNanoNemotronVLProcessor(ABC):
         else:
             pixel_values_lst = self._images_to_pixel_values_lst(
                 images, max_num_tiles)
-            image_inputs: dict[str, NestedTensors] = {
+            image_inputs = {
                 "pixel_values_flat":
                 torch.cat(pixel_values_lst),
                 "image_num_patches":
@@ -326,7 +324,7 @@ class BaseNanoNemotronVLProcessor(ABC):
         images: Optional[Union[Image.Image, list[Image.Image]]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         max_num_tiles: Optional[int] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         # Use default if not provided
         if max_num_tiles is None:
             max_num_tiles = 12
@@ -341,10 +339,9 @@ class BaseNanoNemotronVLProcessor(ABC):
         text_inputs = self.tokenizer(text, add_special_tokens=False)

-        return {
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-        }
+        combined_outputs = {**text_inputs, **image_inputs}
+
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)


 class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
@@ -420,7 +417,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
                 dynamic_image_size=dynamic_image_size,
             )
-            video_inputs: dict[str, NestedTensors] = {
+            video_inputs = {
                 "pixel_values_flat_video":
                 torch.cat(pixel_values_lst_video),
                 "video_num_patches":
@@ -443,7 +440,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         return_tensors: Optional[Union[str, TensorType]] = None,
         max_num_tiles: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         # Use default if not provided
         if max_num_tiles is None:
             max_num_tiles = 12
@@ -467,11 +464,9 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         text_inputs = self.tokenizer(text, add_special_tokens=False)

-        return BatchFeature({
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-            **video_inputs,
-        })
+        combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
+
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)

     def get_image_repl(
         self,
@@ -625,7 +620,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
@@ -645,7 +640,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -724,7 +719,7 @@ class NanoNemotronVLMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(prompt, mm_data,
                                                        mm_kwargs, tok_kwargs)
@@ -736,7 +731,7 @@ class NanoNemotronVLMultiModalProcessor(
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_fields = super()._get_mm_fields_config(hf_inputs,

View File

@@ -28,7 +28,6 @@ from vllm.model_executor.models.internvl import (
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
-from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.processing import PromptUpdateDetails
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processor import (
@@ -37,8 +36,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -289,7 +287,7 @@ class NemotronVLProcessor(InternVLProcessor):
                 max_dynamic_patch=max_dynamic_patch,
                 dynamic_image_size=dynamic_image_size,
             )
-            image_inputs: dict[str, NestedTensors] = {
+            image_inputs = {
                 "pixel_values_flat":
                 torch.cat(pixel_values_lst),
                 "image_num_patches":
@@ -344,6 +342,7 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
     dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo])
 class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
                                SupportsLoRA):
+    merge_by_field_config = True

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -414,7 +413,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             return AutoModel.from_config(config.vision_config,
                                          trust_remote_code=True)

-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         vit_hidden_size = config.vit_hidden_size
         vision_projection_hidden_size = config.projector_hidden_size
         llm_hidden_size = config.text_config.hidden_size
@@ -467,13 +466,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )

         image_token_id = kwargs["image_token_id"]
@@ -481,17 +476,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         self.img_context_token_id = image_token_id.flatten().unique().item()

         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat)}")
-
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
-
             return InternVLImagePixelInputs(
                 type="pixel_values",
                 pixel_values_flat=pixel_values_flat,

View File

@@ -159,7 +159,7 @@ class NVLMMultiModalProcessor(
                                         dummy_inputs=NVLMDummyInputsBuilder)
 class NVLM_D_Model(InternVLChatModel):

-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         vit_hidden_size = config.vision_config.hidden_size
         llm_intermediate_size = config.text_config.intermediate_size
         llm_hidden_size = config.text_config.hidden_size

View File

@@ -14,7 +14,7 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import BatchEncoding, PretrainedConfig, TensorType
+from transformers import BatchFeature, PretrainedConfig, TensorType

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import ReplicatedLinear
@@ -25,7 +25,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, NestedTensors)
+                                    MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -399,7 +398,7 @@ class SkyworkR1VProcessor:
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         if text is None:
             text = []
         if not isinstance(text, list):
@@ -418,7 +417,7 @@ class SkyworkR1VProcessor:
                 max_dynamic_patch=max_dynamic_patch,
                 dynamic_image_size=dynamic_image_size,
             )
-            image_inputs: dict[str, NestedTensors] = {
+            image_inputs = {
                 "pixel_values_flat":
                 torch.cat(pixel_values_lst),
                 "image_num_patches":
@@ -435,10 +434,9 @@ class SkyworkR1VProcessor:
         text_inputs = self.tokenizer(text)

-        return {
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-        }
+        combined_outputs = {**text_inputs, **image_inputs}
+
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)


 class SkyworkR1VProcessingInfo(BaseProcessingInfo):
@@ -529,7 +527,7 @@ class SkyworkR1VMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
@@ -549,7 +547,7 @@ class SkyworkR1VMultiModalProcessor(
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -617,6 +615,7 @@ class SkyworkR1VMultiModalProcessor(
                                         info=SkyworkR1VProcessingInfo,
                                         dummy_inputs=SkyworkR1VDummyInputsBuilder)
 class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -703,7 +702,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             return InternVisionPatchModel(config.vision_config)

-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.text_config.hidden_size
@@ -756,13 +755,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return SkyworkR1VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )

         image_token_id = kwargs["image_token_id"]
@@ -770,17 +765,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.img_context_token_id = image_token_id.flatten().unique().item()

         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat)}")
-
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
-
             return SkyworkR1VImagePixelInputs(
                 type="pixel_values",
                 pixel_values_flat=pixel_values_flat,