diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py
index 8f76d895fdd29..a3f136c5667d5 100644
--- a/tests/multimodal/test_utils.py
+++ b/tests/multimodal/test_utils.py
@@ -9,12 +9,10 @@ from typing import TYPE_CHECKING, NamedTuple, Optional
 import numpy as np
 import pytest
 from PIL import Image, ImageChops
-from transformers import AutoConfig, AutoTokenizer
 
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector,
-                                   merge_and_sort_multimodal_metadata,
-                                   repeat_and_pad_placeholder_tokens)
+                                   merge_and_sort_multimodal_metadata)
 
 if TYPE_CHECKING:
     from vllm.multimodal.hasher import MultiModalHashDict
@@ -136,71 +134,6 @@ async def test_fetch_image_local_files(image_url: str):
                 f"file://{temp_dir}/../{os.path.basename(image_url)}")
 
 
-@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
-def test_repeat_and_pad_placeholder_tokens(model):
-    config = AutoConfig.from_pretrained(model)
-    image_token_id = config.image_token_index
-
-    tokenizer = AutoTokenizer.from_pretrained(model)
-
-    test_cases = [
-        (
-            "<image>",
-            2,
-            "<image><image>",
-            [32000, 32000],
-            [{ "offset": 0, "length": 2 }],
-        ),
-        (
-            "<image><image>",
-            2,
-            "<image><image><image>",
-            [32000, 32000, 32000],
-            [{ "offset": 0, "length": 2 }],
-        ),
-        (
-            "<image><image>",
-            [3, 2],
-            "<image><image><image><image><image>",
-            [32000, 32000, 32000, 32000, 32000],
-            [{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }],
-        ),
-        (
-            "Image:<image>Image:<image>!",
-            [3, 2],
-            "Image:<image><image><image>Image:<image><image>!",
-            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
-            [{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }],
-        ),
-        (
-            "<image>",
-            [3, 2],
-            "<image><image><image>",
-            [32000, 32000, 32000],
-            [{ "offset": 0, "length": 3 }],
-        ),
-    ]  # yapf: disable
-
-    for (
-        prompt,
-        repeat_count,
-        expected_prompt,
-        expected_token_ids,
-        expected_ranges,
-    ) in test_cases:
-        new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-            tokenizer=tokenizer,
-            prompt=prompt,
-            prompt_token_ids=tokenizer.encode(prompt,
-                                              add_special_tokens=False),
-            placeholder_token_id=image_token_id,
-            repeat_count=repeat_count,
-        )
-        assert new_prompt == expected_prompt
-        assert new_token_ids == expected_token_ids
-        assert ranges == expected_ranges
-
-
 # Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
 class TestCase(NamedTuple):
     mm_positions: "MultiModalPlaceholderDict"
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index ad381e1d1d00d..8e4fb7eac49c0 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -12,8 +12,6 @@ from PIL import Image
 
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
-from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 from .audio import AudioMediaIO
 from .base import MediaIO
@@ -21,8 +19,6 @@ from .image import ImageEmbeddingMediaIO, ImageMediaIO
 from .inputs import PlaceholderRange
 from .video import VideoMediaIO
 
-logger = init_logger(__name__)
-
 _M = TypeVar("_M")
 
 if TYPE_CHECKING:
@@ -296,121 +292,6 @@ def encode_video_base64(frames: npt.NDArray) -> str:
     return video_io.encode_base64(frames)
 
 
-# Utilities for input processors
-_T = TypeVar("_T", str, int)
-
-
-def repeat_and_pad_token(
-    token: _T,
-    *,
-    repeat_count: int = 1,
-    pad_token_left: Optional[_T] = None,
-    pad_token_right: Optional[_T] = None,
-) -> list[_T]:
-    replacement = [token] * repeat_count
-    if pad_token_left is not None:
-        replacement = [pad_token_left] + replacement
-    if pad_token_right is not None:
-        replacement = replacement + [pad_token_right]
-
-    return replacement
-
-
-def repeat_and_pad_placeholder_tokens(
-    tokenizer: AnyTokenizer,
-    prompt: Optional[str],
-    prompt_token_ids: list[int],
-    *,
-    placeholder_token_id: int,
-    repeat_count: Union[int, list[int]],
-    pad_token_left: Optional[int] = None,
-    pad_token_right: Optional[int] = None,
-) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
-    if isinstance(repeat_count, int):
-        repeat_count = [repeat_count]
-
-    if prompt is None:
-        new_prompt = None
-    else:
-        placeholder_token_str = tokenizer.decode(placeholder_token_id)
-        pad_token_str_left = (None if pad_token_left is None else
-                              tokenizer.decode(pad_token_left))
-        pad_token_str_right = (None if pad_token_right is None else
-                               tokenizer.decode(pad_token_right))
-
-        placeholder_token_count = prompt.count(placeholder_token_str)
-        # This is an arbitrary number to distinguish between the two cases
-        if placeholder_token_count > 16:
-            logger.warning(
-                "Please follow the prompt format that is "
-                "documented on HuggingFace which does not involve "
-                "repeating %s tokens.", placeholder_token_str)
-        if placeholder_token_count < len(repeat_count):
-            logger.warning(
-                "The number of multi-modal placeholder tokens in the prompt "
-                "is less than the number of multi-modal inputs. Extra "
-                "placeholder tokens will be treated as plain text")
-            repeat_count = repeat_count[:placeholder_token_count]
-
-        prompt_parts = prompt.split(placeholder_token_str,
-                                    maxsplit=len(repeat_count))
-        new_prompt = ""
-        for i, repeat_count_item in enumerate(repeat_count):
-            replacement_str = "".join(
-                repeat_and_pad_token(
-                    placeholder_token_str,
-                    repeat_count=repeat_count_item,
-                    pad_token_left=pad_token_str_left,
-                    pad_token_right=pad_token_str_right,
-                ))
-            # The image tokens are removed to be consistent with HuggingFace
-            new_prompt += prompt_parts[i] + replacement_str
-        new_prompt += prompt_parts[-1]
-
-    new_token_ids = list[int]()
-    placeholder_ranges = list[PlaceholderRange]()
-    placeholder_token_idx = 0
-    for i, token in enumerate(prompt_token_ids):
-        if token == placeholder_token_id:
-            curr_repeat_count = repeat_count[placeholder_token_idx]
-            replacement_ids = repeat_and_pad_token(
-                placeholder_token_id,
-                repeat_count=curr_repeat_count,
-                pad_token_left=pad_token_left,
-                pad_token_right=pad_token_right,
-            )
-            offset = len(new_token_ids)
-            if pad_token_left is not None:
-                offset += 1
-            placeholder_ranges.append({
-                "offset": offset,
-                "length": curr_repeat_count,
-            })
-            new_token_ids.extend(replacement_ids)
-            placeholder_token_idx += 1
-
-            # No need to further scan the list since we replaced all tokens
-            if placeholder_token_idx >= len(repeat_count):
-                new_token_ids.extend(prompt_token_ids[i + 1:])
-                break
-        else:
-            new_token_ids.append(token)
-
-    return new_prompt, new_token_ids, placeholder_ranges
-
-
-def consecutive_placeholder_ranges(
-        num_items: int,
-        item_size: int,
-        initial_offset: int = 0) -> list[PlaceholderRange]:
-    """Returns a list of consecutive PlaceholderRanges of a fixed size"""
-
-    return [
-        PlaceholderRange(offset=initial_offset + i * item_size,
-                         length=item_size) for i in range(num_items)
-    ]
-
-
 def merge_and_sort_multimodal_metadata(
     mm_positions: "MultiModalPlaceholderDict",
     mm_hashes: Optional["MultiModalHashDict"],
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index bb883acdb44b6..dc0d2d59fea7f 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -10,8 +10,7 @@ if TYPE_CHECKING:
     import numpy.typing as npt
 
     from vllm.lora.request import LoRARequest
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.base import PlaceholderRange
+    from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
     from vllm.sampling_params import SamplingParams
     from vllm.v1.request import Request
 
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 01a5cb5548bb4..351b358155801 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -2,13 +2,13 @@
 # Datastructures defining an input batch
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, cast
+from typing import Optional, cast
 
 import numpy as np
 import torch
 
 from vllm.lora.request import LoRARequest
-from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
@@ -18,9 +18,6 @@ from vllm.v1.worker.block_table import BlockTable
 
 _SAMPLING_EPS = 1e-5
 
-if TYPE_CHECKING:
-    from vllm.multimodal.inputs import PlaceholderRange
-
 
 @dataclass
 class CachedRequestState:
@@ -29,7 +26,7 @@ class CachedRequestState:
     prompt_token_ids: list[int]
     prompt: Optional[str]
     mm_inputs: list[MultiModalKwargs]
-    mm_positions: list["PlaceholderRange"]
+    mm_positions: list[PlaceholderRange]
     sampling_params: SamplingParams
     generator: Optional[torch.Generator]