Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 04:34:57 +08:00)
Move DeviceConfig, ObservabilityConfig, SpeechToTextConfig to their own files (#25564)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Commit 8938774c79 (parent e18b714b2e)
vllm/config/__init__.py

@@ -27,6 +27,7 @@ from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType,
                                PrefixCachingHashAlgo)
 from vllm.config.compilation import (CompilationConfig, CompilationLevel,
                                      CUDAGraphMode, PassConfig)
+from vllm.config.device import Device, DeviceConfig
 from vllm.config.kv_events import KVEventsConfig
 from vllm.config.kv_transfer import KVTransferConfig
 from vllm.config.load import LoadConfig
@@ -38,11 +39,13 @@ from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode,
                                try_match_architecture_defaults)
 from vllm.config.multimodal import (MMCacheType, MMEncoderTPMode,
                                     MultiModalConfig)
+from vllm.config.observability import DetailedTraceModules, ObservabilityConfig
 from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig,
                                   ParallelConfig)
 from vllm.config.pooler import PoolerConfig
 from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy
 from vllm.config.speculative import SpeculativeConfig
+from vllm.config.speech_to_text import SpeechToTextConfig
 from vllm.config.structured_outputs import StructuredOutputsConfig
 from vllm.config.utils import ConfigType, config, get_attr_docs, is_init_field
 from vllm.logger import init_logger
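Because `vllm/config/__init__.py` re-imports the moved classes (the three `+` lines above), the aggregate import path keeps working alongside the new dedicated modules. A minimal sketch verifying that the two paths resolve to the same objects, assuming a vLLM checkout that includes this commit:

import vllm.config as cfg
import vllm.config.device as device_cfg
import vllm.config.observability as obs_cfg
import vllm.config.speech_to_text as stt_cfg

# The package __init__ and the new modules expose the same class objects,
# so downstream code importing from vllm.config is unaffected by the move.
assert cfg.DeviceConfig is device_cfg.DeviceConfig
assert cfg.ObservabilityConfig is obs_cfg.ObservabilityConfig
assert cfg.SpeechToTextConfig is stt_cfg.SpeechToTextConfig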
@@ -81,158 +84,6 @@ class SupportsMetricsInfo(Protocol):
         ...
 
 
-Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
-
-
-@config
-@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
-class DeviceConfig:
-    """Configuration for the device to use for vLLM execution."""
-
-    device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
-    """Device type for vLLM execution.
-    This parameter is deprecated and will be
-    removed in a future release.
-    It will now be set automatically based
-    on the current platform."""
-    device_type: str = field(init=False)
-    """Device type from the current platform. This is set in
-    `__post_init__`."""
-
-    def compute_hash(self) -> str:
-        """
-        WARNING: Whenever a new field is added to this config,
-        ensure that it is included in the factors list if
-        it affects the computation graph.
-
-        Provide a hash that uniquely identifies all the configs
-        that affect the structure of the computation
-        graph from input ids/embeddings to the final hidden states,
-        excluding anything before input ids/embeddings and after
-        the final hidden states.
-        """
-        # no factors to consider.
-        # the device/platform information will be summarized
-        # by torch/vllm automatically.
-        factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
-        return hash_str
-
-    def __post_init__(self):
-        if self.device == "auto":
-            # Automated device type detection
-            from vllm.platforms import current_platform
-            self.device_type = current_platform.device_type
-            if not self.device_type:
-                raise RuntimeError(
-                    "Failed to infer device type, please set "
-                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
-                    "to turn on verbose logging to help debug the issue.")
-        else:
-            # Device type is assigned explicitly
-            if isinstance(self.device, str):
-                self.device_type = self.device
-            elif isinstance(self.device, torch.device):
-                self.device_type = self.device.type
-
-        # Some device types require processing inputs on CPU
-        if self.device_type in ["tpu"]:
-            self.device = None
-        else:
-            # Set device with device type
-            self.device = torch.device(self.device_type)
-
-
-DetailedTraceModules = Literal["model", "worker", "all"]
-
-
-@config
-@dataclass
-class ObservabilityConfig:
-    """Configuration for observability - metrics and tracing."""
-
-    show_hidden_metrics_for_version: Optional[str] = None
-    """Enable deprecated Prometheus metrics that have been hidden since the
-    specified version. For example, if a previously deprecated metric has been
-    hidden since the v0.7.0 release, you use
-    `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
-    you migrate to new metrics. The metric is likely to be removed completely
-    in an upcoming release."""
-
-    @cached_property
-    def show_hidden_metrics(self) -> bool:
-        """Check if the hidden metrics should be shown."""
-        if self.show_hidden_metrics_for_version is None:
-            return False
-        return version._prev_minor_version_was(
-            self.show_hidden_metrics_for_version)
-
-    otlp_traces_endpoint: Optional[str] = None
-    """Target URL to which OpenTelemetry traces will be sent."""
-
-    collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
-    """It makes sense to set this only if `--otlp-traces-endpoint` is set. If
-    set, it will collect detailed traces for the specified modules. This
-    involves use of possibly costly and/or blocking operations and hence might
-    have a performance impact.
-
-    Note that collecting detailed timing information for each request can be
-    expensive."""
-
-    @cached_property
-    def collect_model_forward_time(self) -> bool:
-        """Whether to collect model forward time for the request."""
-        return (self.collect_detailed_traces is not None
-                and ("model" in self.collect_detailed_traces
-                     or "all" in self.collect_detailed_traces))
-
-    @cached_property
-    def collect_model_execute_time(self) -> bool:
-        """Whether to collect model execute time for the request."""
-        return (self.collect_detailed_traces is not None
-                and ("worker" in self.collect_detailed_traces
-                     or "all" in self.collect_detailed_traces))
-
-    def compute_hash(self) -> str:
-        """
-        WARNING: Whenever a new field is added to this config,
-        ensure that it is included in the factors list if
-        it affects the computation graph.
-
-        Provide a hash that uniquely identifies all the configs
-        that affect the structure of the computation
-        graph from input ids/embeddings to the final hidden states,
-        excluding anything before input ids/embeddings and after
-        the final hidden states.
-        """
-        # no factors to consider.
-        # this config will not affect the computation graph.
-        factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
-        return hash_str
-
-    def __post_init__(self):
-        if (self.collect_detailed_traces is not None
-                and len(self.collect_detailed_traces) == 1
-                and "," in self.collect_detailed_traces[0]):
-            self._parse_collect_detailed_traces()
-
-        from vllm.tracing import is_otel_available, otel_import_error_traceback
-        if not is_otel_available() and self.otlp_traces_endpoint is not None:
-            raise ValueError(
-                "OpenTelemetry is not available. Unable to configure "
-                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
-                f"installed. Original error:\n{otel_import_error_traceback}")
-
-    def _parse_collect_detailed_traces(self):
-        assert isinstance(self.collect_detailed_traces, list)
-        self.collect_detailed_traces = cast(
-            list[DetailedTraceModules],
-            self.collect_detailed_traces[0].split(","))
-
-
 @config
 @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class VllmConfig:
@@ -1009,37 +860,6 @@ def get_layers_from_vllm_config(
     }
 
 
-@config
-@dataclass
-class SpeechToTextConfig:
-    """Configuration for speech-to-text models."""
-
-    sample_rate: float = 16_000
-    """Sample rate (Hz) to resample input audio to. Most speech models expect
-    16kHz audio input. The input audio will be automatically resampled to this
-    rate before processing."""
-
-    max_audio_clip_s: int = 30
-    """Maximum duration in seconds for a single audio clip without chunking.
-    Audio longer than this will be split into smaller chunks if
-    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
-
-    overlap_chunk_second: int = 1
-    """Overlap duration in seconds between consecutive audio chunks when
-    splitting long audio. This helps maintain context across chunk boundaries
-    and improves transcription quality at split points."""
-
-    min_energy_split_window_size: Optional[int] = 1600
-    """Window size in samples for finding low-energy (quiet) regions to split
-    audio chunks. The algorithm looks for the quietest moment within this
-    window to minimize cutting through speech. Default 1600 samples ≈ 100ms
-    at 16kHz. If None, no chunking will be done."""
-
-    @property
-    def allow_audio_chunking(self) -> bool:
-        return self.min_energy_split_window_size is not None
-
-
 def update_config(config: DataclassInstanceT,
                   overrides: dict[str, Any]) -> DataclassInstanceT:
     processed_overrides = {}
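Both moved config classes implement the same `compute_hash()` pattern: with no graph-affecting factors, the hash is the MD5 of a stringified empty list, i.e. a constant. A minimal standard-library sketch of that computation:

import hashlib
from typing import Any

# Same recipe as DeviceConfig.compute_hash / ObservabilityConfig.compute_hash:
# neither config contributes factors, so every instance hashes str([]).
factors: list[Any] = []
digest = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
print(digest)  # constant; changes only if a factor is appended to the list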
vllm/config/device.py (new file, 74 lines)

@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import hashlib
+from dataclasses import field
+from typing import Any, Literal, Optional, Union
+
+import torch
+from pydantic import ConfigDict, SkipValidation
+from pydantic.dataclasses import dataclass
+
+from vllm.config.utils import config
+
+Device = Literal["auto", "cuda", "cpu", "tpu", "xpu"]
+
+
+@config
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
+class DeviceConfig:
+    """Configuration for the device to use for vLLM execution."""
+
+    device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto"
+    """Device type for vLLM execution.
+    This parameter is deprecated and will be
+    removed in a future release.
+    It will now be set automatically based
+    on the current platform."""
+    device_type: str = field(init=False)
+    """Device type from the current platform. This is set in
+    `__post_init__`."""
+
+    def compute_hash(self) -> str:
+        """
+        WARNING: Whenever a new field is added to this config,
+        ensure that it is included in the factors list if
+        it affects the computation graph.
+
+        Provide a hash that uniquely identifies all the configs
+        that affect the structure of the computation
+        graph from input ids/embeddings to the final hidden states,
+        excluding anything before input ids/embeddings and after
+        the final hidden states.
+        """
+        # no factors to consider.
+        # the device/platform information will be summarized
+        # by torch/vllm automatically.
+        factors: list[Any] = []
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()
+        return hash_str
+
+    def __post_init__(self):
+        if self.device == "auto":
+            # Automated device type detection
+            from vllm.platforms import current_platform
+            self.device_type = current_platform.device_type
+            if not self.device_type:
+                raise RuntimeError(
+                    "Failed to infer device type, please set "
+                    "the environment variable `VLLM_LOGGING_LEVEL=DEBUG` "
+                    "to turn on verbose logging to help debug the issue.")
+        else:
+            # Device type is assigned explicitly
+            if isinstance(self.device, str):
+                self.device_type = self.device
+            elif isinstance(self.device, torch.device):
+                self.device_type = self.device.type
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["tpu"]:
+            self.device = None
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
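To illustrate the normalization done in `__post_init__` above, a minimal sketch, assuming a vLLM install that includes this commit (the `"auto"` branch additionally needs a detectable platform):

import torch
from vllm.config.device import DeviceConfig

# An explicit string device is normalized to a torch.device of that type.
cfg = DeviceConfig(device="cpu")
assert cfg.device_type == "cpu"
assert cfg.device == torch.device("cpu")

# A torch.device instance is accepted too; its .type becomes device_type.
cfg = DeviceConfig(device=torch.device("cuda"))
assert cfg.device_type == "cuda"

# On TPU, __post_init__ instead sets device to None, since that device
# type requires processing inputs on CPU.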
vllm/config/observability.py (new file, 99 lines)

@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import hashlib
+from functools import cached_property
+from typing import Any, Literal, Optional, cast
+
+from pydantic.dataclasses import dataclass
+
+from vllm import version
+from vllm.config.utils import config
+
+DetailedTraceModules = Literal["model", "worker", "all"]
+
+
+@config
+@dataclass
+class ObservabilityConfig:
+    """Configuration for observability - metrics and tracing."""
+
+    show_hidden_metrics_for_version: Optional[str] = None
+    """Enable deprecated Prometheus metrics that have been hidden since the
+    specified version. For example, if a previously deprecated metric has been
+    hidden since the v0.7.0 release, you use
+    `--show-hidden-metrics-for-version=0.7` as a temporary escape hatch while
+    you migrate to new metrics. The metric is likely to be removed completely
+    in an upcoming release."""
+
+    @cached_property
+    def show_hidden_metrics(self) -> bool:
+        """Check if the hidden metrics should be shown."""
+        if self.show_hidden_metrics_for_version is None:
+            return False
+        return version._prev_minor_version_was(
+            self.show_hidden_metrics_for_version)
+
+    otlp_traces_endpoint: Optional[str] = None
+    """Target URL to which OpenTelemetry traces will be sent."""
+
+    collect_detailed_traces: Optional[list[DetailedTraceModules]] = None
+    """It makes sense to set this only if `--otlp-traces-endpoint` is set. If
+    set, it will collect detailed traces for the specified modules. This
+    involves use of possibly costly and/or blocking operations and hence might
+    have a performance impact.
+
+    Note that collecting detailed timing information for each request can be
+    expensive."""
+
+    @cached_property
+    def collect_model_forward_time(self) -> bool:
+        """Whether to collect model forward time for the request."""
+        return (self.collect_detailed_traces is not None
+                and ("model" in self.collect_detailed_traces
+                     or "all" in self.collect_detailed_traces))
+
+    @cached_property
+    def collect_model_execute_time(self) -> bool:
+        """Whether to collect model execute time for the request."""
+        return (self.collect_detailed_traces is not None
+                and ("worker" in self.collect_detailed_traces
+                     or "all" in self.collect_detailed_traces))
+
+    def compute_hash(self) -> str:
+        """
+        WARNING: Whenever a new field is added to this config,
+        ensure that it is included in the factors list if
+        it affects the computation graph.
+
+        Provide a hash that uniquely identifies all the configs
+        that affect the structure of the computation
+        graph from input ids/embeddings to the final hidden states,
+        excluding anything before input ids/embeddings and after
+        the final hidden states.
+        """
+        # no factors to consider.
+        # this config will not affect the computation graph.
+        factors: list[Any] = []
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()
+        return hash_str
+
+    def __post_init__(self):
+        if (self.collect_detailed_traces is not None
+                and len(self.collect_detailed_traces) == 1
+                and "," in self.collect_detailed_traces[0]):
+            self._parse_collect_detailed_traces()
+
+        from vllm.tracing import is_otel_available, otel_import_error_traceback
+        if not is_otel_available() and self.otlp_traces_endpoint is not None:
+            raise ValueError(
+                "OpenTelemetry is not available. Unable to configure "
+                "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are "
+                f"installed. Original error:\n{otel_import_error_traceback}")
+
+    def _parse_collect_detailed_traces(self):
+        assert isinstance(self.collect_detailed_traces, list)
+        self.collect_detailed_traces = cast(
+            list[DetailedTraceModules],
+            self.collect_detailed_traces[0].split(","))
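A small sketch of how the tracing flags above combine, assuming a vLLM install with this commit (leaving `otlp_traces_endpoint` at its `None` default keeps `__post_init__` from requiring the OpenTelemetry packages):

from vllm.config.observability import ObservabilityConfig

# No modules selected: both cached properties are False.
assert not ObservabilityConfig().collect_model_forward_time

# "model" enables forward-time tracing and "worker" enables execute-time
# tracing; "all" would enable both. The CLI form "model,worker" arrives as
# a single-element list and is split by __post_init__ above.
cfg = ObservabilityConfig(collect_detailed_traces=["model", "worker"])
assert cfg.collect_model_forward_time
assert cfg.collect_model_execute_time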
vllm/config/speech_to_text.py (new file, 39 lines)

@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Optional
+
+from pydantic.dataclasses import dataclass
+
+from vllm.config.utils import config
+
+
+@config
+@dataclass
+class SpeechToTextConfig:
+    """Configuration for speech-to-text models."""
+
+    sample_rate: float = 16_000
+    """Sample rate (Hz) to resample input audio to. Most speech models expect
+    16kHz audio input. The input audio will be automatically resampled to this
+    rate before processing."""
+
+    max_audio_clip_s: int = 30
+    """Maximum duration in seconds for a single audio clip without chunking.
+    Audio longer than this will be split into smaller chunks if
+    `allow_audio_chunking` evaluates to True, otherwise it will be rejected."""
+
+    overlap_chunk_second: int = 1
+    """Overlap duration in seconds between consecutive audio chunks when
+    splitting long audio. This helps maintain context across chunk boundaries
+    and improves transcription quality at split points."""
+
+    min_energy_split_window_size: Optional[int] = 1600
+    """Window size in samples for finding low-energy (quiet) regions to split
+    audio chunks. The algorithm looks for the quietest moment within this
+    window to minimize cutting through speech. Default 1600 samples ≈ 100ms
+    at 16kHz. If None, no chunking will be done."""
+
+    @property
+    def allow_audio_chunking(self) -> bool:
+        return self.min_energy_split_window_size is not None
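A short sketch of the chunking arithmetic implied by the defaults above, assuming a vLLM install with this commit:

from vllm.config.speech_to_text import SpeechToTextConfig

cfg = SpeechToTextConfig()
# 1600 samples at the 16 kHz default sample rate is 1600 / 16000 = 0.1 s,
# the window within which the quietest split point is searched.
assert cfg.min_energy_split_window_size / cfg.sample_rate == 0.1
assert cfg.allow_audio_chunking  # a window size is set, so chunking is on

# Setting the window to None disables chunking; clips longer than
# max_audio_clip_s (30 s by default) would then be rejected.
assert not SpeechToTextConfig(
    min_energy_split_window_size=None).allow_audio_chunking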