diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index a3aa546347255..8d04848f8f780 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -36,12 +36,15 @@ FILES = [ "vllm/transformers_utils", "vllm/triton_utils", "vllm/usage", + "vllm/v1/core", + "vllm/v1/engine", ] # After fixing errors resulting from changing follow_imports # from "skip" to "silent", move the following directories to FILES SEPARATE_GROUPS = [ "tests", + # v0 related "vllm/attention", "vllm/compilation", "vllm/engine", @@ -50,7 +53,16 @@ SEPARATE_GROUPS = [ "vllm/model_executor", "vllm/plugins", "vllm/worker", - "vllm/v1", + # v1 related + "vllm/v1/attention", + "vllm/v1/executor", + "vllm/v1/kv_offload", + "vllm/v1/metrics", + "vllm/v1/pool", + "vllm/v1/sample", + "vllm/v1/spec_decode", + "vllm/v1/structured_output", + "vllm/v1/worker", ] # TODO(woosuk): Include the code from Megatron and HuggingFace. diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 20b8eb57f7438..6919bd2f2ff25 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -72,6 +72,7 @@ class EngineClient(ABC): lora_request: LoRARequest | None = None, trace_headers: Mapping[str, str] | None = None, priority: int = 0, + truncate_prompt_tokens: int | None = None, tokenization_kwargs: dict[str, Any] | None = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from a pooling model.""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 08368b7d99efe..06c0c8c942e70 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -165,7 +165,7 @@ class Scheduler(SchedulerInterface): self.kv_cache_manager = KVCacheManager( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, - enable_caching=self.cache_config.enable_prefix_caching, + enable_caching=bool(self.cache_config.enable_prefix_caching), use_eagle=self.use_eagle, log_stats=self.log_stats, enable_kv_cache_events=self.enable_kv_cache_events, @@ -392,7 +392,7 @@ class Scheduler(SchedulerInterface): skipped_waiting_requests.prepend_request(request) continue - num_external_computed_tokens = 0 + num_external_computed_tokens: int | None = 0 load_kv_async = False # Get already-cached tokens. @@ -419,8 +419,8 @@ class Scheduler(SchedulerInterface): continue # Total computed tokens (local + external). - num_computed_tokens = ( - num_new_local_computed_tokens + num_external_computed_tokens + num_computed_tokens = num_new_local_computed_tokens + ( + num_external_computed_tokens or 0 ) # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -434,6 +434,7 @@ class Scheduler(SchedulerInterface): # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: + assert isinstance(num_external_computed_tokens, int) assert num_external_computed_tokens > 0 num_new_tokens = 0 # Number of tokens to be scheduled. @@ -503,7 +504,7 @@ class Scheduler(SchedulerInterface): new_blocks = self.kv_cache_manager.allocate_slots( request, - num_new_tokens + num_external_computed_tokens, + num_new_tokens + (num_external_computed_tokens or 0), num_new_local_computed_tokens, new_computed_blocks, num_lookahead_tokens=effective_lookahead_tokens, @@ -523,7 +524,7 @@ class Scheduler(SchedulerInterface): self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks, - num_external_computed_tokens, + num_external_computed_tokens or 0, ) # Request was already popped from self.waiting @@ -916,13 +917,13 @@ class Scheduler(SchedulerInterface): outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: SpecDecodingStats | None = None - kv_connector_stats = ( + kv_connector_stats: KVConnectorStats | None = ( kv_connector_output.kv_connector_stats if kv_connector_output else None ) if kv_connector_stats and self.connector: - stats = self.connector.get_kv_connector_stats() - if stats: - kv_connector_stats = kv_connector_stats.aggregate(stats) + kv_stats = self.connector.get_kv_connector_stats() + if kv_stats: + kv_connector_stats = kv_connector_stats.aggregate(kv_stats) failed_kv_load_req_ids = None if kv_connector_output and kv_connector_output.invalid_block_ids: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 584956c1f0eb3..4e5c4cd6a4b67 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -6,7 +6,7 @@ import socket import time from collections.abc import AsyncGenerator, Iterable, Mapping from copy import copy -from typing import Any +from typing import Any, cast import numpy as np import torch @@ -122,10 +122,9 @@ class AsyncLLM(EngineClient): self.output_processor = OutputProcessor( self.tokenizer, log_stats=self.log_stats ) - if self.observability_config.otlp_traces_endpoint is not None: - tracer = init_tracer( - "vllm.llm_engine", self.observability_config.otlp_traces_endpoint - ) + endpoint = getattr(self.observability_config, "otlp_traces_endpoint", None) + if endpoint is not None: + tracer = init_tracer("vllm.llm_engine", endpoint) self.output_processor.tracer = tracer # EngineCore (starts the engine in background process). @@ -257,7 +256,9 @@ class AsyncLLM(EngineClient): if engine_core := getattr(self, "engine_core", None): engine_core.shutdown() - cancel_task_threadsafe(getattr(self, "output_handler", None)) + handler = getattr(self, "output_handler", None) + if handler is not None: + cancel_task_threadsafe(handler) async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: return await self.engine_core.get_supported_tasks_async() @@ -305,7 +306,10 @@ class AsyncLLM(EngineClient): priority, data_parallel_rank, ) - prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") + if isinstance(prompt, str): + prompt_text = prompt + elif isinstance(prompt, Mapping): + prompt_text = cast(str | None, prompt.get("prompt")) if is_pooling or params.n == 1: await self._add_request(request, prompt_text, None, 0, queue) @@ -427,6 +431,7 @@ class AsyncLLM(EngineClient): # Note: both OutputProcessor and EngineCore handle their # own request cleanup based on finished. finished = out.finished + assert isinstance(out, RequestOutput) yield out # If the request is disconnected by the client, generate() diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 2773dc61ff3d7..719896a397ccf 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -467,9 +467,10 @@ class EngineCore: self, tensorizer_config, ) -> None: - self.model_executor.save_tensorized_model( - tensorizer_config=tensorizer_config, - ) + if hasattr(self.model_executor, "save_tensorized_model"): + self.model_executor.save_tensorized_model( # type: ignore[attr-defined] + tensorizer_config=tensorizer_config, + ) def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, int]: """Preprocess the request. @@ -1089,6 +1090,7 @@ class DPEngineCoreProc(EngineCoreProc): local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local assert dp_size > 1 + assert local_dp_rank is not None assert 0 <= local_dp_rank <= dp_rank < dp_size if vllm_config.kv_transfer_config is not None: @@ -1235,7 +1237,8 @@ class DPEngineCoreProc(EngineCoreProc): parallel_config.data_parallel_master_port ) - self.model_executor.reinitialize_distributed(reconfig_request) + if hasattr(self.model_executor, "reinitialize_distributed"): + self.model_executor.reinitialize_distributed(reconfig_request) # type: ignore[attr-defined] if reconfig_request.new_data_parallel_size > old_dp_size: assert self.available_gpu_memory_for_kv_cache > 0 # pass available_gpu_memory_for_kv_cache from existing diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index f2e316a909706..57ff8d0b877cb 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -385,7 +385,7 @@ class BackgroundResources: with contextlib.suppress(Exception): task.cancel() - if in_loop(loop): + if loop is not None and in_loop(loop): close_sockets_and_tasks() elif loop and not loop.is_closed(): loop.call_soon_threadsafe(close_sockets_and_tasks) @@ -1044,6 +1044,7 @@ class DPAsyncMPClient(AsyncMPClient): return assert self.stats_update_address is not None + stats_addr: str = self.stats_update_address assert len(self.engine_ranks_managed) > 0 # NOTE: running and waiting counts are all global from # the Coordinator include all global EngineCores. This @@ -1054,9 +1055,7 @@ class DPAsyncMPClient(AsyncMPClient): async def run_engine_stats_update_task(): with ( - make_zmq_socket( - self.ctx, self.stats_update_address, zmq.XSUB, linger=0 - ) as socket, + make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket, make_zmq_socket( self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0 ) as first_req_rcv_socket, diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 5f66e36893bf3..b7a24096bf15f 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -69,14 +69,21 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): # Stop strings params = request.sampling_params assert params is not None - self.stop = stop = params.stop + stop_list: list[str] + if params.stop is None: + stop_list = [] + elif isinstance(params.stop, str): + stop_list = [params.stop] + else: + stop_list = params.stop + self.stop = stop_list self.min_tokens = params.min_tokens self.include_stop_str_in_output = params.include_stop_str_in_output # Number of chars to hold back when stop strings are to be excluded # from streamed output. - if stop and not self.include_stop_str_in_output: - self.stop_buffer_length = max(len(s) for s in stop) - 1 + if self.stop and not self.include_stop_str_in_output: + self.stop_buffer_length = max(len(s) for s in self.stop) - 1 else: self.stop_buffer_length = 0 self._last_output_text_offset: int = 0 diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 538fb6a04bd7b..3646b986da71c 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -4,7 +4,7 @@ import time from collections.abc import Callable, Mapping from copy import copy -from typing import Any +from typing import Any, cast import torch.nn as nn from typing_extensions import TypeVar @@ -112,10 +112,9 @@ class LLMEngine: self.output_processor = OutputProcessor( self.tokenizer, log_stats=self.log_stats ) - if self.observability_config.otlp_traces_endpoint is not None: - tracer = init_tracer( - "vllm.llm_engine", self.observability_config.otlp_traces_endpoint - ) + endpoint = getattr(self.observability_config, "otlp_traces_endpoint", None) + if endpoint is not None: + tracer = init_tracer("vllm.llm_engine", endpoint) self.output_processor.tracer = tracer # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) @@ -259,7 +258,10 @@ class LLMEngine: trace_headers, priority, ) - prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") + if isinstance(prompt, str): + prompt_text = prompt + elif isinstance(prompt, Mapping): + prompt_text = cast(str | None, prompt.get("prompt")) n = params.n if isinstance(params, SamplingParams) else 1 @@ -316,7 +318,14 @@ class LLMEngine: ) self.do_log_stats_with_interval() - return processed_outputs.request_outputs + ro = processed_outputs.request_outputs + if not ro: + return [] + first = ro[0] + if isinstance(first, RequestOutput): + return [x for x in ro if isinstance(x, RequestOutput)] + else: + return [x for x in ro if isinstance(x, PoolingRequestOutput)] def start_profile(self): self.engine_core.profile(True) diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bc1542187c9b..741e99786dccb 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -47,7 +47,14 @@ class RequestOutputCollector: elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)): # This ensures that request outputs with different request indexes # (if n > 1) do not override each other. - self.output.add(output, aggregate=self.aggregate) + if isinstance(self.output, RequestOutput) and isinstance( + output, RequestOutput + ): + self.output.add(output, aggregate=self.aggregate) + elif isinstance(self.output, PoolingRequestOutput) and isinstance( + output, PoolingRequestOutput + ): + self.output = output async def get(self) -> RequestOutput | PoolingRequestOutput: """Get operation blocks on put event.""" @@ -407,7 +414,7 @@ class OutputProcessor: within the loop below. """ - request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = [] + request_outputs: list[RequestOutput | PoolingRequestOutput] = [] reqs_to_abort: list[str] = [] for engine_core_output in engine_core_outputs: req_id = engine_core_output.request_id diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 2a47befec25f1..26ee10d2b9bbf 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import copy -from typing import Optional +from typing import Optional, cast from vllm.outputs import CompletionOutput from vllm.sampling_params import RequestOutputKind, SamplingParams @@ -37,7 +37,7 @@ class ParentRequest: self.child_requests = set() self.output_aggregator = ( - [None] * sampling_params.n + [cast(CompletionOutput, None)] * sampling_params.n if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY) else [] ) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index de15677aeea91..c49fd1bde8b98 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping -from typing import Any, Literal +from typing import Any, Literal, cast from vllm.config import VllmConfig from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs @@ -208,9 +208,9 @@ class Processor: enc = prompt.get("encoder_prompt") dec = prompt.get("decoder_prompt") if enc is not None: - _validate_single_prompt(enc) + _validate_single_prompt(cast(dict | str, enc)) if dec is not None: - _validate_single_prompt(dec) + _validate_single_prompt(cast(dict | str, dec)) else: _validate_single_prompt(prompt) # type: ignore[arg-type] @@ -332,7 +332,7 @@ class Processor: if not mm_data: return None - mm_uuids: MultiModalUUIDDict = {} + mm_uuids: dict[str, list[str | None] | str] = {} for modality, data in mm_data.items(): n = len(data) if isinstance(data, list) else 1 mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] @@ -384,7 +384,9 @@ class Processor: # if provided. self._validate_multi_modal_uuids(prompt) if isinstance(prompt, dict): - mm_uuids = prompt.get("multi_modal_uuids") + mm_uuids = cast( + MultiModalUUIDDict | None, prompt.get("multi_modal_uuids") + ) else: mm_uuids = None @@ -410,20 +412,13 @@ class Processor: encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) self._validate_model_inputs(encoder_inputs, decoder_inputs) - # Mypy does not always properly infer the types of some elements of - # discriminated unions of TypedDicts, because of how it handles - # inheritance of TypedDict. If we explicitly extract the items we want - # we can avoid type errors from using `dict.get` later in the method. - prompt_token_ids = ( - decoder_inputs["prompt_token_ids"] - if decoder_inputs["type"] != "embeds" - else None - ) - prompt_embeds = ( - decoder_inputs["prompt_embeds"] - if decoder_inputs["type"] == "embeds" - else None - ) + # Mypy can be conservative for TypedDict unions; normalize access. + if decoder_inputs["type"] == "embeds": + prompt_token_ids = None + prompt_embeds = decoder_inputs["prompt_embeds"] + else: + prompt_token_ids = decoder_inputs["prompt_token_ids"] + prompt_embeds = None sampling_params = None pooling_params = None