mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 01:15:44 +08:00
Move DeviceConfig, ObservabilityConfig, SpeechToTextConfig to their own files (#25564)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
71566e8afc
commit
44d6701f70
@ -27,6 +27,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
|
|||||||
PrefixCachingHashAlgo)
|
PrefixCachingHashAlgo)
|
||||||
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
from vllm.config.compilation import (CompilationConfig, CompilationLevel,
|
||||||
CUDAGraphMode, PassConfig)
|
CUDAGraphMode, PassConfig)
|
||||||
|
from vllm.config.device import Device, DeviceConfig
|
||||||
from vllm.config.kv_events import KVEventsConfig
|
from vllm.config.kv_events import KVEventsConfig
|
||||||
from vllm.config.kv_transfer import KVTransferConfig
|
from vllm.config.kv_transfer import KVTransferConfig
|
||||||
from vllm.config.load import LoadConfig
|
from vllm.config.load import LoadConfig
|
||||||
@ -38,11 +39,13 @@ from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode,
|
|||||||
try_match_architecture_defaults)
|
try_match_architecture_defaults)
|
||||||
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
|
||||||
MultiModalConfig)
|
MultiModalConfig)
|
||||||
|
from vllm.config.observability import DetailedTraceModules, ObservabilityConfig
|
||||||
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
|
||||||
ParallelConfig)
|
ParallelConfig)
|
||||||
from vllm.config.pooler import PoolerConfig
|
from vllm.config.pooler import PoolerConfig
|
||||||
from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
|
from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
|
||||||
from vllm.config.speculative import SpeculativeConfig
|
from vllm.config.speculative import SpeculativeConfig
|
||||||
|
from vllm.config.speech_to_text import SpeechToTextConfig
|
||||||
from vllm.config.structured_outputs import StructuredOutputsConfig
|
from vllm.config.structured_outputs import StructuredOutputsConfig
|
||||||
from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
|
from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -81,158 +84,6 @@ class SupportsMetricsInfo(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
|
||||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
|
||||||
class DeviceConfig:
|
|
||||||
"""Configuration for the device to use for vLLM execution."""
|
|
||||||
|
|
||||||
device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
|
|
||||||
"""Device type for vLLM execution.
|
|
||||||
This parameter is deprecated and will be
|
|
||||||
removed in a future release.
|
|
||||||
It will now be set automatically based
|
|
||||||
on the current platform."""
|
|
||||||
device_type: str = field(init=False)
|
|
||||||
"""Device type from the current platform. This is set in
|
|
||||||
`__post_init__`."""
|
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
|
||||||
"""
|
|
||||||
WARNING: Whenever a new field is added to this config,
|
|
||||||
ensure that it is included in the factors list if
|
|
||||||
it affects the computation graph.
|
|
||||||
|
|
||||||
Provide a hash that uniquely identifies all the configs
|
|
||||||
that affect the structure of the computation
|
|
||||||
graph from input ids/embeddings to the final hidden states,
|
|
||||||
excluding anything before input ids/embeddings and after
|
|
||||||
the final hidden states.
|
|
||||||
"""
|
|
||||||
# no factors to consider.
|
|
||||||
# the device/platform information will be summarized
|
|
||||||
# by torch/vllm automatically.
|
|
||||||
factors: list[Any] = []
|
|
||||||
hash_str = hashlib.md5(str(factors).encode(),
|
|
||||||
usedforsecurity=False).hexdigest()
|
|
||||||
return hash_str
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.device == "auto":
|
|
||||||
# Automated device type detection
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
self.device_type = current_platform.device_type
|
|
||||||
if not self.device_type:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Failed to infer device type, please set "
|
|
||||||
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
|
|
||||||
"to turn on verbose logging to help debug the issue.")
|
|
||||||
else:
|
|
||||||
# Device type is assigned explicitly
|
|
||||||
if isinstance(self.device, str):
|
|
||||||
self.device_type = self.device
|
|
||||||
elif isinstance(self.device, torch.device):
|
|
||||||
self.device_type = self.device.type
|
|
||||||
|
|
||||||
# Some device types require processing inputs on CPU
|
|
||||||
if self.device_type in ["tpu"]:
|
|
||||||
self.device = None
|
|
||||||
else:
|
|
||||||
# Set device with device type
|
|
||||||
self.device = torch.device(self.device_type)
|
|
||||||
|
|
||||||
|
|
||||||
DetailedTraceModules = Literal["model", "worker", "all"]
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
|
||||||
@dataclass
|
|
||||||
class ObservabilityConfig:
|
|
||||||
"""Configuration for observability - metrics and tracing."""
|
|
||||||
|
|
||||||
show_hidden_metrics_for_version: Optional[str] = None
|
|
||||||
"""Enable deprecated Prometheus metrics that have been hidden since the
|
|
||||||
specified version. For example, if a previously deprecated metric has been
|
|
||||||
hidden since the v0.7.0 release, you use
|
|
||||||
`--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
|
|
||||||
you migrate to new metrics. The metric is likely to be removed completely
|
|
||||||
in an upcoming release."""
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def show_hidden_metrics(self) -> bool:
|
|
||||||
"""Check if the hidden metrics should be shown."""
|
|
||||||
if self.show_hidden_metrics_for_version is None:
|
|
||||||
return False
|
|
||||||
return version._prev_minor_version_was(
|
|
||||||
self.show_hidden_metrics_for_version)
|
|
||||||
|
|
||||||
otlp_traces_endpoint: Optional[str] = None
|
|
||||||
"""Target URL to which OpenTelemetry traces will be sent."""
|
|
||||||
|
|
||||||
collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
|
|
||||||
"""It makes sense to set this only if `--otlp-traces-endpoint` is set. If
|
|
||||||
set, it will collect detailed traces for the specified modules. This
|
|
||||||
involves use of possibly costly and or blocking operations and hence might
|
|
||||||
have a performance impact.
|
|
||||||
|
|
||||||
Note that collecting detailed timing information for each request can be
|
|
||||||
expensive."""
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def collect_model_forward_time(self) -> bool:
|
|
||||||
"""Whether to collect model forward time for the request."""
|
|
||||||
return (self.collect_detailed_traces is not None
|
|
||||||
and ("model" in self.collect_detailed_traces
|
|
||||||
or "all" in self.collect_detailed_traces))
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def collect_model_execute_time(self) -> bool:
|
|
||||||
"""Whether to collect model execute time for the request."""
|
|
||||||
return (self.collect_detailed_traces is not None
|
|
||||||
and ("worker" in self.collect_detailed_traces
|
|
||||||
or "all" in self.collect_detailed_traces))
|
|
||||||
|
|
||||||
def compute_hash(self) -> str:
|
|
||||||
"""
|
|
||||||
WARNING: Whenever a new field is added to this config,
|
|
||||||
ensure that it is included in the factors list if
|
|
||||||
it affects the computation graph.
|
|
||||||
|
|
||||||
Provide a hash that uniquely identifies all the configs
|
|
||||||
that affect the structure of the computation
|
|
||||||
graph from input ids/embeddings to the final hidden states,
|
|
||||||
excluding anything before input ids/embeddings and after
|
|
||||||
the final hidden states.
|
|
||||||
"""
|
|
||||||
# no factors to consider.
|
|
||||||
# this config will not affect the computation graph.
|
|
||||||
factors: list[Any] = []
|
|
||||||
hash_str = hashlib.md5(str(factors).encode(),
|
|
||||||
usedforsecurity=False).hexdigest()
|
|
||||||
return hash_str
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if (self.collect_detailed_traces is not None
|
|
||||||
and len(self.collect_detailed_traces) == 1
|
|
||||||
and "," in self.collect_detailed_traces[0]):
|
|
||||||
self._parse_collect_detailed_traces()
|
|
||||||
|
|
||||||
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
|
||||||
if not is_otel_available() and self.otlp_traces_endpoint is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"OpenTelemetry is not available. Unable to configure "
|
|
||||||
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
|
||||||
f"installed. Original error:\n{otel_import_error_traceback}")
|
|
||||||
|
|
||||||
def _parse_collect_detailed_traces(self):
|
|
||||||
assert isinstance(self.collect_detailed_traces, list)
|
|
||||||
self.collect_detailed_traces = cast(
|
|
||||||
list[DetailedTraceModules],
|
|
||||||
self.collect_detailed_traces[0].split(","))
|
|
||||||
|
|
||||||
|
|
||||||
@config
|
@config
|
||||||
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||||
class VllmConfig:
|
class VllmConfig:
|
||||||
@ -1009,37 +860,6 @@ def get_layers_from_vllm_config(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@config
|
|
||||||
@dataclass
|
|
||||||
class SpeechToTextConfig:
|
|
||||||
"""Configuration for speech-to-text models."""
|
|
||||||
|
|
||||||
sample_rate: float = 16_000
|
|
||||||
"""Sample rate (Hz) to resample input audio to. Most speech models expect
|
|
||||||
16kHz audio input. The input audio will be automatically resampled to this
|
|
||||||
rate before processing."""
|
|
||||||
|
|
||||||
max_audio_clip_s: int = 30
|
|
||||||
"""Maximum duration in seconds for a single audio clip without chunking.
|
|
||||||
Audio longer than this will be split into smaller chunks if
|
|
||||||
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
|
|
||||||
|
|
||||||
overlap_chunk_second: int = 1
|
|
||||||
"""Overlap duration in seconds between consecutive audio chunks when
|
|
||||||
splitting long audio. This helps maintain context across chunk boundaries
|
|
||||||
and improves transcription quality at split points."""
|
|
||||||
|
|
||||||
min_energy_split_window_size: Optional[int] = 1600
|
|
||||||
"""Window size in samples for finding low-energy (quiet) regions to split
|
|
||||||
audio chunks. The algorithm looks for the quietest moment within this
|
|
||||||
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
|
|
||||||
at 16kHz. If None, no chunking will be done."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def allow_audio_chunking(self) -> bool:
|
|
||||||
return self.min_energy_split_window_size is not None
|
|
||||||
|
|
||||||
|
|
||||||
def update_config(config: DataclassInstanceT,
|
def update_config(config: DataclassInstanceT,
|
||||||
overrides: dict[str, Any]) -> DataclassInstanceT:
|
overrides: dict[str, Any]) -> DataclassInstanceT:
|
||||||
processed_overrides = {}
|
processed_overrides = {}
|
||||||
|
|||||||
74
vllm/config/device.py
Normal file
74
vllm/config/device.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import ConfigDict, SkipValidation
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm.config.utils import config
|
||||||
|
|
||||||
|
Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
|
||||||
|
class DeviceConfig:
|
||||||
|
"""Configuration for the device to use for vLLM execution."""
|
||||||
|
|
||||||
|
device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
|
||||||
|
"""Device type for vLLM execution.
|
||||||
|
This parameter is deprecated and will be
|
||||||
|
removed in a future release.
|
||||||
|
It will now be set automatically based
|
||||||
|
on the current platform."""
|
||||||
|
device_type: str = field(init=False)
|
||||||
|
"""Device type from the current platform. This is set in
|
||||||
|
`__post_init__`."""
|
||||||
|
|
||||||
|
def compute_hash(self) -> str:
|
||||||
|
"""
|
||||||
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
ensure that it is included in the factors list if
|
||||||
|
it affects the computation graph.
|
||||||
|
|
||||||
|
Provide a hash that uniquely identifies all the configs
|
||||||
|
that affect the structure of the computation
|
||||||
|
graph from input ids/embeddings to the final hidden states,
|
||||||
|
excluding anything before input ids/embeddings and after
|
||||||
|
the final hidden states.
|
||||||
|
"""
|
||||||
|
# no factors to consider.
|
||||||
|
# the device/platform information will be summarized
|
||||||
|
# by torch/vllm automatically.
|
||||||
|
factors: list[Any] = []
|
||||||
|
hash_str = hashlib.md5(str(factors).encode(),
|
||||||
|
usedforsecurity=False).hexdigest()
|
||||||
|
return hash_str
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.device == "auto":
|
||||||
|
# Automated device type detection
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
self.device_type = current_platform.device_type
|
||||||
|
if not self.device_type:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to infer device type, please set "
|
||||||
|
"the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
|
||||||
|
"to turn on verbose logging to help debug the issue.")
|
||||||
|
else:
|
||||||
|
# Device type is assigned explicitly
|
||||||
|
if isinstance(self.device, str):
|
||||||
|
self.device_type = self.device
|
||||||
|
elif isinstance(self.device, torch.device):
|
||||||
|
self.device_type = self.device.type
|
||||||
|
|
||||||
|
# Some device types require processing inputs on CPU
|
||||||
|
if self.device_type in ["tpu"]:
|
||||||
|
self.device = None
|
||||||
|
else:
|
||||||
|
# Set device with device type
|
||||||
|
self.device = torch.device(self.device_type)
|
||||||
99
vllm/config/observability.py
Normal file
99
vllm/config/observability.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from functools import cached_property
|
||||||
|
from typing import Any, Literal, Optional, cast
|
||||||
|
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm import version
|
||||||
|
from vllm.config.utils import config
|
||||||
|
|
||||||
|
DetailedTraceModules = Literal["model", "worker", "all"]
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class ObservabilityConfig:
|
||||||
|
"""Configuration for observability - metrics and tracing."""
|
||||||
|
|
||||||
|
show_hidden_metrics_for_version: Optional[str] = None
|
||||||
|
"""Enable deprecated Prometheus metrics that have been hidden since the
|
||||||
|
specified version. For example, if a previously deprecated metric has been
|
||||||
|
hidden since the v0.7.0 release, you use
|
||||||
|
`--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
|
||||||
|
you migrate to new metrics. The metric is likely to be removed completely
|
||||||
|
in an upcoming release."""
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def show_hidden_metrics(self) -> bool:
|
||||||
|
"""Check if the hidden metrics should be shown."""
|
||||||
|
if self.show_hidden_metrics_for_version is None:
|
||||||
|
return False
|
||||||
|
return version._prev_minor_version_was(
|
||||||
|
self.show_hidden_metrics_for_version)
|
||||||
|
|
||||||
|
otlp_traces_endpoint: Optional[str] = None
|
||||||
|
"""Target URL to which OpenTelemetry traces will be sent."""
|
||||||
|
|
||||||
|
collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
|
||||||
|
"""It makes sense to set this only if `--otlp-traces-endpoint` is set. If
|
||||||
|
set, it will collect detailed traces for the specified modules. This
|
||||||
|
involves use of possibly costly and or blocking operations and hence might
|
||||||
|
have a performance impact.
|
||||||
|
|
||||||
|
Note that collecting detailed timing information for each request can be
|
||||||
|
expensive."""
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def collect_model_forward_time(self) -> bool:
|
||||||
|
"""Whether to collect model forward time for the request."""
|
||||||
|
return (self.collect_detailed_traces is not None
|
||||||
|
and ("model" in self.collect_detailed_traces
|
||||||
|
or "all" in self.collect_detailed_traces))
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def collect_model_execute_time(self) -> bool:
|
||||||
|
"""Whether to collect model execute time for the request."""
|
||||||
|
return (self.collect_detailed_traces is not None
|
||||||
|
and ("worker" in self.collect_detailed_traces
|
||||||
|
or "all" in self.collect_detailed_traces))
|
||||||
|
|
||||||
|
def compute_hash(self) -> str:
|
||||||
|
"""
|
||||||
|
WARNING: Whenever a new field is added to this config,
|
||||||
|
ensure that it is included in the factors list if
|
||||||
|
it affects the computation graph.
|
||||||
|
|
||||||
|
Provide a hash that uniquely identifies all the configs
|
||||||
|
that affect the structure of the computation
|
||||||
|
graph from input ids/embeddings to the final hidden states,
|
||||||
|
excluding anything before input ids/embeddings and after
|
||||||
|
the final hidden states.
|
||||||
|
"""
|
||||||
|
# no factors to consider.
|
||||||
|
# this config will not affect the computation graph.
|
||||||
|
factors: list[Any] = []
|
||||||
|
hash_str = hashlib.md5(str(factors).encode(),
|
||||||
|
usedforsecurity=False).hexdigest()
|
||||||
|
return hash_str
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if (self.collect_detailed_traces is not None
|
||||||
|
and len(self.collect_detailed_traces) == 1
|
||||||
|
and "," in self.collect_detailed_traces[0]):
|
||||||
|
self._parse_collect_detailed_traces()
|
||||||
|
|
||||||
|
from vllm.tracing import is_otel_available, otel_import_error_traceback
|
||||||
|
if not is_otel_available() and self.otlp_traces_endpoint is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"OpenTelemetry is not available. Unable to configure "
|
||||||
|
"'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
|
||||||
|
f"installed. Original error:\n{otel_import_error_traceback}")
|
||||||
|
|
||||||
|
def _parse_collect_detailed_traces(self):
|
||||||
|
assert isinstance(self.collect_detailed_traces, list)
|
||||||
|
self.collect_detailed_traces = cast(
|
||||||
|
list[DetailedTraceModules],
|
||||||
|
self.collect_detailed_traces[0].split(","))
|
||||||
39
vllm/config/speech_to_text.py
Normal file
39
vllm/config/speech_to_text.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
from vllm.config.utils import config
|
||||||
|
|
||||||
|
|
||||||
|
@config
|
||||||
|
@dataclass
|
||||||
|
class SpeechToTextConfig:
|
||||||
|
"""Configuration for speech-to-text models."""
|
||||||
|
|
||||||
|
sample_rate: float = 16_000
|
||||||
|
"""Sample rate (Hz) to resample input audio to. Most speech models expect
|
||||||
|
16kHz audio input. The input audio will be automatically resampled to this
|
||||||
|
rate before processing."""
|
||||||
|
|
||||||
|
max_audio_clip_s: int = 30
|
||||||
|
"""Maximum duration in seconds for a single audio clip without chunking.
|
||||||
|
Audio longer than this will be split into smaller chunks if
|
||||||
|
`allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
|
||||||
|
|
||||||
|
overlap_chunk_second: int = 1
|
||||||
|
"""Overlap duration in seconds between consecutive audio chunks when
|
||||||
|
splitting long audio. This helps maintain context across chunk boundaries
|
||||||
|
and improves transcription quality at split points."""
|
||||||
|
|
||||||
|
min_energy_split_window_size: Optional[int] = 1600
|
||||||
|
"""Window size in samples for finding low-energy (quiet) regions to split
|
||||||
|
audio chunks. The algorithm looks for the quietest moment within this
|
||||||
|
window to minimize cutting through speech. Default 1600 samples ≈ 100ms
|
||||||
|
at 16kHz. If None, no chunking will be done."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allow_audio_chunking(self) -> bool:
|
||||||
|
return self.min_energy_split_window_size is not None
|
||||||
Loading…
x
Reference in New Issue
Block a user