mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:34:57 +08:00
[mypy] Enable type checking for more directories (#29674)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
9eec282cb5
commit
9e6bcda3ac
@ -27,19 +27,24 @@ FILES = [
|
|||||||
"vllm/*.py",
|
"vllm/*.py",
|
||||||
"vllm/assets",
|
"vllm/assets",
|
||||||
"vllm/distributed",
|
"vllm/distributed",
|
||||||
|
"vllm/engine",
|
||||||
"vllm/entrypoints",
|
"vllm/entrypoints",
|
||||||
"vllm/executor",
|
"vllm/executor",
|
||||||
"vllm/inputs",
|
"vllm/inputs",
|
||||||
"vllm/logging_utils",
|
"vllm/logging_utils",
|
||||||
"vllm/multimodal",
|
"vllm/multimodal",
|
||||||
"vllm/platforms",
|
"vllm/platforms",
|
||||||
|
"vllm/plugins",
|
||||||
"vllm/transformers_utils",
|
"vllm/transformers_utils",
|
||||||
"vllm/triton_utils",
|
"vllm/triton_utils",
|
||||||
"vllm/usage",
|
"vllm/usage",
|
||||||
"vllm/utils",
|
"vllm/utils",
|
||||||
|
"vllm/worker",
|
||||||
"vllm/v1/core",
|
"vllm/v1/core",
|
||||||
"vllm/v1/engine",
|
"vllm/v1/engine",
|
||||||
|
"vllm/v1/metrics",
|
||||||
"vllm/v1/pool",
|
"vllm/v1/pool",
|
||||||
|
"vllm/v1/sample",
|
||||||
"vllm/v1/worker",
|
"vllm/v1/worker",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -50,24 +55,19 @@ SEPARATE_GROUPS = [
|
|||||||
# v0 related
|
# v0 related
|
||||||
"vllm/attention",
|
"vllm/attention",
|
||||||
"vllm/compilation",
|
"vllm/compilation",
|
||||||
"vllm/engine",
|
|
||||||
"vllm/inputs",
|
|
||||||
"vllm/lora",
|
"vllm/lora",
|
||||||
"vllm/model_executor",
|
"vllm/model_executor",
|
||||||
"vllm/plugins",
|
|
||||||
"vllm/worker",
|
|
||||||
# v1 related
|
# v1 related
|
||||||
"vllm/v1/attention",
|
"vllm/v1/attention",
|
||||||
"vllm/v1/executor",
|
"vllm/v1/executor",
|
||||||
"vllm/v1/kv_offload",
|
"vllm/v1/kv_offload",
|
||||||
"vllm/v1/metrics",
|
|
||||||
"vllm/v1/sample",
|
|
||||||
"vllm/v1/spec_decode",
|
"vllm/v1/spec_decode",
|
||||||
"vllm/v1/structured_output",
|
"vllm/v1/structured_output",
|
||||||
]
|
]
|
||||||
|
|
||||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||||
EXCLUDE = [
|
EXCLUDE = [
|
||||||
|
"vllm/engine/arg_utils.py",
|
||||||
"vllm/model_executor/parallel_utils",
|
"vllm/model_executor/parallel_utils",
|
||||||
"vllm/model_executor/models",
|
"vllm/model_executor/models",
|
||||||
"vllm/model_executor/layers/fla/ops",
|
"vllm/model_executor/layers/fla/ops",
|
||||||
|
|||||||
@ -565,7 +565,7 @@ class KVConnectorBase_V1(ABC):
|
|||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
) -> Optional["KVConnectorPromMetrics"]:
|
) -> Optional["KVConnectorPromMetrics"]:
|
||||||
"""
|
"""
|
||||||
Create a KVConnectorPromMetrics subclass which should register
|
Create a KVConnectorPromMetrics subclass which should register
|
||||||
|
|||||||
@ -806,7 +806,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
|||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
) -> Optional["KVConnectorPromMetrics"]:
|
) -> Optional["KVConnectorPromMetrics"]:
|
||||||
"""
|
"""
|
||||||
Create a KVConnectorPromMetrics subclass which should register
|
Create a KVConnectorPromMetrics subclass which should register
|
||||||
|
|||||||
@ -52,13 +52,13 @@ class KVConnectorStats:
|
|||||||
|
|
||||||
|
|
||||||
class KVConnectorLogging:
|
class KVConnectorLogging:
|
||||||
def __init__(self, kv_tranfer_config: KVTransferConfig):
|
def __init__(self, kv_transfer_config: KVTransferConfig | None):
|
||||||
# This should be called on frontend process.
|
# This should be called on frontend process.
|
||||||
assert not has_kv_transfer_group()
|
assert not has_kv_transfer_group()
|
||||||
# Instantiate the connector's stats class.
|
# Instantiate the connector's stats class.
|
||||||
if kv_tranfer_config and kv_tranfer_config.kv_connector:
|
if kv_transfer_config and kv_transfer_config.kv_connector:
|
||||||
self.connector_cls = KVConnectorFactory.get_connector_class(
|
self.connector_cls = KVConnectorFactory.get_connector_class(
|
||||||
kv_tranfer_config
|
kv_transfer_config
|
||||||
)
|
)
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@ -120,7 +120,7 @@ class KVConnectorPromMetrics:
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
):
|
):
|
||||||
self._kv_transfer_config = vllm_config.kv_transfer_config
|
self._kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
self._gauge_cls = metric_types[Gauge]
|
self._gauge_cls = metric_types[Gauge]
|
||||||
@ -129,7 +129,7 @@ class KVConnectorPromMetrics:
|
|||||||
self._labelnames = labelnames
|
self._labelnames = labelnames
|
||||||
self._per_engine_labelvalues = per_engine_labelvalues
|
self._per_engine_labelvalues = per_engine_labelvalues
|
||||||
|
|
||||||
def make_per_engine(self, metric: PromMetric) -> PromMetric:
|
def make_per_engine(self, metric: PromMetric) -> dict[int, PromMetric]:
|
||||||
"""
|
"""
|
||||||
Create a per-engine child of a prometheus_client.Metric with
|
Create a per-engine child of a prometheus_client.Metric with
|
||||||
the appropriate labels set. The parent metric must be created
|
the appropriate labels set. The parent metric must be created
|
||||||
@ -165,7 +165,7 @@ class KVConnectorPrometheus:
|
|||||||
self,
|
self,
|
||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
):
|
):
|
||||||
self.prom_metrics: KVConnectorPromMetrics | None = None
|
self.prom_metrics: KVConnectorPromMetrics | None = None
|
||||||
kv_transfer_config = vllm_config.kv_transfer_config
|
kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
|
|||||||
@ -85,7 +85,7 @@ class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
|
|||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
prom_metrics: dict[str, KVConnectorPromMetrics],
|
prom_metrics: dict[str, KVConnectorPromMetrics],
|
||||||
):
|
):
|
||||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||||
@ -434,7 +434,7 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
vllm_config: "VllmConfig",
|
vllm_config: "VllmConfig",
|
||||||
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
) -> KVConnectorPromMetrics:
|
) -> KVConnectorPromMetrics:
|
||||||
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
|
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
|
||||||
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
|
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
|
||||||
|
|||||||
@ -288,7 +288,7 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
) -> KVConnectorPromMetrics:
|
) -> KVConnectorPromMetrics:
|
||||||
return NixlPromMetrics(
|
return NixlPromMetrics(
|
||||||
vllm_config, metric_types, labelnames, per_engine_labelvalues
|
vllm_config, metric_types, labelnames, per_engine_labelvalues
|
||||||
@ -2345,9 +2345,9 @@ class NixlKVConnectorStats(KVConnectorStats):
|
|||||||
return {
|
return {
|
||||||
"Num successful transfers": n,
|
"Num successful transfers": n,
|
||||||
"Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3),
|
"Avg xfer time (ms)": round(xfer_time.mean() * 1e3, 3),
|
||||||
"P90 xfer time (ms)": round(np.percentile(xfer_time, 90) * 1e3, 3),
|
"P90 xfer time (ms)": round(np.percentile(xfer_time, 90).item() * 1e3, 3),
|
||||||
"Avg post time (ms)": round(post_time.mean() * 1e3, 3),
|
"Avg post time (ms)": round(post_time.mean() * 1e3, 3),
|
||||||
"P90 post time (ms)": round(np.percentile(post_time, 90) * 1e3, 3),
|
"P90 post time (ms)": round(np.percentile(post_time, 90).item() * 1e3, 3),
|
||||||
"Avg MB per transfer": round(avg_mb, 3),
|
"Avg MB per transfer": round(avg_mb, 3),
|
||||||
"Throughput (MB/s)": round(throughput_mb_s, 3),
|
"Throughput (MB/s)": round(throughput_mb_s, 3),
|
||||||
"Avg number of descriptors": round(descs.mean(), 1),
|
"Avg number of descriptors": round(descs.mean(), 1),
|
||||||
@ -2364,7 +2364,7 @@ class NixlPromMetrics(KVConnectorPromMetrics):
|
|||||||
vllm_config: VllmConfig,
|
vllm_config: VllmConfig,
|
||||||
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
metric_types: dict[type[PromMetric], type[PromMetricT]],
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
):
|
):
|
||||||
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
|
||||||
|
|
||||||
|
|||||||
@ -1954,7 +1954,9 @@ class EngineArgs:
|
|||||||
self.enable_prefix_caching = False
|
self.enable_prefix_caching = False
|
||||||
|
|
||||||
def _set_default_max_num_seqs_and_batched_tokens_args(
|
def _set_default_max_num_seqs_and_batched_tokens_args(
|
||||||
self, usage_context: UsageContext, model_config: ModelConfig
|
self,
|
||||||
|
usage_context: UsageContext | None,
|
||||||
|
model_config: ModelConfig,
|
||||||
):
|
):
|
||||||
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
|
||||||
(
|
(
|
||||||
|
|||||||
@ -614,12 +614,12 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
|
|||||||
|
|
||||||
def maybe_override_with_speculators(
|
def maybe_override_with_speculators(
|
||||||
model: str,
|
model: str,
|
||||||
tokenizer: str,
|
tokenizer: str | None,
|
||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
revision: str | None = None,
|
revision: str | None = None,
|
||||||
vllm_speculative_config: dict[str, Any] | None = None,
|
vllm_speculative_config: dict[str, Any] | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> tuple[str, str, dict[str, Any] | None]:
|
) -> tuple[str, str | None, dict[str, Any] | None]:
|
||||||
"""
|
"""
|
||||||
Resolve model configuration when speculators are detected.
|
Resolve model configuration when speculators are detected.
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from vllm.triton_utils.importing import (
|
from vllm.triton_utils.importing import (
|
||||||
HAS_TRITON,
|
HAS_TRITON,
|
||||||
@ -7,7 +8,7 @@ from vllm.triton_utils.importing import (
|
|||||||
TritonPlaceholder,
|
TritonPlaceholder,
|
||||||
)
|
)
|
||||||
|
|
||||||
if HAS_TRITON:
|
if TYPE_CHECKING or HAS_TRITON:
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
import triton.language.extra.libdevice as tldevice
|
import triton.language.extra.libdevice as tldevice
|
||||||
|
|||||||
@ -104,8 +104,8 @@ class LoggingStatLogger(StatLoggerBase):
|
|||||||
self.mm_caching_metrics = CachingMetrics()
|
self.mm_caching_metrics = CachingMetrics()
|
||||||
|
|
||||||
self.spec_decoding_logging = SpecDecodingLogging()
|
self.spec_decoding_logging = SpecDecodingLogging()
|
||||||
kv_tranfer_config = self.vllm_config.kv_transfer_config
|
kv_transfer_config = self.vllm_config.kv_transfer_config
|
||||||
self.kv_connector_logging = KVConnectorLogging(kv_tranfer_config)
|
self.kv_connector_logging = KVConnectorLogging(kv_transfer_config)
|
||||||
self.last_prompt_throughput: float = 0.0
|
self.last_prompt_throughput: float = 0.0
|
||||||
self.last_generation_throughput: float = 0.0
|
self.last_generation_throughput: float = 0.0
|
||||||
self.engine_is_idle = False
|
self.engine_is_idle = False
|
||||||
@ -380,7 +380,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase):
|
|||||||
model_name = vllm_config.model_config.served_model_name
|
model_name = vllm_config.model_config.served_model_name
|
||||||
max_model_len = vllm_config.model_config.max_model_len
|
max_model_len = vllm_config.model_config.max_model_len
|
||||||
|
|
||||||
per_engine_labelvalues: dict[int, list[str]] = {
|
per_engine_labelvalues: dict[int, list[object]] = {
|
||||||
idx: [model_name, str(idx)] for idx in engine_indexes
|
idx: [model_name, str(idx)] for idx in engine_indexes
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1052,7 +1052,7 @@ PromMetric: TypeAlias = Gauge | Counter | Histogram
|
|||||||
|
|
||||||
|
|
||||||
def make_per_engine(
|
def make_per_engine(
|
||||||
metric: PromMetric, engine_idxs: list[int], model_name: str
|
metric: PromMetric, engine_idxs: list[int], model_name: object
|
||||||
) -> dict[int, PromMetric]:
|
) -> dict[int, PromMetric]:
|
||||||
return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}
|
return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs}
|
||||||
|
|
||||||
|
|||||||
@ -313,7 +313,7 @@ class AdapterLogitsProcessor(LogitsProcessor):
|
|||||||
if (len(inspect.signature(req_lp).parameters) == 3)
|
if (len(inspect.signature(req_lp).parameters) == 3)
|
||||||
else [output_ids]
|
else [output_ids]
|
||||||
)
|
)
|
||||||
return partial(req_lp, *args)
|
return partial(req_lp, *args) # type: ignore[misc]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def update_state(self, batch_update: BatchUpdate | None):
|
def update_state(self, batch_update: BatchUpdate | None):
|
||||||
|
|||||||
@ -144,7 +144,7 @@ class SpecDecodingProm:
|
|||||||
self,
|
self,
|
||||||
speculative_config: SpeculativeConfig | None,
|
speculative_config: SpeculativeConfig | None,
|
||||||
labelnames: list[str],
|
labelnames: list[str],
|
||||||
per_engine_labelvalues: dict[int, list[str]],
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
):
|
):
|
||||||
self.spec_decoding_enabled = speculative_config is not None
|
self.spec_decoding_enabled = speculative_config is not None
|
||||||
if not self.spec_decoding_enabled:
|
if not self.spec_decoding_enabled:
|
||||||
@ -215,7 +215,8 @@ class SpecDecodingProm:
|
|||||||
|
|
||||||
|
|
||||||
def make_per_engine(
|
def make_per_engine(
|
||||||
counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[str]]
|
counter: prometheus_client.Counter,
|
||||||
|
per_engine_labelvalues: dict[int, list[object]],
|
||||||
):
|
):
|
||||||
"""Create a counter for each label value."""
|
"""Create a counter for each label value."""
|
||||||
return {
|
return {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user