Remove index_put from MM embeddings merging (#22105)

Co-authored-by: Chenxi Yang <cxyang@meta.com>
Authored by Chenxi Yang on 2025-08-03 22:15:14 -07:00, committed by GitHub
parent 49bcd893e7
commit e5949e5ae0

@@ -418,17 +418,23 @@ def _merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    num_expected_tokens = is_multimodal.sum().item()
-    assert isinstance(num_expected_tokens, int)
-
     flattened = _flatten_embeddings(multimodal_embeddings)
-    if flattened.shape[0] != num_expected_tokens:
-        expr = _embedding_count_expression(multimodal_embeddings)
-        raise ValueError(
-            f"Attempted to assign {expr} = {flattened.shape[0]} "
-            f"multimodal tokens to {num_expected_tokens} placeholders")
-
-    inputs_embeds[is_multimodal] = flattened
+    try:
+        # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
+    except RuntimeError as e:
+        num_expected_tokens = is_multimodal.sum().item()
+        assert isinstance(num_expected_tokens, int)
+
+        if flattened.shape[0] != num_expected_tokens:
+            expr = _embedding_count_expression(multimodal_embeddings)
+            raise ValueError(
+                f"Attempted to assign {expr} = {flattened.shape[0]} "
+                f"multimodal tokens to {num_expected_tokens} placeholders"
+            ) from e
+        else:
+            raise ValueError("Error during masked scatter operation") from e
+
     return inputs_embeds
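
For readers unfamiliar with `masked_scatter_`, here is a minimal standalone sketch (not part of the commit; the tensor shapes and values are invented for illustration) showing that the new in-place `masked_scatter_` call produces the same result as the boolean-mask assignment, i.e. the `index_put` it replaces:

```python
import torch

# Toy stand-ins for the function's arguments (invented for illustration).
hidden_size = 4
inputs_embeds = torch.zeros(6, hidden_size)   # 6 token positions
is_multimodal = torch.tensor([False, True, True, False, True, False])
flattened = torch.randn(3, hidden_size)       # one row per True position

# Old path: boolean-mask assignment, which lowers to index_put.
expected = inputs_embeds.clone()
expected[is_multimodal] = flattened

# New path: in-place masked_scatter_. Unsqueezing the mask to shape
# (num_tokens, 1) lets it broadcast across the hidden dimension, and the
# elements of `flattened` are written into the True positions in order.
result = inputs_embeds.clone()
result.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)

assert torch.equal(result, expected)
```

A side effect of the change is that `is_multimodal.sum().item()`, which forces a device-to-host synchronization, is no longer evaluated on the hot path; the placeholder-count check now runs only inside the `except RuntimeError` branch, so the detailed error message is preserved for the failure case.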