[Chore] Separate out vllm.utils.collections (#26990)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-16 16:35:35 +08:00 committed by GitHub
parent 17838e50ef
commit d2740fafbf
29 changed files with 218 additions and 184 deletions
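In short, the collection-related helpers move out of `vllm.utils` into the new `vllm.utils.collections` submodule, and call sites update their imports accordingly. Below is a minimal before/after sketch of the import change (illustrative only; the exact symbols touched differ per file, as the hunks that follow show):

# Before this change, the helpers were imported from the top-level module:
# from vllm.utils import as_list, chunk_list, is_list_of, swap_dict_values

# After this change, they come from the dedicated submodule:
from vllm.utils.collections import as_list, chunk_list, is_list_of, swap_dict_values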

View File

@@ -57,7 +57,8 @@ from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import is_list_of, set_default_torch_num_threads
from vllm.utils import set_default_torch_num_threads
from vllm.utils.collections import is_list_of
logger = init_logger(__name__)

View File

@@ -25,7 +25,7 @@ from transformers import (
from transformers.video_utils import VideoMetadata
from vllm.logprobs import SampleLogprobs
from vllm.utils import is_list_of
from vllm.utils.collections import is_list_of
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
from .types import RunnerOutput

View File

@@ -35,7 +35,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import is_list_of
from vllm.utils.collections import is_list_of
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
from ...utils import dummy_hf_overrides

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.utils.collections import swap_dict_values
@pytest.mark.parametrize(
"obj,key1,key2",
[
# Tests for both keys exist
({1: "a", 2: "b"}, 1, 2),
# Tests for one key does not exist
({1: "a", 2: "b"}, 1, 3),
# Tests for both keys do not exist
({1: "a", 2: "b"}, 3, 4),
],
)
def test_swap_dict_values(obj, key1, key2):
original_obj = obj.copy()
swap_dict_values(obj, key1, key2)
if key1 in original_obj:
assert obj[key2] == original_obj[key1]
else:
assert key2 not in obj
if key2 in original_obj:
assert obj[key1] == original_obj[key2]
else:
assert key1 not in obj
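For orientation, here is an illustrative sketch (not part of the diff) of the behavior these parametrized cases exercise, following the `swap_dict_values` implementation shown later in this commit:

from vllm.utils.collections import swap_dict_values

# Both keys exist: their values are exchanged.
d = {1: "a", 2: "b"}
swap_dict_values(d, 1, 2)
assert d == {1: "b", 2: "a"}

# One key is missing: the existing value moves to the missing key,
# and the originally present key is removed.
d = {1: "a", 2: "b"}
swap_dict_values(d, 1, 3)
assert d == {2: "b", 3: "a"}

# Neither key exists: the dict is left unchanged.
d = {1: "a", 2: "b"}
swap_dict_values(d, 3, 4)
assert d == {1: "a", 2: "b"}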

View File

@@ -38,7 +38,6 @@ from vllm.utils import (
sha256,
split_host_port,
split_zmq_path,
swap_dict_values,
unique_filepath,
)
@@ -516,30 +515,6 @@ def test_placeholder_module_error_handling():
_ = placeholder_attr.module
@pytest.mark.parametrize(
"obj,key1,key2",
[
# Tests for both keys exist
({1: "a", 2: "b"}, 1, 2),
# Tests for one key does not exist
({1: "a", 2: "b"}, 1, 3),
# Tests for both keys do not exist
({1: "a", 2: "b"}, 3, 4),
],
)
def test_swap_dict_values(obj, key1, key2):
original_obj = obj.copy()
swap_dict_values(obj, key1, key2)
if key1 in original_obj:
assert obj[key2] == original_obj[key1]
else:
assert key2 not in obj
if key2 in original_obj:
assert obj[key1] == original_obj[key2]
else:
assert key1 not in obj
def test_model_specification(
parser_with_config, cli_config_file, cli_config_file_with_model
):

View File

@@ -75,7 +75,8 @@ from vllm.transformers_utils.tokenizer import (
get_cached_tokenizer,
)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, as_iter, is_list_of
from vllm.utils import Counter, Device
from vllm.utils.collections import as_iter, is_list_of
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor

View File

@@ -70,7 +70,7 @@ from vllm.transformers_utils.tokenizers import (
truncate_tool_call_ids,
validate_request_params,
)
from vllm.utils import as_list
from vllm.utils.collections import as_list
logger = init_logger(__name__)

View File

@@ -34,8 +34,8 @@ from vllm.logprobs import Logprob
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import as_list
from vllm.utils.asyncio import merge_async_iterators
from vllm.utils.collections import as_list
logger = init_logger(__name__)

View File

@@ -39,8 +39,8 @@ from vllm.outputs import (
RequestOutput,
)
from vllm.pooling_params import PoolingParams
from vllm.utils import chunk_list
from vllm.utils.asyncio import merge_async_iterators
from vllm.utils.collections import chunk_list
logger = init_logger(__name__)

View File

@@ -90,13 +90,14 @@ from vllm.tracing import (
log_tracing_disabled_warning,
)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of, random_uuid
from vllm.utils import random_uuid
from vllm.utils.asyncio import (
AsyncMicrobatchTokenizer,
collect_from_async_generator,
make_async,
merge_async_iterators,
)
from vllm.utils.collections import is_list_of
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)

View File

@@ -12,7 +12,8 @@ from vllm.entrypoints.openai.protocol import (
)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import import_from_path, is_list_of
from vllm.utils import import_from_path
from vllm.utils.collections import is_list_of
logger = init_logger(__name__)

View File

@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cas
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from vllm.utils.collections import is_list_of
from .data import (
EmbedsPrompt,

View File

@@ -17,7 +17,7 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import LazyDict
from vllm.utils.collections import LazyDict
logger = init_logger(__name__)

View File

@@ -28,7 +28,7 @@ from vllm.model_executor.parameter import (
RowvLLMParameter,
)
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
from vllm.utils.collections import is_list_of
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods

View File

@@ -57,7 +57,7 @@ from vllm.model_executor.parameter import (
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
from vllm.utils.collections import is_list_of
logger = init_logger(__name__)

View File

@@ -49,7 +49,7 @@ from vllm.transformers_utils.configs.deepseek_vl2 import (
)
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import is_list_of
from vllm.utils.collections import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

View File

@@ -33,7 +33,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.collections import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

View File

@@ -86,7 +86,7 @@ from vllm.multimodal.processing import (
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from vllm.utils.collections import flatten_2d_lists
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer

View File

@@ -79,7 +79,7 @@ from vllm.multimodal.processing import (
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.collections import is_list_of
from .interfaces import (
MultiModalEmbeddings,

View File

@@ -22,7 +22,8 @@ from typing import (
import numpy as np
from typing_extensions import NotRequired, TypeVar, deprecated
from vllm.utils import LazyLoader, full_groupby, is_list_of
from vllm.utils import LazyLoader
from vllm.utils.collections import full_groupby, is_list_of
from vllm.utils.jsontree import json_map_leaves
if TYPE_CHECKING:

View File

@@ -19,7 +19,8 @@ import numpy as np
import torch
from typing_extensions import assert_never
from vllm.utils import LazyLoader, is_list_of
from vllm.utils import LazyLoader
from vllm.utils.collections import is_list_of
from .audio import AudioResampler
from .inputs import (
@@ -364,7 +365,7 @@ class MultiModalDataParser:
if isinstance(data, torch.Tensor):
return data.ndim == 3
if is_list_of(data, torch.Tensor):
return data[0].ndim == 2
return data[0].ndim == 2 # type: ignore[index]
return False
@@ -422,6 +423,7 @@ class MultiModalDataParser:
if self._is_embeddings(data):
return AudioEmbeddingItems(data)
data_items: list[AudioItem]
if (
is_list_of(data, float)
or isinstance(data, (np.ndarray, torch.Tensor))
@@ -432,7 +434,7 @@
elif isinstance(data, (np.ndarray, torch.Tensor)):
data_items = [elem for elem in data]
else:
data_items = data
data_items = data # type: ignore[assignment]
new_audios = list[np.ndarray]()
for data_item in data_items:
@@ -485,6 +487,7 @@ class MultiModalDataParser:
if self._is_embeddings(data):
return VideoEmbeddingItems(data)
data_items: list[VideoItem]
if (
is_list_of(data, PILImage.Image)
or isinstance(data, (np.ndarray, torch.Tensor))
@@ -496,7 +499,7 @@
elif isinstance(data, tuple) and len(data) == 2:
data_items = [data]
else:
data_items = data
data_items = data # type: ignore[assignment]
new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
metadata_lst: list[dict[str, Any] | None] = []

View File

@@ -25,7 +25,7 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
from vllm.utils import flatten_2d_lists, full_groupby
from vllm.utils.collections import flatten_2d_lists, full_groupby
from vllm.utils.functools import get_allowed_kwarg_only_overrides
from vllm.utils.jsontree import JSONTree, json_map_leaves
@@ -484,8 +484,11 @@ _M = TypeVar("_M", bound=_HasModalityAttr | _HasModalityProp)
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
"""Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
based on modality."""
"""
Convenience function to apply
[`full_groupby`][vllm.utils.collections.full_groupby]
based on modality.
"""
return full_groupby(values, key=lambda x: x.modality)

View File

@@ -9,7 +9,7 @@ import torch.nn as nn
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
from vllm.utils import ClassRegistry
from vllm.utils.collections import ClassRegistry
from .cache import BaseMultiModalProcessorCache
from .processing import (

View File

@@ -8,7 +8,8 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any
from vllm.logger import init_logger
from vllm.utils import import_from_path, is_list_of
from vllm.utils import import_from_path
from vllm.utils.collections import is_list_of
if TYPE_CHECKING:
from vllm.entrypoints.openai.protocol import (

View File

@@ -37,29 +37,19 @@ from argparse import (
RawDescriptionHelpFormatter,
_ArgumentGroup,
)
from collections import UserDict, defaultdict
from collections import defaultdict
from collections.abc import (
Callable,
Collection,
Generator,
Hashable,
Iterable,
Iterator,
Mapping,
Sequence,
)
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
TextIO,
TypeVar,
)
from typing import TYPE_CHECKING, Any, TextIO, TypeVar
from urllib.parse import urlparse
from uuid import uuid4
@@ -78,7 +68,7 @@ import zmq.asyncio
from packaging import version
from packaging.version import Version
from torch.library import Library
from typing_extensions import Never, TypeIs, assert_never
from typing_extensions import Never
import vllm.envs as envs
from vllm.logger import enable_trace_function_call, init_logger
@@ -170,9 +160,6 @@ def set_default_torch_num_threads(num_threads: int):
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class Device(enum.Enum):
GPU = enum.auto()
@@ -421,12 +408,6 @@ def update_environment_variables(envs: dict[str, str]):
os.environ[k] = v
def chunk_list(lst: list[T], chunk_size: int):
"""Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size):
yield lst[i : i + chunk_size]
def cdiv(a: int, b: int) -> int:
"""Ceiling division."""
return -(a // -b)
@@ -743,53 +724,6 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
)
def as_list(maybe_list: Iterable[T]) -> list[T]:
"""Convert iterable to list, unless it's already a list."""
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
if isinstance(obj, str) or not isinstance(obj, Iterable):
return [obj] # type: ignore[list-item]
return obj
# `collections` helpers
def is_list_of(
value: object,
typ: type[T] | tuple[type[T], ...],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
if not isinstance(value, list):
return False
if check == "first":
return len(value) == 0 or isinstance(value[0], typ)
elif check == "all":
return all(isinstance(v, typ) for v in value)
assert_never(check)
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike [`itertools.groupby`][], groups are not broken by
non-contiguous data.
"""
groups = defaultdict[_K, list[_V]](list)
for value in values:
groups[key(value)].append(value)
return groups.items()
# TODO: This function can be removed if transformer_modules classes are
# serialized by value when communicating between processes
def init_cached_hf_modules() -> None:
@@ -1578,50 +1512,6 @@ class AtomicCounter:
return self._value
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping[str, T], Generic[T]):
def __init__(self, factory: dict[str, Callable[[], T]]):
self._factory = factory
self._dict: dict[str, T] = {}
def __getitem__(self, key: str) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
self._dict[key] = self._factory[key]()
return self._dict[key]
def __setitem__(self, key: str, value: Callable[[], T]):
self._factory[key] = value
def __iter__(self):
return iter(self._factory)
def __len__(self):
return len(self._factory)
class ClassRegistry(UserDict[type[T], _V]):
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type):
return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro())
def weak_ref_tensor(tensor: Any) -> Any:
"""
Create a weak reference to a tensor.
@@ -2588,22 +2478,6 @@ class LazyLoader(types.ModuleType):
return dir(self._module)
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
"""
Helper function to swap values for two keys
"""
v1 = obj.get(key1)
v2 = obj.get(key2)
if v1 is not None:
obj[key2] = v1
else:
obj.pop(key2, None)
if v2 is not None:
obj[key1] = v2
else:
obj.pop(key1, None)
@contextlib.contextmanager
def cprofile_context(save_file: str | None = None):
"""Run a cprofile

vllm/utils/collections.py (new file, 139 additions)
View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Contains helpers that are applied to collections.
This is similar in concept to the `collections` module.
"""
from collections import UserDict, defaultdict
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
from typing import Generic, Literal, TypeVar
from typing_extensions import TypeIs, assert_never
T = TypeVar("T")
U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")
class ClassRegistry(UserDict[type[T], _V]):
"""
A registry that acts like a dictionary but searches for other classes
in the MRO if the original class is not found.
"""
def __getitem__(self, key: type[T]) -> _V:
for cls in key.mro():
if cls in self.data:
return self.data[cls]
raise KeyError(key)
def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type):
return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro())
class LazyDict(Mapping[str, T], Generic[T]):
"""
Evaluates dictionary items only when they are accessed.
Adapted from: https://stackoverflow.com/a/47212782/5082708
"""
def __init__(self, factory: dict[str, Callable[[], T]]):
self._factory = factory
self._dict: dict[str, T] = {}
def __getitem__(self, key: str) -> T:
if key not in self._dict:
if key not in self._factory:
raise KeyError(key)
self._dict[key] = self._factory[key]()
return self._dict[key]
def __setitem__(self, key: str, value: Callable[[], T]):
self._factory[key] = value
def __iter__(self):
return iter(self._factory)
def __len__(self):
return len(self._factory)
def as_list(maybe_list: Iterable[T]) -> list[T]:
"""Convert iterable to list, unless it's already a list."""
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
if isinstance(obj, str) or not isinstance(obj, Iterable):
return [obj] # type: ignore[list-item]
return obj
def is_list_of(
value: object,
typ: type[T] | tuple[type[T], ...],
*,
check: Literal["first", "all"] = "first",
) -> TypeIs[list[T]]:
if not isinstance(value, list):
return False
if check == "first":
return len(value) == 0 or isinstance(value[0], typ)
elif check == "all":
return all(isinstance(v, typ) for v in value)
assert_never(check)
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
"""Yield successive chunk_size chunks from lst."""
for i in range(0, len(lst), chunk_size):
yield lst[i : i + chunk_size]
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
"""
Unlike [`itertools.groupby`][], groups are not broken by
non-contiguous data.
"""
groups = defaultdict[_K, list[_V]](list)
for value in values:
groups[key(value)].append(value)
return groups.items()
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
"""Swap values between two keys."""
v1 = obj.get(key1)
v2 = obj.get(key2)
if v1 is not None:
obj[key2] = v1
else:
obj.pop(key2, None)
if v2 is not None:
obj[key1] = v2
else:
obj.pop(key1, None)
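As a quick orientation for the relocated helpers, an illustrative usage sketch (not part of the diff) based on the definitions above:

from vllm.utils.collections import (
    ClassRegistry,
    LazyDict,
    chunk_list,
    flatten_2d_lists,
    full_groupby,
    is_list_of,
)

# is_list_of: typed list check; by default only the first element is inspected.
assert is_list_of([1, 2, 3], int)
assert not is_list_of((1, 2, 3), int)  # non-lists are rejected outright

# chunk_list / flatten_2d_lists: simple reshaping of lists.
assert list(chunk_list([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]
assert flatten_2d_lists([[1, 2], [3]]) == [1, 2, 3]

# full_groupby: unlike itertools.groupby, the input need not be sorted by key.
assert dict(full_groupby(["aa", "b", "cc"], key=len)) == {2: ["aa", "cc"], 1: ["b"]}

# ClassRegistry: lookups fall back to base classes via the MRO.
class Base: ...
class Child(Base): ...

registry = ClassRegistry()
registry[Base] = "handler registered for Base"
assert registry[Child] == "handler registered for Base"

# LazyDict: values are produced on first access from zero-argument factories.
lazy = LazyDict({"answer": lambda: 42})
assert lazy["answer"] == 42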

View File

@@ -29,8 +29,9 @@ from vllm.tracing import init_tracer
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, as_list, cdiv
from vllm.utils import Device, cdiv
from vllm.utils.asyncio import cancel_task_threadsafe
from vllm.utils.collections import as_list
from vllm.utils.functools import deprecate_kwargs
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient

View File

@@ -12,7 +12,8 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collections import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (

View File

@@ -9,7 +9,8 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingType
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
from vllm.utils import length_from_prompt_token_ids_or_embeds
from vllm.utils.collections import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState