From e5949e5ae013692ba09cc52472cf441675f5a270 Mon Sep 17 00:00:00 2001
From: Chenxi Yang
Date: Sun, 3 Aug 2025 22:15:14 -0700
Subject: [PATCH] Remove index_put from MM embeddings merging (#22105)

Co-authored-by: Chenxi Yang
---
 vllm/model_executor/models/utils.py | 42 ++++++++++++++++------------
 1 file changed, 24 insertions(+), 18 deletions(-)

diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 62deb68035b9..28508e1bac1e 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -393,7 +393,7 @@ def merge_multimodal_embeddings_from_map(
         inputs_embeds: torch.Tensor, multimodal_embeddings: NestedTensors,
         placeholder_map: MultiModalPlaceholderMap.IndexMap) -> torch.Tensor:
     """
-    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided 
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` using the provided
     placeholder map .
 
     Note:
@@ -418,17 +418,23 @@ def _merge_multimodal_embeddings(
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    num_expected_tokens = is_multimodal.sum().item()
-    assert isinstance(num_expected_tokens, int)
-
     flattened = _flatten_embeddings(multimodal_embeddings)
-    if flattened.shape[0] != num_expected_tokens:
-        expr = _embedding_count_expression(multimodal_embeddings)
-        raise ValueError(
-            f"Attempted to assign {expr} = {flattened.shape[0]} "
-            f"multimodal tokens to {num_expected_tokens} placeholders")
+    try:
+        # This is equivalent to: inputs_embeds[is_multimodal] = flattened.
+        inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)
+    except RuntimeError as e:
+        num_expected_tokens = is_multimodal.sum().item()
+        assert isinstance(num_expected_tokens, int)
+
+        if flattened.shape[0] != num_expected_tokens:
+            expr = _embedding_count_expression(multimodal_embeddings)
+            raise ValueError(
+                f"Attempted to assign {expr} = {flattened.shape[0]} "
+                f"multimodal tokens to {num_expected_tokens} placeholders"
+            ) from e
+        else:
+            raise ValueError("Error during masked scatter operation") from e
 
-    inputs_embeds[is_multimodal] = flattened
     return inputs_embeds
 
 
@@ -478,11 +484,11 @@ def merge_multimodal_embeddings(
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
     positions in ``inputs_embeds`` corresponding to placeholder tokens in
     ``input_ids``.
-    
-    ``placeholder_token_id`` can be a list of token ids (e.g, token ids 
-    of img_start, img_break, and img_end tokens) when needed: This means 
-    the order of these tokens in the ``input_ids`` MUST MATCH the order of 
-    their embeddings in ``multimodal_embeddings`` since we need to 
+
+    ``placeholder_token_id`` can be a list of token ids (e.g, token ids
+    of img_start, img_break, and img_end tokens) when needed: This means
+    the order of these tokens in the ``input_ids`` MUST MATCH the order of
+    their embeddings in ``multimodal_embeddings`` since we need to
     slice-merge instead of individually scattering.
 
     For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
@@ -491,9 +497,9 @@ def merge_multimodal_embeddings(
     - T is text token
     - S is image start token
     - I is image embedding token
     - B is image break token
     - E is image end token.
-    
-    Then the image embeddings (that correspond to I's) from vision encoder 
-    must be padded with embeddings of S, B, and E in the same order of 
+
+    Then the image embeddings (that correspond to I's) from vision encoder
+    must be padded with embeddings of S, B, and E in the same order of
     input_ids for a correct embedding merge.
 
     Note:
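
Note (not part of the patch): the masked_scatter_ call introduced above is meant as a
drop-in replacement for the boolean-index assignment it removes. A minimal standalone
sanity check of that equivalence, with made-up tensor shapes, could look like this:

    import torch

    T, H, N = 8, 4, 3                      # sequence length, hidden size, mm tokens
    inputs_embeds = torch.randn(T, H)
    is_multimodal = torch.zeros(T, dtype=torch.bool)
    is_multimodal[torch.tensor([1, 4, 5])] = True   # N placeholder positions
    flattened = torch.randn(N, H)          # flattened multimodal embeddings

    # Old path: boolean-index assignment (index_put).
    expected = inputs_embeds.clone()
    expected[is_multimodal] = flattened

    # New path: in-place masked scatter, mask broadcast over the hidden dim.
    out = inputs_embeds.clone()
    out.masked_scatter_(is_multimodal.unsqueeze(-1), flattened)

    assert torch.equal(out, expected)

masked_scatter_ copies the rows of flattened, in order, into the rows where the mask is
True, which is why the flattened embeddings must already be in placeholder order.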
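
Note (not part of the patch): the merge_multimodal_embeddings docstring touched above
requires the multimodal embeddings to appear in exactly the same order as their
placeholder tokens in input_ids. A hedged sketch of building such a placeholder mask for
several special tokens; the token-id values and names below are invented for
illustration and are not taken from vLLM:

    import torch

    # Hypothetical token ids, for illustration only.
    IMG_START, IMG_BREAK, IMG_END, IMG_TOKEN = 10, 11, 12, 13
    placeholder_ids = torch.tensor([IMG_START, IMG_BREAK, IMG_END, IMG_TOKEN])

    # "T T S I I B I E T" in the docstring's notation.
    input_ids = torch.tensor([1, 2, IMG_START, IMG_TOKEN, IMG_TOKEN, IMG_BREAK,
                              IMG_TOKEN, IMG_END, 3])

    # One boolean per position; True wherever any placeholder token appears.
    is_multimodal = torch.isin(input_ids, placeholder_ids)

    # The merged embeddings must then be supplied in exactly this order
    # (S, I, I, B, I, E here), because the merge is a single ordered scatter
    # rather than a per-token lookup.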