mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:06:03 +08:00
[Chore] Separate out vllm.utils.collections (#26990)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
17838e50ef
commit
d2740fafbf
@ -57,7 +57,8 @@ from vllm.multimodal.utils import fetch_image
|
|||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import BeamSearchParams
|
from vllm.sampling_params import BeamSearchParams
|
||||||
from vllm.transformers_utils.utils import maybe_model_redirect
|
from vllm.transformers_utils.utils import maybe_model_redirect
|
||||||
from vllm.utils import is_list_of, set_default_torch_num_threads
|
from vllm.utils import set_default_torch_num_threads
|
||||||
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from transformers import (
|
|||||||
from transformers.video_utils import VideoMetadata
|
from transformers.video_utils import VideoMetadata
|
||||||
|
|
||||||
from vllm.logprobs import SampleLogprobs
|
from vllm.logprobs import SampleLogprobs
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
|
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
|
||||||
from .types import RunnerOutput
|
from .types import RunnerOutput
|
||||||
|
|||||||
@ -35,7 +35,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
|||||||
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
|
from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
|
||||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
||||||
from ...utils import dummy_hf_overrides
|
from ...utils import dummy_hf_overrides
|
||||||
|
|||||||
31
tests/utils_/test_collections.py
Normal file
31
tests/utils_/test_collections.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.utils.collections import swap_dict_values
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "obj,key1,key2",
    [
        # Both keys are present in the dict
        ({1: "a", 2: "b"}, 1, 2),
        # Only the first key is present
        ({1: "a", 2: "b"}, 1, 3),
        # Neither key is present
        ({1: "a", 2: "b"}, 3, 4),
    ],
)
def test_swap_dict_values(obj, key1, key2):
    """After the swap, each key holds the value the other key had before,
    and a key whose counterpart was absent must itself be absent."""
    snapshot = obj.copy()

    swap_dict_values(obj, key1, key2)

    if key1 not in snapshot:
        assert key2 not in obj
    else:
        assert obj[key2] == snapshot[key1]

    if key2 not in snapshot:
        assert key1 not in obj
    else:
        assert obj[key1] == snapshot[key2]
|
||||||
@ -38,7 +38,6 @@ from vllm.utils import (
|
|||||||
sha256,
|
sha256,
|
||||||
split_host_port,
|
split_host_port,
|
||||||
split_zmq_path,
|
split_zmq_path,
|
||||||
swap_dict_values,
|
|
||||||
unique_filepath,
|
unique_filepath,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -516,30 +515,6 @@ def test_placeholder_module_error_handling():
|
|||||||
_ = placeholder_attr.module
|
_ = placeholder_attr.module
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"obj,key1,key2",
|
|
||||||
[
|
|
||||||
# Tests for both keys exist
|
|
||||||
({1: "a", 2: "b"}, 1, 2),
|
|
||||||
# Tests for one key does not exist
|
|
||||||
({1: "a", 2: "b"}, 1, 3),
|
|
||||||
# Tests for both keys do not exist
|
|
||||||
({1: "a", 2: "b"}, 3, 4),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_swap_dict_values(obj, key1, key2):
|
|
||||||
original_obj = obj.copy()
|
|
||||||
swap_dict_values(obj, key1, key2)
|
|
||||||
if key1 in original_obj:
|
|
||||||
assert obj[key2] == original_obj[key1]
|
|
||||||
else:
|
|
||||||
assert key2 not in obj
|
|
||||||
if key2 in original_obj:
|
|
||||||
assert obj[key1] == original_obj[key2]
|
|
||||||
else:
|
|
||||||
assert key1 not in obj
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_specification(
|
def test_model_specification(
|
||||||
parser_with_config, cli_config_file, cli_config_file_with_model
|
parser_with_config, cli_config_file, cli_config_file_with_model
|
||||||
):
|
):
|
||||||
|
|||||||
@ -75,7 +75,8 @@ from vllm.transformers_utils.tokenizer import (
|
|||||||
get_cached_tokenizer,
|
get_cached_tokenizer,
|
||||||
)
|
)
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Counter, Device, as_iter, is_list_of
|
from vllm.utils import Counter, Device
|
||||||
|
from vllm.utils.collections import as_iter, is_list_of
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.llm_engine import LLMEngine
|
from vllm.v1.engine.llm_engine import LLMEngine
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||||
|
|||||||
@ -70,7 +70,7 @@ from vllm.transformers_utils.tokenizers import (
|
|||||||
truncate_tool_call_ids,
|
truncate_tool_call_ids,
|
||||||
validate_request_params,
|
validate_request_params,
|
||||||
)
|
)
|
||||||
from vllm.utils import as_list
|
from vllm.utils.collections import as_list
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -34,8 +34,8 @@ from vllm.logprobs import Logprob
|
|||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import as_list
|
|
||||||
from vllm.utils.asyncio import merge_async_iterators
|
from vllm.utils.asyncio import merge_async_iterators
|
||||||
|
from vllm.utils.collections import as_list
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -39,8 +39,8 @@ from vllm.outputs import (
|
|||||||
RequestOutput,
|
RequestOutput,
|
||||||
)
|
)
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.utils import chunk_list
|
|
||||||
from vllm.utils.asyncio import merge_async_iterators
|
from vllm.utils.asyncio import merge_async_iterators
|
||||||
|
from vllm.utils.collections import chunk_list
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -90,13 +90,14 @@ from vllm.tracing import (
|
|||||||
log_tracing_disabled_warning,
|
log_tracing_disabled_warning,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import is_list_of, random_uuid
|
from vllm.utils import random_uuid
|
||||||
from vllm.utils.asyncio import (
|
from vllm.utils.asyncio import (
|
||||||
AsyncMicrobatchTokenizer,
|
AsyncMicrobatchTokenizer,
|
||||||
collect_from_async_generator,
|
collect_from_async_generator,
|
||||||
make_async,
|
make_async,
|
||||||
merge_async_iterators,
|
merge_async_iterators,
|
||||||
)
|
)
|
||||||
|
from vllm.utils.collections import is_list_of
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|||||||
@ -12,7 +12,8 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import import_from_path, is_list_of
|
from vllm.utils import import_from_path
|
||||||
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Literal, NamedTuple, TypeAlias, TypedDict, cas
|
|||||||
|
|
||||||
from typing_extensions import TypeIs
|
from typing_extensions import TypeIs
|
||||||
|
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
from .data import (
|
from .data import (
|
||||||
EmbedsPrompt,
|
EmbedsPrompt,
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils import LazyDict
|
from vllm.utils.collections import LazyDict
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from vllm.model_executor.parameter import (
|
|||||||
RowvLLMParameter,
|
RowvLLMParameter,
|
||||||
)
|
)
|
||||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
|
|||||||
@ -57,7 +57,7 @@ from vllm.model_executor.parameter import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|||||||
@ -49,7 +49,7 @@ from vllm.transformers_utils.configs.deepseek_vl2 import (
|
|||||||
)
|
)
|
||||||
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
|
from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils.collections import is_list_of
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
|
|||||||
@ -33,7 +33,7 @@ from vllm.multimodal.processing import (
|
|||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils.collections import is_list_of
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
|
|||||||
@ -86,7 +86,7 @@ from vllm.multimodal.processing import (
|
|||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import flatten_2d_lists
|
from vllm.utils.collections import flatten_2d_lists
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||||
|
|||||||
@ -79,7 +79,7 @@ from vllm.multimodal.processing import (
|
|||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import is_list_of
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
from .interfaces import (
|
from .interfaces import (
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
|
|||||||
@ -22,7 +22,8 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from typing_extensions import NotRequired, TypeVar, deprecated
|
from typing_extensions import NotRequired, TypeVar, deprecated
|
||||||
|
|
||||||
from vllm.utils import LazyLoader, full_groupby, is_list_of
|
from vllm.utils import LazyLoader
|
||||||
|
from vllm.utils.collections import full_groupby, is_list_of
|
||||||
from vllm.utils.jsontree import json_map_leaves
|
from vllm.utils.jsontree import json_map_leaves
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@ -19,7 +19,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from typing_extensions import assert_never
|
from typing_extensions import assert_never
|
||||||
|
|
||||||
from vllm.utils import LazyLoader, is_list_of
|
from vllm.utils import LazyLoader
|
||||||
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
from .audio import AudioResampler
|
from .audio import AudioResampler
|
||||||
from .inputs import (
|
from .inputs import (
|
||||||
@ -364,7 +365,7 @@ class MultiModalDataParser:
|
|||||||
if isinstance(data, torch.Tensor):
|
if isinstance(data, torch.Tensor):
|
||||||
return data.ndim == 3
|
return data.ndim == 3
|
||||||
if is_list_of(data, torch.Tensor):
|
if is_list_of(data, torch.Tensor):
|
||||||
return data[0].ndim == 2
|
return data[0].ndim == 2 # type: ignore[index]
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -422,6 +423,7 @@ class MultiModalDataParser:
|
|||||||
if self._is_embeddings(data):
|
if self._is_embeddings(data):
|
||||||
return AudioEmbeddingItems(data)
|
return AudioEmbeddingItems(data)
|
||||||
|
|
||||||
|
data_items: list[AudioItem]
|
||||||
if (
|
if (
|
||||||
is_list_of(data, float)
|
is_list_of(data, float)
|
||||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||||
@ -432,7 +434,7 @@ class MultiModalDataParser:
|
|||||||
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
elif isinstance(data, (np.ndarray, torch.Tensor)):
|
||||||
data_items = [elem for elem in data]
|
data_items = [elem for elem in data]
|
||||||
else:
|
else:
|
||||||
data_items = data
|
data_items = data # type: ignore[assignment]
|
||||||
|
|
||||||
new_audios = list[np.ndarray]()
|
new_audios = list[np.ndarray]()
|
||||||
for data_item in data_items:
|
for data_item in data_items:
|
||||||
@ -485,6 +487,7 @@ class MultiModalDataParser:
|
|||||||
if self._is_embeddings(data):
|
if self._is_embeddings(data):
|
||||||
return VideoEmbeddingItems(data)
|
return VideoEmbeddingItems(data)
|
||||||
|
|
||||||
|
data_items: list[VideoItem]
|
||||||
if (
|
if (
|
||||||
is_list_of(data, PILImage.Image)
|
is_list_of(data, PILImage.Image)
|
||||||
or isinstance(data, (np.ndarray, torch.Tensor))
|
or isinstance(data, (np.ndarray, torch.Tensor))
|
||||||
@ -496,7 +499,7 @@ class MultiModalDataParser:
|
|||||||
elif isinstance(data, tuple) and len(data) == 2:
|
elif isinstance(data, tuple) and len(data) == 2:
|
||||||
data_items = [data]
|
data_items = [data]
|
||||||
else:
|
else:
|
||||||
data_items = data
|
data_items = data # type: ignore[assignment]
|
||||||
|
|
||||||
new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
|
new_videos = list[tuple[np.ndarray, dict[str, Any] | None]]()
|
||||||
metadata_lst: list[dict[str, Any] | None] = []
|
metadata_lst: list[dict[str, Any] | None] = []
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from typing_extensions import TypeVar, assert_never
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, decode_tokens, encode_tokens
|
||||||
from vllm.utils import flatten_2d_lists, full_groupby
|
from vllm.utils.collections import flatten_2d_lists, full_groupby
|
||||||
from vllm.utils.functools import get_allowed_kwarg_only_overrides
|
from vllm.utils.functools import get_allowed_kwarg_only_overrides
|
||||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||||
|
|
||||||
@ -484,8 +484,11 @@ _M = TypeVar("_M", bound=_HasModalityAttr | _HasModalityProp)
|
|||||||
|
|
||||||
|
|
||||||
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
|
def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
|
||||||
"""Convenience function to apply [`full_groupby`][vllm.utils.full_groupby]
|
"""
|
||||||
based on modality."""
|
Convenience function to apply
|
||||||
|
[`full_groupby`][vllm.utils.collections.full_groupby]
|
||||||
|
based on modality.
|
||||||
|
"""
|
||||||
return full_groupby(values, key=lambda x: x.modality)
|
return full_groupby(values, key=lambda x: x.modality)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ import torch.nn as nn
|
|||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, cached_tokenizer_from_config
|
||||||
from vllm.utils import ClassRegistry
|
from vllm.utils.collections import ClassRegistry
|
||||||
|
|
||||||
from .cache import BaseMultiModalProcessorCache
|
from .cache import BaseMultiModalProcessorCache
|
||||||
from .processing import (
|
from .processing import (
|
||||||
|
|||||||
@ -8,7 +8,8 @@ from functools import cached_property
|
|||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import import_from_path, is_list_of
|
from vllm.utils import import_from_path
|
||||||
|
from vllm.utils.collections import is_list_of
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
|||||||
@ -37,29 +37,19 @@ from argparse import (
|
|||||||
RawDescriptionHelpFormatter,
|
RawDescriptionHelpFormatter,
|
||||||
_ArgumentGroup,
|
_ArgumentGroup,
|
||||||
)
|
)
|
||||||
from collections import UserDict, defaultdict
|
from collections import defaultdict
|
||||||
from collections.abc import (
|
from collections.abc import (
|
||||||
Callable,
|
Callable,
|
||||||
Collection,
|
Collection,
|
||||||
Generator,
|
Generator,
|
||||||
Hashable,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
Iterator,
|
||||||
Mapping,
|
|
||||||
Sequence,
|
Sequence,
|
||||||
)
|
)
|
||||||
from concurrent.futures.process import ProcessPoolExecutor
|
from concurrent.futures.process import ProcessPoolExecutor
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import cache, lru_cache, partial, wraps
|
from functools import cache, lru_cache, partial, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Any, TextIO, TypeVar
|
||||||
TYPE_CHECKING,
|
|
||||||
Any,
|
|
||||||
Generic,
|
|
||||||
Literal,
|
|
||||||
TextIO,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -78,7 +68,7 @@ import zmq.asyncio
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from torch.library import Library
|
from torch.library import Library
|
||||||
from typing_extensions import Never, TypeIs, assert_never
|
from typing_extensions import Never
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.logger import enable_trace_function_call, init_logger
|
from vllm.logger import enable_trace_function_call, init_logger
|
||||||
@ -170,9 +160,6 @@ def set_default_torch_num_threads(num_threads: int):
|
|||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
U = TypeVar("U")
|
U = TypeVar("U")
|
||||||
|
|
||||||
_K = TypeVar("_K", bound=Hashable)
|
|
||||||
_V = TypeVar("_V")
|
|
||||||
|
|
||||||
|
|
||||||
class Device(enum.Enum):
|
class Device(enum.Enum):
|
||||||
GPU = enum.auto()
|
GPU = enum.auto()
|
||||||
@ -421,12 +408,6 @@ def update_environment_variables(envs: dict[str, str]):
|
|||||||
os.environ[k] = v
|
os.environ[k] = v
|
||||||
|
|
||||||
|
|
||||||
def chunk_list(lst: list[T], chunk_size: int):
|
|
||||||
"""Yield successive chunk_size chunks from lst."""
|
|
||||||
for i in range(0, len(lst), chunk_size):
|
|
||||||
yield lst[i : i + chunk_size]
|
|
||||||
|
|
||||||
|
|
||||||
def cdiv(a: int, b: int) -> int:
|
def cdiv(a: int, b: int) -> int:
|
||||||
"""Ceiling division."""
|
"""Ceiling division."""
|
||||||
return -(a // -b)
|
return -(a // -b)
|
||||||
@ -743,53 +724,6 @@ def common_broadcastable_dtype(dtypes: Collection[torch.dtype]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def as_list(maybe_list: Iterable[T]) -> list[T]:
|
|
||||||
"""Convert iterable to list, unless it's already a list."""
|
|
||||||
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
|
|
||||||
|
|
||||||
|
|
||||||
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
|
|
||||||
if isinstance(obj, str) or not isinstance(obj, Iterable):
|
|
||||||
return [obj] # type: ignore[list-item]
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
|
||||||
# `collections` helpers
|
|
||||||
def is_list_of(
|
|
||||||
value: object,
|
|
||||||
typ: type[T] | tuple[type[T], ...],
|
|
||||||
*,
|
|
||||||
check: Literal["first", "all"] = "first",
|
|
||||||
) -> TypeIs[list[T]]:
|
|
||||||
if not isinstance(value, list):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if check == "first":
|
|
||||||
return len(value) == 0 or isinstance(value[0], typ)
|
|
||||||
elif check == "all":
|
|
||||||
return all(isinstance(v, typ) for v in value)
|
|
||||||
|
|
||||||
assert_never(check)
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
|
|
||||||
"""Flatten a list of lists to a single list."""
|
|
||||||
return [item for sublist in lists for item in sublist]
|
|
||||||
|
|
||||||
|
|
||||||
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
|
|
||||||
"""
|
|
||||||
Unlike [`itertools.groupby`][], groups are not broken by
|
|
||||||
non-contiguous data.
|
|
||||||
"""
|
|
||||||
groups = defaultdict[_K, list[_V]](list)
|
|
||||||
|
|
||||||
for value in values:
|
|
||||||
groups[key(value)].append(value)
|
|
||||||
|
|
||||||
return groups.items()
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: This function can be removed if transformer_modules classes are
|
# TODO: This function can be removed if transformer_modules classes are
|
||||||
# serialized by value when communicating between processes
|
# serialized by value when communicating between processes
|
||||||
def init_cached_hf_modules() -> None:
|
def init_cached_hf_modules() -> None:
|
||||||
@ -1578,50 +1512,6 @@ class AtomicCounter:
|
|||||||
return self._value
|
return self._value
|
||||||
|
|
||||||
|
|
||||||
# Adapted from: https://stackoverflow.com/a/47212782/5082708
|
|
||||||
class LazyDict(Mapping[str, T], Generic[T]):
|
|
||||||
def __init__(self, factory: dict[str, Callable[[], T]]):
|
|
||||||
self._factory = factory
|
|
||||||
self._dict: dict[str, T] = {}
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> T:
|
|
||||||
if key not in self._dict:
|
|
||||||
if key not in self._factory:
|
|
||||||
raise KeyError(key)
|
|
||||||
self._dict[key] = self._factory[key]()
|
|
||||||
return self._dict[key]
|
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: Callable[[], T]):
|
|
||||||
self._factory[key] = value
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
return iter(self._factory)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._factory)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassRegistry(UserDict[type[T], _V]):
|
|
||||||
def __getitem__(self, key: type[T]) -> _V:
|
|
||||||
for cls in key.mro():
|
|
||||||
if cls in self.data:
|
|
||||||
return self.data[cls]
|
|
||||||
|
|
||||||
raise KeyError(key)
|
|
||||||
|
|
||||||
def __contains__(self, key: object) -> bool:
|
|
||||||
return self.contains(key)
|
|
||||||
|
|
||||||
def contains(self, key: object, *, strict: bool = False) -> bool:
|
|
||||||
if not isinstance(key, type):
|
|
||||||
return False
|
|
||||||
|
|
||||||
if strict:
|
|
||||||
return key in self.data
|
|
||||||
|
|
||||||
return any(cls in self.data for cls in key.mro())
|
|
||||||
|
|
||||||
|
|
||||||
def weak_ref_tensor(tensor: Any) -> Any:
|
def weak_ref_tensor(tensor: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
Create a weak reference to a tensor.
|
Create a weak reference to a tensor.
|
||||||
@ -2588,22 +2478,6 @@ class LazyLoader(types.ModuleType):
|
|||||||
return dir(self._module)
|
return dir(self._module)
|
||||||
|
|
||||||
|
|
||||||
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
|
|
||||||
"""
|
|
||||||
Helper function to swap values for two keys
|
|
||||||
"""
|
|
||||||
v1 = obj.get(key1)
|
|
||||||
v2 = obj.get(key2)
|
|
||||||
if v1 is not None:
|
|
||||||
obj[key2] = v1
|
|
||||||
else:
|
|
||||||
obj.pop(key2, None)
|
|
||||||
if v2 is not None:
|
|
||||||
obj[key1] = v2
|
|
||||||
else:
|
|
||||||
obj.pop(key1, None)
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def cprofile_context(save_file: str | None = None):
|
def cprofile_context(save_file: str | None = None):
|
||||||
"""Run a cprofile
|
"""Run a cprofile
|
||||||
|
|||||||
139
vllm/utils/collections.py
Normal file
139
vllm/utils/collections.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Contains helpers that are applied to collections.
|
||||||
|
|
||||||
|
This is similar in concept to the `collections` module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections import UserDict, defaultdict
|
||||||
|
from collections.abc import Callable, Generator, Hashable, Iterable, Mapping
|
||||||
|
from typing import Generic, Literal, TypeVar
|
||||||
|
|
||||||
|
from typing_extensions import TypeIs, assert_never
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
U = TypeVar("U")
|
||||||
|
|
||||||
|
_K = TypeVar("_K", bound=Hashable)
|
||||||
|
_V = TypeVar("_V")
|
||||||
|
|
||||||
|
|
||||||
|
class ClassRegistry(UserDict[type[T], _V]):
    """
    A registry that acts like a dictionary but searches for other classes
    in the MRO if the original class is not found.
    """

    def __getitem__(self, key: type[T]) -> _V:
        # Walk the MRO and return the value of the first registered ancestor.
        hit = next((cls for cls in key.mro() if cls in self.data), None)
        if hit is None:
            raise KeyError(key)
        return self.data[hit]

    def __contains__(self, key: object) -> bool:
        return self.contains(key)

    def contains(self, key: object, *, strict: bool = False) -> bool:
        """Report whether ``key`` (or, unless ``strict``, any of its MRO
        ancestors) is registered. Non-class keys never match."""
        if not isinstance(key, type):
            return False

        if strict:
            return key in self.data
        return any(cls in self.data for cls in key.mro())
|
||||||
|
|
||||||
|
|
||||||
|
class LazyDict(Mapping[str, T], Generic[T]):
    """
    Evaluates dictionary items only when they are accessed.

    Adapted from: https://stackoverflow.com/a/47212782/5082708
    """

    def __init__(self, factory: dict[str, Callable[[], T]]):
        # Map of key -> zero-arg callable producing the value.
        self._factory = factory
        # Cache of already-materialized values.
        self._dict: dict[str, T] = {}

    def __getitem__(self, key: str) -> T:
        # Serve from the cache when possible; otherwise materialize once.
        cached = self._dict
        if key not in cached:
            if key not in self._factory:
                raise KeyError(key)
            cached[key] = self._factory[key]()
        return cached[key]

    def __setitem__(self, key: str, value: Callable[[], T]):
        # NOTE: stores the factory; any previously cached value for this
        # key is NOT invalidated (matches the original behavior).
        self._factory[key] = value

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
|
||||||
|
|
||||||
|
|
||||||
|
def as_list(maybe_list: Iterable[T]) -> list[T]:
    """Convert iterable to list, unless it's already a list."""
    if isinstance(maybe_list, list):
        # Already a list: return it as-is (no copy).
        return maybe_list
    return list(maybe_list)
|
||||||
|
|
||||||
|
|
||||||
|
def as_iter(obj: T | Iterable[T]) -> Iterable[T]:
    """Return ``obj`` unchanged if it is a non-string iterable; otherwise
    wrap the single item in a one-element list."""
    # Strings are iterable but are treated as scalar items here.
    if isinstance(obj, Iterable) and not isinstance(obj, str):
        return obj
    return [obj]  # type: ignore[list-item]
|
||||||
|
|
||||||
|
|
||||||
|
def is_list_of(
|
||||||
|
value: object,
|
||||||
|
typ: type[T] | tuple[type[T], ...],
|
||||||
|
*,
|
||||||
|
check: Literal["first", "all"] = "first",
|
||||||
|
) -> TypeIs[list[T]]:
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if check == "first":
|
||||||
|
return len(value) == 0 or isinstance(value[0], typ)
|
||||||
|
elif check == "all":
|
||||||
|
return all(isinstance(v, typ) for v in value)
|
||||||
|
|
||||||
|
assert_never(check)
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_list(lst: list[T], chunk_size: int) -> Generator[list[T]]:
    """Yield successive chunk_size chunks from lst.

    The final chunk may be shorter than ``chunk_size``.
    """
    start = 0
    total = len(lst)
    while start < total:
        yield lst[start : start + chunk_size]
        start += chunk_size
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_2d_lists(lists: Iterable[Iterable[T]]) -> list[T]:
    """Flatten a list of lists to a single list."""
    flat: list[T] = []
    for sub in lists:
        flat.extend(sub)
    return flat
|
||||||
|
|
||||||
|
|
||||||
|
def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]):
    """
    Unlike [`itertools.groupby`][], groups are not broken by
    non-contiguous data.
    """
    # Plain dict + setdefault: insertion order of first occurrence is
    # preserved, matching the defaultdict-based formulation.
    groups: dict[_K, list[_V]] = {}
    for item in values:
        groups.setdefault(key(item), []).append(item)
    return groups.items()
|
||||||
|
|
||||||
|
|
||||||
|
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
    """Swap values between two keys.

    A key that is absent (or mapped to ``None`` — ``dict.get`` does not
    distinguish the two) causes the other key to be removed rather than
    assigned ``None``.
    """
    first = obj.get(key1)
    second = obj.get(key2)
    for target, moved in ((key2, first), (key1, second)):
        if moved is None:
            obj.pop(target, None)
        else:
            obj[target] = moved
|
||||||
@ -29,8 +29,9 @@ from vllm.tracing import init_tracer
|
|||||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, init_tokenizer_from_configs
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import Device, as_list, cdiv
|
from vllm.utils import Device, cdiv
|
||||||
from vllm.utils.asyncio import cancel_task_threadsafe
|
from vllm.utils.asyncio import cancel_task_threadsafe
|
||||||
|
from vllm.utils.collections import as_list
|
||||||
from vllm.utils.functools import deprecate_kwargs
|
from vllm.utils.functools import deprecate_kwargs
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.engine.core_client import EngineCoreClient
|
from vllm.v1.engine.core_client import EngineCoreClient
|
||||||
|
|||||||
@ -12,7 +12,8 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||||
|
from vllm.utils.collections import swap_dict_values
|
||||||
from vllm.v1.outputs import LogprobsTensors
|
from vllm.v1.outputs import LogprobsTensors
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.logits_processor import (
|
from vllm.v1.sample.logits_processor import (
|
||||||
|
|||||||
@ -9,7 +9,8 @@ import torch
|
|||||||
|
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||||
|
from vllm.utils.collections import swap_dict_values
|
||||||
from vllm.v1.outputs import LogprobsTensors
|
from vllm.v1.outputs import LogprobsTensors
|
||||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user