[Misc] Remove redundant num_embeds (#15443)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2025-03-25 18:27:57 +08:00 (committed by GitHub)
parent a9e879b316
commit 5994430b84
5 changed files with 25 additions and 64 deletions


@@ -63,9 +63,6 @@ class Gemma3ImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
 
@@ -317,11 +314,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             tokenizer.encode(image_repl, add_special_tokens=False)
             for image_repl in image_repl_features
         ]
-        num_embeds = [
-            len(image_repl_feature_tokens)
-            for image_repl_feature_tokens in image_repls_feature_tokens
-        ]
-        processed_outputs["num_embeds"] = torch.tensor(num_embeds)
 
         vocab = tokenizer.get_vocab()
         image_token_id = vocab[tokenizer.image_token]
@@ -354,7 +346,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
                 "image", num_crops + 1),
             num_crops=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -583,7 +574,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         pixel_values = kwargs.pop("pixel_values", None)
         num_crops = kwargs.pop("num_crops", None)
         embed_is_patch = kwargs.pop("embed_is_patch", None)
-        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
 
         if pixel_values is None:
@@ -601,10 +591,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")
-            if not isinstance(num_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_embeds. "
-                                 f"Got type: {type(num_embeds)}")
-
             pixel_values = flatten_bn(pixel_values, concat=True)
             num_crops = flatten_bn(num_crops, concat=True)
@@ -613,7 +599,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                 pixel_values=self._validate_pixel_values(pixel_values),
                 num_patches=num_crops + 1,
                 embed_is_patch=embed_is_patch,
-                num_embeds=num_embeds,
             )
 
     def _image_pixels_to_features(
@@ -656,7 +641,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))
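Every deletion in this file follows the same pattern: the per-image count that `num_embeds` carried is already implied by the length of each `embed_is_patch` mask. A minimal sketch of that equivalence, using hypothetical masks rather than real processor output:

    import torch

    # Hypothetical per-image masks; in vLLM they come from the HF processor.
    embed_is_patch = [
        torch.tensor([False, True, True, False]),
        torch.tensor([False, True, True, True, False]),
    ]

    # The old `num_embeds` tensor was just the mask lengths, so it can be
    # recovered on demand instead of being stored and validated separately.
    num_embeds = torch.tensor([mask.numel() for mask in embed_is_patch])
    assert num_embeds.tolist() == [4, 5]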


@@ -69,9 +69,6 @@ class InternVLImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class InternVLImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -426,7 +423,6 @@ class BaseInternVLProcessor(ABC):
         tokenizer = self.tokenizer
         image_token_id = self.image_token_id
 
-        num_embeds = list[int]()
         embed_is_patch = list[torch.Tensor]()
 
         for pixel_values in pixel_values_lst:
@@ -438,11 +434,9 @@ class BaseInternVLProcessor(ABC):
                 add_special_tokens=False)
             text = [t.replace('<image>', image_repl.full, 1) for t in text]
 
-            num_embeds.append(len(feature_tokens))
             embed_is_patch.append(
                 torch.tensor(feature_tokens) == image_token_id)
 
-        image_inputs["num_embeds"] = torch.tensor(num_embeds)
         image_inputs["embed_is_patch"] = embed_is_patch
 
         text_inputs = self.tokenizer(text)
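Here the mask is built by comparing the encoded replacement tokens against the image token id, so the number of feature tokens travels with the mask itself and `num_embeds` adds nothing. A toy illustration with a made-up token id:

    import torch

    IMAGE_TOKEN_ID = 92546  # hypothetical id, for illustration only
    feature_tokens = [101, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 102]

    embed_is_patch = torch.tensor(feature_tokens) == IMAGE_TOKEN_ID
    # tensor([False, True, True, False]); its numel() equals
    # len(feature_tokens), the value the deleted num_embeds.append(...)
    # used to record.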
@@ -607,7 +601,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )
@@ -840,7 +833,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
         image_num_patches = kwargs.pop("image_num_patches", None)
         embed_is_patch = kwargs.pop("embed_is_patch", None)
-        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values_flat is None and image_embeds is None:
@@ -873,10 +865,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")
-            if not isinstance(num_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_embeds. "
-                                 f"Got type: {type(num_embeds)}")
-
             pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
             image_num_patches = flatten_bn(image_num_patches, concat=True)
@@ -886,7 +874,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
                     pixel_values_flat),
                 num_patches=image_num_patches,
                 embed_is_patch=embed_is_patch,
-                num_embeds=num_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
@@ -941,7 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))


@@ -76,9 +76,6 @@ class PixtralHFImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class LlavaImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
@@ -358,15 +355,10 @@ class PixtralHFMultiModalProcessor(
                 image_height=pixel_value.shape[-2],
             ) for pixel_value in processed_outputs["pixel_values"]
         ]
-        num_embeds = torch.tensor([(ncols + 1) * nrows
-                                   for ncols, nrows in tile_sizes])
-        # Each image may result to masks of different sizes, so we need to
-        # later use `num_embeds` to get per-image masks.
         embed_is_patch = [
             torch.tensor(([True] * ncols + [False]) * nrows)
             for ncols, nrows in tile_sizes
         ]
-        processed_outputs["num_embeds"] = num_embeds
         processed_outputs["embed_is_patch"] = embed_is_patch
 
         return processed_outputs
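The mask built here likewise keeps the deleted `num_embeds` value recoverable: each tile row contributes `ncols` patch positions plus one break token, so the mask length is exactly `(ncols + 1) * nrows`. A small illustration with hypothetical tile sizes:

    import torch

    # Hypothetical 2-column, 3-row tile grid for one image.
    ncols, nrows = 2, 3
    embed_is_patch = torch.tensor(([True] * ncols + [False]) * nrows)

    assert embed_is_patch.numel() == (ncols + 1) * nrows  # 9, the old num_embeds
    assert int(embed_is_patch.sum()) == ncols * nrows     # 6 patch embeddings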
@@ -378,7 +370,6 @@ class PixtralHFMultiModalProcessor(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
         )
@@ -627,16 +618,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")
 
-            num_embeds = kwargs.pop("num_embeds")
-            if not isinstance(num_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_embeds. "
-                                 f"Got type: {type(num_embeds)}")
-
             return PixtralHFImagePixelInputs(
                 type="pixel_values_pixtral",
                 pixel_values=flatten_bn(pixel_values),
                 embed_is_patch=embed_is_patch,
-                num_embeds=num_embeds,
             )
 
         return LlavaImagePixelInputs(
@@ -738,7 +723,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 vision_embeddings,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))


@@ -77,9 +77,6 @@ class PixtralImagePixelInputs(TypedDict):
     Shape: `(batch_size, num_images, num_embeds)`
     """
-    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """Shape: `(batch_size, num_images)`"""
-
 
 class PixtralProcessorAdapter:
     """
@@ -153,7 +150,6 @@ class PixtralProcessorAdapter:
         images_processed = list[torch.Tensor]()
         images_tokens = list[torch.Tensor]()
         images_embed_is_patch = list[torch.Tensor]()
-        images_num_embeds = list[int]()
 
         for image in images:
             image_inputs = self.image_processor(ImageChunk(image=image))
@@ -163,13 +159,11 @@ class PixtralProcessorAdapter:
             images_processed.append(image_processed)
             images_tokens.append(image_tokens)
             images_embed_is_patch.append(image_tokens == image_token_id)
-            images_num_embeds.append(len(image_tokens))
 
         return {
             "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
             "images": images_processed,
             "embed_is_patch": images_embed_is_patch,
-            "num_embeds": torch.tensor(images_num_embeds),
         }
@@ -273,7 +267,6 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
         return dict(
             images=MultiModalFieldConfig.batched("image"),
             embed_is_patch=MultiModalFieldConfig.batched("image"),
-            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -394,16 +387,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
                 raise ValueError("Incorrect type of embed_is_patch. "
                                  f"Got type: {type(embed_is_patch)}")
 
-            num_embeds = kwargs.pop("num_embeds")
-            if not isinstance(num_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_embeds. "
-                                 f"Got type: {type(num_embeds)}")
-
             return PixtralImagePixelInputs(
                 type="pixel_values",
                 images=flatten_bn(images),
                 embed_is_patch=embed_is_patch,
-                num_embeds=num_embeds,
             )
 
     def _process_image_input(
@@ -447,7 +434,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return flatten_2d_lists(
             scatter_patch_features(*args) for args in zip(
                 image_features,
-                image_input["num_embeds"],
                 image_input["embed_is_patch"],
             ))


@@ -155,7 +155,6 @@ def resolve_visual_encoder_outputs(
 
 def scatter_patch_features(
     features: torch.Tensor,
-    num_embeds: torch.Tensor,
     embed_is_patch: torch.Tensor,
 ) -> tuple[torch.Tensor, ...]:
     """
@@ -168,13 +167,35 @@ def scatter_patch_features(
     Args:
         features: The patch features, concatenated across each image.
             Shape: `(num_patch, feature_depth)`
-        num_embeds: The number of image embeddings for each image.
-            Shape: `(num_images,)`
         embed_is_patch: A boolean mask indicating which image embeddings
             correspond to patch tokens for each image.
             Shape: `(num_images, num_embeds)`
+
+    Note:
+        The original code only considers patch tokens as feature
+        tokens, but our processor considers all image-related tokens
+        as feature tokens because the feature tokens need to be
+        consecutive in `input_ids`.
+
+    Example:
+        A simplified example for one image:
+
+        .. code-block::
+
+            Embedding tokens (from HF processor):
+                [<start> <patch> <patch> <col> <patch> <patch> <col> <end>]
+
+            embed_is_patch (from HF processor):
+                [ False   True    True   False   True    True   False False]
+
+            Encoder outputs (from model):
+                [ p1 p2 p3 p4 ]
+
+            The resulting embedding tensor is:
+                [ nan p1 p2 nan p3 p4 nan nan ]
     """
-    num_embeds_per_image: list[int] = num_embeds.tolist()
+    num_images, num_embeds = embed_is_patch.shape
+    num_embeds_per_image = [num_embeds] * num_images
 
     embeds_flat = features.new_full(
         (sum(num_embeds_per_image), features.shape[-1]),
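The excerpt ends mid-statement. For reference, a self-contained sketch of the mask-driven scatter that the rewritten body implements, consistent with the docstring above but not the verbatim vLLM source:

    import torch

    def scatter_patch_features_sketch(
        features: torch.Tensor,        # (num_patches, feature_depth), all images
        embed_is_patch: torch.Tensor,  # (num_images, num_embeds) boolean mask
    ) -> tuple[torch.Tensor, ...]:
        num_images, num_embeds = embed_is_patch.shape
        num_embeds_per_image = [num_embeds] * num_images

        # Non-patch positions stay nan, matching the docstring example.
        embeds_flat = features.new_full(
            (sum(num_embeds_per_image), features.shape[-1]),
            fill_value=torch.nan,
        )
        embeds_flat[embed_is_patch.view(-1)] = features
        return tuple(embeds_flat.split(num_embeds_per_image))

    # The docstring example: one image, four patch outputs p1..p4.
    feats = torch.arange(1., 5.)[:, None]  # p1..p4 with feature_depth=1
    mask = torch.tensor([[False, True, True, False, True, True, False, False]])
    (out,) = scatter_patch_features_sketch(feats, mask)
    # out.squeeze(-1) -> [nan, 1., 2., nan, 3., 4., nan, nan]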