[CI] Fix mypy for vllm/v1/core and vllm/v1/engine (#27108)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-10-30 07:32:17 -04:00 committed by GitHub
parent c7d2a554ba
commit c01f6e525f
12 changed files with 91 additions and 61 deletions

View File

@@ -36,12 +36,15 @@ FILES = [
     "vllm/transformers_utils",
     "vllm/triton_utils",
     "vllm/usage",
+    "vllm/v1/core",
+    "vllm/v1/engine",
 ]
 # After fixing errors resulting from changing follow_imports
 # from "skip" to "silent", move the following directories to FILES
 SEPARATE_GROUPS = [
     "tests",
+    # v0 related
     "vllm/attention",
     "vllm/compilation",
     "vllm/engine",
@@ -50,7 +53,16 @@ SEPARATE_GROUPS = [
     "vllm/model_executor",
     "vllm/plugins",
     "vllm/worker",
-    "vllm/v1",
+    # v1 related
+    "vllm/v1/attention",
+    "vllm/v1/executor",
+    "vllm/v1/kv_offload",
+    "vllm/v1/metrics",
+    "vllm/v1/pool",
+    "vllm/v1/sample",
+    "vllm/v1/spec_decode",
+    "vllm/v1/structured_output",
+    "vllm/v1/worker",
 ]
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
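
A hypothetical sketch (not vLLM's actual tooling) of how a FILES / SEPARATE_GROUPS split like the one above is commonly consumed: the directories in FILES are checked together with --follow-imports=silent, while each SEPARATE_GROUPS entry is checked on its own so its errors do not block the main group.

# Hypothetical runner; the real invocation lives in vLLM's pre-commit tooling.
import subprocess

FILES = ["vllm/v1/core", "vllm/v1/engine"]
SEPARATE_GROUPS = ["vllm/v1/attention", "vllm/v1/worker"]

def run_mypy(targets: list[str], follow_imports: str) -> int:
    cmd = ["mypy", f"--follow-imports={follow_imports}", *targets]
    return subprocess.run(cmd).returncode

def main() -> int:
    # One combined, stricter run for the promoted directories ...
    rc = run_mypy(FILES, "silent")
    # ... and isolated, more permissive runs for the rest.
    for group in SEPARATE_GROUPS:
        rc |= run_mypy([group], "skip")
    return rc

if __name__ == "__main__":
    raise SystemExit(main())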

View File

@@ -84,7 +84,9 @@ class VllmConfig:
         default_factory=StructuredOutputsConfig
     )
     """Structured outputs configuration."""
-    observability_config: ObservabilityConfig | None = None
+    observability_config: ObservabilityConfig = Field(
+        default_factory=ObservabilityConfig
+    )
     """Observability configuration."""
     quant_config: QuantizationConfig | None = None
     """Quantization configuration."""
@@ -170,10 +172,7 @@ class VllmConfig:
             vllm_factors.append(self.structured_outputs_config.compute_hash())
         else:
             vllm_factors.append("None")
-        if self.observability_config:
-            vllm_factors.append(self.observability_config.compute_hash())
-        else:
-            vllm_factors.append("None")
+        vllm_factors.append(self.observability_config.compute_hash())
         if self.quant_config:
             pass  # should be captured by model_config.quantization
         if self.compilation_config:
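
A minimal sketch of the field change above, using dataclasses.field in place of pydantic's Field: making the sub-config non-Optional with a default_factory lets both mypy and compute_hash use it without a None branch. ObservabilityCfg and EngineCfg are illustrative stand-ins, not vLLM's real classes.

from dataclasses import dataclass, field

@dataclass
class ObservabilityCfg:
    otlp_traces_endpoint: str | None = None

    def compute_hash(self) -> str:
        return str(hash(self.otlp_traces_endpoint))

@dataclass
class EngineCfg:
    # Before: observability: ObservabilityCfg | None = None
    # After: always present, so attribute access never hits "None has no attribute".
    observability: ObservabilityCfg = field(default_factory=ObservabilityCfg)

cfg = EngineCfg()
print([cfg.observability.compute_hash()])  # no "if cfg.observability:" branch needed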

View File

@@ -77,6 +77,7 @@ class EngineClient(ABC):
         lora_request: LoRARequest | None = None,
         trace_headers: Mapping[str, str] | None = None,
         priority: int = 0,
+        truncate_prompt_tokens: int | None = None,
         tokenization_kwargs: dict[str, Any] | None = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from a pooling model."""

View File

@@ -167,7 +167,7 @@ class Scheduler(SchedulerInterface):
         self.kv_cache_manager = KVCacheManager(
             kv_cache_config=kv_cache_config,
             max_model_len=self.max_model_len,
-            enable_caching=self.cache_config.enable_prefix_caching,
+            enable_caching=bool(self.cache_config.enable_prefix_caching),
             use_eagle=self.use_eagle,
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
@@ -407,13 +407,13 @@ class Scheduler(SchedulerInterface):
             # Get externally-cached tokens if using a KVConnector.
             if self.connector is not None:
-                num_external_computed_tokens, load_kv_async = (
+                ext_tokens, load_kv_async = (
                     self.connector.get_num_new_matched_tokens(
                         request, num_new_local_computed_tokens
                     )
                 )
-                if num_external_computed_tokens is None:
+                if ext_tokens is None:
                     # The request cannot be scheduled because
                     # the KVConnector couldn't determine
                     # the number of matched tokens.
@@ -421,6 +421,8 @@
                     skipped_waiting_requests.prepend_request(request)
                     continue
+                num_external_computed_tokens = ext_tokens
                 # Total computed tokens (local + external).
                 num_computed_tokens = (
                     num_new_local_computed_tokens + num_external_computed_tokens
@@ -905,13 +907,13 @@
         outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
         spec_decoding_stats: SpecDecodingStats | None = None
-        kv_connector_stats = (
+        kv_connector_stats: KVConnectorStats | None = (
             kv_connector_output.kv_connector_stats if kv_connector_output else None
         )
         if kv_connector_stats and self.connector:
-            stats = self.connector.get_kv_connector_stats()
-            if stats:
-                kv_connector_stats = kv_connector_stats.aggregate(stats)
+            kv_stats = self.connector.get_kv_connector_stats()
+            if kv_stats:
+                kv_connector_stats = kv_connector_stats.aggregate(kv_stats)
         failed_kv_load_req_ids = None
         if kv_connector_output and kv_connector_output.invalid_block_ids:
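
A minimal sketch of the renaming pattern applied in the scheduler hunks: unpack the Optional tuple element into a temporary, handle the None case, then bind the non-Optional name that the rest of the function uses. The connector call is replaced by a stub here; the names and numbers are illustrative.

def get_num_new_matched_tokens() -> tuple[int | None, bool]:
    # Stand-in for the KV connector call, which may return None tokens.
    return 4, False

def schedule(num_new_local_computed_tokens: int) -> int | None:
    ext_tokens, load_kv_async = get_num_new_matched_tokens()
    if ext_tokens is None:
        # Cannot schedule yet; returning early keeps the variable below
        # typed as plain int instead of int | None.
        return None
    num_external_computed_tokens: int = ext_tokens
    return num_new_local_computed_tokens + num_external_computed_tokens

print(schedule(3))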

View File

@@ -6,7 +6,7 @@ import socket
 import time
 from collections.abc import AsyncGenerator, Iterable, Mapping
 from copy import copy
-from typing import Any
+from typing import Any, cast
 import numpy as np
 import torch
@@ -131,10 +131,9 @@ class AsyncLLM(EngineClient):
         self.output_processor = OutputProcessor(
             self.tokenizer, log_stats=self.log_stats
         )
-        if self.observability_config.otlp_traces_endpoint is not None:
-            tracer = init_tracer(
-                "vllm.llm_engine", self.observability_config.otlp_traces_endpoint
-            )
+        endpoint = self.observability_config.otlp_traces_endpoint
+        if endpoint is not None:
+            tracer = init_tracer("vllm.llm_engine", endpoint)
             self.output_processor.tracer = tracer
         # EngineCore (starts the engine in background process).
@@ -266,7 +265,9 @@ class AsyncLLM(EngineClient):
         if engine_core := getattr(self, "engine_core", None):
             engine_core.shutdown()
-        cancel_task_threadsafe(getattr(self, "output_handler", None))
+        handler = getattr(self, "output_handler", None)
+        if handler is not None:
+            cancel_task_threadsafe(handler)
     async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
         return await self.engine_core.get_supported_tasks_async()
@@ -314,7 +315,10 @@
             priority,
             data_parallel_rank,
         )
-        prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
+        if isinstance(prompt, str):
+            prompt_text = prompt
+        elif isinstance(prompt, Mapping):
+            prompt_text = cast(str | None, prompt.get("prompt"))
         if is_pooling or params.n == 1:
             await self._add_request(request, prompt_text, None, 0, queue)
@@ -436,6 +440,7 @@
            # Note: both OutputProcessor and EngineCore handle their
            # own request cleanup based on finished.
            finished = out.finished
+           assert isinstance(out, RequestOutput)
            yield out
        # If the request is disconnected by the client, generate()
@@ -653,7 +658,7 @@ class AsyncLLM(EngineClient):
         return self.tokenizer
     async def is_tracing_enabled(self) -> bool:
-        return self.observability_config.otlp_traces_endpoint is not None
+        return self.observability_config.otlp_traces_endpoint is not None  # type: ignore
     async def do_log_stats(self) -> None:
         if self.logger_manager:
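
A minimal sketch of the local-binding idiom used in the tracer setup above: copying the Optional attribute into a local variable gives mypy a binding it can narrow reliably before the value is passed to a function that requires str. The Observability class and init_tracer stub are illustrative only.

class Observability:
    otlp_traces_endpoint: str | None = None

def init_tracer(name: str, endpoint: str) -> str:
    # Stand-in for the real tracer factory, which requires a non-None endpoint.
    return f"tracer({name}, {endpoint})"

def maybe_init(obs: Observability) -> str | None:
    endpoint = obs.otlp_traces_endpoint
    if endpoint is not None:
        return init_tracer("vllm.llm_engine", endpoint)  # endpoint is str here
    return None

obs = Observability()
obs.otlp_traces_endpoint = "http://localhost:4317"
print(maybe_init(obs))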

View File

@@ -1075,6 +1075,7 @@ class DPEngineCoreProc(EngineCoreProc):
         local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
         assert dp_size > 1
         assert local_dp_rank is not None
+        assert 0 <= local_dp_rank <= dp_rank < dp_size
         if vllm_config.kv_transfer_config is not None:
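
A short sketch of the assert-based checks above: the None assert narrows the Optional for mypy, while the chained comparison is a runtime invariant rather than something mypy itself verifies. The function and values here are illustrative.

def pick_local_rank(local_dp_rank: int | None, dp_rank: int, dp_size: int) -> int:
    assert local_dp_rank is not None  # narrows int | None -> int
    assert 0 <= local_dp_rank <= dp_rank < dp_size  # runtime sanity check only
    return local_dp_rank

print(pick_local_rank(0, 1, 2))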

View File

@@ -385,10 +385,11 @@ class BackgroundResources:
                 with contextlib.suppress(Exception):
                     task.cancel()
-        if in_loop(loop):
-            close_sockets_and_tasks()
-        elif loop and not loop.is_closed():
-            loop.call_soon_threadsafe(close_sockets_and_tasks)
+        if loop is not None:
+            if in_loop(loop):
+                close_sockets_and_tasks()
+            elif not loop.is_closed():
+                loop.call_soon_threadsafe(close_sockets_and_tasks)
         else:
             # Loop has been closed, try to clean up directly.
             del tasks
@@ -1044,6 +1045,7 @@ class DPAsyncMPClient(AsyncMPClient):
             return
         assert self.stats_update_address is not None
+        stats_addr: str = self.stats_update_address
         assert len(self.engine_ranks_managed) > 0
         # NOTE: running and waiting counts are all global from
         # the Coordinator include all global EngineCores. This
@@ -1054,9 +1056,7 @@
         async def run_engine_stats_update_task():
             with (
-                make_zmq_socket(
-                    self.ctx, self.stats_update_address, zmq.XSUB, linger=0
-                ) as socket,
+                make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket,
                 make_zmq_socket(
                     self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0
                 ) as first_req_rcv_socket,
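
A sketch of the second change above (binding stats_addr before the nested coroutine): an assert narrows self.addr only in the enclosing scope, not inside a nested function, so the narrowed value is copied into a typed local that the closure captures. The Client class here is a toy stand-in.

import asyncio

class Client:
    def __init__(self, addr: str | None) -> None:
        self.addr = addr

    def start(self) -> None:
        assert self.addr is not None
        addr: str = self.addr  # captured by the closure as a plain str

        async def task() -> str:
            # Referring to self.addr here would be "str | None" again for mypy,
            # since narrowing does not carry into nested functions.
            return addr.upper()

        print(asyncio.run(task()))

Client("tcp://127.0.0.1:5555").start()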

View File

@@ -69,14 +69,21 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
         # Stop strings
         params = request.sampling_params
         assert params is not None
-        self.stop = stop = params.stop
+        stop_list: list[str]
+        if params.stop is None:
+            stop_list = []
+        elif isinstance(params.stop, str):
+            stop_list = [params.stop]
+        else:
+            stop_list = params.stop
+        self.stop = stop_list
         self.min_tokens = params.min_tokens
         self.include_stop_str_in_output = params.include_stop_str_in_output
         # Number of chars to hold back when stop strings are to be excluded
         # from streamed output.
-        if stop and not self.include_stop_str_in_output:
-            self.stop_buffer_length = max(len(s) for s in stop) - 1
+        if self.stop and not self.include_stop_str_in_output:
+            self.stop_buffer_length = max(len(s) for s in self.stop) - 1
         else:
             self.stop_buffer_length = 0
         self._last_output_text_offset: int = 0
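
A minimal sketch of the normalization above, assuming (as the branches indicate) that the stop field may be None, a single string, or a list of strings: folding it into list[str] once means every later use, such as the buffer-length computation, type-checks without further narrowing.

def normalize_stop(stop: str | list[str] | None) -> list[str]:
    if stop is None:
        return []
    if isinstance(stop, str):
        return [stop]
    return stop

for value in (None, "</s>", ["</s>", "\n\n"]):
    stops = normalize_stop(value)
    stop_buffer_length = max(len(s) for s in stops) - 1 if stops else 0
    print(stops, stop_buffer_length)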

View File

@@ -4,7 +4,7 @@
 import time
 from collections.abc import Callable, Mapping
 from copy import copy
-from typing import Any
+from typing import Any, cast
 import torch.nn as nn
 from typing_extensions import TypeVar
@@ -112,10 +112,9 @@ class LLMEngine:
         self.output_processor = OutputProcessor(
             self.tokenizer, log_stats=self.log_stats
         )
-        if self.observability_config.otlp_traces_endpoint is not None:
-            tracer = init_tracer(
-                "vllm.llm_engine", self.observability_config.otlp_traces_endpoint
-            )
+        endpoint = self.observability_config.otlp_traces_endpoint
+        if endpoint is not None:
+            tracer = init_tracer("vllm.llm_engine", endpoint)
             self.output_processor.tracer = tracer
         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
@@ -259,7 +258,10 @@
             trace_headers,
             priority,
         )
-        prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt")
+        if isinstance(prompt, str):
+            prompt_text = prompt
+        elif isinstance(prompt, Mapping):
+            prompt_text = cast(str | None, prompt.get("prompt"))
         n = params.n if isinstance(params, SamplingParams) else 1
@@ -285,7 +287,7 @@
             # Add the request to EngineCore.
             self.engine_core.add_request(child_request)
-    def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]:
+    def step(self) -> list[RequestOutput | PoolingRequestOutput]:
         if self.should_execute_dummy_batch:
             self.should_execute_dummy_batch = False
             self.engine_core.execute_dummy_batch()
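
A small illustration of the return-type change on step(): a list built item by item cannot be annotated as list[A] | list[B], because mypy cannot prove which of the two homogeneous lists it is when elements of either type are appended; list[A | B] accepts both. RequestOut and PoolingOut are toy stand-ins.

class RequestOut: ...
class PoolingOut: ...

def collect(items: list[object]) -> list[RequestOut | PoolingOut]:
    outputs: list[RequestOut | PoolingOut] = []
    for item in items:
        if isinstance(item, (RequestOut, PoolingOut)):
            outputs.append(item)  # fine: item matches the union element type
    return outputs

print(len(collect([RequestOut(), PoolingOut()])))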

View File

@@ -44,10 +44,16 @@ class RequestOutputCollector:
         if self.output is None or isinstance(output, Exception):
             self.output = output
             self.ready.set()
-        elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
+        elif isinstance(self.output, RequestOutput) and isinstance(
+            output, RequestOutput
+        ):
             # This ensures that request outputs with different request indexes
             # (if n > 1) do not override each other.
             self.output.add(output, aggregate=self.aggregate)
+        elif isinstance(self.output, PoolingRequestOutput) and isinstance(
+            output, PoolingRequestOutput
+        ):
+            self.output = output
     async def get(self) -> RequestOutput | PoolingRequestOutput:
         """Get operation blocks on put event."""
@@ -408,7 +414,7 @@
         within the loop below.
         """
-        request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = []
+        request_outputs: list[RequestOutput | PoolingRequestOutput] = []
         reqs_to_abort: list[str] = []
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
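
A sketch of the paired isinstance checks above: aggregation only happens when the stored value and the incoming value are the same concrete type, so mypy knows .add() receives a matching argument; otherwise the new value simply replaces the old one. ReqOut and PoolOut are toy stand-ins for the real output classes.

class ReqOut:
    def __init__(self, text: str) -> None:
        self.text = text

    def add(self, other: "ReqOut") -> None:
        self.text += other.text

class PoolOut: ...

def put(stored: ReqOut | PoolOut | None, incoming: ReqOut | PoolOut) -> ReqOut | PoolOut:
    if stored is None:
        return incoming
    if isinstance(stored, ReqOut) and isinstance(incoming, ReqOut):
        stored.add(incoming)  # both narrowed to ReqOut
        return stored
    return incoming  # mismatched or pooling outputs just replace the old value

result = put(ReqOut("a"), ReqOut("b"))
assert isinstance(result, ReqOut)
print(result.text)  # -> "ab"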

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from copy import copy
-from typing import Optional
+from typing import Optional, cast
 from vllm.outputs import CompletionOutput
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -37,7 +37,7 @@ class ParentRequest:
         self.child_requests = set()
         self.output_aggregator = (
-            [None] * sampling_params.n
+            [cast(CompletionOutput, None)] * sampling_params.n
             if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
             else []
        )
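
A sketch of the placeholder pattern above: the aggregator holds one completion per child request, so the None slots that will be overwritten before use are cast to the element type to keep the list homogeneous for mypy. An alternative would be declaring the list as list[CompletionOutput | None]. Completion is a toy stand-in.

from typing import cast

class Completion:
    def __init__(self, index: int) -> None:
        self.index = index

n = 3
slots = [cast(Completion, None)] * n  # typed as list[Completion], filled below
for i in range(n):
    slots[i] = Completion(i)          # every slot is written before it is read
print([c.index for c in slots])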

View File

@@ -3,7 +3,7 @@
 import time
 from collections.abc import Mapping
-from typing import Any, Literal
+from typing import Any, Literal, cast
 from vllm.config import VllmConfig
 from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
@@ -208,9 +208,9 @@
             enc = prompt.get("encoder_prompt")
             dec = prompt.get("decoder_prompt")
             if enc is not None:
-                _validate_single_prompt(enc)
+                _validate_single_prompt(cast(dict | str, enc))
             if dec is not None:
-                _validate_single_prompt(dec)
+                _validate_single_prompt(cast(dict | str, dec))
         else:
             _validate_single_prompt(prompt)  # type: ignore[arg-type]
@@ -332,7 +332,7 @@
         if not mm_data:
             return None
-        mm_uuids: MultiModalUUIDDict = {}
+        mm_uuids: dict[str, list[str | None] | str] = {}
         for modality, data in mm_data.items():
             n = len(data) if isinstance(data, list) else 1
             mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
@@ -384,7 +384,9 @@
         # if provided.
         self._validate_multi_modal_uuids(prompt)
         if isinstance(prompt, dict):
-            mm_uuids = prompt.get("multi_modal_uuids")
+            mm_uuids = cast(
+                MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
+            )
         else:
             mm_uuids = None
@@ -410,20 +412,13 @@
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
         self._validate_model_inputs(encoder_inputs, decoder_inputs)
-        # Mypy does not always properly infer the types of some elements of
-        # discriminated unions of TypedDicts, because of how it handles
-        # inheritance of TypedDict. If we explicitly extract the items we want
-        # we can avoid type errors from using `dict.get` later in the method.
-        prompt_token_ids = (
-            decoder_inputs["prompt_token_ids"]
-            if decoder_inputs["type"] != "embeds"
-            else None
-        )
-        prompt_embeds = (
-            decoder_inputs["prompt_embeds"]
-            if decoder_inputs["type"] == "embeds"
-            else None
-        )
+        # Mypy can be conservative for TypedDict unions; normalize access.
+        if decoder_inputs["type"] == "embeds":
+            prompt_token_ids = None
+            prompt_embeds = decoder_inputs["prompt_embeds"]
+        else:
+            prompt_token_ids = decoder_inputs["prompt_token_ids"]
+            prompt_embeds = None
         sampling_params = None
         pooling_params = None
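
A minimal sketch of the TypedDict-union normalization in the last hunk: branching on the "type" tag in a plain if/else lets mypy narrow which TypedDict member is present, whereas packing both cases into conditional expressions can defeat the narrowing. TokenInputs and EmbedsInputs are simplified stand-ins for the real input types.

from typing import Literal, TypedDict

class TokenInputs(TypedDict):
    type: Literal["token"]
    prompt_token_ids: list[int]

class EmbedsInputs(TypedDict):
    type: Literal["embeds"]
    prompt_embeds: list[float]

DecoderInputs = TokenInputs | EmbedsInputs

def split(decoder_inputs: DecoderInputs) -> tuple[list[int] | None, list[float] | None]:
    if decoder_inputs["type"] == "embeds":
        return None, decoder_inputs["prompt_embeds"]    # narrowed to EmbedsInputs
    return decoder_inputs["prompt_token_ids"], None     # narrowed to TokenInputs

print(split({"type": "token", "prompt_token_ids": [1, 2, 3]}))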