[Model] Use merge_by_field_config for MM models (D-F) (#26076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 00c0b25e82
commit 3884dce376
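
Note: the diff below opts the DeepSeek-VL2, DotsOCR, Ernie4.5-VL and Fuyu models into merge_by_field_config. Roughly speaking, with that flag set the multimodal kwargs reach the model already merged according to each field's MultiModalFieldConfig, so the per-model flatten_bn calls and the ad-hoc _validate_and_reshape_mm_tensor helpers become unnecessary, and the expected layouts are declared with TensorSchema annotations instead. A rough sketch of the pattern, using an illustrative class name rather than a real vLLM model:

    from typing import Annotated, Literal

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class MyImagePixelInputs(TensorSchema):
        """
        Dimensions:
            - np: Total number of patches over all images in the batch
            - cps: Number of channels * patch_size * patch_size
            - ni: Number of images
        """
        type: Literal["pixel_values"]

        pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
        image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

With the schema in place, each model's _parse_and_validate_image_input reduces to constructing such an object directly from the merged kwargs, as the hunks below show.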

diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py
@@ -20,8 +20,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models.transformers import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, MultiModalUUIDDict,
-                                    NestedTensors)
+                                    MultiModalKwargsItems, MultiModalUUIDDict)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -40,7 +39,7 @@ from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 # The image token id may be various
@@ -50,15 +49,15 @@ _IMAGE_TOKEN = "<image>"
 class DeepseekVL2ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * number of patches
         - p: Number of patches
         - c: Number of channels (3)
         - h: Height of each image
         - w: Width of each image
     """
     type: Literal["pixel_values"]
-    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
+    data: Annotated[torch.Tensor,
+                    TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
     images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
 
 
@@ -228,12 +227,8 @@ class DeepseekVL2MultiModalProcessor(
             tok_kwargs=tok_kwargs,
         )
 
-        pixel_values = processed_outputs["pixel_values"]
-        # split pixel values into patches corresponding to each image
-        images_spatial_crop = processed_outputs["images_spatial_crop"]
-        patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
-        pixel_values = pixel_values.split(patches_per_image)
-        processed_outputs["pixel_values"] = pixel_values
+        processed_outputs["num_patches"] = (
+            processed_outputs["images_spatial_crop"].prod(-1) + 1)
 
         return processed_outputs
 
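
Note: the replacement computes the per-image patch count directly from images_spatial_crop. A worked example in plain torch; the +1 on top of the crop product is the extra tile beyond the local crops (presumably the global view), mirroring the removed list comprehension:

    import torch

    # Each row of images_spatial_crop holds the crop grid for one image.
    images_spatial_crop = torch.tensor([[2, 3], [1, 1]])

    # Same arithmetic as the removed
    # [x.prod().item() + 1 for x in images_spatial_crop], but vectorized.
    num_patches = images_spatial_crop.prod(-1) + 1
    print(num_patches)  # tensor([7, 2])
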
@@ -242,8 +237,11 @@ class DeepseekVL2MultiModalProcessor(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
+
         return dict(
-            pixel_values=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_patches),
             images_spatial_crop=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
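
Note: a conceptual sketch (plain torch, not vLLM's actual merging code) of the two field configs used above. batched("image") keeps one entry per image, while flat_from_sizes("image", num_patches) treats the field as a concatenation along dim 0 and uses the per-image sizes to recover item boundaries:

    import torch

    # Per-image values as a processor might emit them (toy shapes).
    pixel_values_per_image = [torch.randn(7, 3, 384, 384), torch.randn(5, 3, 384, 384)]
    spatial_crop_per_image = [torch.tensor([2, 3]), torch.tensor([2, 2])]

    # flat_from_sizes: concatenate along dim 0, keep per-image sizes.
    num_patches = torch.tensor([v.shape[0] for v in pixel_values_per_image])  # tensor([7, 5])
    pixel_values = torch.cat(pixel_values_per_image, dim=0)                   # (12, 3, 384, 384)

    # batched: stack one entry per image along a new leading dim.
    images_spatial_crop = torch.stack(spatial_crop_per_image, dim=0)          # (2, 2)

    # The sizes make it possible to split the flattened field back per image.
    per_image = pixel_values.split(num_patches.tolist(), dim=0)
    assert [p.shape[0] for p in per_image] == [7, 5]
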
@@ -318,6 +316,7 @@ class DeepseekVL2MultiModalProcessor(
     info=DeepseekVL2ProcessingInfo,
     dummy_inputs=DeepseekVL2DummyInputsBuilder)
 class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "language.": "language_model.",
@@ -460,11 +459,10 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
         if pixel_values is not None:
             expected_h = expected_w = self.vision_config.image_size
-            return DeepseekVL2ImagePixelInputs(type="pixel_values",
-                                               data=flatten_bn(pixel_values),
-                                               images_spatial_crop=flatten_bn(
-                                                   images_spatial_crop,
-                                                   concat=True),
+            return DeepseekVL2ImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_spatial_crop=images_spatial_crop,
                 resolve_bindings={
                     "h": expected_h,
                     "w": expected_w,
@@ -473,24 +471,18 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if image_embeds is not None:
             return DeepseekVL2VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
 
     def _pixel_values_to_embedding(
         self,
-        pixel_values: NestedTensors,
+        pixel_values: torch.Tensor,
         images_spatial_crop: torch.Tensor,
-    ) -> NestedTensors:
-        # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
-        total_tiles = [x for x in pixel_values]
-
-        # [batch_all_tiles, 3, height, width]
-        total_tiles = torch.cat(total_tiles, dim=0)
-
+    ) -> list[torch.Tensor]:
         # [batch_all_tiles, vit_seq_len, c]
-        images_feature = self.vision.forward_features(total_tiles)
+        images_feature = self.vision.forward_features(pixel_values)
 
         # [batch_all_tiles, hw, D]
         images_embeds = self.projector(images_feature)
@@ -573,7 +565,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return vision_embeddings
 
     def _process_image_input(
-            self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
+            self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
             image_data = image_input["data"]
             if is_list_of(image_data, torch.Tensor):
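
Note: the resolve_bindings={"h": expected_h, "w": expected_w} argument kept above pins the symbolic dims of the schema to concrete values at construction time. A small stand-alone illustration of the idea (this is not vLLM's TensorSchema implementation, just the check such an annotation expresses):

    import torch

    def check_shape(t: torch.Tensor, spec: tuple, bindings: dict) -> None:
        """Verify `t` against a symbolic shape such as ("bnp", 3, "h", "w")."""
        assert t.ndim == len(spec), f"expected {len(spec)} dims, got {t.ndim}"
        for actual, dim in zip(t.shape, spec):
            if isinstance(dim, int):
                assert actual == dim
            elif dim in bindings:
                assert actual == bindings[dim]
            # unbound symbolic dims (e.g. "bnp") are left unconstrained here

    image_size = 384  # stand-in for self.vision_config.image_size
    pixel_values = torch.randn(12, 3, image_size, image_size)
    check_shape(pixel_values, ("bnp", 3, "h", "w"),
                {"h": image_size, "w": image_size})
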

diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -42,34 +42,38 @@ from vllm.platforms import _Backend
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
                                                      DotsVisionConfig)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .vision import run_dp_sharded_mrope_vision_model
 
 IMAGE_TOKEN = "<|imgpad|>"
 
 
-class DotsOCRImagePixelInputs(TypedDict):
-    type: Literal["pixel_values", "image_grid_thw"]
-
-    pixel_values: torch.Tensor
-    image_grid_thw: torch.Tensor
-
-
-class DotsOCRImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds", "image_grid_thw"]
-    image_embeds: torch.Tensor
-    """Supported types:
-    - List[`torch.Tensor`]: A list of tensors holding all images' features.
-        Each tensor holds an image's features.
-    - `torch.Tensor`: A tensor holding all images' features
-        (concatenation of all images' feature tensors).
-    Tensor shape: `(num_image_features, hidden_size)`
-    - `num_image_features` varies based on
-        the number and resolution of the images.
-    - `hidden_size` must match the hidden size of language model backbone.
+class DotsOCRImagePixelInputs(TensorSchema):
     """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+    """
+    type: Literal["pixel_values"]
 
-    image_grid_thw: torch.Tensor
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
+
+
+class DotsOCRImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - nf: Number of image features
+        - hs: Hidden size
+        - ni: Number of images
+    """
+    type: Literal["image_embeds"]
+
+    image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 DotsOCRImageInputs = Union[DotsOCRImagePixelInputs,
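
Note: to make the symbolic dims above concrete, here is one way they fit together, assuming the usual ViT-style layout in which an image with grid (t, h, w) contributes t * h * w patch rows; that relationship is an assumption for illustration, not something stated in this diff:

    import torch

    image_grid_thw = torch.tensor([[1, 16, 16], [1, 8, 12]])  # (ni, 3)
    cps = 3 * 14 * 14  # channels * patch_size * patch_size, example values

    num_patches_total = int(image_grid_thw.prod(-1).sum())    # 256 + 96 = 352
    pixel_values = torch.randn(num_patches_total, cps)        # TensorShape("np", "cps")
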
@@ -654,6 +658,8 @@ class DotsVisionTransformer(nn.Module):
 )
 class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
                          SupportsLoRA):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             ".attn.qkv_proj.": ".attn.qkv.",
@@ -709,22 +715,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
             architectures=["Qwen2ForCausalLM"],
         )
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return torch.concat(list(mm_input))
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -735,28 +725,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return DotsOCRImagePixelInputs(type="pixel_values",
                                            pixel_values=pixel_values,
                                            image_grid_thw=image_grid_thw)
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
             return DotsOCRImageEmbeddingInputs(type="image_embeds",
                                                image_embeds=image_embeds,
                                                image_grid_thw=image_grid_thw)
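
Note: for reference, the deleted _validate_and_reshape_mm_tensor collapsed either a batched 3-D tensor or a list of 2-D tensors into a single 2-D tensor; its behaviour is reproduced below on toy data, reconstructed from the removed lines. With merge_by_field_config the kwargs already arrive in the layout the TensorSchema declares, which is why the helper and the isinstance checks could go:

    import torch

    def old_validate_and_reshape(mm_input, name):
        # Reconstructed from the removed helper above, for illustration only.
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if isinstance(mm_input, torch.Tensor):
            if mm_input.ndim == 2:
                return mm_input
            if mm_input.ndim != 3:
                raise ValueError(f"{name} should be 2D or batched 3D tensor.")
            return torch.concat(list(mm_input))
        return torch.concat(mm_input)

    assert old_validate_and_reshape(torch.randn(2, 5, 8), "pixel_values").shape == (10, 8)
    assert old_validate_and_reshape([torch.randn(5, 8), torch.randn(3, 8)], "x").shape == (8, 8)
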

diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
@@ -25,7 +25,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -56,6 +56,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -579,38 +580,38 @@ class Ernie4_5_VisionTransformer(nn.Module):
 # === Vision Inputs === #
 
 
-class Ernie4_5_VLImagePixelInputs(TypedDict):
+class Ernie4_5_VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * patch_size * patch_size
+    """
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """Shape:
-    `(num_patches, num_channels * patch_size * patch_size)`
-    """
 
-    grid_thw: torch.Tensor
-    """Shape: `(num_images, 3)`
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
 Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs
 
 
-class Ernie4_5_VLVideoPixelInputs(TypedDict):
+class Ernie4_5_VLVideoPixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - ni: Number of images
+        - cps: Number of channels * temporal_patch_size * patch_size *
+          patch_size
+    """
     type: Literal["pixel_values_videos"]
-    pixel_values_videos: torch.Tensor
-    """Shape:
-    `(num_patches,
-    num_channels * temporal_patch_size * patch_size * patch_size)`
-    """
-
-    video_grid_thw: torch.Tensor
-    """Shape: `(num_videos, 3)`
-
-    This should be in `(grid_t, grid_h, grid_w)` format.
-    """
+    pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")]
+    video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
 
 
-Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs
+Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs
 
 # === Vision Processor === #
 
@@ -1213,6 +1214,7 @@ class Ernie4_5_VLDummyInputsBuilder(
     dummy_inputs=Ernie4_5_VLDummyInputsBuilder)
 class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
                                              SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": [
@@ -1325,22 +1327,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1350,15 +1336,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             return Ernie4_5_VLImagePixelInputs(type="pixel_values",
                                                pixel_values=pixel_values,
                                                image_grid_thw=image_grid_thw)
@@ -1372,11 +1349,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Ernie4_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,

diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
 
     type: Literal["image_patches"] = "image_patches"
 
-    flat_data: Annotated[
-        torch.Tensor,
-        TensorShape("bnp", "fn"),
-    ]
+    image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]
 
     patches_per_image: Annotated[list[int], TensorShape("bn")]
     """
     The number of total patches for each image in the batch.
 
     This is used to split the embeddings which has the first two dimensions
-    flattened just like `flat_data`.
+    flattened just like `image_patches_flat`.
     """
 
 
@@ -174,28 +171,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
             tok_kwargs=tok_kwargs,
         )
 
-        image_patches = processed_outputs.get("image_patches")
-        if image_patches is not None:
-            images = mm_data["images"]
-            assert isinstance(images, list)
-
-            # Original output: (1, num_images, Pn, Px * Py * C)
-            # New output: (num_images, Pn, Px * Py * C)
-            # image_patches is a list with shape:
-            # (1, num_images, Pn, Px * Py * C)
-            # before Transformers 4.53
-            if isinstance(image_patches, list):
-                assert len(image_patches) == 1
-                assert (isinstance(image_patches[0], torch.Tensor)
-                        and len(image_patches[0]) == len(images))
-                processed_outputs["image_patches"] = image_patches[0]
-            # image_patches is a tensor with shape:
-            # (num_images, Pn, Px * Py * C)
-            # after Transformers 4.53
-            elif isinstance(image_patches, torch.Tensor):
-                assert len(image_patches) == len(images)
-            else:
-                raise AssertionError("This line should be unreachable.")
+        image_patches = processed_outputs["image_patches"]
+        processed_outputs["image_patches"] = flatten_bn(image_patches)
+        processed_outputs["patches_per_image"] = torch.tensor(
+            [len(p) for p in image_patches])
 
         return processed_outputs
 
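
Note: the Fuyu processor change above drops the Transformers-version branching and instead always flattens image_patches while recording the per-image patch counts; those counts feed the flat_from_sizes config in the next hunk. A toy illustration of the bookkeeping, using torch.cat as a stand-in (the real code uses flatten_bn, whose exact behaviour depends on how the HF processor returns image_patches):

    import torch

    # One (num_patches_i, feature) tensor per image, toy shapes.
    image_patches = [torch.randn(4, 9), torch.randn(6, 9)]

    patches_per_image = torch.tensor([len(p) for p in image_patches])  # tensor([4, 6])
    image_patches_flat = torch.cat(image_patches, dim=0)               # (10, 9)
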
@@ -218,7 +197,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(image_patches=MultiModalFieldConfig.batched("image"))
+        patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
+
+        return dict(
+            image_patches=MultiModalFieldConfig.flat_from_sizes(
+                "image", patches_per_image),
+            patches_per_image=MultiModalFieldConfig.batched("image"),
+        )
 
     def _get_prompt_updates(
         self,
@@ -263,6 +248,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
     info=FuyuProcessingInfo,
     dummy_inputs=FuyuDummyInputsBuilder)
 class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -306,29 +292,28 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
         image_patches = kwargs.pop("image_patches", None)
-        if image_patches is not None:
-            image_patches_flat = flatten_bn(image_patches)
-            flat_data = flatten_bn(image_patches_flat, concat=True)
+        patches_per_image = kwargs.pop("patches_per_image", None)
 
-            return FuyuImagePatchInputs(
-                type="image_patches",
-                flat_data=flat_data,
-                patches_per_image=[x.size(0) for x in image_patches_flat],
-                resolve_bindings={"fn": self.image_feature_size},
-            )
+        if image_patches is None:
+            return None
 
-        return None
+        return FuyuImagePatchInputs(
+            type="image_patches",
+            image_patches_flat=image_patches,
+            patches_per_image=patches_per_image,
+            resolve_bindings={"fn": self.image_feature_size},
+        )
 
     def _process_image_input(
             self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
-        image_patches_flat = image_input["flat_data"]
+        image_patches_flat = image_input["image_patches_flat"]
         patches_per_image = image_input["patches_per_image"]
 
         assert self.vision_embed_tokens is not None
         vision_embeddings_flat, _ = self.vision_embed_tokens(
             image_patches_flat)
 
-        return vision_embeddings_flat.split(patches_per_image, dim=0)
+        return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
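
Note on the final change: previously patches_per_image was a plain Python list ([x.size(0) for x in image_patches_flat]), so it could be passed to Tensor.split() directly. After merging it arrives as a tensor, and Tensor.split() takes an int or a list of ints for the per-chunk sizes, hence the .tolist():

    import torch

    vision_embeddings_flat = torch.randn(10, 4)
    patches_per_image = torch.tensor([4, 6])

    chunks = vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
    assert [c.shape[0] for c in chunks] == [4, 6]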