mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 10:17:52 +08:00
[Bugfix] Fix embedding assignment for InternVL-based models (#15086)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
70e500cad9
commit
ffa443afed
@ -169,7 +169,6 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
max_model_len=2048,
|
max_model_len=2048,
|
||||||
max_num_seqs=2,
|
max_num_seqs=2,
|
||||||
# Default is False; setting it to True is not supported in V1 yet
|
|
||||||
mm_processor_kwargs={"do_pan_and_scan": True},
|
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||||
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -91,8 +91,6 @@ def load_gemma3(question: str, image_urls: list[str]) -> ModelRequestData:
|
|||||||
model=model_name,
|
model=model_name,
|
||||||
max_model_len=8192,
|
max_model_len=8192,
|
||||||
max_num_seqs=2,
|
max_num_seqs=2,
|
||||||
# Default is False; setting it to True is not supported in V1 yet
|
|
||||||
mm_processor_kwargs={"do_pan_and_scan": True},
|
|
||||||
limit_mm_per_prompt={"image": len(image_urls)},
|
limit_mm_per_prompt={"image": len(image_urls)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -183,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
|||||||
image_width: int,
|
image_width: int,
|
||||||
image_height: int,
|
image_height: int,
|
||||||
processor: Optional[Gemma3Processor],
|
processor: Optional[Gemma3Processor],
|
||||||
) -> PromptUpdateDetails:
|
) -> PromptUpdateDetails[str]:
|
||||||
if processor is None:
|
if processor is None:
|
||||||
processor = self.get_hf_processor()
|
processor = self.get_hf_processor()
|
||||||
|
|
||||||
|
|||||||
@ -249,20 +249,15 @@ class H2OVLProcessor(BaseInternVLProcessor):
|
|||||||
def image_token_id(self) -> int:
|
def image_token_id(self) -> int:
|
||||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||||
|
|
||||||
def get_image_repl_features(
|
def get_image_repl(
|
||||||
self,
|
self,
|
||||||
feature_size: int,
|
feature_size: int,
|
||||||
num_patches: Optional[int],
|
num_patches: Optional[int],
|
||||||
) -> str:
|
) -> PromptUpdateDetails[str]:
|
||||||
return IMG_CONTEXT * feature_size
|
repl_features = IMG_CONTEXT * feature_size
|
||||||
|
repl_full = IMG_START + repl_features + IMG_END
|
||||||
|
|
||||||
def get_image_repl_full(
|
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||||
self,
|
|
||||||
feature_size: int,
|
|
||||||
num_patches: Optional[int],
|
|
||||||
) -> str:
|
|
||||||
features = self.get_image_repl_features(feature_size, num_patches)
|
|
||||||
return IMG_START + features + IMG_END
|
|
||||||
|
|
||||||
def resolve_min_max_num(
|
def resolve_min_max_num(
|
||||||
self,
|
self,
|
||||||
@ -501,12 +496,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
|
|||||||
if num_patches is not None:
|
if num_patches is not None:
|
||||||
assert isinstance(num_patches, int)
|
assert isinstance(num_patches, int)
|
||||||
|
|
||||||
return PromptUpdateDetails(
|
return hf_processor.get_image_repl(feature_size, num_patches)
|
||||||
full=hf_processor.get_image_repl_full(feature_size,
|
|
||||||
num_patches),
|
|
||||||
features=hf_processor.get_image_repl_features(
|
|
||||||
feature_size, num_patches),
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
PromptReplacement(
|
PromptReplacement(
|
||||||
|
|||||||
@ -9,14 +9,13 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
|
from typing import Literal, Optional, Set, Tuple, TypedDict, TypeVar, Union
|
||||||
Union)
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
from transformers import BatchEncoding, PretrainedConfig, TensorType
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
@ -36,10 +35,12 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
|||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
|
from vllm.utils import flatten_2d_lists
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
maybe_prefix, merge_multimodal_embeddings)
|
||||||
|
from .vision import scatter_patch_features, select_patch_features
|
||||||
|
|
||||||
IMG_START = '<img>'
|
IMG_START = '<img>'
|
||||||
IMG_END = '</img>'
|
IMG_END = '</img>'
|
||||||
@ -51,16 +52,26 @@ IMAGENET_STD = (0.229, 0.224, 0.225)
|
|||||||
|
|
||||||
class InternVLImagePixelInputs(TypedDict):
|
class InternVLImagePixelInputs(TypedDict):
|
||||||
type: Literal["pixel_values"]
|
type: Literal["pixel_values"]
|
||||||
data: torch.Tensor
|
pixel_values_flat: torch.Tensor
|
||||||
"""
|
"""
|
||||||
Shape:
|
Shape:
|
||||||
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
|
`(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
|
||||||
"""
|
"""
|
||||||
patches_per_image: List[int]
|
|
||||||
|
num_patches: torch.Tensor
|
||||||
|
"""Shape: `(batch_size * num_images)`"""
|
||||||
|
|
||||||
|
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
"""
|
"""
|
||||||
List of number of total patches for each image in the batch.
|
A boolean mask indicating which image embeddings correspond
|
||||||
|
to patch tokens.
|
||||||
|
|
||||||
|
Shape: `(batch_size, num_images, num_embeds)`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||||
|
"""Shape: `(batch_size, num_images)`"""
|
||||||
|
|
||||||
|
|
||||||
class InternVLImageEmbeddingInputs(TypedDict):
|
class InternVLImageEmbeddingInputs(TypedDict):
|
||||||
type: Literal["image_embeds"]
|
type: Literal["image_embeds"]
|
||||||
@ -286,19 +297,11 @@ class BaseInternVLProcessor(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_image_repl_features(
|
def get_image_repl(
|
||||||
self,
|
self,
|
||||||
feature_size: int,
|
feature_size: int,
|
||||||
num_patches: Optional[int],
|
num_patches: Optional[int],
|
||||||
) -> str:
|
) -> PromptUpdateDetails[str]:
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_image_repl_full(
|
|
||||||
self,
|
|
||||||
feature_size: int,
|
|
||||||
num_patches: Optional[int],
|
|
||||||
) -> str:
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def resolve_min_max_num(
|
def resolve_min_max_num(
|
||||||
@ -394,7 +397,7 @@ class BaseInternVLProcessor(ABC):
|
|||||||
max_dynamic_patch: Optional[int] = None,
|
max_dynamic_patch: Optional[int] = None,
|
||||||
dynamic_image_size: Optional[bool] = None,
|
dynamic_image_size: Optional[bool] = None,
|
||||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||||
) -> BatchFeature:
|
) -> Mapping[str, NestedTensors]:
|
||||||
if text is None:
|
if text is None:
|
||||||
text = []
|
text = []
|
||||||
if not isinstance(text, list):
|
if not isinstance(text, list):
|
||||||
@ -413,28 +416,41 @@ class BaseInternVLProcessor(ABC):
|
|||||||
max_dynamic_patch=max_dynamic_patch,
|
max_dynamic_patch=max_dynamic_patch,
|
||||||
dynamic_image_size=dynamic_image_size,
|
dynamic_image_size=dynamic_image_size,
|
||||||
)
|
)
|
||||||
image_inputs = {
|
image_inputs: dict[str, NestedTensors] = {
|
||||||
"pixel_values_flat": torch.cat(pixel_values_lst),
|
"pixel_values_flat":
|
||||||
"image_num_patches": list(map(len, pixel_values_lst)),
|
torch.cat(pixel_values_lst),
|
||||||
|
"image_num_patches":
|
||||||
|
torch.tensor([len(item) for item in pixel_values_lst]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tokenizer = self.tokenizer
|
||||||
|
image_token_id = self.image_token_id
|
||||||
|
|
||||||
|
num_embeds = list[int]()
|
||||||
|
embed_is_patch = list[torch.Tensor]()
|
||||||
|
|
||||||
for pixel_values in pixel_values_lst:
|
for pixel_values in pixel_values_lst:
|
||||||
num_patches = pixel_values.shape[0]
|
num_patches = pixel_values.shape[0]
|
||||||
feature_size = num_patches * self.num_image_token
|
feature_size = num_patches * self.num_image_token
|
||||||
|
|
||||||
image_repl = self.get_image_repl_full(feature_size,
|
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||||
num_patches)
|
feature_tokens = tokenizer.encode(image_repl.features,
|
||||||
text = [t.replace('<image>', image_repl, 1) for t in text]
|
add_special_tokens=False)
|
||||||
|
|
||||||
|
text = [t.replace('<image>', image_repl.full, 1) for t in text]
|
||||||
|
num_embeds.append(len(feature_tokens))
|
||||||
|
embed_is_patch.append(
|
||||||
|
torch.tensor(feature_tokens) == image_token_id)
|
||||||
|
|
||||||
|
image_inputs["num_embeds"] = torch.tensor(num_embeds)
|
||||||
|
image_inputs["embed_is_patch"] = embed_is_patch
|
||||||
|
|
||||||
text_inputs = self.tokenizer(text)
|
text_inputs = self.tokenizer(text)
|
||||||
|
|
||||||
return BatchFeature(
|
return {
|
||||||
{
|
**BatchEncoding(text_inputs, tensor_type=return_tensors),
|
||||||
**text_inputs,
|
**image_inputs,
|
||||||
**image_inputs,
|
}
|
||||||
},
|
|
||||||
tensor_type=return_tensors,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class InternVLProcessor(BaseInternVLProcessor):
|
class InternVLProcessor(BaseInternVLProcessor):
|
||||||
@ -443,20 +459,15 @@ class InternVLProcessor(BaseInternVLProcessor):
|
|||||||
def image_token_id(self) -> int:
|
def image_token_id(self) -> int:
|
||||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||||
|
|
||||||
def get_image_repl_features(
|
def get_image_repl(
|
||||||
self,
|
self,
|
||||||
feature_size: int,
|
feature_size: int,
|
||||||
num_patches: Optional[int],
|
num_patches: Optional[int],
|
||||||
) -> str:
|
) -> PromptUpdateDetails[str]:
|
||||||
return IMG_CONTEXT * feature_size
|
repl_features = IMG_CONTEXT * feature_size
|
||||||
|
repl_full = IMG_START + repl_features + IMG_END
|
||||||
|
|
||||||
def get_image_repl_full(
|
return PromptUpdateDetails(full=repl_full, features=repl_features)
|
||||||
self,
|
|
||||||
feature_size: int,
|
|
||||||
num_patches: Optional[int],
|
|
||||||
) -> str:
|
|
||||||
features = self.get_image_repl_features(feature_size, num_patches)
|
|
||||||
return IMG_START + features + IMG_END
|
|
||||||
|
|
||||||
|
|
||||||
class BaseInternVLProcessingInfo(BaseProcessingInfo):
|
class BaseInternVLProcessingInfo(BaseProcessingInfo):
|
||||||
@ -566,16 +577,15 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
mm_data: Mapping[str, object],
|
mm_data: Mapping[str, object],
|
||||||
mm_kwargs: Mapping[str, object],
|
mm_kwargs: Mapping[str, object],
|
||||||
) -> BatchFeature:
|
) -> Mapping[str, NestedTensors]:
|
||||||
processed_outputs = super()._call_hf_processor(
|
processed_outputs = super()._call_hf_processor(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
mm_data=mm_data,
|
mm_data=mm_data,
|
||||||
mm_kwargs=mm_kwargs,
|
mm_kwargs=mm_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
|
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||||
image_data = mm_data.get("images", [])
|
image_token_id = hf_processor.image_token_id
|
||||||
assert isinstance(image_data, list)
|
|
||||||
|
|
||||||
# Since there may be extra tokens in the feature placeholders,
|
# Since there may be extra tokens in the feature placeholders,
|
||||||
# we need to pass the image token ID to the model to select the
|
# we need to pass the image token ID to the model to select the
|
||||||
@ -586,7 +596,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
|
|
||||||
def _get_mm_fields_config(
|
def _get_mm_fields_config(
|
||||||
self,
|
self,
|
||||||
hf_inputs: BatchFeature,
|
hf_inputs: Mapping[str, NestedTensors],
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||||
@ -596,6 +606,8 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"image", image_num_patches),
|
"image", image_num_patches),
|
||||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||||
|
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||||
|
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||||
)
|
)
|
||||||
@ -637,12 +649,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
if num_patches is not None:
|
if num_patches is not None:
|
||||||
assert isinstance(num_patches, int)
|
assert isinstance(num_patches, int)
|
||||||
|
|
||||||
return PromptUpdateDetails(
|
return hf_processor.get_image_repl(feature_size, num_patches)
|
||||||
full=hf_processor.get_image_repl_full(feature_size,
|
|
||||||
num_patches),
|
|
||||||
features=hf_processor.get_image_repl_features(
|
|
||||||
feature_size, num_patches),
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
PromptReplacement(
|
PromptReplacement(
|
||||||
@ -832,6 +839,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
self, **kwargs: object) -> Optional[InternVLImageInputs]:
|
self, **kwargs: object) -> Optional[InternVLImageInputs]:
|
||||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||||
|
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||||
|
num_embeds = kwargs.pop("num_embeds", None)
|
||||||
image_embeds = kwargs.pop("image_embeds", None)
|
image_embeds = kwargs.pop("image_embeds", None)
|
||||||
|
|
||||||
if pixel_values_flat is None and image_embeds is None:
|
if pixel_values_flat is None and image_embeds is None:
|
||||||
@ -858,35 +867,47 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
|
|
||||||
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
if not isinstance(image_num_patches, (torch.Tensor, list)):
|
||||||
raise ValueError("Incorrect type of image_num_patches. "
|
raise ValueError("Incorrect type of image_num_patches. "
|
||||||
f"Got type: {type(pixel_values_flat)}")
|
f"Got type: {type(image_num_patches)}")
|
||||||
|
|
||||||
|
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of embed_is_patch. "
|
||||||
|
f"Got type: {type(embed_is_patch)}")
|
||||||
|
|
||||||
|
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of num_embeds. "
|
||||||
|
f"Got type: {type(num_embeds)}")
|
||||||
|
|
||||||
|
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||||
|
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||||
|
|
||||||
return InternVLImagePixelInputs(
|
return InternVLImagePixelInputs(
|
||||||
type="pixel_values",
|
type="pixel_values",
|
||||||
data=self._validate_pixel_values(
|
pixel_values_flat=self._validate_pixel_values(
|
||||||
flatten_bn(pixel_values_flat, concat=True)),
|
pixel_values_flat),
|
||||||
patches_per_image=flatten_bn(image_num_patches,
|
num_patches=image_num_patches,
|
||||||
concat=True).tolist())
|
embed_is_patch=embed_is_patch,
|
||||||
|
num_embeds=num_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
raise AssertionError("This line should be unreachable.")
|
raise AssertionError("This line should be unreachable.")
|
||||||
|
|
||||||
def _process_image_input(
|
def _process_image_input(
|
||||||
self,
|
self,
|
||||||
image_input: InternVLImageInputs,
|
image_input: InternVLImageInputs,
|
||||||
) -> tuple[torch.Tensor, ...]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
|
||||||
if image_input["type"] == "image_embeds":
|
if image_input["type"] == "image_embeds":
|
||||||
return image_input["data"]
|
return image_input["data"]
|
||||||
|
|
||||||
assert self.vision_model is not None
|
assert self.vision_model is not None
|
||||||
|
|
||||||
image_embeds = self.extract_feature(image_input["data"])
|
image_embeds = self.extract_feature(image_input["pixel_values_flat"])
|
||||||
|
|
||||||
patches_per_image = image_input["patches_per_image"]
|
num_patches = image_input["num_patches"]
|
||||||
|
|
||||||
# Only one image in the current batch
|
# Only one image in the current batch
|
||||||
if len(patches_per_image) == 1:
|
if len(num_patches) == 1:
|
||||||
image_embeds = image_embeds.view(
|
return image_embeds.view(
|
||||||
-1, self.config.text_config.hidden_size).unsqueeze(0)
|
-1, self.config.text_config.hidden_size).unsqueeze(0)
|
||||||
return image_embeds
|
|
||||||
|
|
||||||
# NOTE: Image embeddings are split into separate tensors for each image
|
# NOTE: Image embeddings are split into separate tensors for each image
|
||||||
# by the size of each embedding.
|
# by the size of each embedding.
|
||||||
@ -894,10 +915,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
image_embeds = image_embeds.view(-1,
|
image_embeds = image_embeds.view(-1,
|
||||||
self.config.text_config.hidden_size)
|
self.config.text_config.hidden_size)
|
||||||
image_feature_sizes = [
|
image_feature_sizes = [
|
||||||
num_patches * feature_size for num_patches in patches_per_image
|
num_patches * feature_size for num_patches in num_patches
|
||||||
]
|
]
|
||||||
image_embeds = image_embeds.split(image_feature_sizes)
|
return image_embeds.split(image_feature_sizes)
|
||||||
return image_embeds
|
|
||||||
|
|
||||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
|
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
|
||||||
if self.is_mono:
|
if self.is_mono:
|
||||||
@ -911,8 +931,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
if image_input is None:
|
if image_input is None:
|
||||||
return None
|
return None
|
||||||
vision_embeddings = self._process_image_input(image_input)
|
|
||||||
return vision_embeddings
|
image_features = self._process_image_input(image_input)
|
||||||
|
|
||||||
|
if (kwargs.get("v0_path", False)
|
||||||
|
or image_input["type"] != "pixel_values"):
|
||||||
|
return image_features
|
||||||
|
|
||||||
|
return flatten_2d_lists(
|
||||||
|
scatter_patch_features(*args) for args in zip(
|
||||||
|
image_features,
|
||||||
|
image_input["num_embeds"],
|
||||||
|
image_input["embed_is_patch"],
|
||||||
|
))
|
||||||
|
|
||||||
def get_input_embeddings(
|
def get_input_embeddings(
|
||||||
self,
|
self,
|
||||||
@ -924,8 +955,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
assert self.img_context_token_id is not None
|
assert self.img_context_token_id is not None
|
||||||
self._set_visual_token_mask(input_ids)
|
self._set_visual_token_mask(input_ids)
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
input_ids, inputs_embeds, multimodal_embeddings,
|
input_ids,
|
||||||
self.img_context_token_id)
|
inputs_embeds,
|
||||||
|
select_patch_features(multimodal_embeddings),
|
||||||
|
self.img_context_token_id,
|
||||||
|
)
|
||||||
return inputs_embeds
|
return inputs_embeds
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -944,6 +978,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||||
# condition is for v0 compatibility.
|
# condition is for v0 compatibility.
|
||||||
elif inputs_embeds is None:
|
elif inputs_embeds is None:
|
||||||
|
kwargs.update({"v0_path": True})
|
||||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||||
vision_embeddings)
|
vision_embeddings)
|
||||||
|
|||||||
@ -36,11 +36,11 @@ class NVLMProcessor(BaseInternVLProcessor):
|
|||||||
def image_token_id(self) -> int:
|
def image_token_id(self) -> int:
|
||||||
return self.tokenizer.get_vocab()[IMG_PAD]
|
return self.tokenizer.get_vocab()[IMG_PAD]
|
||||||
|
|
||||||
def get_image_repl_features(
|
def get_image_repl(
|
||||||
self,
|
self,
|
||||||
feature_size: int,
|
feature_size: int,
|
||||||
num_patches: Optional[int],
|
num_patches: Optional[int],
|
||||||
) -> str:
|
) -> PromptUpdateDetails[str]:
|
||||||
if num_patches is None:
|
if num_patches is None:
|
||||||
raise NotImplementedError("Embedding inputs are not supported")
|
raise NotImplementedError("Embedding inputs are not supported")
|
||||||
|
|
||||||
@ -55,14 +55,9 @@ class NVLMProcessor(BaseInternVLProcessor):
|
|||||||
# We include the start and end as well because "<Image><tile" is
|
# We include the start and end as well because "<Image><tile" is
|
||||||
# tokenized as ["<Image", "><", "tile"], resulting in assertion error
|
# tokenized as ["<Image", "><", "tile"], resulting in assertion error
|
||||||
# when trying to find "<tile" as a subsequence of "<Image><tile"
|
# when trying to find "<tile" as a subsequence of "<Image><tile"
|
||||||
return "<Image>" + features + "</Image>"
|
repl = "<Image>" + features + "</Image>"
|
||||||
|
|
||||||
def get_image_repl_full(
|
return PromptUpdateDetails(full=repl, features=repl)
|
||||||
self,
|
|
||||||
feature_size: int,
|
|
||||||
num_patches: Optional[int],
|
|
||||||
) -> str:
|
|
||||||
return self.get_image_repl_features(feature_size, num_patches)
|
|
||||||
|
|
||||||
|
|
||||||
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
|
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
|
||||||
@ -180,11 +175,11 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
|
|||||||
if num_patches is not None:
|
if num_patches is not None:
|
||||||
assert isinstance(num_patches, int)
|
assert isinstance(num_patches, int)
|
||||||
|
|
||||||
|
repl = hf_processor.get_image_repl(feature_size, num_patches)
|
||||||
|
|
||||||
return PromptUpdateDetails(
|
return PromptUpdateDetails(
|
||||||
full=hf_processor.get_image_repl_full(feature_size,
|
full=repl.full + "\n",
|
||||||
num_patches) + "\n",
|
features=repl.features + "\n",
|
||||||
features=hf_processor.get_image_repl_features(
|
|
||||||
feature_size, num_patches) + "\n",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# See note in dummy data regarding why we have the extra newline
|
# See note in dummy data regarding why we have the extra newline
|
||||||
|
|||||||
@ -103,13 +103,13 @@ The token sequence or text to update.
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PromptUpdateDetails:
|
class PromptUpdateDetails(Generic[_S]):
|
||||||
"""Details about the token sequence or text that are part of the update."""
|
"""Details about the token sequence or text that are part of the update."""
|
||||||
|
|
||||||
full: PromptSeq
|
full: _S
|
||||||
"""The full content."""
|
"""The full content."""
|
||||||
|
|
||||||
features: PromptSeq
|
features: _S
|
||||||
"""
|
"""
|
||||||
The part of the content that corresponds to feature placeholders;
|
The part of the content that corresponds to feature placeholders;
|
||||||
this will be replaced by the output of the vision encoder during model
|
this will be replaced by the output of the vision encoder during model
|
||||||
@ -117,7 +117,7 @@ class PromptUpdateDetails:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
|
def from_seq(seq: _S) -> "PromptUpdateDetails[_S]":
|
||||||
return PromptUpdateDetails(full=seq, features=seq)
|
return PromptUpdateDetails(full=seq, features=seq)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user