mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-26 02:44:27 +08:00
[Code Quality] Add missing return type annotations to core modules
Add return type annotations to functions in:
- vllm/plugins/__init__.py: load_general_plugins() -> None
- vllm/beam_search.py: create_sort_beams_key_function() -> Callable
- vllm/forward_context.py: create_forward_context() -> ForwardContext, override_forward_context() and set_forward_context() -> Generator

Signed-off-by: yurekami <yurekami@users.noreply.github.com>
This commit is contained in:
parent
09dc7c690c
commit
3c1b94e5b9
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
@ -79,7 +80,9 @@ def get_beam_search_score(
|
|||||||
return cumulative_logprob / (seq_len**length_penalty)
|
return cumulative_logprob / (seq_len**length_penalty)
|
||||||
|
|
||||||
|
|
||||||
def create_sort_beams_key_function(eos_token_id: int, length_penalty: float):
|
def create_sort_beams_key_function(
|
||||||
|
eos_token_id: int, length_penalty: float
|
||||||
|
) -> Callable[[BeamSearchSequence], float]:
|
||||||
def sort_beams_key(x: BeamSearchSequence) -> float:
|
def sort_beams_key(x: BeamSearchSequence) -> float:
|
||||||
return get_beam_search_score(
|
return get_beam_search_score(
|
||||||
x.tokens, x.cum_logprob, eos_token_id, length_penalty
|
x.tokens, x.cum_logprob, eos_token_id, length_penalty
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from collections.abc import Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
@ -235,7 +236,7 @@ def create_forward_context(
|
|||||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor: BatchDescriptor | None = None,
|
batch_descriptor: BatchDescriptor | None = None,
|
||||||
ubatch_slices: UBatchSlices | None = None,
|
ubatch_slices: UBatchSlices | None = None,
|
||||||
):
|
) -> ForwardContext:
|
||||||
return ForwardContext(
|
return ForwardContext(
|
||||||
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
no_compile_layers=vllm_config.compilation_config.static_forward_context,
|
||||||
virtual_engine=virtual_engine,
|
virtual_engine=virtual_engine,
|
||||||
@ -248,7 +249,9 @@ def create_forward_context(
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def override_forward_context(forward_context: ForwardContext | None):
|
def override_forward_context(
|
||||||
|
forward_context: ForwardContext | None,
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
"""A context manager that overrides the current forward context.
|
"""A context manager that overrides the current forward context.
|
||||||
This is used to override the forward context for a specific
|
This is used to override the forward context for a specific
|
||||||
forward pass.
|
forward pass.
|
||||||
@ -272,7 +275,7 @@ def set_forward_context(
|
|||||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor: BatchDescriptor | None = None,
|
batch_descriptor: BatchDescriptor | None = None,
|
||||||
ubatch_slices: UBatchSlices | None = None,
|
ubatch_slices: UBatchSlices | None = None,
|
||||||
):
|
) -> Generator[None, None, None]:
|
||||||
"""A context manager that stores the current forward context,
|
"""A context manager that stores the current forward context,
|
||||||
can be attention metadata, etc.
|
can be attention metadata, etc.
|
||||||
Here we can inject common logic for every model forward pass.
|
Here we can inject common logic for every model forward pass.
|
||||||
|
|||||||
@ -65,7 +65,7 @@ def load_plugins_by_group(group: str) -> dict[str, Callable[[], Any]]:
|
|||||||
return plugins
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
def load_general_plugins():
|
def load_general_plugins() -> None:
|
||||||
"""WARNING: plugins can be loaded for multiple times in different
|
"""WARNING: plugins can be loaded for multiple times in different
|
||||||
processes. They should be designed in a way that they can be loaded
|
processes. They should be designed in a way that they can be loaded
|
||||||
multiple times without causing issues.
|
multiple times without causing issues.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user