mirror of https://git.datalinker.icu/vllm-project/vllm.git
[V0 Deprecation] Remove MultiModalPlaceholderMap (#25366)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 6d0b827cbd
commit f92d952632
@@ -959,7 +959,6 @@ def make_test_metadata(
     return attn_backend_obj.make_metadata(
         num_prefills=num_prefills,
         slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
-        multi_modal_placeholder_index_maps=None,
         enable_kv_scales_calculation=True,
         num_prefill_tokens=num_prefill_tokens,
         num_decode_tokens=num_decode_tokens,
@@ -1009,7 +1008,6 @@ def make_test_metadata(
     return attn_backend_obj.make_metadata(
         num_prefills=num_prefills,
         slot_mapping=kv_mmap.slot_mapping,
-        multi_modal_placeholder_index_maps=None,
         enable_kv_scales_calculation=True,
         num_prefill_tokens=num_prefill_tokens,
         num_decode_tokens=num_decode_tokens,
@@ -10,7 +10,6 @@ from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
 import torch

 from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
-from vllm.multimodal import MultiModalPlaceholderMap


 class AttentionType:
@@ -116,15 +115,6 @@ class AttentionMetadata:
     # in block 0, and 1st slot in block 1, respectively.
     slot_mapping: torch.Tensor

-    # The index maps that relate multi-modal embeddings to the corresponding
-    # placeholders.
-    #
-    # N.B. These aren't really related to attention and don't belong on this
-    # type -- this is just a temporary solution to make them available to
-    # `model_executable`.
-    multi_modal_placeholder_index_maps: Optional[Dict[
-        str, MultiModalPlaceholderMap.IndexMap]]
-
     # Enable/disable KV scales calculation. This is so that we can disable the
     # calculation until after prefill and cuda graph capture.
     enable_kv_scales_calculation: bool
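Note: the hunk above drops the V0-only field from `AttentionMetadata`. A minimal sketch of the before/after shape of the change (illustrative dataclasses only, not the full vLLM class):

from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch


@dataclass
class MetadataBefore:
    # V0: the placeholder index maps were piggybacked on attention metadata
    # purely so `model_executable` could reach them (see the deleted comment).
    slot_mapping: torch.Tensor
    multi_modal_placeholder_index_maps: Optional[Dict[str, Any]]
    enable_kv_scales_calculation: bool


@dataclass
class MetadataAfter:
    # After this commit the field is gone; every call site below simply
    # stops passing the keyword argument.
    slot_mapping: torch.Tensor
    enable_kv_scales_calculation: bool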
@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections import defaultdict
 from dataclasses import dataclass
 from itertools import accumulate
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import torch

@@ -12,7 +11,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder)
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d

 # Placeholder attention backend for models like Mamba and pooling models that
@@ -141,8 +139,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=self.
-            multi_modal_placeholder_index_maps,
             enable_kv_scales_calculation=self.enable_kv_scales_calculation,
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
@@ -178,7 +174,6 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=seq_lens_tensor,
@@ -210,9 +205,6 @@ class PlaceholderAttentionMetadataBuilder(
         self.prefill_seq_lens: List[int] = []
         self.context_lens: List[int] = []
         self.curr_seq_lens: List[int] = []
-        self.multimodal_placeholder_maps: Dict[
-            str,
-            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -232,12 +224,6 @@ class PlaceholderAttentionMetadataBuilder(
         self.context_lens.append(context_len)

         if is_prompt:
-            mm_maps = inter_data.multi_modal_placeholder_maps
-            if mm_maps:
-                for modality, placeholders in mm_maps.items():
-                    self.multimodal_placeholder_maps[modality].extend(
-                        placeholders)
-
             self.num_prefills += 1
             self.num_prefill_tokens += token_len
             self.prefill_seq_lens.append(seq_len)
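The deleted block above was the per-sequence accumulation step: each prefill's per-modality placeholder maps were merged into builder-wide maps. A standalone sketch of that pattern, using a pared-down stand-in for the removed class (`_Map` and `_add_seq_group` are illustrative names, not vLLM API):

from collections import defaultdict


class _Map:
    """Pared-down stand-in for the removed MultiModalPlaceholderMap."""

    def __init__(self):
        self.dest_ranges: list[range] = []
        self.dest_len = 0

    def extend(self, other: "_Map"):
        # Re-base the incoming ranges as if the destination tensors
        # were concatenated after ours.
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len


builder_maps: dict[str, _Map] = defaultdict(_Map)


def _add_seq_group(mm_maps: dict[str, _Map]) -> None:
    # Mirrors the removed loop: merge each modality's map into the
    # builder's running map for that modality.
    for modality, placeholders in mm_maps.items():
        builder_maps[modality].extend(placeholders)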
@@ -295,12 +281,6 @@ class PlaceholderAttentionMetadataBuilder(
         seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                 device, self.runner.pin_memory)

-        placeholder_index_maps = {
-            modality: placeholder_map.index_map()
-            for modality, placeholder_map in
-            self.multimodal_placeholder_maps.items()
-        }
-
         # Placeholders
         slot_mapping_tensor = torch.empty(0)
         block_tables = torch.empty(0)
@@ -308,7 +288,6 @@ class PlaceholderAttentionMetadataBuilder(
         return PlaceholderAttentionMetadata(
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
-            multi_modal_placeholder_index_maps=placeholder_index_maps,
             enable_kv_scales_calculation=True,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention backend utils"""
-from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from itertools import accumulate
@@ -15,7 +14,6 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 logger = init_logger(__name__)
@@ -135,9 +133,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         self.context_lens: List[int] = []
         self.block_tables: List[List[int]] = []
         self.curr_seq_lens: List[int] = []
-        self.multimodal_placeholder_maps: Dict[
-            str,
-            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
         self.num_prefills = 0
         self.num_prefill_tokens = 0
         self.num_decode_tokens = 0
@@ -154,12 +149,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                 inter_data.curr_sliding_window_blocks):
             self.context_lens.append(context_len)
             if is_prompt:
-                mm_maps = inter_data.multi_modal_placeholder_maps
-                if mm_maps:
-                    for modality, placeholders in mm_maps.items():
-                        self.multimodal_placeholder_maps[modality].extend(
-                            placeholders)
-
                 self.num_prefills += 1
                 self.num_prefill_tokens += token_len
                 self.prefill_seq_lens.append(seq_len)
@@ -254,16 +243,10 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
                                                    self.runner.pin_memory)
         seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
                                                 device, self.runner.pin_memory)
-        placeholder_index_maps = {
-            modality: placeholder_map.index_map()
-            for modality, placeholder_map in
-            self.multimodal_placeholder_maps.items()
-        }
-
         return self._metadata_cls(  # type: ignore
             num_prefills=self.num_prefills,
             slot_mapping=slot_mapping_tensor,
-            multi_modal_placeholder_index_maps=placeholder_index_maps,
             enable_kv_scales_calculation=True,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
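Above, the deleted comprehension finalized each builder map into an `IndexMap` just before constructing the metadata. Downstream, V0 used those flat index lists to scatter multi-modal embeddings over their placeholder positions; a hedged sketch of that consumption (function name and tensor shapes are illustrative, not vLLM API):

import torch


def _scatter_mm_embeds(inputs_embeds: torch.Tensor,  # [num_tokens, hidden]
                       mm_embeds: torch.Tensor,      # [num_mm_embeds, hidden]
                       src: list[int],
                       dest: list[int]) -> torch.Tensor:
    # src and dest come from IndexMap and are equal-length by construction,
    # so this copies each selected multi-modal embedding onto its
    # placeholder slot in the input embeddings.
    inputs_embeds[dest] = mm_embeds[src]
    return inputs_embeds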
@@ -320,7 +303,6 @@ class CommonAttentionState(AttentionState):
             num_prefill_tokens=0,
             num_decode_tokens=batch_size,
             slot_mapping=self._graph_slot_mapping[:batch_size],
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=True,
             seq_lens=None,
             seq_lens_tensor=self._graph_seq_lens[:batch_size],
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from .base import MultiModalPlaceholderMap
 from .hasher import MultiModalHasher
 from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
                      MultiModalDataDict, MultiModalKwargs,
@@ -27,7 +26,6 @@ __all__ = [
     "MultiModalKwargs",
     "MultiModalKwargsItems",
     "MultiModalPlaceholderDict",
-    "MultiModalPlaceholderMap",
     "MultiModalUUIDDict",
     "NestedTensors",
     "MULTIMODAL_REGISTRY",
@@ -3,83 +3,11 @@

 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Generic, NamedTuple, TypeVar
+from typing import Generic, TypeVar

 _T = TypeVar("_T")


-class MultiModalPlaceholderMap:
-    """
-    Relates multi-modal embeddings to their corresponding placeholders.
-
-    Note: This is only used in V0.
-    """
-
-    class IndexMap(NamedTuple):
-        src: list[int]
-        dest: list[int]
-
-    src_ranges: list[range]
-    """
-    The indices of the multi-modal embeddings that will replace the
-    corresponding placeholder embeddings pointed to by ``dest_ranges``.
-    """
-
-    src_len: int
-    """
-    The total number of flattened multi-modal embeddings.
-    """
-
-    dest_ranges: list[range]
-    """
-    The indices of the placeholder embeddings that will be replaced by the
-    multimodal embeddings.
-    """
-
-    dest_len: int
-    """
-    The total number of embeddings in the destination tensor.
-    """
-
-    def __init__(self):
-        self.src_ranges = []
-        self.src_len = 0
-        self.dest_ranges = []
-        self.dest_len = 0
-
-    def extend(self, other: "MultiModalPlaceholderMap"):
-        """
-        Adds the placeholders from another ``MultiModalPlaceholderMap`` to this
-        instance based on the source and destination tensors being
-        concatenated.
-        """
-
-        self.src_ranges.extend(
-            range(self.src_len + r.start, self.src_len + r.stop)
-            for r in other.src_ranges)
-        self.src_len += other.src_len
-        self.dest_ranges.extend(
-            range(self.dest_len + r.start, self.dest_len + r.stop)
-            for r in other.dest_ranges)
-        self.dest_len += other.dest_len
-
-    def index_map(self) -> "IndexMap":
-        """
-        Finalizes the placeholder map into lists of indices that can be used to
-        index the source and destination tensors.
-        """
-
-        src_indices = [i for r in self.src_ranges for i in r]
-        dest_indices = [i for r in self.dest_ranges for i in r]
-
-        if len(src_indices) != len(dest_indices):
-            raise ValueError(
-                f"The number of source ({len(src_indices)}) and destination "
-                f"indices ({len(dest_indices)}) must be the same.")
-
-        return self.IndexMap(src=src_indices, dest=dest_indices)
-
-
 class MediaIO(ABC, Generic[_T]):

     @abstractmethod
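Since the class is deleted wholesale above, here is a self-contained record of its semantics, reconstructed directly from the removed code: extend() re-bases ranges as if the source and destination tensors were concatenated, and index_map() flattens the ranges into parallel src/dest index lists. A condensed, runnable reconstruction with a worked example:

class MultiModalPlaceholderMap:
    # Condensed reconstruction of the removed V0 class, for reference.
    def __init__(self):
        self.src_ranges: list[range] = []
        self.src_len = 0
        self.dest_ranges: list[range] = []
        self.dest_len = 0

    def extend(self, other: "MultiModalPlaceholderMap"):
        # Shift the other map's ranges past our current lengths, as if the
        # source/destination tensors were concatenated after ours.
        self.src_ranges.extend(
            range(self.src_len + r.start, self.src_len + r.stop)
            for r in other.src_ranges)
        self.src_len += other.src_len
        self.dest_ranges.extend(
            range(self.dest_len + r.start, self.dest_len + r.stop)
            for r in other.dest_ranges)
        self.dest_len += other.dest_len

    def index_map(self) -> tuple[list[int], list[int]]:
        # Flatten the ranges into parallel source/destination index lists.
        src = [i for r in self.src_ranges for i in r]
        dest = [i for r in self.dest_ranges for i in r]
        assert len(src) == len(dest)
        return src, dest


# Example: one image whose 3 embeddings fill placeholder tokens 5..7
# of a 10-token prompt.
a = MultiModalPlaceholderMap()
a.src_ranges, a.src_len = [range(0, 3)], 3
a.dest_ranges, a.dest_len = [range(5, 8)], 10

# A second sequence with 2 embeddings at placeholder tokens 1..2
# of a 6-token prompt.
b = MultiModalPlaceholderMap()
b.src_ranges, b.src_len = [range(0, 2)], 2
b.dest_ranges, b.dest_len = [range(1, 3)], 6

a.extend(b)
print(a.index_map())  # ([0, 1, 2, 3, 4], [5, 6, 7, 11, 12])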
@@ -425,7 +425,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
                 num_prompt_req],  # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
-            multi_modal_placeholder_index_maps=None,
             enable_kv_scales_calculation=False,
         )
