[V0 Deprecation] Remove MultiModalPlaceholderMap (#25366)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung, 2025-09-22 16:49:19 +08:00 (committed by GitHub)
Parent: 6d0b827cbd
Commit: f92d952632
7 changed files with 2 additions and 128 deletions


@@ -959,7 +959,6 @@ def make_test_metadata(
        return attn_backend_obj.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
-           multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
@@ -1009,7 +1008,6 @@ def make_test_metadata(
        return attn_backend_obj.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=kv_mmap.slot_mapping,
-           multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,


@@ -10,7 +10,6 @@ from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
 import torch

 from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
-from vllm.multimodal import MultiModalPlaceholderMap


 class AttentionType:
@@ -116,15 +115,6 @@ class AttentionMetadata:
     # in block 0, and 1st slot in block 1, respectively.
     slot_mapping: torch.Tensor

-    # The index maps that relate multi-modal embeddings to the corresponding
-    # placeholders.
-    #
-    # N.B. These aren't really related to attention and don't belong on this
-    # type -- this is just a temporary solution to make them available to
-    # `model_executable`.
-    multi_modal_placeholder_index_maps: Optional[Dict[
-        str, MultiModalPlaceholderMap.IndexMap]]
-
     # Enable/disable KV scales calculation. This is so that we can disable the
     # calculation until after prefill and cuda graph capture.
     enable_kv_scales_calculation: bool
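For reference, the removed index maps related flattened multi-modal embeddings to the placeholder token positions they replace. A minimal sketch of how such an index map could be consumed (not vLLM's actual model-runner code; the function and tensor names below are hypothetical):

import torch


def merge_multimodal_embeddings_sketch(
    inputs_embeds: torch.Tensor,  # [num_tokens, hidden_size], includes placeholder rows
    mm_embeds: torch.Tensor,      # [num_mm_embeddings, hidden_size]
    src: list[int],               # hypothetical IndexMap.src: indices into mm_embeds
    dest: list[int],              # hypothetical IndexMap.dest: placeholder rows to overwrite
) -> torch.Tensor:
    # The two index lists are parallel: src[i] selects a multi-modal embedding
    # and dest[i] is the placeholder position it replaces.
    assert len(src) == len(dest)
    inputs_embeds[torch.tensor(dest)] = mm_embeds[torch.tensor(src)]
    return inputs_embeds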


@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections import defaultdict
 from dataclasses import dataclass
 from itertools import accumulate
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import torch
@@ -12,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder)
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d

 # Placeholder attention backend for models like Mamba and pooling models that
@@ -141,8 +139,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=self.
-            multi_modal_placeholder_index_maps,
             enable_kv_scales_calculation=self.enable_kv_scales_calculation,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
@@ -178,7 +174,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=seq_lens_tensor,
@@ -210,9 +205,6 @@ class PlaceholderAttentionMetadataBuilder(
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
-        self.multimodal_placeholder_maps: Dict[
-            str,
-            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -232,12 +224,6 @@ class PlaceholderAttentionMetadataBuilder(
             self.context_lens.append(context_len)

             if is_prompt:
-                mm_maps = inter_data.multi_modal_placeholder_maps
-                if mm_maps:
-                    for modality, placeholders in mm_maps.items():
-                        self.multimodal_placeholder_maps[modality].extend(
-                            placeholders)
-
                 self.num_prefills += 1
                 self.num_prefill_tokens += token_len
                 self.prefill_seq_lens.append(seq_len)
@@ -295,12 +281,6 @@ class PlaceholderAttentionMetadataBuilder(
         seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                 device, self.runner.pin_memory)
-        placeholder_index_maps = {
-            modality: placeholder_map.index_map()
-            for modality, placeholder_map in
-            self.multimodal_placeholder_maps.items()
-        }

         # Placeholders
         slot_mapping_tensor = torch.empty(0)
         block_tables = torch.empty(0)
@@ -308,7 +288,6 @@ class PlaceholderAttentionMetadataBuilder(
         return PlaceholderAttentionMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
-            multi_modal_placeholder_index_maps=placeholder_index_maps,
             enable_kv_scales_calculation=True,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
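The builder changes above and in the following file remove the same V0 pattern: per-modality placeholder maps accumulated while prefill sequences are added, then finalized into index maps when the metadata is built. A condensed sketch of that removed pattern, with `inter_data_list` as a hypothetical stand-in for the builder's per-sequence inputs (see the removed lines for the real call sites):

from collections import defaultdict

from vllm.multimodal import MultiModalPlaceholderMap  # import removed by this commit


def build_placeholder_index_maps(inter_data_list):
    # Accumulate one placeholder map per modality across all prefill sequences.
    maps = defaultdict(MultiModalPlaceholderMap)
    for inter_data in inter_data_list:
        mm_maps = inter_data.multi_modal_placeholder_maps
        if mm_maps:
            for modality, placeholders in mm_maps.items():
                maps[modality].extend(placeholders)
    # Finalize each map into flat src/dest index lists usable for tensor indexing.
    return {modality: m.index_map() for modality, m in maps.items()}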


@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention backend utils"""
-from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from itertools import accumulate
@@ -15,7 +14,6 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 logger = init_logger(__name__)
@@ -135,9 +133,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
-        self.multimodal_placeholder_maps: Dict[
-            str,
-            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -154,12 +149,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
-                mm_maps = inter_data.multi_modal_placeholder_maps
-                if mm_maps:
-                    for modality, placeholders in mm_maps.items():
-                        self.multimodal_placeholder_maps[modality].extend(
-                            placeholders)
-
                 self.num_prefills += 1
                 self.num_prefill_tokens += token_len
                 self.prefill_seq_lens.append(seq_len)
@@ -254,16 +243,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                                                 self.runner.pin_memory)
         seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                 device, self.runner.pin_memory)
-        placeholder_index_maps = {
-            modality: placeholder_map.index_map()
-            for modality, placeholder_map in
-            self.multimodal_placeholder_maps.items()
-        }

         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
-            multi_modal_placeholder_index_maps=placeholder_index_maps,
             enable_kv_scales_calculation=True,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
@@ -320,7 +303,6 @@ class CommonAttentionState(AttentionState):
             num_prefill_tokens=0,
             num_decode_tokens=batch_size,
             slot_mapping=self._graph_slot_mapping[:batch_size],
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],


@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from .base import MultiModalPlaceholderMap
 from .hasher import MultiModalHasher
 from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
                      MultiModalDataDict, MultiModalKwargs,
@@ -27,7 +26,6 @@ __all__ = [
     "MultiModalKwargs",
     "MultiModalKwargsItems",
     "MultiModalPlaceholderDict",
-    "MultiModalPlaceholderMap",
     "MultiModalUUIDDict",
     "NestedTensors",
     "MULTIMODAL_REGISTRY",


@@ -3,83 +3,11 @@
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Generic, NamedTuple, TypeVar
+from typing import Generic, TypeVar

 _T = TypeVar("_T")


-class MultiModalPlaceholderMap:
-    """
-    Relates multi-modal embeddings to their corresponding placeholders.
-
-    Note: This is only used in V0.
-    """
-
-    class IndexMap(NamedTuple):
-        src: list[int]
-        dest: list[int]
-
-    src_ranges: list[range]
-    """
-    The indices of the multi-modal embeddings that will replace the
-    corresponding placeholder embeddings pointed to by ``dest_ranges``.
-    """
-
-    src_len: int
-    """
-    The total number of flattened multi-modal embeddings.
-    """
-
-    dest_ranges: list[range]
-    """
-    The indices of the placeholder embeddings that will be replaced by the
-    multimodal embeddings.
-    """
-
-    dest_len: int
-    """
-    The total number of embeddings in the destination tensor.
-    """
-
-    def __init__(self):
-        self.src_ranges = []
-        self.src_len = 0
-        self.dest_ranges = []
-        self.dest_len = 0
-
-    def extend(self, other: "MultiModalPlaceholderMap"):
-        """
-        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
-        instance based on the source and destination tensors being
-        concatenated.
-        """
-        self.src_ranges.extend(
-            range(self.src_len + r.start, self.src_len + r.stop)
-            for r in other.src_ranges)
-        self.src_len += other.src_len
-
-        self.dest_ranges.extend(
-            range(self.dest_len + r.start, self.dest_len + r.stop)
-            for r in other.dest_ranges)
-        self.dest_len += other.dest_len
-
-    def index_map(self) -> "IndexMap":
-        """
-        Finalizes the placeholder map into lists of indices that can be used to
-        index the source and destination tensors.
-        """
-        src_indices = [i for r in self.src_ranges for i in r]
-        dest_indices = [i for r in self.dest_ranges for i in r]
-        if len(src_indices) != len(dest_indices):
-            raise ValueError(
-                f"The number of source ({len(src_indices)}) and destination "
-                f"indices ({len(dest_indices)}) must be the same.")
-
-        return self.IndexMap(src=src_indices, dest=dest_indices)
-
-
 class MediaIO(ABC, Generic[_T]):

     @abstractmethod
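To make the semantics of the deleted class concrete, here is a small illustration that uses it as defined in the removed lines above, with the ranges populated by hand purely for demonstration (the real V0 code filled them in through builder helpers). `extend()` offsets the second map's ranges by the running source/destination lengths, and `index_map()` then flattens the ranges into parallel index lists:

a = MultiModalPlaceholderMap()
a.src_ranges, a.src_len = [range(0, 2)], 2
a.dest_ranges, a.dest_len = [range(5, 7)], 10

b = MultiModalPlaceholderMap()
b.src_ranges, b.src_len = [range(0, 3)], 3
b.dest_ranges, b.dest_len = [range(1, 4)], 8

a.extend(b)
print(a.index_map())
# IndexMap(src=[0, 1, 2, 3, 4], dest=[5, 6, 11, 12, 13])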


@@ -425,7 +425,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
                                               num_prompt_req], # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1], # for logits index
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=False,
         )