Remove index_put from MM embeddings merging (#22105)

Co-authored-by: Chenxi Yang <cxyang@meta.com>
Chenxi Yang authored 2025-08-03 22:15:14 -07:00 · committed by GitHub
parent 49bcd893e7
commit e5949e5ae0


@@ -393,7 +393,7 @@ def merge_multimodal_embeddings_from_map(
         inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
         placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
     placeholder map .
 
     Note:
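The body of ``merge_multimodal_embeddings_from_map`` is not part of this hunk. As a rough sketch only, a placeholder-map merge of this kind usually amounts to copying rows of the flattened multimodal embeddings into the destination indices recorded in the map; the ``src``/``dest`` parameter names below are assumptions for illustration, not taken from this diff.

import torch

def merge_from_map_sketch(inputs_embeds: torch.Tensor,
                          flattened: torch.Tensor,
                          src: torch.Tensor,
                          dest: torch.Tensor) -> torch.Tensor:
    # Copy the rows selected by ``src`` into the positions given by ``dest``;
    # this mirrors what an IndexMap-driven merge typically does.
    inputs_embeds[dest] = flattened[src]
    return inputs_embeds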
@@ -418,17 +418,23 @@ def _merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    num_expected_tokens = is_multimodal.sum().item()
-    assert isinstance(num_expected_tokens, int)
-
     flattened = _flatten_embeddings(multimodal_embeddings)
-    if flattened.shape[0] != num_expected_tokens:
-        expr = _embedding_count_expression(multimodal_embeddings)
-        raise ValueError(
-            f"Attempted to assign {expr} = {flattened.shape[0]} "
-            f"multimodal tokens to {num_expected_tokens} placeholders")
+    try:
+        # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
+    except RuntimeError as e:
+        num_expected_tokens = is_multimodal.sum().item()
+        assert isinstance(num_expected_tokens, int)
+
+        if flattened.shape[0] != num_expected_tokens:
+            expr = _embedding_count_expression(multimodal_embeddings)
+            raise ValueError(
+                f"Attempted to assign {expr} = {flattened.shape[0]} "
+                f"multimodal tokens to {num_expected_tokens} placeholders"
+            ) from e
+        else:
+            raise ValueError("Error during masked scatter operation") from e
 
-    inputs_embeds[is_multimodal] = flattened
     return inputs_embeds
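For reference, a standalone sketch (not part of the diff; the shapes and values are invented) of why the new ``masked_scatter_`` call produces the same result as the boolean-index assignment it replaces: unsqueezing the mask to (seq_len, 1) broadcasts it across the hidden dimension, and the source rows are consumed in order.

import torch

hidden = 4
inputs_embeds = torch.zeros(6, hidden)
is_multimodal = torch.tensor([False, True, True, False, True, False])
flattened = torch.arange(3 * hidden, dtype=torch.float32).reshape(3, hidden)

# Old path (removed here): boolean index assignment, i.e. index_put.
a = inputs_embeds.clone()
a[is_multimodal] = flattened

# New path (added here): in-place masked_scatter_ with a broadcast mask.
b = inputs_embeds.clone()
b.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)

assert torch.equal(a, b)  # both fill the same rows with the same values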
@@ -478,11 +484,11 @@ def merge_multimodal_embeddings(
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
     positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
 
     ``placeholder_token_id`` can be a list of token ids (e.g, token ids
     of img_start, img_break, and img_end tokens) when needed: This means
     the order of these tokens in the ``input_ids`` MUST MATCH the order of
     their embeddings in ``multimodal_embeddings`` since we need to
     slice-merge instead of individually scattering.
 
     For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
@@ -491,9 +497,9 @@ def merge_multimodal_embeddings(
     - I is image embedding token
     - B is image break token
     - E is image end token.
 
     Then the image embeddings (that correspond to I's) from vision encoder
     must be padded with embeddings of S, B, and E in the same order of
     input_ids for a correct embedding merge.
 
     Note:
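To make the ordering requirement above concrete, here is a hedged sketch (not code from this diff; the token ids and sequence are made up) of how the placeholder mask could be built when ``placeholder_token_id`` is a list: every position holding any of the listed ids is overwritten, so the flattened multimodal embeddings must already contain the S/B/E embeddings in exactly the order they appear in ``input_ids``.

import torch

IMG_START, IMG_BREAK, IMG_END, IMG_TOKEN = 100, 101, 102, 103
placeholder_token_id = [IMG_START, IMG_BREAK, IMG_END, IMG_TOKEN]

# Text tokens are id 1; placeholder tokens are mixed in, as in the example above.
input_ids = torch.tensor([1, 1, IMG_START, IMG_TOKEN, IMG_TOKEN, IMG_BREAK,
                          IMG_TOKEN, IMG_TOKEN, IMG_END, 1, 1])

# Mark every placeholder position; these are the rows masked_scatter_ fills.
is_multimodal = torch.isin(input_ids, torch.tensor(placeholder_token_id))
print(is_multimodal.sum().item())  # 7 placeholder positions to fill, in order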