[Model] Use merge_by_field_config for MM models (InternVL family) (#26153)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-03 16:59:06 +08:00 committed by GitHub
parent 3e70e3d4d5
commit f9a8084e48
9 changed files with 84 additions and 182 deletions
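In short: a model class that sets `merge_by_field_config = True` tells vLLM's multimodal framework to merge and unbatch incoming keyword arguments according to the processor's `_get_mm_fields_config`, so the per-model `flatten_bn` calls and hand-written isinstance checks in the diffs below can be deleted. A minimal before/after sketch (simplified signatures and illustrative shapes, not the exact vLLM internals):

import torch

# Before: each model collapsed the batch dimension itself, e.g. via
# flatten_bn(..., concat=True) from vllm.model_executor.models.utils.
def parse_image_input_old(pixel_values: list[torch.Tensor],
                          image_num_patches: list[torch.Tensor]):
    if not isinstance(pixel_values, (torch.Tensor, list)):
        raise ValueError("Incorrect type of pixel values. "
                         f"Got type: {type(pixel_values)}")
    # Concatenate per-request tensors into one flat (total_patches, 3, h, w).
    pixel_values_flat = torch.cat(pixel_values)
    num_patches = torch.cat(image_num_patches)
    return pixel_values_flat, num_patches

# After: with merge_by_field_config = True, the kwargs arrive already
# merged by field config, so the model consumes them directly.
def parse_image_input_new(pixel_values: torch.Tensor,
                          image_num_patches: torch.Tensor):
    return pixel_values, image_num_patches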

View File

@@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
 | `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
 | `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
-| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |

View File

@@ -576,7 +576,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
 # Intern-S1
 def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
-    model_name = "internlm/Intern-S1"
+    model_name = "internlm/Intern-S1-mini"
     engine_args = EngineArgs(
         model=model_name,

View File

@@ -309,7 +309,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
 def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData:
-    model_name = "internlm/Intern-S1"
+    model_name = "internlm/Intern-S1-mini"
     engine_args = EngineArgs(
         model=model_name,
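The two example scripts above only swap the checkpoint to the smaller `internlm/Intern-S1-mini`. A quick smoke test outside the example harness (a sketch; engine arguments trimmed to the minimum, text-only prompt for brevity):

from vllm import LLM

# Sketch: load the smaller checkpoint the examples now use.
llm = LLM(model="internlm/Intern-S1-mini", trust_remote_code=True)

# llm.chat applies the model's chat template; multimodal inputs are
# omitted here, see the example scripts for image/video usage.
outputs = llm.chat([{"role": "user", "content": "Say hello."}])
print(outputs[0].outputs[0].text)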

View File

@@ -25,7 +25,7 @@ from vllm.model_executor.models.interns1_vit import InternS1VisionModel
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, NestedTensors)
+                                    MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -39,7 +39,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
@@ -304,7 +304,7 @@ class InternS1MultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         mm_data = dict(mm_data)
         videos = mm_data.pop("videos", [])
         images = mm_data.pop("images", [])
@@ -342,7 +342,7 @@ class InternS1MultiModalProcessor(
                                               image_placeholder, 1)
             num_patches = [len(item) for item in image_pixel_values]
-            image_outputs: dict[str, NestedTensors] = {
+            image_outputs = {
                 "pixel_values": torch.concat(image_pixel_values),
                 "image_num_patches": torch.tensor(num_patches),
                 "image_token_id": torch.tensor(hf_processor.image_token_id),
@@ -370,7 +370,7 @@ class InternS1MultiModalProcessor(
                                               video_placeholder, 1)
             num_frames = [len(item) for item in video_pixel_values]
-            video_outputs: dict[str, NestedTensors] = {
+            video_outputs = {
                 "pixel_values_videos": torch.concat(video_pixel_values),
                 "video_num_patches": torch.tensor(num_frames),
                 "video_token_id": torch.tensor(video_token_id),
@@ -382,16 +382,11 @@ class InternS1MultiModalProcessor(
                                               prompt)
         text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")
-        combined_outputs = dict(
-            **text_outputs,
-            **image_outputs,
-            **video_outputs,
-        )
-        return BatchFeature(combined_outputs)
+        return BatchFeature({**text_outputs, **image_outputs, **video_outputs})
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
@@ -487,6 +482,7 @@ class InternS1MultiModalProcessor(
                                         dummy_inputs=InternS1DummyInputsBuilder)
 class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP, SupportsLoRA):
+    merge_by_field_config = True
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -561,7 +557,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
             prefix=prefix,
         )
-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         return InternS1MultiModalProjector(config)
     def pixel_shuffle(self, x, scale_factor=0.5):
@@ -599,13 +595,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return InternS1ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
         image_token_id = kwargs["image_token_id"]
@@ -613,17 +605,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.img_context_token_id = image_token_id.flatten().unique().item()
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
             h, w = self.config.vision_config.image_size
             return InternS1ImagePixelInputs(
                 type="pixel_values",
@@ -638,7 +619,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         raise AssertionError("This line should be unreachable.")
     def _parse_and_validate_video_input(
-            self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
+            self, **kwargs: object) -> Optional[InternS1VideoInputs]:
         pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
         video_num_patches = kwargs.pop("video_num_patches", None)
         video_embeds = kwargs.pop("video_embeds", None)
@@ -647,13 +628,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
         if video_embeds is not None:
-            if not isinstance(video_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of video embeddings. "
-                                 f"Got type: {type(video_embeds)}")
-            return InternS1ImageEmbeddingInputs(
+            return InternS1VideoEmbeddingInputs(
                 type="video_embeds",
-                data=flatten_bn(video_embeds),
+                data=video_embeds,
             )
         video_token_id = kwargs["video_token_id"]
@@ -661,18 +638,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.video_context_token_id = video_token_id.flatten().unique().item()
         if pixel_values_flat_video is not None:
-            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat_video)}")
-            if not isinstance(video_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(video_num_patches)}")
-            pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
-                                                 concat=True)
-            video_num_patches = flatten_bn(video_num_patches, concat=True)
             h, w = self.config.vision_config.image_size
             return InternS1VideoPixelInputs(
                 type="pixel_values_videos",
@@ -686,11 +651,12 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         raise AssertionError("This line should be unreachable.")
-    def _process_image_input(
+    def _process_vision_input(
         self,
-        image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
+        image_input: Union[InternS1ImageInputs, InternS1VideoInputs],
     ) -> tuple[torch.Tensor, ...]:
-        if image_input["type"] == "image_embeds":
+        if (image_input["type"] == "image_embeds"
+                or image_input["type"] == "video_embeds"):
             return image_input["data"]
         assert self.vision_tower is not None
@@ -753,11 +719,11 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         for modality in modalities:
             if modality == "images":
                 image_input = modalities["images"]
-                vision_embeddings = self._process_image_input(image_input)
+                vision_embeddings = self._process_vision_input(image_input)
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_vision_input(video_input)
                 multimodal_embeddings += video_embeddings
         return multimodal_embeddings
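The merge that replaces the deleted `flatten_bn` calls is driven by the field config this processor already declares. A sketch of the pattern (field names follow the diff above; helper names and signatures as I understand vLLM's `MultiModalFieldConfig`, so treat them as assumptions):

import torch
from vllm.multimodal.inputs import MultiModalFieldConfig

def get_mm_fields_config_sketch(hf_inputs):
    # Rows of the flat pixel tensor are grouped per image by patch count,
    # telling the framework how to split and re-merge them across requests.
    image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
    return dict(
        pixel_values=MultiModalFieldConfig.flat_from_sizes(
            "image", image_num_patches),
        # One entry per image item.
        image_num_patches=MultiModalFieldConfig.batched("image"),
        image_embeds=MultiModalFieldConfig.batched("image"),
        # A single value shared by every image item (batch size 1 here is
        # illustrative only).
        image_token_id=MultiModalFieldConfig.shared("image", 1),
    )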

View File

@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import BatchEncoding, PretrainedConfig, TensorType
+from transformers import BatchFeature, PretrainedConfig, TensorType
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -28,7 +28,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, NestedTensors)
+                                    MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -42,8 +42,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -471,7 +470,7 @@ class BaseInternVLProcessor(ABC):
             max_dynamic_patch=max_dynamic_patch,
             dynamic_image_size=dynamic_image_size,
         )
-        image_inputs: dict[str, NestedTensors] = {
+        image_inputs = {
             "pixel_values_flat":
             torch.cat(pixel_values_lst),
             "image_num_patches":
@@ -502,7 +501,7 @@ class BaseInternVLProcessor(ABC):
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         text, images = [self._make_batch_input(x) for x in (text, images)]
         text, image_inputs = self._preprocess_image(
@@ -515,10 +514,9 @@ class BaseInternVLProcessor(ABC):
         text_inputs = self.tokenizer(text)
-        return {
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-        }
+        combined_outputs = {**text_inputs, **image_inputs}
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)
 class InternVLProcessor(BaseInternVLProcessor):
@@ -598,7 +596,7 @@ class InternVLProcessor(BaseInternVLProcessor):
             videos,
             dynamic_image_size=dynamic_image_size,
         )
-        video_inputs: dict[str, NestedTensors] = {
+        video_inputs = {
             "pixel_values_flat_video":
             torch.cat(pixel_values_lst_video),
             "video_num_patches":
@@ -622,7 +620,7 @@ class InternVLProcessor(BaseInternVLProcessor):
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         text, images, videos = [
             self._make_batch_input(x) for x in (text, images, videos)
         ]
@@ -643,11 +641,9 @@ class InternVLProcessor(BaseInternVLProcessor):
         text_inputs = self.tokenizer(text)
-        return {
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-            **video_inputs,
-        }
+        combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)
     def get_image_repl(
         self,
@@ -773,7 +769,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
@@ -793,7 +789,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -948,7 +944,7 @@ class InternVLMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(prompt, mm_data,
                                                        mm_kwargs, tok_kwargs)
@@ -960,7 +956,7 @@ class InternVLMultiModalProcessor(
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_fields = super()._get_mm_fields_config(hf_inputs,
@@ -1033,6 +1029,7 @@ class InternVLMultiModalProcessor(
                                         dummy_inputs=InternVLDummyInputsBuilder)
 class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
                         SupportsLoRA):
+    merge_by_field_config = True
     supports_encoder_tp_data = True
@@ -1126,7 +1123,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         else:
             return InternVisionPatchModel(config.vision_config)
-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.text_config.hidden_size
@@ -1175,13 +1172,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             return None
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
         image_token_id = kwargs["image_token_id"]
@@ -1189,16 +1182,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         self.img_context_token_id = image_token_id.flatten().unique().item()
         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat)}")
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
             expected_h = expected_w = self.config.vision_config.image_size
             resolve_bindings = {"h": expected_h, "w": expected_w}
@@ -1223,7 +1206,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         if video_embeds is not None:
             return InternVLVideoEmbeddingInputs(
                 type="video_embeds",
-                data=flatten_bn(video_embeds),
+                data=video_embeds,
             )
         video_token_id = kwargs["video_token_id"]
@@ -1231,17 +1214,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         self.video_context_token_id = video_token_id.flatten().unique().item()
         if pixel_values_flat_video is not None:
-            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat_video)}")
-            if not isinstance(video_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(video_num_patches)}")
-            pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
-                                                 concat=True)
-            video_num_patches = flatten_bn(video_num_patches, concat=True)
             expected_h = expected_w = self.config.vision_config.image_size
             resolve_bindings = {"h": expected_h, "w": expected_w}
@@ -1254,11 +1226,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         raise AssertionError("This line should be unreachable.")
-    def _process_image_input(
+    def _process_vision_input(
         self,
-        image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs],
+        image_input: Union[InternVLImageInputs, InternVLVideoInputs],
     ) -> tuple[torch.Tensor, ...]:
-        if image_input["type"] == "image_embeds":
+        if (image_input["type"] == "image_embeds"
+                or image_input["type"] == "video_embeds"):
             return image_input["data"]
         assert self.vision_model is not None
@@ -1326,11 +1299,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         for modality in modalities:
             if modality == "images":
                 image_input = modalities["images"]
-                vision_embeddings = self._process_image_input(image_input)
+                vision_embeddings = self._process_vision_input(image_input)
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_vision_input(video_input)
                 multimodal_embeddings += video_embeddings
         return multimodal_embeddings
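Note that dropping the isinstance/ValueError blocks does not remove validation: the `*Inputs` classes here derive from `TensorSchema`, which checks shapes when the object is constructed, with symbolic dimensions such as `h`/`w` supplied through `resolve_bindings` as in the diff above. A simplified sketch (fewer fields than the real `InternVLImagePixelInputs`, and the exact constructor behavior is my reading of the API):

from typing import Annotated, Literal

import torch
from vllm.utils.tensor_schema import TensorSchema, TensorShape

class ImagePixelInputsSketch(TensorSchema):
    # "bnp": total patches across all images; "bn": number of images.
    type: Literal["pixel_values"]
    pixel_values_flat: Annotated[torch.Tensor,
                                 TensorShape("bnp", 3, "h", "w")]
    image_num_patches: Annotated[torch.Tensor, TensorShape("bn")]

# Shape checking happens at construction; a mismatched h/w would raise here.
inputs = ImagePixelInputsSketch(
    type="pixel_values",
    pixel_values_flat=torch.randn(5, 3, 448, 448),
    image_num_patches=torch.tensor([3, 2]),
    resolve_bindings={"h": 448, "w": 448},  # as in the diff above
)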

View File

@@ -18,8 +18,7 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import (BatchEncoding, BatchFeature, PretrainedConfig,
-                          TensorType)
+from transformers import BatchFeature, PretrainedConfig, TensorType
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
@@ -38,8 +37,7 @@ from vllm.model_executor.models.utils import (flatten_bn,
                                               maybe_prefix)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargs, MultiModalKwargsItems,
-                                    NestedTensors)
+                                    MultiModalKwargs, MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -298,7 +296,7 @@ class BaseNanoNemotronVLProcessor(ABC):
         else:
             pixel_values_lst = self._images_to_pixel_values_lst(
                 images, max_num_tiles)
-            image_inputs: dict[str, NestedTensors] = {
+            image_inputs = {
                 "pixel_values_flat":
                 torch.cat(pixel_values_lst),
                 "image_num_patches":
@@ -326,7 +324,7 @@ class BaseNanoNemotronVLProcessor(ABC):
         images: Optional[Union[Image.Image, list[Image.Image]]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         max_num_tiles: Optional[int] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         # Use default if not provided
         if max_num_tiles is None:
             max_num_tiles = 12
@@ -341,10 +339,9 @@ class BaseNanoNemotronVLProcessor(ABC):
         text_inputs = self.tokenizer(text, add_special_tokens=False)
-        return {
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-        }
+        combined_outputs = {**text_inputs, **image_inputs}
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)
 class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
@@ -420,7 +417,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
             dynamic_image_size=dynamic_image_size,
         )
-        video_inputs: dict[str, NestedTensors] = {
+        video_inputs = {
             "pixel_values_flat_video":
             torch.cat(pixel_values_lst_video),
             "video_num_patches":
@@ -443,7 +440,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         return_tensors: Optional[Union[str, TensorType]] = None,
         max_num_tiles: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         # Use default if not provided
         if max_num_tiles is None:
             max_num_tiles = 12
@@ -467,11 +464,9 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
         text_inputs = self.tokenizer(text, add_special_tokens=False)
-        return BatchFeature({
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-            **video_inputs,
-        })
+        combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)
     def get_image_repl(
         self,
@@ -625,7 +620,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
@@ -645,7 +640,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -724,7 +719,7 @@ class NanoNemotronVLMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(prompt, mm_data,
                                                        mm_kwargs, tok_kwargs)
@@ -736,7 +731,7 @@ class NanoNemotronVLMultiModalProcessor(
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_fields = super()._get_mm_fields_config(hf_inputs,
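Across all of these processors the return type tightens from a plain mapping to `transformers.BatchFeature`, with `tensor_type` doing the text-side conversion that `BatchEncoding` used to handle. A minimal standalone sketch of that pattern (illustrative values only):

import torch
from transformers import BatchFeature

# Tokenizer output (plain Python lists) plus precomputed vision tensors.
text_inputs = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
image_inputs = {"pixel_values_flat": torch.randn(5, 3, 448, 448)}

combined_outputs = {**text_inputs, **image_inputs}
batch = BatchFeature(combined_outputs, tensor_type="pt")
# tensor_type="pt" converts the lists; existing tensors pass through.
assert isinstance(batch["input_ids"], torch.Tensor)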

View File

@@ -28,7 +28,6 @@ from vllm.model_executor.models.internvl import (
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
-from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.processing import PromptUpdateDetails
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processor import (
@@ -37,8 +36,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -289,7 +287,7 @@ class NemotronVLProcessor(InternVLProcessor):
             max_dynamic_patch=max_dynamic_patch,
             dynamic_image_size=dynamic_image_size,
         )
-        image_inputs: dict[str, NestedTensors] = {
+        image_inputs = {
             "pixel_values_flat":
             torch.cat(pixel_values_lst),
             "image_num_patches":
@@ -344,6 +342,7 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
     dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo])
 class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
                                SupportsLoRA):
+    merge_by_field_config = True
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -414,7 +413,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             return AutoModel.from_config(config.vision_config,
                                          trust_remote_code=True)
-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         vit_hidden_size = config.vit_hidden_size
         vision_projection_hidden_size = config.projector_hidden_size
         llm_hidden_size = config.text_config.hidden_size
@@ -467,13 +466,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
             return None
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return InternVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
         image_token_id = kwargs["image_token_id"]
@@ -481,17 +476,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
         self.img_context_token_id = image_token_id.flatten().unique().item()
         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat)}")
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
             return InternVLImagePixelInputs(
                 type="pixel_values",
                 pixel_values_flat=pixel_values_flat,

View File

@@ -159,7 +159,7 @@ class NVLMMultiModalProcessor(
                         dummy_inputs=NVLMDummyInputsBuilder)
 class NVLM_D_Model(InternVLChatModel):
-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         vit_hidden_size = config.vision_config.hidden_size
         llm_intermediate_size = config.text_config.intermediate_size
         llm_hidden_size = config.text_config.hidden_size

View File

@@ -14,7 +14,7 @@ import torch
 import torch.nn as nn
 import torchvision.transforms as T
 from PIL import Image
-from transformers import BatchEncoding, PretrainedConfig, TensorType
+from transformers import BatchFeature, PretrainedConfig, TensorType
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import ReplicatedLinear
@@ -25,7 +25,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, NestedTensors)
+                                    MultiModalKwargsItems)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    maybe_prefix)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -399,7 +398,7 @@ class SkyworkR1VProcessor:
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         if text is None:
             text = []
         if not isinstance(text, list):
@@ -418,7 +417,7 @@ class SkyworkR1VProcessor:
             max_dynamic_patch=max_dynamic_patch,
             dynamic_image_size=dynamic_image_size,
         )
-        image_inputs: dict[str, NestedTensors] = {
+        image_inputs = {
             "pixel_values_flat":
             torch.cat(pixel_values_lst),
             "image_num_patches":
@@ -435,10 +434,9 @@ class SkyworkR1VProcessor:
         text_inputs = self.tokenizer(text)
-        return {
-            **BatchEncoding(text_inputs, tensor_type=return_tensors),
-            **image_inputs,
-        }
+        combined_outputs = {**text_inputs, **image_inputs}
+        return BatchFeature(combined_outputs, tensor_type=return_tensors)
 class SkyworkR1VProcessingInfo(BaseProcessingInfo):
@@ -529,7 +527,7 @@ class SkyworkR1VMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
-    ) -> Mapping[str, NestedTensors]:
+    ) -> BatchFeature:
         processed_outputs = super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
@@ -549,7 +547,7 @@ class SkyworkR1VMultiModalProcessor(
     def _get_mm_fields_config(
         self,
-        hf_inputs: Mapping[str, NestedTensors],
+        hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
@@ -617,6 +615,7 @@ class SkyworkR1VMultiModalProcessor(
                         info=SkyworkR1VProcessingInfo,
                         dummy_inputs=SkyworkR1VDummyInputsBuilder)
 class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -703,7 +702,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         else:
             return InternVisionPatchModel(config.vision_config)
-    def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
+    def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
         vit_hidden_size = config.vision_config.hidden_size
         llm_hidden_size = config.text_config.hidden_size
@@ -756,13 +755,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             return None
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return SkyworkR1VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
            )
         image_token_id = kwargs["image_token_id"]
@@ -770,17 +765,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.img_context_token_id = image_token_id.flatten().unique().item()
         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values_flat)}")
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image_num_patches. "
-                                 f"Got type: {type(image_num_patches)}")
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
             return SkyworkR1VImagePixelInputs(
                 type="pixel_values",
                 pixel_values_flat=pixel_values_flat,