Fix mypy errors for vllm/v1/core and vllm/v1/engine

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
yewentao256 2025-10-17 08:06:28 -07:00
parent 2ba60ec7fe
commit 8244ff7fee
11 changed files with 98 additions and 59 deletions

View File

@ -36,12 +36,15 @@ FILES = [
"vllm/transformers_utils", "vllm/transformers_utils",
"vllm/triton_utils", "vllm/triton_utils",
"vllm/usage", "vllm/usage",
"vllm/v1/core",
"vllm/v1/engine",
] ]
# After fixing errors resulting from changing follow_imports # After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES # from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [ SEPARATE_GROUPS = [
"tests", "tests",
# v0 related
"vllm/attention", "vllm/attention",
"vllm/compilation", "vllm/compilation",
"vllm/engine", "vllm/engine",
@ -50,7 +53,16 @@ SEPARATE_GROUPS = [
"vllm/model_executor", "vllm/model_executor",
"vllm/plugins", "vllm/plugins",
"vllm/worker", "vllm/worker",
"vllm/v1", # v1 related
"vllm/v1/attention",
"vllm/v1/executor",
"vllm/v1/kv_offload",
"vllm/v1/metrics",
"vllm/v1/pool",
"vllm/v1/sample",
"vllm/v1/spec_decode",
"vllm/v1/structured_output",
"vllm/v1/worker",
] ]
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.

View File

@ -72,6 +72,7 @@ class EngineClient(ABC):
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.""" """Generate outputs for a request from a pooling model."""

View File

@ -165,7 +165,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching, enable_caching=bool(self.cache_config.enable_prefix_caching),
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
log_stats=self.log_stats, log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,
@ -392,7 +392,7 @@ class Scheduler(SchedulerInterface):
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
num_external_computed_tokens = 0 num_external_computed_tokens: int | None = 0
load_kv_async = False load_kv_async = False
# Get already-cached tokens. # Get already-cached tokens.
@ -419,8 +419,8 @@ class Scheduler(SchedulerInterface):
continue continue
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = ( num_computed_tokens = num_new_local_computed_tokens + (
num_new_local_computed_tokens + num_external_computed_tokens num_external_computed_tokens or 0
) )
# KVTransfer: WAITING reqs have num_computed_tokens > 0 # KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed. # after async KV recvs are completed.
@ -434,6 +434,7 @@ class Scheduler(SchedulerInterface):
# KVTransfer: loading remote KV, do not allocate for new work. # KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async: if load_kv_async:
assert isinstance(num_external_computed_tokens, int)
assert num_external_computed_tokens > 0 assert num_external_computed_tokens > 0
num_new_tokens = 0 num_new_tokens = 0
# Number of tokens to be scheduled. # Number of tokens to be scheduled.
@ -503,7 +504,7 @@ class Scheduler(SchedulerInterface):
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, request,
num_new_tokens + num_external_computed_tokens, num_new_tokens + (num_external_computed_tokens or 0),
num_new_local_computed_tokens, num_new_local_computed_tokens,
new_computed_blocks, new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens, num_lookahead_tokens=effective_lookahead_tokens,
@ -523,7 +524,7 @@ class Scheduler(SchedulerInterface):
self.connector.update_state_after_alloc( self.connector.update_state_after_alloc(
request, request,
new_computed_blocks + new_blocks, new_computed_blocks + new_blocks,
num_external_computed_tokens, num_external_computed_tokens or 0,
) )
# Request was already popped from self.waiting # Request was already popped from self.waiting
@ -916,13 +917,13 @@ class Scheduler(SchedulerInterface):
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats = ( kv_connector_stats: KVConnectorStats | None = (
kv_connector_output.kv_connector_stats if kv_connector_output else None kv_connector_output.kv_connector_stats if kv_connector_output else None
) )
if kv_connector_stats and self.connector: if kv_connector_stats and self.connector:
stats = self.connector.get_kv_connector_stats() kv_stats = self.connector.get_kv_connector_stats()
if stats: if kv_stats:
kv_connector_stats = kv_connector_stats.aggregate(stats) kv_connector_stats = kv_connector_stats.aggregate(kv_stats)
failed_kv_load_req_ids = None failed_kv_load_req_ids = None
if kv_connector_output and kv_connector_output.invalid_block_ids: if kv_connector_output and kv_connector_output.invalid_block_ids:

View File

@ -6,7 +6,7 @@ import socket
import time import time
from collections.abc import AsyncGenerator, Iterable, Mapping from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy from copy import copy
from typing import Any from typing import Any, cast
import numpy as np import numpy as np
import torch import torch
@ -122,10 +122,9 @@ class AsyncLLM(EngineClient):
self.output_processor = OutputProcessor( self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats self.tokenizer, log_stats=self.log_stats
) )
if self.observability_config.otlp_traces_endpoint is not None: endpoint = getattr(self.observability_config, "otlp_traces_endpoint", None)
tracer = init_tracer( if endpoint is not None:
"vllm.llm_engine", self.observability_config.otlp_traces_endpoint tracer = init_tracer("vllm.llm_engine", endpoint)
)
self.output_processor.tracer = tracer self.output_processor.tracer = tracer
# EngineCore (starts the engine in background process). # EngineCore (starts the engine in background process).
@ -257,7 +256,9 @@ class AsyncLLM(EngineClient):
if engine_core := getattr(self, "engine_core", None): if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown() engine_core.shutdown()
cancel_task_threadsafe(getattr(self, "output_handler", None)) handler = getattr(self, "output_handler", None)
if handler is not None:
cancel_task_threadsafe(handler)
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async() return await self.engine_core.get_supported_tasks_async()
@ -305,7 +306,10 @@ class AsyncLLM(EngineClient):
priority, priority,
data_parallel_rank, data_parallel_rank,
) )
prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
if is_pooling or params.n == 1: if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue) await self._add_request(request, prompt_text, None, 0, queue)
@ -427,6 +431,7 @@ class AsyncLLM(EngineClient):
# Note: both OutputProcessor and EngineCore handle their # Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished. # own request cleanup based on finished.
finished = out.finished finished = out.finished
assert isinstance(out, RequestOutput)
yield out yield out
# If the request is disconnected by the client, generate() # If the request is disconnected by the client, generate()

View File

@ -467,9 +467,10 @@ class EngineCore:
self, self,
tensorizer_config, tensorizer_config,
) -> None: ) -> None:
self.model_executor.save_tensorized_model( if hasattr(self.model_executor, "save_tensorized_model"):
tensorizer_config=tensorizer_config, self.model_executor.save_tensorized_model( # type: ignore[attr-defined]
) tensorizer_config=tensorizer_config,
)
def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]: def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]:
"""Preprocess the request. """Preprocess the request.
@ -1089,6 +1090,7 @@ class DPEngineCoreProc(EngineCoreProc):
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1 assert dp_size > 1
assert local_dp_rank is not None
assert 0 <= local_dp_rank <= dp_rank < dp_size assert 0 <= local_dp_rank <= dp_rank < dp_size
if vllm_config.kv_transfer_config is not None: if vllm_config.kv_transfer_config is not None:
@ -1235,7 +1237,8 @@ class DPEngineCoreProc(EngineCoreProc):
parallel_config.data_parallel_master_port parallel_config.data_parallel_master_port
) )
self.model_executor.reinitialize_distributed(reconfig_request) if hasattr(self.model_executor, "reinitialize_distributed"):
self.model_executor.reinitialize_distributed(reconfig_request) # type: ignore[attr-defined]
if reconfig_request.new_data_parallel_size > old_dp_size: if reconfig_request.new_data_parallel_size > old_dp_size:
assert self.available_gpu_memory_for_kv_cache > 0 assert self.available_gpu_memory_for_kv_cache > 0
# pass available_gpu_memory_for_kv_cache from existing # pass available_gpu_memory_for_kv_cache from existing

View File

@ -385,7 +385,7 @@ class BackgroundResources:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
task.cancel() task.cancel()
if in_loop(loop): if loop is not None and in_loop(loop):
close_sockets_and_tasks() close_sockets_and_tasks()
elif loop and not loop.is_closed(): elif loop and not loop.is_closed():
loop.call_soon_threadsafe(close_sockets_and_tasks) loop.call_soon_threadsafe(close_sockets_and_tasks)
@ -1044,6 +1044,7 @@ class DPAsyncMPClient(AsyncMPClient):
return return
assert self.stats_update_address is not None assert self.stats_update_address is not None
stats_addr: str = self.stats_update_address
assert len(self.engine_ranks_managed) > 0 assert len(self.engine_ranks_managed) > 0
# NOTE: running and waiting counts are all global from # NOTE: running and waiting counts are all global from
# the Coordinator include all global EngineCores. This # the Coordinator include all global EngineCores. This
@ -1054,9 +1055,7 @@ class DPAsyncMPClient(AsyncMPClient):
async def run_engine_stats_update_task(): async def run_engine_stats_update_task():
with ( with (
make_zmq_socket( make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket,
self.ctx, self.stats_update_address, zmq.XSUB, linger=0
) as socket,
make_zmq_socket( make_zmq_socket(
self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0 self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0
) as first_req_rcv_socket, ) as first_req_rcv_socket,

View File

@ -69,14 +69,21 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
# Stop strings # Stop strings
params = request.sampling_params params = request.sampling_params
assert params is not None assert params is not None
self.stop = stop = params.stop stop_list: list[str]
if params.stop is None:
stop_list = []
elif isinstance(params.stop, str):
stop_list = [params.stop]
else:
stop_list = params.stop
self.stop = stop_list
self.min_tokens = params.min_tokens self.min_tokens = params.min_tokens
self.include_stop_str_in_output = params.include_stop_str_in_output self.include_stop_str_in_output = params.include_stop_str_in_output
# Number of chars to hold back when stop strings are to be excluded # Number of chars to hold back when stop strings are to be excluded
# from streamed output. # from streamed output.
if stop and not self.include_stop_str_in_output: if self.stop and not self.include_stop_str_in_output:
self.stop_buffer_length = max(len(s) for s in stop) - 1 self.stop_buffer_length = max(len(s) for s in self.stop) - 1
else: else:
self.stop_buffer_length = 0 self.stop_buffer_length = 0
self._last_output_text_offset: int = 0 self._last_output_text_offset: int = 0

View File

@ -4,7 +4,7 @@
import time import time
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from copy import copy from copy import copy
from typing import Any from typing import Any, cast
import torch.nn as nn import torch.nn as nn
from typing_extensions import TypeVar from typing_extensions import TypeVar
@ -112,10 +112,9 @@ class LLMEngine:
self.output_processor = OutputProcessor( self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats self.tokenizer, log_stats=self.log_stats
) )
if self.observability_config.otlp_traces_endpoint is not None: endpoint = getattr(self.observability_config, "otlp_traces_endpoint", None)
tracer = init_tracer( if endpoint is not None:
"vllm.llm_engine", self.observability_config.otlp_traces_endpoint tracer = init_tracer("vllm.llm_engine", endpoint)
)
self.output_processor.tracer = tracer self.output_processor.tracer = tracer
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
@ -259,7 +258,10 @@ class LLMEngine:
trace_headers, trace_headers,
priority, priority,
) )
prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
n = params.n if isinstance(params, SamplingParams) else 1 n = params.n if isinstance(params, SamplingParams) else 1
@ -316,7 +318,14 @@ class LLMEngine:
) )
self.do_log_stats_with_interval() self.do_log_stats_with_interval()
return processed_outputs.request_outputs ro = processed_outputs.request_outputs
if not ro:
return []
first = ro[0]
if isinstance(first, RequestOutput):
return [x for x in ro if isinstance(x, RequestOutput)]
else:
return [x for x in ro if isinstance(x, PoolingRequestOutput)]
def start_profile(self): def start_profile(self):
self.engine_core.profile(True) self.engine_core.profile(True)

View File

@ -47,7 +47,14 @@ class RequestOutputCollector:
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)): elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)):
# This ensures that request outputs with different request indexes # This ensures that request outputs with different request indexes
# (if n > 1) do not override each other. # (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate) if isinstance(self.output, RequestOutput) and isinstance(
output, RequestOutput
):
self.output.add(output, aggregate=self.aggregate)
elif isinstance(self.output, PoolingRequestOutput) and isinstance(
output, PoolingRequestOutput
):
self.output = output
async def get(self) -> RequestOutput | PoolingRequestOutput: async def get(self) -> RequestOutput | PoolingRequestOutput:
"""Get operation blocks on put event.""" """Get operation blocks on put event."""
@ -407,7 +414,7 @@ class OutputProcessor:
within the loop below. within the loop below.
""" """
request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = [] request_outputs: list[RequestOutput | PoolingRequestOutput] = []
reqs_to_abort: list[str] = [] reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs: for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id req_id = engine_core_output.request_id

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import copy from copy import copy
from typing import Optional from typing import Optional, cast
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
@ -37,7 +37,7 @@ class ParentRequest:
self.child_requests = set() self.child_requests = set()
self.output_aggregator = ( self.output_aggregator = (
[None] * sampling_params.n [cast(CompletionOutput, None)] * sampling_params.n
if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY) if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
else [] else []
) )

View File

@ -3,7 +3,7 @@
import time import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Literal from typing import Any, Literal, cast
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
@ -208,9 +208,9 @@ class Processor:
enc = prompt.get("encoder_prompt") enc = prompt.get("encoder_prompt")
dec = prompt.get("decoder_prompt") dec = prompt.get("decoder_prompt")
if enc is not None: if enc is not None:
_validate_single_prompt(enc) _validate_single_prompt(cast(dict | str, enc))
if dec is not None: if dec is not None:
_validate_single_prompt(dec) _validate_single_prompt(cast(dict | str, dec))
else: else:
_validate_single_prompt(prompt) # type: ignore[arg-type] _validate_single_prompt(prompt) # type: ignore[arg-type]
@ -332,7 +332,7 @@ class Processor:
if not mm_data: if not mm_data:
return None return None
mm_uuids: MultiModalUUIDDict = {} mm_uuids: dict[str, list[str | None] | str] = {}
for modality, data in mm_data.items(): for modality, data in mm_data.items():
n = len(data) if isinstance(data, list) else 1 n = len(data) if isinstance(data, list) else 1
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
@ -384,7 +384,9 @@ class Processor:
# if provided. # if provided.
self._validate_multi_modal_uuids(prompt) self._validate_multi_modal_uuids(prompt)
if isinstance(prompt, dict): if isinstance(prompt, dict):
mm_uuids = prompt.get("multi_modal_uuids") mm_uuids = cast(
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
)
else: else:
mm_uuids = None mm_uuids = None
@ -410,20 +412,13 @@ class Processor:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs) self._validate_model_inputs(encoder_inputs, decoder_inputs)
# Mypy does not always properly infer the types of some elements of # Mypy can be conservative for TypedDict unions; normalize access.
# discriminated unions of TypedDicts, because of how it handles if decoder_inputs["type"] == "embeds":
# inheritance of TypedDict. If we explicitly extract the items we want prompt_token_ids = None
# we can avoid type errors from using `dict.get` later in the method. prompt_embeds = decoder_inputs["prompt_embeds"]
prompt_token_ids = ( else:
decoder_inputs["prompt_token_ids"] prompt_token_ids = decoder_inputs["prompt_token_ids"]
if decoder_inputs["type"] != "embeds" prompt_embeds = None
else None
)
prompt_embeds = (
decoder_inputs["prompt_embeds"]
if decoder_inputs["type"] == "embeds"
else None
)
sampling_params = None sampling_params = None
pooling_params = None pooling_params = None