mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-07 23:53:14 +08:00
[Misc] Remove redundant num_embeds (#15443)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
a9e879b316
commit
5994430b84
@ -63,9 +63,6 @@ class Gemma3ImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
Gemma3ImageInputs = Gemma3ImagePixelInputs
|
||||
|
||||
@ -317,11 +314,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
tokenizer.encode(image_repl, add_special_tokens=False)
|
||||
for image_repl in image_repl_features
|
||||
]
|
||||
num_embeds = [
|
||||
len(image_repl_feature_tokens)
|
||||
for image_repl_feature_tokens in image_repls_feature_tokens
|
||||
]
|
||||
processed_outputs["num_embeds"] = torch.tensor(num_embeds)
|
||||
|
||||
vocab = tokenizer.get_vocab()
|
||||
image_token_id = vocab[tokenizer.image_token]
|
||||
@ -354,7 +346,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
"image", num_crops + 1),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@ -583,7 +574,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
num_crops = kwargs.pop("num_crops", None)
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
num_embeds = kwargs.pop("num_embeds", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
assert image_embeds is None, "Gemma3 does not support image_embeds."
|
||||
if pixel_values is None:
|
||||
@ -601,10 +591,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
pixel_values = flatten_bn(pixel_values, concat=True)
|
||||
num_crops = flatten_bn(num_crops, concat=True)
|
||||
|
||||
@ -613,7 +599,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
pixel_values=self._validate_pixel_values(pixel_values),
|
||||
num_patches=num_crops + 1,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
def _image_pixels_to_features(
|
||||
@ -656,7 +641,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
|
||||
@ -69,9 +69,6 @@ class InternVLImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
class InternVLImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
@ -426,7 +423,6 @@ class BaseInternVLProcessor(ABC):
|
||||
tokenizer = self.tokenizer
|
||||
image_token_id = self.image_token_id
|
||||
|
||||
num_embeds = list[int]()
|
||||
embed_is_patch = list[torch.Tensor]()
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
@ -438,11 +434,9 @@ class BaseInternVLProcessor(ABC):
|
||||
add_special_tokens=False)
|
||||
|
||||
text = [t.replace('<image>', image_repl.full, 1) for t in text]
|
||||
num_embeds.append(len(feature_tokens))
|
||||
embed_is_patch.append(
|
||||
torch.tensor(feature_tokens) == image_token_id)
|
||||
|
||||
image_inputs["num_embeds"] = torch.tensor(num_embeds)
|
||||
image_inputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
@ -607,7 +601,6 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
"image", image_num_patches),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||
)
|
||||
@ -840,7 +833,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
embed_is_patch = kwargs.pop("embed_is_patch", None)
|
||||
num_embeds = kwargs.pop("num_embeds", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values_flat is None and image_embeds is None:
|
||||
@ -873,10 +865,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
|
||||
image_num_patches = flatten_bn(image_num_patches, concat=True)
|
||||
|
||||
@ -886,7 +874,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
pixel_values_flat),
|
||||
num_patches=image_num_patches,
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
@ -941,7 +928,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
|
||||
@ -76,9 +76,6 @@ class PixtralHFImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
class LlavaImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
@ -358,15 +355,10 @@ class PixtralHFMultiModalProcessor(
|
||||
image_height=pixel_value.shape[-2],
|
||||
) for pixel_value in processed_outputs["pixel_values"]
|
||||
]
|
||||
num_embeds = torch.tensor([(ncols + 1) * nrows
|
||||
for ncols, nrows in tile_sizes])
|
||||
# Each image may result to masks of different sizes, so we need to
|
||||
# later use `num_embeds` to get per-image masks.
|
||||
embed_is_patch = [
|
||||
torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
for ncols, nrows in tile_sizes
|
||||
]
|
||||
processed_outputs["num_embeds"] = num_embeds
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
@ -378,7 +370,6 @@ class PixtralHFMultiModalProcessor(
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
@ -627,16 +618,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_embeds = kwargs.pop("num_embeds")
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
return PixtralHFImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
@ -738,7 +723,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
vision_embeddings,
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
|
||||
@ -77,9 +77,6 @@ class PixtralImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
class PixtralProcessorAdapter:
|
||||
"""
|
||||
@ -153,7 +150,6 @@ class PixtralProcessorAdapter:
|
||||
images_processed = list[torch.Tensor]()
|
||||
images_tokens = list[torch.Tensor]()
|
||||
images_embed_is_patch = list[torch.Tensor]()
|
||||
images_num_embeds = list[int]()
|
||||
|
||||
for image in images:
|
||||
image_inputs = self.image_processor(ImageChunk(image=image))
|
||||
@ -163,13 +159,11 @@ class PixtralProcessorAdapter:
|
||||
images_processed.append(image_processed)
|
||||
images_tokens.append(image_tokens)
|
||||
images_embed_is_patch.append(image_tokens == image_token_id)
|
||||
images_num_embeds.append(len(image_tokens))
|
||||
|
||||
return {
|
||||
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
|
||||
"images": images_processed,
|
||||
"embed_is_patch": images_embed_is_patch,
|
||||
"num_embeds": torch.tensor(images_num_embeds),
|
||||
}
|
||||
|
||||
|
||||
@ -273,7 +267,6 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
|
||||
return dict(
|
||||
images=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@ -394,16 +387,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_embeds = kwargs.pop("num_embeds")
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
return PixtralImagePixelInputs(
|
||||
type="pixel_values",
|
||||
images=flatten_bn(images),
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
@ -447,7 +434,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return flatten_2d_lists(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
image_features,
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
|
||||
@ -155,7 +155,6 @@ def resolve_visual_encoder_outputs(
|
||||
|
||||
def scatter_patch_features(
|
||||
features: torch.Tensor,
|
||||
num_embeds: torch.Tensor,
|
||||
embed_is_patch: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
@ -168,13 +167,35 @@ def scatter_patch_features(
|
||||
Args:
|
||||
features: The patch features, concatenated across each image.
|
||||
Shape: `(num_patch, feature_depth)`
|
||||
num_embeds: The number of image embeddings for each image.
|
||||
Shape: `(num_images,)`
|
||||
embed_is_patch: A boolean mask indicating which image embeddings
|
||||
correspond to patch tokens for each image.
|
||||
Shape: `(num_images, num_embeds)`
|
||||
|
||||
Note:
|
||||
The original code only considers patch tokens as feature
|
||||
tokens, but our processor considers all image-related tokens
|
||||
as feature tokens because the feature tokens need to be
|
||||
consecutive in `input_ids`.
|
||||
|
||||
Example:
|
||||
A simplified example for one image:
|
||||
|
||||
.. code-block::
|
||||
|
||||
Embedding tokens (from HF processor):
|
||||
[<start> <patch> <patch> <col> <patch> <patch> <col> <end> ]
|
||||
|
||||
embed_is_patch (from HF processor):
|
||||
[ False True True False True True False False ]
|
||||
|
||||
Encoder outputs (from model):
|
||||
[ p1 p2 p3 p4 ]
|
||||
|
||||
The resulting embedding tensor is:
|
||||
[ nan p1 p2 nan p3 p4 nan nan ]
|
||||
"""
|
||||
num_embeds_per_image: list[int] = num_embeds.tolist()
|
||||
num_images, num_embeds = embed_is_patch.shape
|
||||
num_embeds_per_image = [num_embeds] * num_images
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_embeds_per_image), features.shape[-1]),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user