diff --git a/vllm/beam_search.py b/vllm/beam_search.py index fcd2d1f0e01ab..e7006a8b5c38b 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional @@ -79,7 +80,9 @@ def get_beam_search_score( return cumulative_logprob / (seq_len**length_penalty) -def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): +def create_sort_beams_key_function( + eos_token_id: int, length_penalty: float +) -> Callable[[BeamSearchSequence], float]: def sort_beams_key(x: BeamSearchSequence) -> float: return get_beam_search_score( x.tokens, x.cum_logprob, eos_token_id, length_penalty diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 7a569ec32eac9..73950534b8559 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -3,6 +3,7 @@ import time from collections import defaultdict +from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass from typing import Any, NamedTuple @@ -235,7 +236,7 @@ def create_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, -): +) -> ForwardContext: return ForwardContext( no_compile_layers=vllm_config.compilation_config.static_forward_context, virtual_engine=virtual_engine, @@ -248,7 +249,9 @@ def create_forward_context( @contextmanager -def override_forward_context(forward_context: ForwardContext | None): +def override_forward_context( + forward_context: ForwardContext | None, +) -> Generator[None, None, None]: """A context manager that overrides the current forward context. This is used to override the forward context for a specific forward pass. @@ -272,7 +275,7 @@ def set_forward_context( cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, batch_descriptor: BatchDescriptor | None = None, ubatch_slices: UBatchSlices | None = None, -): +) -> Generator[None, None, None]: """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 4c59d5364a763..99e4e9955c2dc 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -65,7 +65,7 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]: return plugins -def load_general_plugins(): +def load_general_plugins() -> None: """WARNING: plugins can be loaded for multiple times in different processes. They should be designed in a way that they can be loaded multiple times without causing issues.