Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 20:54:59 +08:00)

Commit 61f412187d (parent 05ccd0aa35)

[Bugfix] Re-enable Gemma3 for V1 (#14980)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -768,7 +768,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   * `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc.
   * ✅︎
   * ✅︎
-  *
+  * ⚠️
 - * `GLM4VForCausalLM`<sup>^</sup>
   * GLM-4V
   * T + I
@@ -951,13 +951,10 @@ V0 correctly implements the model's attention pattern:
 
 V1 currently uses a simplified attention pattern:
 - Uses causal attention for all tokens, including image tokens
-- Generates reasonable outputs but does not match the original model's attention for text + image inputs
+- Generates reasonable outputs but does not match the original model's attention for text + image inputs, especially when `{"do_pan_and_scan": True}`
 - Will be updated in the future to support the correct behavior
-- Does not support `"do_pan_and_scan": True`
 
 This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
-
-For these reasons, `Gemma3ForConditionalGeneration` is supported only on V0 at the moment.
 :::
 
 :::{note}
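For context on the `do_pan_and_scan` option mentioned above: pan-and-scan is requested through the multimodal processor kwargs. A minimal sketch, assuming a local Gemma 3 checkpoint and vLLM's offline `LLM` API (the exact flag name comes from the HF Gemma3 processor, as used in this diff):

from vllm import LLM

# Sketch only: enable Gemma3 pan-and-scan preprocessing via processor kwargs.
# On V1 this currently trades accuracy for compatibility because of the
# simplified attention pattern described in the note above.
llm = LLM(
    model="google/gemma-3-4b-it",
    mm_processor_kwargs={"do_pan_and_scan": True},
)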
@@ -19,7 +19,8 @@ from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
                                         apply_token_matches,
                                         find_mm_placeholders,
                                         find_text_matches, find_token_matches,
-                                        iter_token_matches)
+                                        iter_token_matches,
+                                        replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import MultiModalProfiler
 from vllm.transformers_utils.tokenizer import (AnyTokenizer,
@@ -89,6 +90,58 @@ def test_iter_token_matches(token_ids, match_ids, expected):
     assert all(match_len == len(match_ids) for match_len in match_lens)
 
 
+# yapf: disable
+@pytest.mark.parametrize(
+    ("token_ids", "match_ids", "new_ids", "expected"),
+    [
+        ([], [], [-1], []),
+        ([], [32000], [-1], []),
+        (
+            [32000, 32000, 32000],
+            [32000],
+            [-1],
+            [-1, -1, -1],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000],
+            [-1],
+            [-1, 32000],
+        ),
+        (
+            [32000, 32000, 32000],
+            [32000, 32000, 32000],
+            [-1],
+            [-1],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000],
+            [-1],
+            [9833, -1, 32000, 32000, 9833, -1, 32000, 918],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 32000, 32000, 32000],
+            [-1],
+            [9833, -1, 9833, 28747, 32000, 32000, 918],
+        ),
+        (
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+            [28747, 0, 32000],
+            [-1],
+            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
+        ),
+    ],
+)
+# yapf: enable
+def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
+    result = replace_token_matches(token_ids, match_ids, new_ids)
+
+    # Manually constructed results
+    assert result == expected
+
+
 # yapf: disable
 @pytest.mark.parametrize(
     ("prompt", "target_by_key", "expected_by_key"),
@@ -1,34 +1,43 @@
 # SPDX-License-Identifier: Apache-2.0
 import math
-from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
-                    Tuple, TypedDict, Union)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 from torch import nn
 from transformers import BatchFeature, Gemma3Config, Gemma3Processor
 from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
+# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, encode_tokens)
+                                        BaseProcessingInfo, BoundPromptUpdate,
+                                        PlaceholderFeaturesInfo,
+                                        PromptReplacement, PromptTargetMatch,
+                                        PromptUpdate, PromptUpdateDetails,
+                                        encode_tokens, find_mm_placeholders,
+                                        replace_token_matches)
+# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP, SupportsV0Only)
+                         SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
 
 logger = init_logger(__name__)
 
@@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     pixel_values: torch.Tensor
     """
-    Shape: `(num_crops_total, num_channels, height, width)`
+    Shape: `(num_patches_total, num_channels, height, width)`
 
-    `num_crops_total` is the total number of crops
+    `num_patches_total` is the total number of patches
     over each image over each prompt in the batch.
     """
-    num_crops: torch.Tensor
-    """Shape: `(batch_size * num_images,)`"""
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
+    """
+
+    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size, num_images)`"""
 
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
@@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
 
 class Gemma3ProcessingInfo(BaseProcessingInfo):
 
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(Gemma3Config)
+
     def get_hf_processor(self, **kwargs: object):
         return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
 
@@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         if not do_pan_and_scan:
             return 0
 
+        if envs.VLLM_USE_V1:
+            logger.warning_once(
+                "`do_pan_and_scan=True` has suboptimal results on V1 "
+                "because of the simplified attention pattern being used.")
+
         # Based on Gemma3ImageProcessor.pan_and_scan
         if image_width >= image_height:
             if image_width / image_height < pan_and_scan_min_ratio_to_activate:
@@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         image_width: int,
         image_height: int,
         processor: Optional[Gemma3Processor],
-    ) -> str:
+    ) -> PromptUpdateDetails:
         if processor is None:
             processor = self.get_hf_processor()
 
@@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
                 f"Here is the original image {image_token} and here are some "
                 f"crops to help you see better {crops_image_tokens}")
 
-        return image_text.replace(image_token, processor.full_image_sequence)
+        repl_full = image_text.replace(image_token,
+                                       processor.full_image_sequence)
+        repl_features = repl_full.strip("\n")
+
+        return PromptUpdateDetails(full=repl_full, features=repl_features)
 
     def get_num_image_tokens(
         self,
@@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
 
         image_repl_tokens = encode_tokens(
             tokenizer,
-            image_repl,
+            image_repl.features,
            add_special_tokens=False,
         )
         return len(image_repl_tokens)
@@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
                                    num_images=num_images)
         }
 
-        # NOTE: We need to separate the image tokens here because
-        # encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
-        # with the detection of prompt updates when the image tokens are
-        # right next to each other
         return ProcessorInputs(
-            prompt_text=" ".join([image_token] * num_images),
+            prompt_text=image_token * num_images,
             mm_data=mm_data,
         )
 
@@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         ]
         hf_processor = self.info.get_hf_processor(**mm_kwargs)
 
+        image_repl_features = [
+            self.info.get_image_repl(image_width=size.width,
+                                     image_height=size.height,
+                                     processor=hf_processor).features
+            for size in image_sizes
+        ]
+
+        tokenizer = self.info.get_tokenizer()
+        image_repls_feature_tokens = [
+            tokenizer.encode(image_repl, add_special_tokens=False)
+            for image_repl in image_repl_features
+        ]
+        num_embeds = [
+            len(image_repl_feature_tokens)
+            for image_repl_feature_tokens in image_repls_feature_tokens
+        ]
+        processed_outputs["num_embeds"] = torch.tensor(num_embeds)
+
+        vocab = tokenizer.get_vocab()
+        image_token_id = vocab[tokenizer.image_token]
+
+        embed_is_patch = [
+            torch.tensor(image_repl_tokens) == image_token_id
+            for image_repl_tokens in image_repls_feature_tokens
+        ]
+        processed_outputs["embed_is_patch"] = embed_is_patch
+
         num_crops = [
             self.info.get_num_crops(image_width=size.width,
                                     image_height=size.height,
                                     processor=hf_processor)
             for size in image_sizes
         ]
 
         processed_outputs["num_crops"] = torch.tensor(num_crops)
 
         return processed_outputs
@@ -300,6 +355,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
                 "image", num_crops + 1),
             num_crops=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             )
         ]
 
+    def _apply_token_matches(
+        self,
+        prompt: list[int],
+        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+        mm_item_counts: Mapping[str, int],
+    ) -> list[int]:
+        token_ids = super()._apply_token_matches(
+            prompt,
+            mm_matches,
+            mm_item_counts,
+        )
+
+        # "\n\n\n" and "\n\n\n\n" are single tokens
+        # Since our replacement can insert "\n\n" next to "\n"
+        # tokens, we have to combine them to be consistent with
+        # the output of the tokenizer
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        newline_1 = vocab["\n"]
+        newline_2 = vocab["\n\n"]
+        newline_3 = vocab["\n\n\n"]
+        newline_4 = vocab["\n\n\n\n"]
+
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_1, newline_2],
+            [newline_3],
+        )
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_2, newline_1],
+            [newline_3],
+        )
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_2, newline_2],
+            [newline_4],
+        )
+
+        return token_ids
+
+    def _find_mm_placeholders(
+        self,
+        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
+        new_token_ids: list[int],
+        mm_item_counts: Mapping[str, int],
+    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+        # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        newline_1 = vocab["\n"]
+        newline_2 = vocab["\n\n"]
+        newline_3 = vocab["\n\n\n"]
+        newline_4 = vocab["\n\n\n\n"]
+
+        def get_repl_toks(tok: int) -> list[int]:
+            if tok == newline_3:
+                return [newline_1, newline_2]
+            if tok == newline_4:
+                return [newline_2, newline_2]
+
+            return [tok]
+
+        repl_token_ids = list[int]()
+        repl_orig_idxs = list[int]()
+        for orig_idx, orig_tok in enumerate(new_token_ids):
+            repl_toks = get_repl_toks(orig_tok)
+            repl_token_ids.extend(repl_toks)
+            repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
+
+        repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
+                                     mm_item_counts)
+
+        return {
+            modality: [
+                PlaceholderFeaturesInfo(
+                    modality=p.modality,
+                    item_idx=p.item_idx,
+                    start_idx=repl_orig_idxs[p.start_idx],
+                    tokens=p.tokens,
+                ) for p in placeholders
+            ]
+            for modality, placeholders in repls.items()
+        }
+
 
 class Gemma3MultiModalProjector(nn.Module):
 
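The overrides above exist because this tokenizer has dedicated tokens for runs of newlines, so a prompt replacement that places "\n\n" next to "\n" must be collapsed to match what the tokenizer itself would emit. A standalone sketch of that canonicalization (token ids below are made up for illustration, not real Gemma3 vocabulary ids):

# Illustrative sketch, not vLLM code: collapse adjacent newline tokens the way
# the overridden _apply_token_matches does via replace_token_matches.
NL1, NL2, NL3, NL4 = 107, 108, 109, 110  # "\n", "\n\n", "\n\n\n", "\n\n\n\n"

def canonicalize_newlines(token_ids: list[int]) -> list[int]:
    """Merge ("\n", "\n\n"), ("\n\n", "\n") -> "\n\n\n" and
    ("\n\n", "\n\n") -> "\n\n\n\n", one pattern at a time."""
    for match, repl in (([NL1, NL2], [NL3]), ([NL2, NL1], [NL3]),
                        ([NL2, NL2], [NL4])):
        out: list[int] = []
        i = 0
        while i < len(token_ids):
            if token_ids[i:i + 2] == match:
                out.extend(repl)
                i += 2
            else:
                out.append(token_ids[i])
                i += 1
        token_ids = out
    return token_ids

# A replacement ending in "\n\n" placed right before a literal "\n" would
# otherwise disagree with the tokenizer's output:
print(canonicalize_newlines([5, NL2, NL1, 7]))  # [5, 109, 7]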
@@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module):
                                         info=Gemma3ProcessingInfo,
                                         dummy_inputs=Gemma3DummyInputsBuilder)
 class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
-                                     SupportsLoRA, SupportsV0Only):
+                                     SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
     @property
     def sampler(self):
         return self.language_model.sampler
@@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         num_crops = kwargs.pop("num_crops", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
@@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                              f"Got type: {type(pixel_values)}")
 
         if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_crops values. "
+            raise ValueError("Incorrect type of num_crops. "
                              f"Got type: {type(num_crops)}")
 
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        if not isinstance(num_embeds, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of num_embeds. "
+                             f"Got type: {type(num_embeds)}")
+
         pixel_values = flatten_bn(pixel_values, concat=True)
         num_crops = flatten_bn(num_crops, concat=True)
 
         return Gemma3ImagePixelInputs(
             type="pixel_values",
             pixel_values=self._validate_pixel_values(pixel_values),
-            num_crops=num_crops,
+            num_patches=num_crops + 1,
+            embed_is_patch=embed_is_patch,
+            num_embeds=num_embeds,
         )
 
     def _image_pixels_to_features(
@@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def _process_image_input(
         self,
         image_input: Gemma3ImageInputs,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, ...]:
         assert self.vision_tower is not None
 
         pixel_values = image_input["pixel_values"]
-        vision_outputs = self._image_pixels_to_features(
+        num_patches = image_input["num_patches"]
+
+        image_features = self._image_pixels_to_features(
             self.vision_tower,
             pixel_values,
         )
-        return self.multi_modal_projector(vision_outputs)
+        image_embeds = self.multi_modal_projector(image_features)
+
+        return image_embeds.split(num_patches.tolist())
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+
+        image_features = self._process_image_input(image_input)
+
+        if kwargs.get("v0_path", False):
+            return image_features
+
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                image_features,
+                image_input["num_embeds"],
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
-        if multimodal_embeddings is None:
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        else:
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-            inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.config.image_token_index)
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                select_patch_features(multimodal_embeddings),
+                self.config.image_token_index,
+            )
         return inputs_embeds
 
     def forward(self,
@@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
             vision_embeddings = self.get_multimodal_embeddings(**kwargs)
 
             inputs_embeds = self.get_input_embeddings(input_ids,
@@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             kwargs = self.prepare_attn_masks(
                 input_ids,
                 positions,
-                mask_dtype=vision_embeddings.dtype,
-                **kwargs)
+                mask_dtype=self.dtype,
+                **kwargs,
+            )
             input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
@@ -18,7 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
 
 from vllm.config import VllmConfig
 from vllm.inputs import InputProcessingContext
-from vllm.jsontree import JSONTree, json_map_leaves
+from vllm.jsontree import json_map_leaves
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -27,8 +27,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors)
+                                    MultiModalInputs, MultiModalKwargs)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -44,7 +43,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import get_vision_encoder_info
+from .vision import (get_vision_encoder_info, scatter_patch_features,
+                     select_patch_features)
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -76,7 +76,7 @@ class PixtralHFImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_patches: Union[torch.Tensor, list[torch.Tensor]]
+    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
 
@@ -352,15 +352,15 @@ class PixtralHFMultiModalProcessor(
                 image_height=pixel_value.shape[-2],
             ) for pixel_value in processed_outputs["pixel_values"]
         ]
-        num_patches = torch.tensor([(ncols + 1) * nrows
-                                    for ncols, nrows in tile_sizes])
+        num_embeds = torch.tensor([(ncols + 1) * nrows
+                                   for ncols, nrows in tile_sizes])
         # Each image may result to masks of different sizes, so we need to
-        # later use `num_patches` to get per-image masks.
+        # later use `num_embeds` to get per-image masks.
         embed_is_patch = [
             torch.tensor(([True] * ncols + [False]) * nrows)
             for ncols, nrows in tile_sizes
         ]
-        processed_outputs["num_patches"] = num_patches
+        processed_outputs["num_embeds"] = num_embeds
         processed_outputs["embed_is_patch"] = embed_is_patch
 
         return processed_outputs
@@ -372,7 +372,7 @@ class PixtralHFMultiModalProcessor(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            num_patches=MultiModalFieldConfig.batched("image"),
+            num_embeds=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
@@ -621,16 +621,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-            num_patches = kwargs.pop("num_patches")
-            if not isinstance(num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_patches. "
-                                 f"Got type: {type(num_patches)}")
+            num_embeds = kwargs.pop("num_embeds")
+            if not isinstance(num_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of num_embeds. "
+                                 f"Got type: {type(num_embeds)}")
 
             return PixtralHFImagePixelInputs(
                 type="pixel_values_pixtral",
                 pixel_values=flatten_bn(pixel_values),
                 embed_is_patch=embed_is_patch,
-                num_patches=num_patches,
+                num_embeds=num_embeds,
             )
 
         return LlavaImagePixelInputs(
@@ -716,33 +716,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
-    def _get_mm_embeds(
-        self,
-        features: torch.Tensor,  # Shape: (num_patch, d)
-        num_patches: torch.Tensor,  # Shape: (num_images,)
-        embed_is_patch: torch.Tensor,  # Shape: (num_images, num_embeds)
-    ) -> tuple[torch.Tensor, ...]:
-        """Scatter the patch features into a contiguous tensor that corresponds
-        to the embedding tokens defined by the multimodal processor.
-
-        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
-        """
-        # Insert columns of nan values according to `embed_is_patch`. This work
-        # ideally should be done in `_process_image_input`, but
-        # `_process_image_input` is used in both V0 and V1 path. It's safer to
-        # put the logic here.
-        # FIXME: Move this logic to `_process_image_input` when v0 is
-        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
-        num_patches_per_image: list[int] = num_patches.tolist()
-
-        embeds_flat = features.new_full(
-            (sum(num_patches_per_image), *features.shape[1:]),
-            fill_value=torch.nan,
-        )
-        embeds_flat[embed_is_patch.view(-1)] = features
-
-        return embeds_flat.split(num_patches_per_image)
-
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
@@ -757,9 +730,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
             return vision_embeddings
 
         return flatten_2d_lists(
-            self._get_mm_embeds(*args) for args in zip(
+            scatter_patch_features(*args) for args in zip(
                 vision_embeddings,
-                image_input["num_patches"],
+                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))
 
@@ -770,16 +743,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
-            # Extract the patch tokens
-            patch_embeddings = json_map_leaves(
-                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
-                cast(JSONTree[torch.Tensor], multimodal_embeddings),
-            )
-
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                cast(NestedTensors, patch_embeddings),
+                select_patch_features(multimodal_embeddings),
                 self.config.image_token_index,
             )
         return inputs_embeds
@@ -4,7 +4,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from functools import cached_property, partial
-from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
+from typing import List, Optional, Set, Tuple, TypedDict, Union
 
 import numpy as np
 import torch
@@ -24,7 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather)
-from vllm.jsontree import JSONTree, json_map_leaves
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
                                                    SiluAndMul)
@@ -42,8 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -59,6 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import select_patch_features
 
 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
@@ -1602,16 +1601,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
         if multimodal_embeddings is not None:
             assert self.img_patch_id is not None
 
-            # Extract the patch tokens scattered in _get_mm_embeds
-            patch_embeddings = json_map_leaves(
-                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
-                cast(JSONTree[torch.Tensor], multimodal_embeddings),
-            )
-
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                cast(NestedTensors, patch_embeddings),
+                select_patch_features(multimodal_embeddings),
                 self.img_patch_id,
             )
         return inputs_embeds
@@ -4,7 +4,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass, fields
 from functools import cached_property
-from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union, cast
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 import torch.nn as nn
@@ -22,7 +22,6 @@ from transformers.tokenization_utils_base import TextInput
 
 from vllm.config import VllmConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.jsontree import JSONTree, json_map_leaves
 from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -48,7 +47,8 @@ from vllm.utils import flatten_2d_lists
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
-from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
+from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
+                     scatter_patch_features, select_patch_features)
 
 try:
     from xformers import ops as xops
@@ -77,7 +77,7 @@ class PixtralImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
 
-    num_patches: Union[torch.Tensor, list[torch.Tensor]]
+    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
     """Shape: `(batch_size, num_images)`"""
 
 
@@ -153,7 +153,7 @@ class PixtralProcessorAdapter:
         images_processed = list[torch.Tensor]()
         images_tokens = list[torch.Tensor]()
         images_embed_is_patch = list[torch.Tensor]()
-        images_num_patches = list[int]()
+        images_num_embeds = list[int]()
 
         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
@@ -163,13 +163,13 @@ class PixtralProcessorAdapter:
             images_processed.append(image_processed)
             images_tokens.append(image_tokens)
             images_embed_is_patch.append(image_tokens == image_token_id)
-            images_num_patches.append(len(image_tokens))
+            images_num_embeds.append(len(image_tokens))
 
         return {
             "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
             "images": images_processed,
             "embed_is_patch": images_embed_is_patch,
-            "num_patches": torch.tensor(images_num_patches),
+            "num_embeds": torch.tensor(images_num_embeds),
         }
 
 
@@ -273,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
         return dict(
             images=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_patches=MultiModalFieldConfig.batched("image"),
+            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -394,16 +394,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             raise ValueError("Incorrect type of embed_is_patch. "
                              f"Got type: {type(embed_is_patch)}")
 
-        num_patches = kwargs.pop("num_patches")
-        if not isinstance(num_patches, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_patches. "
-                             f"Got type: {type(num_patches)}")
+        num_embeds = kwargs.pop("num_embeds")
+        if not isinstance(num_embeds, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of num_embeds. "
+                             f"Got type: {type(num_embeds)}")
 
         return PixtralImagePixelInputs(
             type="pixel_values",
             images=flatten_bn(images),
             embed_is_patch=embed_is_patch,
-            num_patches=num_patches,
+            num_embeds=num_embeds,
         )
 
     def _process_image_input(
@@ -433,33 +433,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_embeds = torch.split(image_embeds, feature_sizes)
         return image_embeds
 
-    def _get_mm_embeds(
-        self,
-        features: torch.Tensor,  # Shape: (num_patch, d)
-        num_patches: torch.Tensor,  # Shape: (num_images,)
-        embed_is_patch: torch.Tensor,  # Shape: (num_images, num_embeds)
-    ) -> tuple[torch.Tensor, ...]:
-        """Scatter the patch features into a contiguous tensor that corresponds
-        to the embedding tokens defined by the multimodal processor.
-
-        Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
-        """
-        # Insert columns of nan values according to `embed_is_patch`. This work
-        # ideally should be done in `_process_image_input`, but
-        # `_process_image_input` is used in both V0 and V1 path. It's safer to
-        # put the logic here.
-        # FIXME: Move this logic to `_process_image_input` when v0 is
-        # deprecated. Merge this function with `Molmo._get_mm_embeds`.
-        num_patches_per_image: list[int] = num_patches.tolist()
-
-        embeds_flat = features.new_full(
-            (sum(num_patches_per_image), *features.shape[1:]),
-            fill_value=torch.nan,
-        )
-        embeds_flat[embed_is_patch.view(-1)] = features
-
-        return embeds_flat.split(num_patches_per_image)
-
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
@@ -472,9 +445,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             return image_features
 
         return flatten_2d_lists(
-            self._get_mm_embeds(*args) for args in zip(
+            scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_patches"],
+                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))
 
@@ -485,15 +458,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
-            # Extract the patch tokens
-            patch_embeddings = json_map_leaves(
-                lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
-                cast(JSONTree[torch.Tensor], multimodal_embeddings),
-            )
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                cast(NestedTensors, patch_embeddings),
+                select_patch_features(multimodal_embeddings),
                 self.vision_args.image_token_id,
             )
         return inputs_embeds
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC, abstractmethod
-from typing import Final, Generic, Optional, Protocol, TypeVar, Union
+from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
 
 import torch
 from transformers import PretrainedConfig
@@ -9,9 +9,12 @@ from transformers import PretrainedConfig
 import vllm.envs as envs
 from vllm.attention.selector import (backend_name_to_enum,
                                      get_global_forced_attn_backend)
+from vllm.jsontree import JSONTree, json_map_leaves
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 
+from .interfaces import MultiModalEmbeddings
+
 logger = init_logger(__name__)
 
 _C = TypeVar("_C", bound=PretrainedConfig)
@@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs(
     if post_layer_norm is not None and uses_last_layer:
         hs_pool[-1] = post_layer_norm(encoder_outputs)
     return torch.cat(hs_pool, dim=-1)
+
+
+def scatter_patch_features(
+    features: torch.Tensor,
+    num_embeds: torch.Tensor,
+    embed_is_patch: torch.Tensor,
+) -> tuple[torch.Tensor, ...]:
+    """
+    Scatter the patch features into a contiguous tensor that corresponds
+    to the embedding tokens defined by the multimodal processor.
+
+    The rest of the values in the tensor are set to NaN so that they
+    can be filtered out by :func`select_patch_features`.
+
+    Args:
+        features: The patch features, concatenated across each image.
+            Shape: `(num_patch, feature_depth)`
+        num_embeds: The number of image embeddings for each image.
+            Shape: `(num_images,)`
+        embed_is_patch: A boolean mask indicating which image embeddings
+            correspond to patch tokens for each image.
+            Shape: `(num_images, num_embeds)`
+    """
+    num_embeds_per_image: list[int] = num_embeds.tolist()
+
+    embeds_flat = features.new_full(
+        (sum(num_embeds_per_image), features.shape[-1]),
+        fill_value=torch.nan,
+    )
+    embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
+
+    return embeds_flat.split(num_embeds_per_image)
+
+
+def select_patch_features(
+        multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
+    """
+    Given the outputs of :func:`scatter_patch_features`, return only
+    the values that correspond to patch features.
+    """
+    selected_features = json_map_leaves(
+        lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
+        cast(JSONTree[torch.Tensor], multimodal_embeddings),
+    )
+    return cast(MultiModalEmbeddings, selected_features)
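To illustrate the scatter/select round-trip that these two helpers implement, here is a standalone sketch using plain PyTorch (it does not import vLLM; the toy shapes and values are made up): patch features are placed into a per-image buffer sized to the full embedding sequence, the non-patch slots are filled with NaN, and dropping the NaN rows later recovers exactly the original patch features.

import torch

features = torch.arange(6, dtype=torch.float32).view(3, 2)   # (num_patch, d)
num_embeds = torch.tensor([4])                                # one image, 4 embedding slots
embed_is_patch = torch.tensor([[True, True, False, True]])    # (num_images, num_embeds)

# scatter: place patch features into a (num_embeds, d) buffer, NaN elsewhere
buf = features.new_full((int(num_embeds.sum()), features.shape[-1]), torch.nan)
buf[embed_is_patch.view(-1)] = features
per_image = buf.split(num_embeds.tolist())

# select: drop the NaN rows again, recovering the original patch features
recovered = per_image[0][~per_image[0].isnan().any(dim=-1)]
assert torch.equal(recovered, features)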
|||||||
@ -511,8 +511,35 @@ def iter_token_matches(
|
|||||||
start_idx += 1
|
start_idx += 1
|
||||||
|
|
||||||
|
|
||||||
|
def replace_token_matches(
|
||||||
|
token_ids: list[int],
|
||||||
|
match_ids: list[int],
|
||||||
|
new_ids: list[int],
|
||||||
|
) -> list[int]:
|
||||||
|
"""
|
||||||
|
Replace each occurrence of :code:`match_ids` in :code:`token_ids`
|
||||||
|
with :code:`new_ids`.
|
||||||
|
|
||||||
|
Note that empty matches are ignored.
|
||||||
|
"""
|
||||||
|
out_seqs = list[list[int]]()
|
||||||
|
prev_end_idx = 0
|
||||||
|
|
||||||
|
for match in iter_token_matches(token_ids, match_ids):
|
||||||
|
start_idx = match.start_idx
|
||||||
|
end_idx = match.end_idx
|
||||||
|
|
||||||
|
out_seqs.append(token_ids[prev_end_idx:start_idx])
|
||||||
|
out_seqs.append(new_ids)
|
||||||
|
prev_end_idx = end_idx
|
||||||
|
|
||||||
|
out_seqs.append(token_ids[prev_end_idx:])
|
||||||
|
|
||||||
|
return flatten_2d_lists(out_seqs)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(repr=False)
|
@dataclass(repr=False)
|
||||||
class _PromptTargetMatch(ABC):
|
class PromptTargetMatch(ABC):
|
||||||
_origin: BoundPromptUpdate
|
_origin: BoundPromptUpdate
|
||||||
|
|
||||||
@property
|
@property
|
||||||
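A quick usage sketch of the new helper (it assumes vLLM at this commit is installed so the import resolves, as exercised by the new test above): matches are found left to right and do not overlap.

from vllm.multimodal.processing import replace_token_matches

token_ids = [1, 2, 3, 2, 3, 4]
# Both non-overlapping [2, 3] occurrences are replaced by [9].
assert replace_token_matches(token_ids, [2, 3], [9]) == [1, 9, 9, 4]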
@@ -535,7 +562,7 @@ class _PromptTargetMatch(ABC):
 
 
 @dataclass(repr=False)
-class _PromptTargetIndexMatch(_PromptTargetMatch):
+class _PromptTargetIndexMatch(PromptTargetMatch):
     match_idx: int
 
     @property
@@ -548,7 +575,7 @@ class _PromptTargetIndexMatch(_PromptTargetMatch):
 
 
 @dataclass(repr=False)
-class _PromptTargetTokenMatch(_PromptTargetMatch):
+class _PromptTargetTokenMatch(PromptTargetMatch):
     match: _TokenMatch
 
     @property
@@ -561,7 +588,7 @@ class _PromptTargetTokenMatch(_PromptTargetMatch):
 
 
 @dataclass(repr=False)
-class _PromptTargetTextMatch(_PromptTargetMatch):
+class _PromptTargetTextMatch(PromptTargetMatch):
     match: re.Match[str]
 
     @property
@@ -594,7 +621,7 @@ class PlaceholderFeaturesInfo:
 def find_token_matches(
     prompt: list[int],
     prompt_updates: Sequence[BoundPromptUpdate],
-) -> Sequence[_PromptTargetMatch]:
+) -> Sequence[PromptTargetMatch]:
     """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
 
     def get_matches(update: BoundPromptUpdate):
@@ -620,7 +647,7 @@ def find_token_matches(
 def find_text_matches(
     prompt: str,
     prompt_updates: Sequence[BoundPromptUpdate],
-) -> Sequence[_PromptTargetMatch]:
+) -> Sequence[PromptTargetMatch]:
     """Return each target of :code:`prompt_updates` found in :code:`prompt`."""
 
     def get_matches(update: BoundPromptUpdate):
@@ -645,15 +672,15 @@ def find_text_matches(
 
 def _resolve_matches(
     prompt: PromptSeq,
-    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
-) -> list[_PromptTargetMatch]:
+    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+) -> list[PromptTargetMatch]:
     """
     Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
     and sort them such that earlier matches take priority over later ones.
     """
     matches = [m for matches in mm_matches.values() for m in matches]
 
-    seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)
+    seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt)
 
     for match in matches:
         for idx in range(match.start_idx, match.end_idx):
@@ -669,7 +696,7 @@ def _resolve_matches(
 
 def _apply_matches(
     prompt: _S,
-    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
+    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
     mm_item_counts: Mapping[str, int],
 ) -> list[_S]:
     """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -718,7 +745,7 @@ def _apply_matches(
 
 def apply_token_matches(
     prompt: list[int],
-    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
+    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
     mm_item_counts: Mapping[str, int],
 ) -> list[int]:
     """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -732,7 +759,7 @@ def apply_token_matches(
 
 def apply_text_matches(
     prompt: str,
-    mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
+    mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
     mm_item_counts: Mapping[str, int],
 ) -> str:
     """Apply the updates in :code:`mm_matches` to :code:`prompt`."""
@@ -1055,14 +1082,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         Given the original multi-modal items for this modality
         and HF-processed data, output the updates to perform.
 
-        Notes:
-            - You should not assume that HF processor always performs prompt
-              updates: in :meth:`_apply_hf_processor_missing`, this method
-              is called on text-only and multimodal-only inputs separately,
-              instead of passing them in the same call.
-            - The update information returned by this method is also used to
-              determine the placeholder token positions for each multi-modal
-              item.
+        The information returned by this method is used to update token inputs
+        which bypass the HF processor. It is also used to update the output of
+        HF processor if the HF process does not apply prompt updates to text
+        inputs.
+
+        Moreover, this information is critical to determine the token positions
+        in order to construct :class:`~vllm-multimodal.input.PlaceholderRange`
+        for each multi-modal item.
         """
         raise NotImplementedError
 
@@ -1357,6 +1384,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         it = (update.bind(tokenizer) for update in prompt_updates)
         return dict(full_groupby_modality(it))
 
+    def _apply_token_matches(
+        self,
+        prompt: list[int],
+        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+        mm_item_counts: Mapping[str, int],
+    ) -> list[int]:
+        return apply_token_matches(prompt, mm_matches, mm_item_counts)
+
+    def _apply_text_matches(
+        self,
+        prompt: str,
+        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+        mm_item_counts: Mapping[str, int],
+    ) -> str:
+        return apply_text_matches(prompt, mm_matches, mm_item_counts)
+
     def _apply_prompt_updates(
         self,
         token_ids: list[int],
@@ -1388,7 +1431,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                 mm_match_counts.get(modality, 0) >= item_count
                 for modality, item_count in mm_item_counts.items()
             ):  # yapf: disable
-                token_ids = apply_token_matches(
+                token_ids = self._apply_token_matches(
                     token_ids,
                     mm_token_matches,
                     mm_item_counts,
@@ -1406,7 +1449,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                 modality: find_text_matches(text, updates)
                 for modality, updates in mm_prompt_updates.items()
             }
-            text = apply_text_matches(
+            text = self._apply_text_matches(
                 text,
                 mm_text_matches,
                 mm_item_counts,