mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 18:44:30 +08:00
[Misc] Remove unused utils and clean up imports (#15708)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
70e132244a
commit
c6bc0034d0
@ -9,12 +9,10 @@ from typing import TYPE_CHECKING, NamedTuple, Optional
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image, ImageChops
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.utils import (MediaConnector,
|
||||
merge_and_sort_multimodal_metadata,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
merge_and_sort_multimodal_metadata)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.hasher import MultiModalHashDict
|
||||
@ -136,71 +134,6 @@ async def test_fetch_image_local_files(image_url: str):
|
||||
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
def test_repeat_and_pad_placeholder_tokens(model):
|
||||
config = AutoConfig.from_pretrained(model)
|
||||
image_token_id = config.image_token_index
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
|
||||
test_cases = [
|
||||
(
|
||||
"<image>",
|
||||
2,
|
||||
"<image><image>",
|
||||
[32000, 32000],
|
||||
[{ "offset": 0, "length": 2 }],
|
||||
),
|
||||
(
|
||||
"<image><image>",
|
||||
2,
|
||||
"<image><image><image>",
|
||||
[32000, 32000, 32000],
|
||||
[{ "offset": 0, "length": 2 }],
|
||||
),
|
||||
(
|
||||
"<image><image>",
|
||||
[3, 2],
|
||||
"<image><image><image><image><image>",
|
||||
[32000, 32000, 32000, 32000, 32000],
|
||||
[{ "offset": 0, "length": 3 }, { "offset": 3, "length": 2 }],
|
||||
),
|
||||
(
|
||||
"Image:<image>Image:<image>!",
|
||||
[3, 2],
|
||||
"Image:<image><image><image>Image:<image><image>!",
|
||||
[9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
|
||||
[{ "offset": 2, "length": 3 }, { "offset": 7, "length": 2 }],
|
||||
),
|
||||
(
|
||||
"<image>",
|
||||
[3, 2],
|
||||
"<image><image><image>",
|
||||
[32000, 32000, 32000],
|
||||
[{ "offset": 0, "length": 3 }],
|
||||
),
|
||||
] # yapf: disable
|
||||
|
||||
for (
|
||||
prompt,
|
||||
repeat_count,
|
||||
expected_prompt,
|
||||
expected_token_ids,
|
||||
expected_ranges,
|
||||
) in test_cases:
|
||||
new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens(
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=tokenizer.encode(prompt,
|
||||
add_special_tokens=False),
|
||||
placeholder_token_id=image_token_id,
|
||||
repeat_count=repeat_count,
|
||||
)
|
||||
assert new_prompt == expected_prompt
|
||||
assert new_token_ids == expected_token_ids
|
||||
assert ranges == expected_ranges
|
||||
|
||||
|
||||
# Used for the next two tests related to `merge_and_sort_multimodal_metadata`.
|
||||
class TestCase(NamedTuple):
|
||||
mm_positions: "MultiModalPlaceholderDict"
|
||||
|
||||
@ -12,8 +12,6 @@ from PIL import Image
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.connections import HTTPConnection, global_http_connection
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .audio import AudioMediaIO
|
||||
from .base import MediaIO
|
||||
@ -21,8 +19,6 @@ from .image import ImageEmbeddingMediaIO, ImageMediaIO
|
||||
from .inputs import PlaceholderRange
|
||||
from .video import VideoMediaIO
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_M = TypeVar("_M")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -296,121 +292,6 @@ def encode_video_base64(frames: npt.NDArray) -> str:
|
||||
return video_io.encode_base64(frames)
|
||||
|
||||
|
||||
# Utilities for input processors
|
||||
_T = TypeVar("_T", str, int)
|
||||
|
||||
|
||||
def repeat_and_pad_token(
|
||||
token: _T,
|
||||
*,
|
||||
repeat_count: int = 1,
|
||||
pad_token_left: Optional[_T] = None,
|
||||
pad_token_right: Optional[_T] = None,
|
||||
) -> list[_T]:
|
||||
replacement = [token] * repeat_count
|
||||
if pad_token_left is not None:
|
||||
replacement = [pad_token_left] + replacement
|
||||
if pad_token_right is not None:
|
||||
replacement = replacement + [pad_token_right]
|
||||
|
||||
return replacement
|
||||
|
||||
|
||||
def repeat_and_pad_placeholder_tokens(
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: list[int],
|
||||
*,
|
||||
placeholder_token_id: int,
|
||||
repeat_count: Union[int, list[int]],
|
||||
pad_token_left: Optional[int] = None,
|
||||
pad_token_right: Optional[int] = None,
|
||||
) -> tuple[Optional[str], list[int], list[PlaceholderRange]]:
|
||||
if isinstance(repeat_count, int):
|
||||
repeat_count = [repeat_count]
|
||||
|
||||
if prompt is None:
|
||||
new_prompt = None
|
||||
else:
|
||||
placeholder_token_str = tokenizer.decode(placeholder_token_id)
|
||||
pad_token_str_left = (None if pad_token_left is None else
|
||||
tokenizer.decode(pad_token_left))
|
||||
pad_token_str_right = (None if pad_token_right is None else
|
||||
tokenizer.decode(pad_token_right))
|
||||
|
||||
placeholder_token_count = prompt.count(placeholder_token_str)
|
||||
# This is an arbitrary number to distinguish between the two cases
|
||||
if placeholder_token_count > 16:
|
||||
logger.warning(
|
||||
"Please follow the prompt format that is "
|
||||
"documented on HuggingFace which does not involve "
|
||||
"repeating %s tokens.", placeholder_token_str)
|
||||
if placeholder_token_count < len(repeat_count):
|
||||
logger.warning(
|
||||
"The number of multi-modal placeholder tokens in the prompt "
|
||||
"is less than the number of multi-modal inputs. Extra "
|
||||
"placeholder tokens will be treated as plain text")
|
||||
repeat_count = repeat_count[:placeholder_token_count]
|
||||
|
||||
prompt_parts = prompt.split(placeholder_token_str,
|
||||
maxsplit=len(repeat_count))
|
||||
new_prompt = ""
|
||||
for i, repeat_count_item in enumerate(repeat_count):
|
||||
replacement_str = "".join(
|
||||
repeat_and_pad_token(
|
||||
placeholder_token_str,
|
||||
repeat_count=repeat_count_item,
|
||||
pad_token_left=pad_token_str_left,
|
||||
pad_token_right=pad_token_str_right,
|
||||
))
|
||||
# The image tokens are removed to be consistent with HuggingFace
|
||||
new_prompt += prompt_parts[i] + replacement_str
|
||||
new_prompt += prompt_parts[-1]
|
||||
|
||||
new_token_ids = list[int]()
|
||||
placeholder_ranges = list[PlaceholderRange]()
|
||||
placeholder_token_idx = 0
|
||||
for i, token in enumerate(prompt_token_ids):
|
||||
if token == placeholder_token_id:
|
||||
curr_repeat_count = repeat_count[placeholder_token_idx]
|
||||
replacement_ids = repeat_and_pad_token(
|
||||
placeholder_token_id,
|
||||
repeat_count=curr_repeat_count,
|
||||
pad_token_left=pad_token_left,
|
||||
pad_token_right=pad_token_right,
|
||||
)
|
||||
offset = len(new_token_ids)
|
||||
if pad_token_left is not None:
|
||||
offset += 1
|
||||
placeholder_ranges.append({
|
||||
"offset": offset,
|
||||
"length": curr_repeat_count,
|
||||
})
|
||||
new_token_ids.extend(replacement_ids)
|
||||
placeholder_token_idx += 1
|
||||
|
||||
# No need to further scan the list since we replaced all tokens
|
||||
if placeholder_token_idx >= len(repeat_count):
|
||||
new_token_ids.extend(prompt_token_ids[i + 1:])
|
||||
break
|
||||
else:
|
||||
new_token_ids.append(token)
|
||||
|
||||
return new_prompt, new_token_ids, placeholder_ranges
|
||||
|
||||
|
||||
def consecutive_placeholder_ranges(
|
||||
num_items: int,
|
||||
item_size: int,
|
||||
initial_offset: int = 0) -> list[PlaceholderRange]:
|
||||
"""Returns a list of consecutive PlaceholderRanges of a fixed size"""
|
||||
|
||||
return [
|
||||
PlaceholderRange(offset=initial_offset + i * item_size,
|
||||
length=item_size) for i in range(num_items)
|
||||
]
|
||||
|
||||
|
||||
def merge_and_sort_multimodal_metadata(
|
||||
mm_positions: "MultiModalPlaceholderDict",
|
||||
mm_hashes: Optional["MultiModalHashDict"],
|
||||
|
||||
@ -10,8 +10,7 @@ if TYPE_CHECKING:
|
||||
import numpy.typing as npt
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.base import PlaceholderRange
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@ -2,13 +2,13 @@
|
||||
# Datastructures defining an input batch
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
@ -18,9 +18,6 @@ from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
@ -29,7 +26,7 @@ class CachedRequestState:
|
||||
prompt_token_ids: list[int]
|
||||
prompt: Optional[str]
|
||||
mm_inputs: list[MultiModalKwargs]
|
||||
mm_positions: list["PlaceholderRange"]
|
||||
mm_positions: list[PlaceholderRange]
|
||||
sampling_params: SamplingParams
|
||||
generator: Optional[torch.Generator]
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user