Move DeviceConfig, ObservabilityConfig, SpeechToTextConfig to their own files (#25564)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-09-24 14:59:05 +01:00 committed by GitHub
parent e18b714b2e
commit 8938774c79
4 changed files with 215 additions and 183 deletions

vllm/config/__init__.py

@@ -27,6 +27,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
                                PrefixCachingHashAlgo)
 from vllm.config.compilation import (CompilationConfig, CompilationLevel,
                                      CUDAGraphMode, PassConfig)
+from vllm.config.device import Device, DeviceConfig
 from vllm.config.kv_events import KVEventsConfig
 from vllm.config.kv_transfer import KVTransferConfig
 from vllm.config.load import LoadConfig
@@ -38,11 +39,13 @@ from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode,
                                try_match_architecture_defaults)
 from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                     MultiModalConfig)
+from vllm.config.observability import DetailedTraceModules, ObservabilityConfig
 from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
                                   ParallelConfig)
 from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
 from vllm.config.speculative import SpeculativeConfig
+from vllm.config.speech_to_text import SpeechToTextConfig
 from vllm.config.structured_outputs import StructuredOutputsConfig
 from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
 from vllm.logger import init_logger
@@ -81,158 +84,6 @@ class SupportsMetricsInfo(Protocol):
         ...
 
 
-Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
-
-
-@config
-@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
-class DeviceConfig:
-    """Configuration for the device to use for vLLM execution."""
-
-    device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
-    """Device type for vLLM execution.
-    This parameter is deprecated and will be
-    removed in a future release.
-    It will now be set automatically based
-    on the current platform."""
-    device_type: str = field(init=False)
-    """Device type from the current platform. This is set in
-    `__post_init__`."""
-
-    def compute_hash(self) -> str:
-        """
-        WARNING: Whenever a new field is added to this config,
-        ensure that it is included in the factors list if
-        it affects the computation graph.
-
-        Provide a hash that uniquely identifies all the configs
-        that affect the structure of the computation
-        graph from input ids/embeddings to the final hidden states,
-        excluding anything before input ids/embeddings and after
-        the final hidden states.
-        """
-        # no factors to consider.
-        # the device/platform information will be summarized
-        # by torch/vllm automatically.
-        factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
-        return hash_str
-
-    def __post_init__(self):
-        if self.device == "auto":
-            # Automated device type detection
-            from vllm.platforms import current_platform
-            self.device_type = current_platform.device_type
-            if not self.device_type:
-                raise RuntimeError(
-                    "Failed to infer device type, please set "
-                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
-                    "to turn on verbose logging to help debug the issue.")
-        else:
-            # Device type is assigned explicitly
-            if isinstance(self.device, str):
-                self.device_type = self.device
-            elif isinstance(self.device, torch.device):
-                self.device_type = self.device.type
-
-        # Some device types require processing inputs on CPU
-        if self.device_type in ["tpu"]:
-            self.device = None
-        else:
-            # Set device with device type
-            self.device = torch.device(self.device_type)
-
-
-DetailedTraceModules = Literal["model", "worker", "all"]
-
-
-@config
-@dataclass
-class ObservabilityConfig:
-    """Configuration for observability - metrics and tracing."""
-
-    show_hidden_metrics_for_version: Optional[str] = None
-    """Enable deprecated Prometheus metrics that have been hidden since the
-    specified version. For example, if a previously deprecated metric has been
-    hidden since the v0.7.0 release, you can use
-    `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
-    you migrate to new metrics. The metric is likely to be removed completely
-    in an upcoming release."""
-
-    @cached_property
-    def show_hidden_metrics(self) -> bool:
-        """Check if the hidden metrics should be shown."""
-        if self.show_hidden_metrics_for_version is None:
-            return False
-        return version._prev_minor_version_was(
-            self.show_hidden_metrics_for_version)
-
-    otlp_traces_endpoint: Optional[str] = None
-    """Target URL to which OpenTelemetry traces will be sent."""
-
-    collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
-    """It makes sense to set this only if `--otlp-traces-endpoint` is set. If
-    set, it will collect detailed traces for the specified modules. This
-    involves use of possibly costly and/or blocking operations and hence might
-    have a performance impact.
-
-    Note that collecting detailed timing information for each request can be
-    expensive."""
-
-    @cached_property
-    def collect_model_forward_time(self) -> bool:
-        """Whether to collect model forward time for the request."""
-        return (self.collect_detailed_traces is not None
-                and ("model" in self.collect_detailed_traces
-                     or "all" in self.collect_detailed_traces))
-
-    @cached_property
-    def collect_model_execute_time(self) -> bool:
-        """Whether to collect model execute time for the request."""
-        return (self.collect_detailed_traces is not None
-                and ("worker" in self.collect_detailed_traces
-                     or "all" in self.collect_detailed_traces))
-
-    def compute_hash(self) -> str:
-        """
-        WARNING: Whenever a new field is added to this config,
-        ensure that it is included in the factors list if
-        it affects the computation graph.
-
-        Provide a hash that uniquely identifies all the configs
-        that affect the structure of the computation
-        graph from input ids/embeddings to the final hidden states,
-        excluding anything before input ids/embeddings and after
-        the final hidden states.
-        """
-        # no factors to consider.
-        # this config will not affect the computation graph.
-        factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
-        return hash_str
-
-    def __post_init__(self):
-        if (self.collect_detailed_traces is not None
-                and len(self.collect_detailed_traces) == 1
-                and "," in self.collect_detailed_traces[0]):
-            self._parse_collect_detailed_traces()
-
-        from vllm.tracing import is_otel_available, otel_import_error_traceback
-        if not is_otel_available() and self.otlp_traces_endpoint is not None:
-            raise ValueError(
-                "OpenTelemetry is not available. Unable to configure "
-                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
-                f"installed. Original error:\n{otel_import_error_traceback}")
-
-    def _parse_collect_detailed_traces(self):
-        assert isinstance(self.collect_detailed_traces, list)
-        self.collect_detailed_traces = cast(
-            list[DetailedTraceModules],
-            self.collect_detailed_traces[0].split(","))
 
 
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class VllmConfig:
@@ -1009,37 +860,6 @@ def get_layers_from_vllm_config(
     }
 
 
-@config
-@dataclass
-class SpeechToTextConfig:
-    """Configuration for speech-to-text models."""
-
-    sample_rate: float = 16_000
-    """Sample rate (Hz) to resample input audio to. Most speech models expect
-    16kHz audio input. The input audio will be automatically resampled to this
-    rate before processing."""
-
-    max_audio_clip_s: int = 30
-    """Maximum duration in seconds for a single audio clip without chunking.
-    Audio longer than this will be split into smaller chunks if
-    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
-
-    overlap_chunk_second: int = 1
-    """Overlap duration in seconds between consecutive audio chunks when
-    splitting long audio. This helps maintain context across chunk boundaries
-    and improves transcription quality at split points."""
-
-    min_energy_split_window_size: Optional[int] = 1600
-    """Window size in samples for finding low-energy (quiet) regions to split
-    audio chunks. The algorithm looks for the quietest moment within this
-    window to minimize cutting through speech. Default is 1600 samples
-    (100ms at 16kHz). If None, no chunking will be done."""
-
-    @property
-    def allow_audio_chunking(self) -> bool:
-        return self.min_energy_split_window_size is not None
-
-
 def update_config(config: DataclassInstanceT,
                   overrides: dict[str, Any]) -> DataclassInstanceT:
     processed_overrides = {}
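
Because `vllm/config/__init__.py` re-imports the moved classes (the `+` lines in the hunks above), existing top-level imports should keep working after this commit. A minimal sketch checking the re-export (the alias name is illustrative):

# Both paths should resolve to the same class object after this commit,
# since vllm.config re-imports DeviceConfig from vllm.config.device.
from vllm.config import DeviceConfig as ReExportedDeviceConfig
from vllm.config.device import DeviceConfig

assert ReExportedDeviceConfig is DeviceConfig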

vllm/config/device.py Normal file

@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from dataclasses import field
from typing import Any, Literal, Optional, Union

import torch
from pydantic import ConfigDict, SkipValidation
from pydantic.dataclasses import dataclass

from vllm.config.utils import config

Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]


@config
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class DeviceConfig:
    """Configuration for the device to use for vLLM execution."""

    device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
    """Device type for vLLM execution.
    This parameter is deprecated and will be
    removed in a future release.
    It will now be set automatically based
    on the current platform."""
    device_type: str = field(init=False)
    """Device type from the current platform. This is set in
    `__post_init__`."""

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # the device/platform information will be summarized
        # by torch/vllm automatically.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self):
        if self.device == "auto":
            # Automated device type detection
            from vllm.platforms import current_platform
            self.device_type = current_platform.device_type
            if not self.device_type:
                raise RuntimeError(
                    "Failed to infer device type, please set "
                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
                    "to turn on verbose logging to help debug the issue.")
        else:
            # Device type is assigned explicitly
            if isinstance(self.device, str):
                self.device_type = self.device
            elif isinstance(self.device, torch.device):
                self.device_type = self.device.type

        # Some device types require processing inputs on CPU
        if self.device_type in ["tpu"]:
            self.device = None
        else:
            # Set device with device type
            self.device = torch.device(self.device_type)
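
A hedged usage sketch for the moved class: with the default `device="auto"`, `__post_init__` resolves `device_type` from `vllm.platforms.current_platform`, so the printed value below depends on the host (this assumes a working vLLM install):

import torch
from vllm.config.device import DeviceConfig

cfg = DeviceConfig()                  # "auto" -> platform detection
print(cfg.device_type)                # host-dependent, e.g. "cuda" or "cpu"

cpu_cfg = DeviceConfig(device="cpu")  # explicit-assignment branch
assert cpu_cfg.device_type == "cpu"
assert cpu_cfg.device == torch.device("cpu")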

vllm/config/observability.py Normal file

@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from functools import cached_property
from typing import Any, Literal, Optional, cast

from pydantic.dataclasses import dataclass

from vllm import version
from vllm.config.utils import config

DetailedTraceModules = Literal["model", "worker", "all"]


@config
@dataclass
class ObservabilityConfig:
    """Configuration for observability - metrics and tracing."""

    show_hidden_metrics_for_version: Optional[str] = None
    """Enable deprecated Prometheus metrics that have been hidden since the
    specified version. For example, if a previously deprecated metric has been
    hidden since the v0.7.0 release, you can use
    `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
    you migrate to new metrics. The metric is likely to be removed completely
    in an upcoming release."""

    @cached_property
    def show_hidden_metrics(self) -> bool:
        """Check if the hidden metrics should be shown."""
        if self.show_hidden_metrics_for_version is None:
            return False
        return version._prev_minor_version_was(
            self.show_hidden_metrics_for_version)

    otlp_traces_endpoint: Optional[str] = None
    """Target URL to which OpenTelemetry traces will be sent."""

    collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
    """It makes sense to set this only if `--otlp-traces-endpoint` is set. If
    set, it will collect detailed traces for the specified modules. This
    involves use of possibly costly and/or blocking operations and hence might
    have a performance impact.

    Note that collecting detailed timing information for each request can be
    expensive."""

    @cached_property
    def collect_model_forward_time(self) -> bool:
        """Whether to collect model forward time for the request."""
        return (self.collect_detailed_traces is not None
                and ("model" in self.collect_detailed_traces
                     or "all" in self.collect_detailed_traces))

    @cached_property
    def collect_model_execute_time(self) -> bool:
        """Whether to collect model execute time for the request."""
        return (self.collect_detailed_traces is not None
                and ("worker" in self.collect_detailed_traces
                     or "all" in self.collect_detailed_traces))

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str

    def __post_init__(self):
        if (self.collect_detailed_traces is not None
                and len(self.collect_detailed_traces) == 1
                and "," in self.collect_detailed_traces[0]):
            self._parse_collect_detailed_traces()

        from vllm.tracing import is_otel_available, otel_import_error_traceback
        if not is_otel_available() and self.otlp_traces_endpoint is not None:
            raise ValueError(
                "OpenTelemetry is not available. Unable to configure "
                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
                f"installed. Original error:\n{otel_import_error_traceback}")

    def _parse_collect_detailed_traces(self):
        assert isinstance(self.collect_detailed_traces, list)
        self.collect_detailed_traces = cast(
            list[DetailedTraceModules],
            self.collect_detailed_traces[0].split(","))
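
A small sketch of the derived tracing flags, assuming no OTLP endpoint is configured (so the OpenTelemetry availability check in `__post_init__` never raises):

from vllm.config.observability import ObservabilityConfig

cfg = ObservabilityConfig(collect_detailed_traces=["all"])
# "all" enables both derived cached properties.
assert cfg.collect_model_forward_time
assert cfg.collect_model_execute_time

default_cfg = ObservabilityConfig()
assert not default_cfg.show_hidden_metrics  # no hidden-metrics opt-in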

vllm/config/speech_to_text.py Normal file

@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

from pydantic.dataclasses import dataclass

from vllm.config.utils import config


@config
@dataclass
class SpeechToTextConfig:
    """Configuration for speech-to-text models."""

    sample_rate: float = 16_000
    """Sample rate (Hz) to resample input audio to. Most speech models expect
    16kHz audio input. The input audio will be automatically resampled to this
    rate before processing."""

    max_audio_clip_s: int = 30
    """Maximum duration in seconds for a single audio clip without chunking.
    Audio longer than this will be split into smaller chunks if
    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""

    overlap_chunk_second: int = 1
    """Overlap duration in seconds between consecutive audio chunks when
    splitting long audio. This helps maintain context across chunk boundaries
    and improves transcription quality at split points."""

    min_energy_split_window_size: Optional[int] = 1600
    """Window size in samples for finding low-energy (quiet) regions to split
    audio chunks. The algorithm looks for the quietest moment within this
    window to minimize cutting through speech. Default is 1600 samples
    (100ms at 16kHz). If None, no chunking will be done."""

    @property
    def allow_audio_chunking(self) -> bool:
        return self.min_energy_split_window_size is not None
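
A short sketch of the chunking switch: `allow_audio_chunking` is driven entirely by `min_energy_split_window_size`, per the property above:

from vllm.config.speech_to_text import SpeechToTextConfig

stt = SpeechToTextConfig()              # window defaults to 1600 samples
assert stt.allow_audio_chunking         # long audio may be chunked

no_chunk = SpeechToTextConfig(min_energy_split_window_size=None)
assert not no_chunk.allow_audio_chunking  # over-length clips are rejected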