[VLM] Remove input processor from clip and siglip (#13165)

This commit is contained in:
Isotr0py 2025-02-13 16:31:37 +08:00 committed by GitHub
parent 9605c1256e
commit fa253f1a70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 213 deletions

View File

@ -1,156 +1,24 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of CLIPVisionModel intended to be only used """Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model.""" within a vision language model."""
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig from transformers import CLIPVisionConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
return image_size // patch_size
def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
grid_length = get_clip_patch_grid_length(image_size=image_size,
patch_size=patch_size)
return grid_length * grid_length
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size) + 1
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
return get_clip_image_feature_size(hf_config)
def dummy_seq_data_for_clip(hf_config: CLIPVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
mm_key: str = "image"):
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
return SequenceData.from_prompt_token_counts(
(image_token_id, image_feature_size * num_images),
(0, seq_len - image_feature_size * num_images),
), {
mm_key:
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_clip(
hf_config: CLIPVisionConfig,
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
width = height = hf_config.image_size
if image_width_override is not None:
width = image_width_override
if image_height_override is not None:
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_video_for_clip(
hf_config: CLIPVisionConfig,
num_frames: int,
num_videos: int = 1,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
pil_frame = dummy_image_for_clip(
hf_config,
num_images=1,
image_width_override=image_width_override,
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
video_data = [mm_data_per_video] * num_videos
mm_data = {"video": video_data}
return mm_data
def input_processor_for_clip(
model_config: ModelConfig,
hf_config: CLIPVisionConfig,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]): class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
def get_num_image_tokens( def get_num_image_tokens(
@ -159,10 +27,10 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
image_width: int, image_width: int,
image_height: int, image_height: int,
) -> int: ) -> int:
return get_clip_image_feature_size(self.vision_config) return self.get_patch_grid_length()**2 + 1
def get_max_image_tokens(self) -> int: def get_max_image_tokens(self) -> int:
return get_max_clip_image_tokens(self.vision_config) return self.get_patch_grid_length()**2 + 1
def get_image_size(self) -> int: def get_image_size(self) -> int:
return self.vision_config.image_size return self.vision_config.image_size
@ -171,10 +39,9 @@ class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
return self.vision_config.patch_size return self.vision_config.patch_size
def get_patch_grid_length(self) -> int: def get_patch_grid_length(self) -> int:
return get_clip_patch_grid_length( image_size, patch_size = self.get_image_size(), self.get_patch_size()
image_size=self.vision_config.image_size, assert image_size % patch_size == 0
patch_size=self.vision_config.patch_size, return image_size // patch_size
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
@ -186,6 +53,7 @@ class CLIPVisionEmbeddings(nn.Module):
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.image_size = config.image_size self.image_size = config.image_size
self.patch_size = config.patch_size self.patch_size = config.patch_size
assert self.image_size % self.patch_size == 0
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
@ -197,8 +65,7 @@ class CLIPVisionEmbeddings(nn.Module):
bias=False, bias=False,
) )
self.num_patches = get_clip_num_patches(image_size=self.image_size, self.num_patches = (self.image_size // self.patch_size)**2
patch_size=self.patch_size)
self.num_positions = self.num_patches + 1 self.num_positions = self.num_patches + 1
self.position_embedding = nn.Embedding(self.num_positions, self.position_embedding = nn.Embedding(self.num_positions,
self.embed_dim) self.embed_dim)

View File

@ -3,18 +3,15 @@
within a vision language model.""" within a vision language model."""
import math import math
from typing import Iterable, List, Optional, Set, Tuple, Union from typing import Iterable, Optional, Set, Tuple, Union
import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from torch import nn from torch import nn
from transformers import SiglipVisionConfig from transformers import SiglipVisionConfig
from vllm.attention.layer import MultiHeadAttention from vllm.attention.layer import MultiHeadAttention
from vllm.config import ModelConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.inputs import DecoderOnlyInputs, token_inputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
@ -23,9 +20,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import consecutive_placeholder_ranges
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
@ -93,71 +88,6 @@ def dummy_image_for_siglip(
return {"image": image if num_images == 1 else [image] * num_images} return {"image": image if num_images == 1 else [image] * num_images}
def dummy_video_for_siglip(
hf_config: SiglipVisionConfig,
num_frames: int,
num_videos: int = 1,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
pil_frame = dummy_image_for_siglip(
hf_config,
num_images=1,
image_width_override=image_width_override,
image_height_override=image_height_override)
np_frame = np.array(pil_frame["image"])
mm_data_per_video = np.repeat([np_frame], num_frames, axis=0)
video_data = [mm_data_per_video] * num_videos
mm_data = {"video": video_data}
return mm_data
def input_processor_for_siglip(
model_config: ModelConfig,
hf_config: SiglipVisionConfig,
inputs: DecoderOnlyInputs,
*,
image_token_id: int,
image_feature_size_override: Optional[Union[int, List[int]]] = None,
):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
if "multi_modal_placeholders" in inputs and "image" in inputs[
"multi_modal_placeholders"]:
# The inputs already have placeholders.
return inputs
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_siglip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
tokenizer,
inputs.get("prompt"),
inputs["prompt_token_ids"],
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]): class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
def get_num_image_tokens( def get_num_image_tokens(