[Misc] Remove unused utils and clean up imports (#15708)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-03-29 00:41:16 +08:00; committed by GitHub
parent 70e132244a
commit c6bc0034d0
4 changed files with 5 additions and 195 deletions


@@ -9,12 +9,10 @@ from typing import TYPE_CHECKING, NamedTuple, Optional
 import numpy as np
 import pytest
 from PIL import Image, ImageChops
-from transformers import AutoConfig, AutoTokenizer
 
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.multimodal.utils import (MediaConnector,
-                                   merge_and_sort_multimodal_metadata,
-                                   repeat_and_pad_placeholder_tokens)
+                                   merge_and_sort_multimodal_metadata)
 
 if TYPE_CHECKING:
     from vllm.multimodal.hasher import MultiModalHashDict
@@ -136,71 +134,6 @@ async def test_fetch_image_local_files(image_url: str):
             f"file://{temp_dir}/../{os.path.basename(image_url)}")
 
 
-@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
-def test_repeat_and_pad_placeholder_tokens(model):
-    config = AutoConfig.from_pretrained(model)
-    image_token_id = config.image_token_index
-
-    tokenizer = AutoTokenizer.from_pretrained(model)
-
-    test_cases = [
-        (
-            "<image>",
-            2,
-            "<image><image>",
-            [32000, 32000],
-            [{"offset": 0, "length": 2}],
-        ),
-        (
-            "<image><image>",
-            2,
-            "<image><image><image>",
-            [32000, 32000, 32000],
-            [{"offset": 0, "length": 2}],
-        ),
-        (
-            "<image><image>",
-            [3, 2],
-            "<image><image><image><image><image>",
-            [32000, 32000, 32000, 32000, 32000],
-            [{"offset": 0, "length": 3}, {"offset": 3, "length": 2}],
-        ),
-        (
-            "Image:<image>Image:<image>!",
-            [3, 2],
-            "Image:<image><image><image>Image:<image><image>!",
-            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
-            [{"offset": 2, "length": 3}, {"offset": 7, "length": 2}],
-        ),
-        (
-            "<image>",
-            [3, 2],
-            "<image><image><image>",
-            [32000, 32000, 32000],
-            [{"offset": 0, "length": 3}],
-        ),
-    ]  # yapf: disable
-
-    for (
-            prompt,
-            repeat_count,
-            expected_prompt,
-            expected_token_ids,
-            expected_ranges,
-    ) in test_cases:
-        new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
-            tokenizer=tokenizer,
-            prompt=prompt,
-            prompt_token_ids=tokenizer.encode(prompt,
-                                              add_special_tokens=False),
-            placeholder_token_id=image_token_id,
-            repeat_count=repeat_count,
-        )
-
-        assert new_prompt == expected_prompt
-        assert new_token_ids == expected_token_ids
-        assert ranges == expected_ranges
-
-
 # Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
 class TestCase(NamedTuple):
     mm_positions: "MultiModalPlaceholderDict"

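Note: the deleted test pinned down the expansion rule at the string level: the prompt is split on the placeholder string, and the i-th occurrence is repeated repeat_count[i] times (extra counts beyond the number of placeholders in the prompt are dropped). A minimal standalone sketch of that rule, not vLLM API; the "<image>" strings and expected outputs come from the test cases above:

def expand_placeholders(prompt: str, token: str,
                        repeat_counts: list[int]) -> str:
    # Mirrors the string branch of the removed
    # repeat_and_pad_placeholder_tokens: counts beyond the number of
    # placeholders actually present are ignored.
    repeat_counts = repeat_counts[:prompt.count(token)]
    parts = prompt.split(token, maxsplit=len(repeat_counts))
    out = ""
    for part, n in zip(parts, repeat_counts):
        out += part + token * n
    return out + parts[-1]

assert expand_placeholders("<image>", "<image>", [2]) == "<image><image>"
assert expand_placeholders("<image>", "<image>", [3, 2]) == "<image><image><image>"
assert (expand_placeholders("Image:<image>Image:<image>!", "<image>", [3, 2])
        == "Image:<image><image><image>Image:<image><image>!")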

@@ -12,8 +12,6 @@ from PIL import Image
 
 import vllm.envs as envs
 from vllm.connections import HTTPConnection, global_http_connection
 from vllm.logger import init_logger
-from vllm.transformers_utils.tokenizer import AnyTokenizer
-
 from .audio import AudioMediaIO
 from .base import MediaIO
@@ -21,8 +19,6 @@ from .image import ImageEmbeddingMediaIO, ImageMediaIO
-from .inputs import PlaceholderRange
 from .video import VideoMediaIO
 
 logger = init_logger(__name__)
 
 _M = TypeVar("_M")
 
 if TYPE_CHECKING:
@@ -296,121 +292,6 @@ def encode_video_base64(frames: npt.NDArray) -> str:
     return video_io.encode_base64(frames)
 
 
-# Utilities for input processors
-_T = TypeVar("_T", str, int)
-
-
-def repeat_and_pad_token(
-    token: _T,
-    *,
-    repeat_count: int = 1,
-    pad_token_left: Optional[_T] = None,
-    pad_token_right: Optional[_T] = None,
-) -> list[_T]:
-    replacement = [token] * repeat_count
-    if pad_token_left is not None:
-        replacement = [pad_token_left] + replacement
-    if pad_token_right is not None:
-        replacement = replacement + [pad_token_right]
-
-    return replacement
-
-
-def repeat_and_pad_placeholder_tokens(
-    tokenizer: AnyTokenizer,
-    prompt: Optional[str],
-    prompt_token_ids: list[int],
-    *,
-    placeholder_token_id: int,
-    repeat_count: Union[int, list[int]],
-    pad_token_left: Optional[int] = None,
-    pad_token_right: Optional[int] = None,
-) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
-    if isinstance(repeat_count, int):
-        repeat_count = [repeat_count]
-
-    if prompt is None:
-        new_prompt = None
-    else:
-        placeholder_token_str = tokenizer.decode(placeholder_token_id)
-        pad_token_str_left = (None if pad_token_left is None else
-                              tokenizer.decode(pad_token_left))
-        pad_token_str_right = (None if pad_token_right is None else
-                               tokenizer.decode(pad_token_right))
-
-        placeholder_token_count = prompt.count(placeholder_token_str)
-        # This is an arbitrary number to distinguish between the two cases
-        if placeholder_token_count > 16:
-            logger.warning(
-                "Please follow the prompt format that is "
-                "documented on HuggingFace which does not involve "
-                "repeating %s tokens.", placeholder_token_str)
-
-        if placeholder_token_count < len(repeat_count):
-            logger.warning(
-                "The number of multi-modal placeholder tokens in the prompt "
-                "is less than the number of multi-modal inputs. Extra "
-                "placeholder tokens will be treated as plain text")
-            repeat_count = repeat_count[:placeholder_token_count]
-
-        prompt_parts = prompt.split(placeholder_token_str,
-                                    maxsplit=len(repeat_count))
-        new_prompt = ""
-        for i, repeat_count_item in enumerate(repeat_count):
-            replacement_str = "".join(
-                repeat_and_pad_token(
-                    placeholder_token_str,
-                    repeat_count=repeat_count_item,
-                    pad_token_left=pad_token_str_left,
-                    pad_token_right=pad_token_str_right,
-                ))
-            # The image tokens are removed to be consistent with HuggingFace
-            new_prompt += prompt_parts[i] + replacement_str
-        new_prompt += prompt_parts[-1]
-
-    new_token_ids = list[int]()
-    placeholder_ranges = list[PlaceholderRange]()
-    placeholder_token_idx = 0
-    for i, token in enumerate(prompt_token_ids):
-        if token == placeholder_token_id:
-            curr_repeat_count = repeat_count[placeholder_token_idx]
-            replacement_ids = repeat_and_pad_token(
-                placeholder_token_id,
-                repeat_count=curr_repeat_count,
-                pad_token_left=pad_token_left,
-                pad_token_right=pad_token_right,
-            )
-            offset = len(new_token_ids)
-            if pad_token_left is not None:
-                offset += 1
-            placeholder_ranges.append({
-                "offset": offset,
-                "length": curr_repeat_count,
-            })
-            new_token_ids.extend(replacement_ids)
-            placeholder_token_idx += 1
-
-            # No need to further scan the list since we replaced all tokens
-            if placeholder_token_idx >= len(repeat_count):
-                new_token_ids.extend(prompt_token_ids[i + 1:])
-                break
-        else:
-            new_token_ids.append(token)
-
-    return new_prompt, new_token_ids, placeholder_ranges
-
-
-def consecutive_placeholder_ranges(
-        num_items: int,
-        item_size: int,
-        initial_offset: int = 0) -> list[PlaceholderRange]:
-    """Returns a list of consecutive PlaceholderRanges of a fixed size"""
-
-    return [
-        PlaceholderRange(offset=initial_offset + i * item_size,
-                         length=item_size) for i in range(num_items)
-    ]
-
-
 def merge_and_sort_multimodal_metadata(
     mm_positions: "MultiModalPlaceholderDict",
     mm_hashes: Optional["MultiModalHashDict"],

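Since the helpers above are deleted outright rather than deprecated, any external caller that relied on them has to inline the logic. A self-contained sketch with the same semantics as the removed repeat_and_pad_token and consecutive_placeholder_ranges; the names are illustrative, not vLLM API, and PlaceholderRange is modeled as a plain dict here:

from typing import Optional, TypeVar

_T = TypeVar("_T", str, int)

def repeat_and_pad(token: _T, repeat_count: int = 1,
                   pad_left: Optional[_T] = None,
                   pad_right: Optional[_T] = None) -> list[_T]:
    # [pad_left?] + [token] * repeat_count + [pad_right?]
    out: list[_T] = [token] * repeat_count
    if pad_left is not None:
        out.insert(0, pad_left)
    if pad_right is not None:
        out.append(pad_right)
    return out

def consecutive_ranges(num_items: int, item_size: int,
                       initial_offset: int = 0) -> list[dict]:
    # Back-to-back {"offset", "length"} ranges of a fixed size.
    return [{"offset": initial_offset + i * item_size, "length": item_size}
            for i in range(num_items)]

assert repeat_and_pad(32000, repeat_count=3) == [32000, 32000, 32000]
assert repeat_and_pad("<img>", 2, pad_left="<s>") == ["<s>", "<img>", "<img>"]
assert consecutive_ranges(2, 3) == [{"offset": 0, "length": 3},
                                    {"offset": 3, "length": 3}]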

@@ -10,8 +10,7 @@ if TYPE_CHECKING:
     import numpy.typing as npt
 
     from vllm.lora.request import LoRARequest
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.base import PlaceholderRange
+    from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
     from vllm.sampling_params import SamplingParams
     from vllm.v1.request import Request


@@ -2,13 +2,13 @@
 # Datastructures defining an input batch
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, cast
+from typing import Optional, cast
 
 import numpy as np
 import torch
 
 from vllm.lora.request import LoRARequest
-from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
@@ -18,9 +18,6 @@ from vllm.v1.worker.block_table import BlockTable
 
 _SAMPLING_EPS = 1e-5
 
-if TYPE_CHECKING:
-    from vllm.multimodal.inputs import PlaceholderRange
-
 
 @dataclass
 class CachedRequestState:
@@ -29,7 +26,7 @@ class CachedRequestState:
     prompt_token_ids: list[int]
     prompt: Optional[str]
    mm_inputs: list[MultiModalKwargs]
-    mm_positions: list["PlaceholderRange"]
+    mm_positions: list[PlaceholderRange]
     sampling_params: SamplingParams
     generator: Optional[torch.Generator]
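
A note on the mm_positions annotation change above: once PlaceholderRange is imported at module level instead of under `if TYPE_CHECKING:`, the name exists at runtime, so the string-literal forward reference is unnecessary. A generic sketch of that rule using only the standard library, nothing vLLM-specific:

from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from decimal import Decimal  # visible to type checkers only

@dataclass
class Quoted:
    amounts: list["Decimal"]  # quotes avoid needing the name at runtime

# With a real runtime import, the quotes can be dropped, which is what
# this commit does for PlaceholderRange:
from decimal import Decimal

@dataclass
class Unquoted:
    amounts: list[Decimal]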