mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 02:47:03 +08:00
[Code Quality] Add missing return type annotations to core modules
Add return type annotations to functions in: - vllm/plugins/__init__.py: load_general_plugins() -> None - vllm/beam_search.py: create_sort_beams_key_function() -> Callable - vllm/forward_context.py: create_forward_context() -> ForwardContext, override_forward_context() and set_forward_context() -> Generator Signed-off-by: yurekami <yurekami@users.noreply.github.com>
This commit is contained in:
parent
09dc7c690c
commit
3c1b94e5b9
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
@ -79,7 +80,9 @@ def get_beam_search_score(
|
||||
return cumulative_logprob / (seq_len**length_penalty)
|
||||
|
||||
|
||||
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
|
||||
def create_sort_beams_key_function(
|
||||
eos_token_id: int, length_penalty: float
|
||||
) -> Callable[[BeamSearchSequence], float]:
|
||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||
return get_beam_search_score(
|
||||
x.tokens, x.cum_logprob, eos_token_id, length_penalty
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NamedTuple
|
||||
@ -235,7 +236,7 @@ def create_forward_context(
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
):
|
||||
) -> ForwardContext:
|
||||
return ForwardContext(
|
||||
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
||||
virtual_engine=virtual_engine,
|
||||
@ -248,7 +249,9 @@ def create_forward_context(
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_forward_context(forward_context: ForwardContext | None):
|
||||
def override_forward_context(
|
||||
forward_context: ForwardContext | None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""A context manager that overrides the current forward context.
|
||||
This is used to override the forward context for a specific
|
||||
forward pass.
|
||||
@ -272,7 +275,7 @@ def set_forward_context(
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor: BatchDescriptor | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
):
|
||||
) -> Generator[None, None, None]:
|
||||
"""A context manager that stores the current forward context,
|
||||
can be attention metadata, etc.
|
||||
Here we can inject common logic for every model forward pass.
|
||||
|
||||
@ -65,7 +65,7 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
|
||||
return plugins
|
||||
|
||||
|
||||
def load_general_plugins():
|
||||
def load_general_plugins() -> None:
|
||||
"""WARNING: plugins can be loaded for multiple times in different
|
||||
processes. They should be designed in a way that they can be loaded
|
||||
multiple times without causing issues.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user