diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 8d6ce381976b..39ea07309134 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -959,7 +959,6 @@ def make_test_metadata(
         return attn_backend_obj.make_metadata(
             num_prefills=num_prefills,
             slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
@@ -1009,7 +1008,6 @@ def make_test_metadata(
         return attn_backend_obj.make_metadata(
             num_prefills=num_prefills,
             slot_mapping=kv_mmap.slot_mapping,
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
             num_prefill_tokens=num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index ab7ef2112b08..1b392cd7c88d 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -10,7 +10,6 @@ from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
 import torch
 
 from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
-from vllm.multimodal import MultiModalPlaceholderMap
 
 
 class AttentionType:
@@ -116,15 +115,6 @@ class AttentionMetadata:
     # in block 0, and 1st slot in block 1, respectively.
     slot_mapping: torch.Tensor
 
-    # The index maps that relate multi-modal embeddings to the corresponding
-    # placeholders.
-    #
-    # N.B. These aren't really related to attention and don't belong on this
-    # type -- this is just a temporary solution to make them available to
-    # `model_executable`.
-    multi_modal_placeholder_index_maps: Optional[Dict[
-        str, MultiModalPlaceholderMap.IndexMap]]
-
     # Enable/disable KV scales calculation. This is so that we can disable the
     # calculation until after prefill and cuda graph capture.
     enable_kv_scales_calculation: bool
diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py
index f82d28938f45..cddeb2cf39bf 100644
--- a/vllm/attention/backends/placeholder_attn.py
+++ b/vllm/attention/backends/placeholder_attn.py
@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from collections import defaultdict
 from dataclasses import dataclass
 from itertools import accumulate
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type
 
 import torch
 
@@ -12,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder)
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d
 
 # Placeholder attention backend for models like Mamba and pooling models that
@@ -141,8 +139,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=self.
-            multi_modal_placeholder_index_maps,
             enable_kv_scales_calculation=self.enable_kv_scales_calculation,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
@@ -178,7 +174,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=seq_lens_tensor,
@@ -210,9 +205,6 @@ class PlaceholderAttentionMetadataBuilder(
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
-        self.multimodal_placeholder_maps: Dict[
-            str,
-            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -232,12 +224,6 @@ class PlaceholderAttentionMetadataBuilder(
             self.context_lens.append(context_len)
 
             if is_prompt:
-                mm_maps = inter_data.multi_modal_placeholder_maps
-                if mm_maps:
-                    for modality, placeholders in mm_maps.items():
-                        self.multimodal_placeholder_maps[modality].extend(
-                            placeholders)
-
                 self.num_prefills += 1
                 self.num_prefill_tokens += token_len
                 self.prefill_seq_lens.append(seq_len)
@@ -295,12 +281,6 @@ class PlaceholderAttentionMetadataBuilder(
         seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                 device,
                                                 self.runner.pin_memory)
-        placeholder_index_maps = {
-            modality: placeholder_map.index_map()
-            for modality, placeholder_map in
-            self.multimodal_placeholder_maps.items()
-        }
-
         # Placeholders
         slot_mapping_tensor = torch.empty(0)
         block_tables = torch.empty(0)
@@ -308,7 +288,6 @@ class PlaceholderAttentionMetadataBuilder(
         return PlaceholderAttentionMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
-            multi_modal_placeholder_index_maps=placeholder_index_maps,
             enable_kv_scales_calculation=True,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 3f15580872c7..33d8168f8a13 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention backend utils"""
-from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from itertools import accumulate
@@ -15,7 +14,6 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 logger = init_logger(__name__)
@@ -135,9 +133,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
-        self.multimodal_placeholder_maps: Dict[
-            str,
-            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -154,12 +149,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
-                mm_maps = inter_data.multi_modal_placeholder_maps
-                if mm_maps:
-                    for modality, placeholders in mm_maps.items():
-                        self.multimodal_placeholder_maps[modality].extend(
-                            placeholders)
-
                 self.num_prefills += 1
                 self.num_prefill_tokens += token_len
                 self.prefill_seq_lens.append(seq_len)
@@ -254,16 +243,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                                                   self.runner.pin_memory)
         seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                 device, self.runner.pin_memory)
-        placeholder_index_maps = {
-            modality: placeholder_map.index_map()
-            for modality, placeholder_map in
-            self.multimodal_placeholder_maps.items()
-        }
 
         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
-            multi_modal_placeholder_index_maps=placeholder_index_maps,
             enable_kv_scales_calculation=True,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
@@ -320,7 +303,6 @@ class CommonAttentionState(AttentionState):
             num_prefill_tokens=0,
             num_decode_tokens=batch_size,
             slot_mapping=self._graph_slot_mapping[:batch_size],
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],
diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py
index 7ffa732cf370..8ea79078465e 100644
--- a/vllm/multimodal/__init__.py
+++ b/vllm/multimodal/__init__.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from .base import MultiModalPlaceholderMap
 from .hasher import MultiModalHasher
 from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
                      MultiModalDataDict, MultiModalKwargs,
@@ -27,7 +26,6 @@ __all__ = [
     "MultiModalKwargs",
     "MultiModalKwargsItems",
     "MultiModalPlaceholderDict",
-    "MultiModalPlaceholderMap",
     "MultiModalUUIDDict",
     "NestedTensors",
     "MULTIMODAL_REGISTRY",
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index e0edb3e883ed..faffddd57199 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -3,83 +3,11 @@
 
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Generic, NamedTuple, TypeVar
+from typing import Generic, TypeVar
 
 _T = TypeVar("_T")
 
 
-class MultiModalPlaceholderMap:
-    """
-    Relates multi-modal embeddings to their corresponding placeholders.
-
-    Note: This is only used in V0.
-    """
-
-    class IndexMap(NamedTuple):
-        src: list[int]
-        dest: list[int]
-
-    src_ranges: list[range]
-    """
-    The indices of the multi-modal embeddings that will replace the
-    corresponding placeholder embeddings pointed to by ``dest_ranges``.
-    """
-
-    src_len: int
-    """
-    The total number of flattened multi-modal embeddings.
-    """
-
-    dest_ranges: list[range]
-    """
-    The indices of the placeholder embeddings that will be replaced by the
-    multimodal embeddings.
-    """
-
-    dest_len: int
-    """
-    The total number of embeddings in the destination tensor.
-    """
-
-    def __init__(self):
-        self.src_ranges = []
-        self.src_len = 0
-        self.dest_ranges = []
-        self.dest_len = 0
-
-    def extend(self, other: "MultiModalPlaceholderMap"):
-        """
-        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
-        instance based on the source and destination tensors being
-        concatenated.
- """ - - self.src_ranges.extend( - range(self.src_len + r.start, self.src_len + r.stop) - for r in other.src_ranges) - self.src_len += other.src_len - self.dest_ranges.extend( - range(self.dest_len + r.start, self.dest_len + r.stop) - for r in other.dest_ranges) - self.dest_len += other.dest_len - - def index_map(self) -> "IndexMap": - """ - Finalizes the placeholder map into lists of indices that can be used to - index the source and destination tensors. - """ - - src_indices = [i for r in self.src_ranges for i in r] - dest_indices = [i for r in self.dest_ranges for i in r] - - if len(src_indices) != len(dest_indices): - raise ValueError( - f"The number of source ({len(src_indices)}) and destination " - f"indices ({len(dest_indices)}) must be the same.") - - return self.IndexMap(src=src_indices, dest=dest_indices) - - class MediaIO(ABC, Generic[_T]): @abstractmethod diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 6627164c9879..7e485fea2689 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -425,7 +425,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): num_prompt_req], # prefill query_start_loc=query_start_loc_cpu[:num_reqs + 1], # for logits index - multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, )