mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 07:24:54 +08:00
[Chore] Separate out vllm.utils.collections (#26990)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
17838e50ef
commit
d2740fafbf
@ -57,7 +57,8 @@ from vllm.multimodal.utils import fetch_image
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.transformers_utils.utils import maybe_model_redirect
|
||||
from vllm.utils import is_list_of, set_default_torch_num_threads
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ from transformers import (
|
||||
from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.logprobs import SampleLogprobs
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
|
||||
from .types import RunnerOutput
|
||||
|
||||
@ -35,7 +35,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
||||
from ...utils import dummy_hf_overrides
|
||||
|
||||
31
tests/utils_/test_collections.py
Normal file
31
tests/utils_/test_collections.py
Normal file
@ -0,0 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.utils.collections import swap_dict_values
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"obj,key1,key2",
|
||||
[
|
||||
# Tests for both keys exist
|
||||
({1: "a", 2: "b"}, 1, 2),
|
||||
# Tests for one key does not exist
|
||||
({1: "a", 2: "b"}, 1, 3),
|
||||
# Tests for both keys do not exist
|
||||
({1: "a", 2: "b"}, 3, 4),
|
||||
],
|
||||
)
|
||||
def test_swap_dict_values(obj, key1, key2):
|
||||
original_obj = obj.copy()
|
||||
|
||||
swap_dict_values(obj, key1, key2)
|
||||
|
||||
if key1 in original_obj:
|
||||
assert obj[key2] == original_obj[key1]
|
||||
else:
|
||||
assert key2 not in obj
|
||||
if key2 in original_obj:
|
||||
assert obj[key1] == original_obj[key2]
|
||||
else:
|
||||
assert key1 not in obj
|
||||
@ -38,7 +38,6 @@ from vllm.utils import (
|
||||
sha256,
|
||||
split_host_port,
|
||||
split_zmq_path,
|
||||
swap_dict_values,
|
||||
unique_filepath,
|
||||
)
|
||||
|
||||
@ -516,30 +515,6 @@ def test_placeholder_module_error_handling():
|
||||
_ = placeholder_attr.module
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"obj,key1,key2",
|
||||
[
|
||||
# Tests for both keys exist
|
||||
({1: "a", 2: "b"}, 1, 2),
|
||||
# Tests for one key does not exist
|
||||
({1: "a", 2: "b"}, 1, 3),
|
||||
# Tests for both keys do not exist
|
||||
({1: "a", 2: "b"}, 3, 4),
|
||||
],
|
||||
)
|
||||
def test_swap_dict_values(obj, key1, key2):
|
||||
original_obj = obj.copy()
|
||||
swap_dict_values(obj, key1, key2)
|
||||
if key1 in original_obj:
|
||||
assert obj[key2] == original_obj[key1]
|
||||
else:
|
||||
assert key2 not in obj
|
||||
if key2 in original_obj:
|
||||
assert obj[key1] == original_obj[key2]
|
||||
else:
|
||||
assert key1 not in obj
|
||||
|
||||
|
||||
def test_model_specification(
|
||||
parser_with_config, cli_config_file, cli_config_file_with_model
|
||||
):
|
||||
|
||||
@ -75,7 +75,8 @@ from vllm.transformers_utils.tokenizer import (
|
||||
get_cached_tokenizer,
|
||||
)
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Counter, Device, as_iter, is_list_of
|
||||
from vllm.utils import Counter, Device
|
||||
from vllm.utils.collections import as_iter, is_list_of
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
@ -70,7 +70,7 @@ from vllm.transformers_utils.tokenizers import (
|
||||
truncate_tool_call_ids,
|
||||
validate_request_params,
|
||||
)
|
||||
from vllm.utils import as_list
|
||||
from vllm.utils.collections import as_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -34,8 +34,8 @@ from vllm.logprobs import Logprob
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import as_list
|
||||
from vllm.utils.asyncio import merge_async_iterators
|
||||
from vllm.utils.collections import as_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -39,8 +39,8 @@ from vllm.outputs import (
|
||||
RequestOutput,
|
||||
)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.utils import chunk_list
|
||||
from vllm.utils.asyncio import merge_async_iterators
|
||||
from vllm.utils.collections import chunk_list
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -90,13 +90,14 @@ from vllm.tracing import (
|
||||
log_tracing_disabled_warning,
|
||||
)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import is_list_of, random_uuid
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils.asyncio import (
|
||||
AsyncMicrobatchTokenizer,
|
||||
collect_from_async_generator,
|
||||
make_async,
|
||||
merge_async_iterators,
|
||||
)
|
||||
from vllm.utils.collections import is_list_of
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -12,7 +12,8 @@ from vllm.entrypoints.openai.protocol import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import import_from_path, is_list_of
|
||||
from vllm.utils import import_from_path
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cas
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
from .data import (
|
||||
EmbedsPrompt,
|
||||
|
||||
@ -17,7 +17,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import LazyDict
|
||||
from vllm.utils.collections import LazyDict
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ from vllm.model_executor.parameter import (
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
|
||||
@ -57,7 +57,7 @@ from vllm.model_executor.parameter import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -49,7 +49,7 @@ from vllm.transformers_utils.configs.deepseek_vl2 import (
|
||||
)
|
||||
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.collections import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
|
||||
@ -33,7 +33,7 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.collections import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
|
||||
@ -86,7 +86,7 @@ from vllm.multimodal.processing import (
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import flatten_2d_lists
|
||||
from vllm.utils.collections import flatten_2d_lists
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
|
||||
@ -79,7 +79,7 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
|
||||
@ -22,7 +22,8 @@ from typing import (
|
||||
import numpy as np
|
||||
from typing_extensions import NotRequired, TypeVar, deprecated
|
||||
|
||||
from vllm.utils import LazyLoader, full_groupby, is_list_of
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.utils.collections import full_groupby, is_list_of
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@ -19,7 +19,8 @@ import numpy as np
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.utils import LazyLoader, is_list_of
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
from .audio import AudioResampler
|
||||
from .inputs import (
|
||||
@ -364,7 +365,7 @@ class MultiModalDataParser:
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.ndim == 3
|
||||
if is_list_of(data, torch.Tensor):
|
||||
return data[0].ndim == 2
|
||||
return data[0].ndim == 2 # type: ignore[index]
|
||||
|
||||
return False
|
||||
|
||||
@ -422,6 +423,7 @@ class MultiModalDataParser:
|
||||
if self._is_embeddings(data):
|
||||
return AudioEmbeddingItems(data)
|
||||
|
||||
data_items: list[AudioItem]
|
||||
if (
|
||||
is_list_of(data, float)
|
||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||
@ -432,7 +434,7 @@ class MultiModalDataParser:
|
||||
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
||||
data_items = [elem for elem in data]
|
||||
else:
|
||||
data_items = data
|
||||
data_items = data # type: ignore[assignment]
|
||||
|
||||
new_audios = list[np.ndarray]()
|
||||
for data_item in data_items:
|
||||
@ -485,6 +487,7 @@ class MultiModalDataParser:
|
||||
if self._is_embeddings(data):
|
||||
return VideoEmbeddingItems(data)
|
||||
|
||||
data_items: list[VideoItem]
|
||||
if (
|
||||
is_list_of(data, PILImage.Image)
|
||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||
@ -496,7 +499,7 @@ class MultiModalDataParser:
|
||||
elif isinstance(data, tuple) and len(data) == 2:
|
||||
data_items = [data]
|
||||
else:
|
||||
data_items = data
|
||||
data_items = data # type: ignore[assignment]
|
||||
|
||||
new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
|
||||
metadata_lst: list[dict[str, Any] | None] = []
|
||||
|
||||
@ -25,7 +25,7 @@ from typing_extensions import TypeVar, assert_never
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
|
||||
from vllm.utils import flatten_2d_lists, full_groupby
|
||||
from vllm.utils.collections import flatten_2d_lists, full_groupby
|
||||
from vllm.utils.functools import get_allowed_kwarg_only_overrides
|
||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||
|
||||
@ -484,8 +484,11 @@ _M = TypeVar("_M", bound=_HasModalityAttr | _HasModalityProp)
|
||||
|
||||
|
||||
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
|
||||
"""Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
|
||||
based on modality."""
|
||||
"""
|
||||
Convenience function to apply
|
||||
[`full_groupby`][vllm.utils.collections.full_groupby]
|
||||
based on modality.
|
||||
"""
|
||||
return full_groupby(values, key=lambda x: x.modality)
|
||||
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ import torch.nn as nn
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
|
||||
from vllm.utils import ClassRegistry
|
||||
from vllm.utils.collections import ClassRegistry
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .processing import (
|
||||
|
||||
@ -8,7 +8,8 @@ from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import import_from_path, is_list_of
|
||||
from vllm.utils import import_from_path
|
||||
from vllm.utils.collections import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
|
||||
@ -37,29 +37,19 @@ from argparse import (
|
||||
RawDescriptionHelpFormatter,
|
||||
_ArgumentGroup,
|
||||
)
|
||||
from collections import UserDict, defaultdict
|
||||
from collections import defaultdict
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
Collection,
|
||||
Generator,
|
||||
Hashable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
from concurrent.futures.process import ProcessPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cache, lru_cache, partial, wraps
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TextIO,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, TextIO, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
@ -78,7 +68,7 @@ import zmq.asyncio
|
||||
from packaging import version
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from typing_extensions import Never, TypeIs, assert_never
|
||||
from typing_extensions import Never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
@ -170,9 +160,6 @@ def set_default_torch_num_threads(num_threads: int):
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
_K = TypeVar("_K", bound=Hashable)
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
class Device(enum.Enum):
|
||||
GPU = enum.auto()
|
||||
@ -421,12 +408,6 @@ def update_environment_variables(envs: dict[str, str]):
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
def chunk_list(lst: list[T], chunk_size: int):
|
||||
"""Yield successive chunk_size chunks from lst."""
|
||||
for i in range(0, len(lst), chunk_size):
|
||||
yield lst[i : i + chunk_size]
|
||||
|
||||
|
||||
def cdiv(a: int, b: int) -> int:
|
||||
"""Ceiling division."""
|
||||
return -(a // -b)
|
||||
@ -743,53 +724,6 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
|
||||
)
|
||||
|
||||
|
||||
def as_list(maybe_list: Iterable[T]) -> list[T]:
|
||||
"""Convert iterable to list, unless it's already a list."""
|
||||
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
|
||||
|
||||
|
||||
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
|
||||
if isinstance(obj, str) or not isinstance(obj, Iterable):
|
||||
return [obj] # type: ignore[list-item]
|
||||
return obj
|
||||
|
||||
|
||||
# `collections` helpers
|
||||
def is_list_of(
|
||||
value: object,
|
||||
typ: type[T] | tuple[type[T], ...],
|
||||
*,
|
||||
check: Literal["first", "all"] = "first",
|
||||
) -> TypeIs[list[T]]:
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if check == "first":
|
||||
return len(value) == 0 or isinstance(value[0], typ)
|
||||
elif check == "all":
|
||||
return all(isinstance(v, typ) for v in value)
|
||||
|
||||
assert_never(check)
|
||||
|
||||
|
||||
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
|
||||
"""Flatten a list of lists to a single list."""
|
||||
return [item for sublist in lists for item in sublist]
|
||||
|
||||
|
||||
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
|
||||
"""
|
||||
Unlike [`itertools.groupby`][], groups are not broken by
|
||||
non-contiguous data.
|
||||
"""
|
||||
groups = defaultdict[_K, list[_V]](list)
|
||||
|
||||
for value in values:
|
||||
groups[key(value)].append(value)
|
||||
|
||||
return groups.items()
|
||||
|
||||
|
||||
# TODO: This function can be removed if transformer_modules classes are
|
||||
# serialized by value when communicating between processes
|
||||
def init_cached_hf_modules() -> None:
|
||||
@ -1578,50 +1512,6 @@ class AtomicCounter:
|
||||
return self._value
|
||||
|
||||
|
||||
# Adapted from: https://stackoverflow.com/a/47212782/5082708
|
||||
class LazyDict(Mapping[str, T], Generic[T]):
|
||||
def __init__(self, factory: dict[str, Callable[[], T]]):
|
||||
self._factory = factory
|
||||
self._dict: dict[str, T] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> T:
|
||||
if key not in self._dict:
|
||||
if key not in self._factory:
|
||||
raise KeyError(key)
|
||||
self._dict[key] = self._factory[key]()
|
||||
return self._dict[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Callable[[], T]):
|
||||
self._factory[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._factory)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._factory)
|
||||
|
||||
|
||||
class ClassRegistry(UserDict[type[T], _V]):
|
||||
def __getitem__(self, key: type[T]) -> _V:
|
||||
for cls in key.mro():
|
||||
if cls in self.data:
|
||||
return self.data[cls]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return self.contains(key)
|
||||
|
||||
def contains(self, key: object, *, strict: bool = False) -> bool:
|
||||
if not isinstance(key, type):
|
||||
return False
|
||||
|
||||
if strict:
|
||||
return key in self.data
|
||||
|
||||
return any(cls in self.data for cls in key.mro())
|
||||
|
||||
|
||||
def weak_ref_tensor(tensor: Any) -> Any:
|
||||
"""
|
||||
Create a weak reference to a tensor.
|
||||
@ -2588,22 +2478,6 @@ class LazyLoader(types.ModuleType):
|
||||
return dir(self._module)
|
||||
|
||||
|
||||
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
|
||||
"""
|
||||
Helper function to swap values for two keys
|
||||
"""
|
||||
v1 = obj.get(key1)
|
||||
v2 = obj.get(key2)
|
||||
if v1 is not None:
|
||||
obj[key2] = v1
|
||||
else:
|
||||
obj.pop(key2, None)
|
||||
if v2 is not None:
|
||||
obj[key1] = v2
|
||||
else:
|
||||
obj.pop(key1, None)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def cprofile_context(save_file: str | None = None):
|
||||
"""Run a cprofile
|
||||
|
||||
139
vllm/utils/collections.py
Normal file
139
vllm/utils/collections.py
Normal file
@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Contains helpers that are applied to collections.
|
||||
|
||||
This is similar in concept to the `collections` module.
|
||||
"""
|
||||
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
|
||||
from typing import Generic, Literal, TypeVar
|
||||
|
||||
from typing_extensions import TypeIs, assert_never
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
|
||||
_K = TypeVar("_K", bound=Hashable)
|
||||
_V = TypeVar("_V")
|
||||
|
||||
|
||||
class ClassRegistry(UserDict[type[T], _V]):
|
||||
"""
|
||||
A registry that acts like a dictionary but searches for other classes
|
||||
in the MRO if the original class is not found.
|
||||
"""
|
||||
|
||||
def __getitem__(self, key: type[T]) -> _V:
|
||||
for cls in key.mro():
|
||||
if cls in self.data:
|
||||
return self.data[cls]
|
||||
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key: object) -> bool:
|
||||
return self.contains(key)
|
||||
|
||||
def contains(self, key: object, *, strict: bool = False) -> bool:
|
||||
if not isinstance(key, type):
|
||||
return False
|
||||
|
||||
if strict:
|
||||
return key in self.data
|
||||
|
||||
return any(cls in self.data for cls in key.mro())
|
||||
|
||||
|
||||
class LazyDict(Mapping[str, T], Generic[T]):
|
||||
"""
|
||||
Evaluates dictionary items only when they are accessed.
|
||||
|
||||
Adapted from: https://stackoverflow.com/a/47212782/5082708
|
||||
"""
|
||||
|
||||
def __init__(self, factory: dict[str, Callable[[], T]]):
|
||||
self._factory = factory
|
||||
self._dict: dict[str, T] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> T:
|
||||
if key not in self._dict:
|
||||
if key not in self._factory:
|
||||
raise KeyError(key)
|
||||
self._dict[key] = self._factory[key]()
|
||||
return self._dict[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Callable[[], T]):
|
||||
self._factory[key] = value
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._factory)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._factory)
|
||||
|
||||
|
||||
def as_list(maybe_list: Iterable[T]) -> list[T]:
|
||||
"""Convert iterable to list, unless it's already a list."""
|
||||
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
|
||||
|
||||
|
||||
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
|
||||
if isinstance(obj, str) or not isinstance(obj, Iterable):
|
||||
return [obj] # type: ignore[list-item]
|
||||
return obj
|
||||
|
||||
|
||||
def is_list_of(
|
||||
value: object,
|
||||
typ: type[T] | tuple[type[T], ...],
|
||||
*,
|
||||
check: Literal["first", "all"] = "first",
|
||||
) -> TypeIs[list[T]]:
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if check == "first":
|
||||
return len(value) == 0 or isinstance(value[0], typ)
|
||||
elif check == "all":
|
||||
return all(isinstance(v, typ) for v in value)
|
||||
|
||||
assert_never(check)
|
||||
|
||||
|
||||
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
|
||||
"""Yield successive chunk_size chunks from lst."""
|
||||
for i in range(0, len(lst), chunk_size):
|
||||
yield lst[i : i + chunk_size]
|
||||
|
||||
|
||||
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
|
||||
"""Flatten a list of lists to a single list."""
|
||||
return [item for sublist in lists for item in sublist]
|
||||
|
||||
|
||||
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
|
||||
"""
|
||||
Unlike [`itertools.groupby`][], groups are not broken by
|
||||
non-contiguous data.
|
||||
"""
|
||||
groups = defaultdict[_K, list[_V]](list)
|
||||
|
||||
for value in values:
|
||||
groups[key(value)].append(value)
|
||||
|
||||
return groups.items()
|
||||
|
||||
|
||||
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
|
||||
"""Swap values between two keys."""
|
||||
v1 = obj.get(key1)
|
||||
v2 = obj.get(key2)
|
||||
if v1 is not None:
|
||||
obj[key2] = v1
|
||||
else:
|
||||
obj.pop(key2, None)
|
||||
if v2 is not None:
|
||||
obj[key1] = v2
|
||||
else:
|
||||
obj.pop(key1, None)
|
||||
@ -29,8 +29,9 @@ from vllm.tracing import init_tracer
|
||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import Device, as_list, cdiv
|
||||
from vllm.utils import Device, cdiv
|
||||
from vllm.utils.asyncio import cancel_task_threadsafe
|
||||
from vllm.utils.collections import as_list
|
||||
from vllm.utils.functools import deprecate_kwargs
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
|
||||
@ -12,7 +12,8 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.collections import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import (
|
||||
|
||||
@ -9,7 +9,8 @@ import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.utils.collections import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user