mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 00:05:38 +08:00
[Model] Use merge_by_field_config for MM models (InternVL family) (#26153)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
3e70e3d4d5
commit
f9a8084e48
@ -677,7 +677,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
|
||||
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, `internlm/Intern-S1-mini`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |
|
||||
|
||||
@ -576,7 +576,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData:
|
||||
|
||||
# Intern-S1
|
||||
def run_interns1(questions: list[str], modality: str) -> ModelRequestData:
|
||||
model_name = "internlm/Intern-S1"
|
||||
model_name = "internlm/Intern-S1-mini"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
|
||||
@ -309,7 +309,7 @@ def load_idefics3(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
|
||||
|
||||
def load_interns1(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||
model_name = "internlm/Intern-S1"
|
||||
model_name = "internlm/Intern-S1-mini"
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
|
||||
@ -25,7 +25,7 @@ from vllm.model_executor.models.interns1_vit import InternS1VisionModel
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems, NestedTensors)
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@ -39,7 +39,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
|
||||
@ -304,7 +304,7 @@ class InternS1MultiModalProcessor(
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
mm_data = dict(mm_data)
|
||||
videos = mm_data.pop("videos", [])
|
||||
images = mm_data.pop("images", [])
|
||||
@ -342,7 +342,7 @@ class InternS1MultiModalProcessor(
|
||||
image_placeholder, 1)
|
||||
|
||||
num_patches = [len(item) for item in image_pixel_values]
|
||||
image_outputs: dict[str, NestedTensors] = {
|
||||
image_outputs = {
|
||||
"pixel_values": torch.concat(image_pixel_values),
|
||||
"image_num_patches": torch.tensor(num_patches),
|
||||
"image_token_id": torch.tensor(hf_processor.image_token_id),
|
||||
@ -370,7 +370,7 @@ class InternS1MultiModalProcessor(
|
||||
video_placeholder, 1)
|
||||
|
||||
num_frames = [len(item) for item in video_pixel_values]
|
||||
video_outputs: dict[str, NestedTensors] = {
|
||||
video_outputs = {
|
||||
"pixel_values_videos": torch.concat(video_pixel_values),
|
||||
"video_num_patches": torch.tensor(num_frames),
|
||||
"video_token_id": torch.tensor(video_token_id),
|
||||
@ -382,16 +382,11 @@ class InternS1MultiModalProcessor(
|
||||
prompt)
|
||||
text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")
|
||||
|
||||
combined_outputs = dict(
|
||||
**text_outputs,
|
||||
**image_outputs,
|
||||
**video_outputs,
|
||||
)
|
||||
return BatchFeature(combined_outputs)
|
||||
return BatchFeature({**text_outputs, **image_outputs, **video_outputs})
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
|
||||
@ -487,6 +482,7 @@ class InternS1MultiModalProcessor(
|
||||
dummy_inputs=InternS1DummyInputsBuilder)
|
||||
class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP, SupportsLoRA):
|
||||
merge_by_field_config = True
|
||||
|
||||
# To ensure correct weight loading and mapping.
|
||||
hf_to_vllm_mapper = WeightsMapper(
|
||||
@ -561,7 +557,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
|
||||
return InternS1MultiModalProjector(config)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
@ -599,13 +595,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return InternS1ImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
data=image_embeds,
|
||||
)
|
||||
|
||||
image_token_id = kwargs["image_token_id"]
|
||||
@ -613,17 +605,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.img_context_token_id = image_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
|
||||
h, w = self.config.vision_config.image_size
|
||||
return InternS1ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
@ -638,7 +619,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
|
||||
self, **kwargs: object) -> Optional[InternS1VideoInputs]:
|
||||
pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
|
||||
video_num_patches = kwargs.pop("video_num_patches", None)
|
||||
video_embeds = kwargs.pop("video_embeds", None)
|
||||
@ -647,13 +628,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return None
|
||||
|
||||
if video_embeds is not None:
|
||||
if not isinstance(video_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of video embeddings. "
|
||||
f"Got type: {type(video_embeds)}")
|
||||
|
||||
return InternS1ImageEmbeddingInputs(
|
||||
return InternS1VideoEmbeddingInputs(
|
||||
type="video_embeds",
|
||||
data=flatten_bn(video_embeds),
|
||||
data=video_embeds,
|
||||
)
|
||||
|
||||
video_token_id = kwargs["video_token_id"]
|
||||
@ -661,18 +638,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.video_context_token_id = video_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values_flat_video is not None:
|
||||
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat_video)}")
|
||||
|
||||
if not isinstance(video_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(video_num_patches)}")
|
||||
|
||||
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
|
||||
concat=True)
|
||||
video_num_patches = flatten_bn(video_num_patches, concat=True)
|
||||
|
||||
h, w = self.config.vision_config.image_size
|
||||
return InternS1VideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
@ -686,11 +651,12 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
def _process_vision_input(
|
||||
self,
|
||||
image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
|
||||
image_input: Union[InternS1ImageInputs, InternS1VideoInputs],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
if (image_input["type"] == "image_embeds"
|
||||
or image_input["type"] == "video_embeds"):
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_tower is not None
|
||||
@ -753,11 +719,11 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
vision_embeddings = self._process_vision_input(image_input)
|
||||
multimodal_embeddings += vision_embeddings
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_image_input(video_input)
|
||||
video_embeddings = self._process_vision_input(video_input)
|
||||
multimodal_embeddings += video_embeddings
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
@ -17,7 +17,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import BatchEncoding, PretrainedConfig, TensorType
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@ -28,7 +28,7 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems, NestedTensors)
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@ -42,8 +42,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix)
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
@ -471,7 +470,7 @@ class BaseInternVLProcessor(ABC):
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
image_inputs: dict[str, NestedTensors] = {
|
||||
image_inputs = {
|
||||
"pixel_values_flat":
|
||||
torch.cat(pixel_values_lst),
|
||||
"image_num_patches":
|
||||
@ -502,7 +501,7 @@ class BaseInternVLProcessor(ABC):
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
text, images = [self._make_batch_input(x) for x in (text, images)]
|
||||
|
||||
text, image_inputs = self._preprocess_image(
|
||||
@ -515,10 +514,9 @@ class BaseInternVLProcessor(ABC):
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
return {
|
||||
**BatchEncoding(text_inputs, tensor_type=return_tensors),
|
||||
**image_inputs,
|
||||
}
|
||||
combined_outputs = {**text_inputs, **image_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
|
||||
|
||||
class InternVLProcessor(BaseInternVLProcessor):
|
||||
@ -598,7 +596,7 @@ class InternVLProcessor(BaseInternVLProcessor):
|
||||
videos,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
video_inputs: dict[str, NestedTensors] = {
|
||||
video_inputs = {
|
||||
"pixel_values_flat_video":
|
||||
torch.cat(pixel_values_lst_video),
|
||||
"video_num_patches":
|
||||
@ -622,7 +620,7 @@ class InternVLProcessor(BaseInternVLProcessor):
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
text, images, videos = [
|
||||
self._make_batch_input(x) for x in (text, images, videos)
|
||||
]
|
||||
@ -643,11 +641,9 @@ class InternVLProcessor(BaseInternVLProcessor):
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
return {
|
||||
**BatchEncoding(text_inputs, tensor_type=return_tensors),
|
||||
**image_inputs,
|
||||
**video_inputs,
|
||||
}
|
||||
combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
|
||||
def get_image_repl(
|
||||
self,
|
||||
@ -773,7 +769,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
@ -793,7 +789,7 @@ class BaseInternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
@ -948,7 +944,7 @@ class InternVLMultiModalProcessor(
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(prompt, mm_data,
|
||||
mm_kwargs, tok_kwargs)
|
||||
|
||||
@ -960,7 +956,7 @@ class InternVLMultiModalProcessor(
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_fields = super()._get_mm_fields_config(hf_inputs,
|
||||
@ -1033,6 +1029,7 @@ class InternVLMultiModalProcessor(
|
||||
dummy_inputs=InternVLDummyInputsBuilder)
|
||||
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsLoRA):
|
||||
merge_by_field_config = True
|
||||
|
||||
supports_encoder_tp_data = True
|
||||
|
||||
@ -1126,7 +1123,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
else:
|
||||
return InternVisionPatchModel(config.vision_config)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
@ -1175,13 +1172,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return InternVLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
data=image_embeds,
|
||||
)
|
||||
|
||||
image_token_id = kwargs["image_token_id"]
|
||||
@ -1189,16 +1182,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self.img_context_token_id = image_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values_flat is not None:
|
||||
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
|
||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
expected_h = expected_w = self.config.vision_config.image_size
|
||||
resolve_bindings = {"h": expected_h, "w": expected_w}
|
||||
|
||||
@ -1223,7 +1206,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
if video_embeds is not None:
|
||||
return InternVLVideoEmbeddingInputs(
|
||||
type="video_embeds",
|
||||
data=flatten_bn(video_embeds),
|
||||
data=video_embeds,
|
||||
)
|
||||
|
||||
video_token_id = kwargs["video_token_id"]
|
||||
@ -1231,17 +1214,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self.video_context_token_id = video_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values_flat_video is not None:
|
||||
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat_video)}")
|
||||
|
||||
if not isinstance(video_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(video_num_patches)}")
|
||||
|
||||
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
|
||||
concat=True)
|
||||
video_num_patches = flatten_bn(video_num_patches, concat=True)
|
||||
expected_h = expected_w = self.config.vision_config.image_size
|
||||
resolve_bindings = {"h": expected_h, "w": expected_w}
|
||||
|
||||
@ -1254,11 +1226,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
def _process_vision_input(
|
||||
self,
|
||||
image_input: Union[InternVLImageInputs, InternVLVideoPixelInputs],
|
||||
image_input: Union[InternVLImageInputs, InternVLVideoInputs],
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
if (image_input["type"] == "image_embeds"
|
||||
or image_input["type"] == "video_embeds"):
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_model is not None
|
||||
@ -1326,11 +1299,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
vision_embeddings = self._process_vision_input(image_input)
|
||||
multimodal_embeddings += vision_embeddings
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_image_input(video_input)
|
||||
video_embeddings = self._process_vision_input(video_input)
|
||||
multimodal_embeddings += video_embeddings
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
@ -18,8 +18,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import (BatchEncoding, BatchFeature, PretrainedConfig,
|
||||
TensorType)
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
@ -38,8 +37,7 @@ from vllm.model_executor.models.utils import (flatten_bn,
|
||||
maybe_prefix)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargs, MultiModalKwargsItems,
|
||||
NestedTensors)
|
||||
MultiModalKwargs, MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@ -298,7 +296,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
else:
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(
|
||||
images, max_num_tiles)
|
||||
image_inputs: dict[str, NestedTensors] = {
|
||||
image_inputs = {
|
||||
"pixel_values_flat":
|
||||
torch.cat(pixel_values_lst),
|
||||
"image_num_patches":
|
||||
@ -326,7 +324,7 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
max_num_tiles: Optional[int] = None,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
# Use default if not provided
|
||||
if max_num_tiles is None:
|
||||
max_num_tiles = 12
|
||||
@ -341,10 +339,9 @@ class BaseNanoNemotronVLProcessor(ABC):
|
||||
|
||||
text_inputs = self.tokenizer(text, add_special_tokens=False)
|
||||
|
||||
return {
|
||||
**BatchEncoding(text_inputs, tensor_type=return_tensors),
|
||||
**image_inputs,
|
||||
}
|
||||
combined_outputs = {**text_inputs, **image_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
|
||||
|
||||
class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
@ -420,7 +417,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
video_inputs: dict[str, NestedTensors] = {
|
||||
video_inputs = {
|
||||
"pixel_values_flat_video":
|
||||
torch.cat(pixel_values_lst_video),
|
||||
"video_num_patches":
|
||||
@ -443,7 +440,7 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
max_num_tiles: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
# Use default if not provided
|
||||
if max_num_tiles is None:
|
||||
max_num_tiles = 12
|
||||
@ -467,11 +464,9 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
|
||||
|
||||
text_inputs = self.tokenizer(text, add_special_tokens=False)
|
||||
|
||||
return BatchFeature({
|
||||
**BatchEncoding(text_inputs, tensor_type=return_tensors),
|
||||
**image_inputs,
|
||||
**video_inputs,
|
||||
})
|
||||
combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
|
||||
def get_image_repl(
|
||||
self,
|
||||
@ -625,7 +620,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
@ -645,7 +640,7 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
@ -724,7 +719,7 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(prompt, mm_data,
|
||||
mm_kwargs, tok_kwargs)
|
||||
|
||||
@ -736,7 +731,7 @@ class NanoNemotronVLMultiModalProcessor(
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_fields = super()._get_mm_fields_config(hf_inputs,
|
||||
|
||||
@ -28,7 +28,6 @@ from vllm.model_executor.models.internvl import (
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.processing import PromptUpdateDetails
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.processor import (
|
||||
@ -37,8 +36,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix)
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
@ -289,7 +287,7 @@ class NemotronVLProcessor(InternVLProcessor):
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
image_inputs: dict[str, NestedTensors] = {
|
||||
image_inputs = {
|
||||
"pixel_values_flat":
|
||||
torch.cat(pixel_values_lst),
|
||||
"image_num_patches":
|
||||
@ -344,6 +342,7 @@ class NemotronVLProcessingInfo(BaseInternVLProcessingInfo):
|
||||
dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo])
|
||||
class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsLoRA):
|
||||
merge_by_field_config = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
@ -414,7 +413,7 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
return AutoModel.from_config(config.vision_config,
|
||||
trust_remote_code=True)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
|
||||
vit_hidden_size = config.vit_hidden_size
|
||||
vision_projection_hidden_size = config.projector_hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
@ -467,13 +466,9 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return InternVLImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
data=image_embeds,
|
||||
)
|
||||
|
||||
image_token_id = kwargs["image_token_id"]
|
||||
@ -481,17 +476,6 @@ class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self.img_context_token_id = image_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values_flat is not None:
|
||||
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
|
||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values_flat=pixel_values_flat,
|
||||
|
||||
@ -159,7 +159,7 @@ class NVLMMultiModalProcessor(
|
||||
dummy_inputs=NVLMDummyInputsBuilder)
|
||||
class NVLM_D_Model(InternVLChatModel):
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_intermediate_size = config.text_config.intermediate_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
@ -14,7 +14,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import BatchEncoding, PretrainedConfig, TensorType
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
@ -25,7 +25,7 @@ from vllm.model_executor.models.intern_vit import (InternVisionModel,
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems, NestedTensors)
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix)
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
@ -399,7 +398,7 @@ class SkyworkR1VProcessor:
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
@ -418,7 +417,7 @@ class SkyworkR1VProcessor:
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
image_inputs: dict[str, NestedTensors] = {
|
||||
image_inputs = {
|
||||
"pixel_values_flat":
|
||||
torch.cat(pixel_values_lst),
|
||||
"image_num_patches":
|
||||
@ -435,10 +434,9 @@ class SkyworkR1VProcessor:
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
return {
|
||||
**BatchEncoding(text_inputs, tensor_type=return_tensors),
|
||||
**image_inputs,
|
||||
}
|
||||
combined_outputs = {**text_inputs, **image_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
|
||||
|
||||
class SkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||
@ -529,7 +527,7 @@ class SkyworkR1VMultiModalProcessor(
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
@ -549,7 +547,7 @@ class SkyworkR1VMultiModalProcessor(
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: Mapping[str, NestedTensors],
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
@ -617,6 +615,7 @@ class SkyworkR1VMultiModalProcessor(
|
||||
info=SkyworkR1VProcessingInfo,
|
||||
dummy_inputs=SkyworkR1VDummyInputsBuilder)
|
||||
class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
merge_by_field_config = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
@ -703,7 +702,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
else:
|
||||
return InternVisionPatchModel(config.vision_config)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
@ -756,13 +755,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image embeddings. "
|
||||
f"Got type: {type(image_embeds)}")
|
||||
|
||||
return SkyworkR1VImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
data=image_embeds,
|
||||
)
|
||||
|
||||
image_token_id = kwargs["image_token_id"]
|
||||
@ -770,17 +765,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.img_context_token_id = image_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values_flat is not None:
|
||||
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
|
||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image_num_patches. "
|
||||
f"Got type: {type(image_num_patches)}")
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
|
||||
return SkyworkR1VImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values_flat=pixel_values_flat,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user