[mypy] Pass type checking for vllm/utils and vllm/v1/pool (#29666)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-11-28 20:40:47 +08:00 committed by GitHub
parent 33b06a6f24
commit 953d9c820b
9 changed files with 37 additions and 43 deletions

View File

@@ -36,8 +36,10 @@ FILES = [
    "vllm/transformers_utils",
    "vllm/triton_utils",
    "vllm/usage",
    "vllm/utils",
    "vllm/v1/core",
    "vllm/v1/engine",
    "vllm/v1/pool",
    "vllm/v1/worker",
]
@@ -59,7 +61,6 @@ SEPARATE_GROUPS = [
    "vllm/v1/executor",
    "vllm/v1/kv_offload",
    "vllm/v1/metrics",
    "vllm/v1/pool",
    "vllm/v1/sample",
    "vllm/v1/spec_decode",
    "vllm/v1/structured_output",

View File

@@ -12,7 +12,7 @@ from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import AsyncGenerator, Awaitable, Callable
from concurrent.futures import Executor, ThreadPoolExecutor
from functools import partial
from typing import TypeVar
from typing import TYPE_CHECKING, TypeVar
from transformers.tokenization_utils_base import BatchEncoding
from typing_extensions import ParamSpec
@@ -257,6 +257,13 @@ def in_loop(event_loop: AbstractEventLoop) -> bool:
        return False
# A hack to pass mypy
if TYPE_CHECKING:
    def anext(it: AsyncGenerator[T, None]):
        return it.__anext__()
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
) -> AsyncGenerator[tuple[int, T], None]:
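For context, a minimal self-contained sketch of the same `TYPE_CHECKING`-only `anext` shim; the surrounding example (`count_up`, `main`) is illustrative and not from this commit. At runtime the builtin `anext` is used unchanged, and the shim only gives mypy a concrete signature for async generators.

import asyncio
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, TypeVar

T = TypeVar("T")

if TYPE_CHECKING:
    # Seen only by the type checker; erased at runtime.
    def anext(it: AsyncGenerator[T, None]):
        return it.__anext__()

async def count_up(n: int) -> AsyncGenerator[int, None]:
    for i in range(n):
        yield i

async def main() -> None:
    gen = count_up(3)
    first = await anext(gen)  # mypy infers `int` here
    print(first)

asyncio.run(main())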

View File

@@ -4,7 +4,7 @@
from collections.abc import Callable, Iterable
from functools import reduce
from typing import TYPE_CHECKING, TypeAlias, TypeVar, cast, overload
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, overload
if TYPE_CHECKING:
    import torch
@@ -82,16 +82,13 @@ def json_map_leaves(
def json_map_leaves(
    func: Callable[[_T], _U],
    value: "BatchedTensorInputs" | _JSONTree[_T],
    value: Any,
) -> "BatchedTensorInputs" | _JSONTree[_U]:
    """Apply a function to each leaf in a nested JSON structure."""
    if isinstance(value, dict):
        return {
            k: json_map_leaves(func, v)  # type: ignore[arg-type]
            for k, v in value.items()
        }
        return {k: json_map_leaves(func, v) for k, v in value.items()}  # type: ignore
    elif isinstance(value, list):
        return [json_map_leaves(func, v) for v in value]
        return [json_map_leaves(func, v) for v in value]  # type: ignore
    elif isinstance(value, tuple):
        return tuple(json_map_leaves(func, v) for v in value)
    else:
@@ -140,9 +137,9 @@ def json_reduce_leaves(
def json_reduce_leaves(
    func: Callable[..., _T | _U],
    func: Callable[[_T, _T], _T] | Callable[[_U, _T], _U],
    value: _JSONTree[_T],
    initial: _U = cast(_U, ...),  # noqa: B008
    initial: _U = ...,  # type: ignore[assignment]
    /,
) -> _T | _U:
    """
@@ -151,13 +148,9 @@
    sequence to a single value.
    """
    if initial is ...:
        return reduce(func, json_iter_leaves(value))  # type: ignore[arg-type]
        return reduce(func, json_iter_leaves(value))  # type: ignore
    return reduce(
        func,  # type: ignore[arg-type]
        json_iter_leaves(value),
        initial,
    )
    return reduce(func, json_iter_leaves(value), initial)  # type: ignore
def json_count_leaves(value: JSONTree[_T]) -> int:
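A rough usage sketch of the two helpers as their signatures above suggest; the values are illustrative and the import path (`vllm.utils.jsontree`) is an assumption about where this module lives.

from vllm.utils.jsontree import json_map_leaves, json_reduce_leaves

tree = {"a": [1, 2], "b": {"c": (3,)}}

# Apply a function to every non-container leaf, keeping the nesting shape.
doubled = json_map_leaves(lambda x: x * 2, tree)
# -> {"a": [2, 4], "b": {"c": (6,)}}

# Fold the leaves into a single value; the third argument is the initial value.
total = json_reduce_leaves(lambda acc, x: acc + x, tree, 0)
# -> 6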

View File

@@ -68,11 +68,11 @@ class MemorySnapshot:
    timestamp: float = 0.0
    auto_measure: bool = True
    def __post_init__(self):
    def __post_init__(self) -> None:
        if self.auto_measure:
            self.measure()
    def measure(self):
    def measure(self) -> None:
        from vllm.platforms import current_platform
        # we measure the torch peak memory usage via allocated_bytes,

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import importlib
import importlib.util
import os
import torch
@@ -47,8 +47,8 @@ def find_nccl_include_paths() -> list[str] | None:
    try:
        spec = importlib.util.find_spec("nvidia.nccl")
        if spec and getattr(spec, "submodule_search_locations", None):
            for loc in spec.submodule_search_locations:
        if spec and (locs := getattr(spec, "submodule_search_locations", None)):
            for loc in locs:
                inc_dir = os.path.join(loc, "include")
                if os.path.exists(os.path.join(inc_dir, "nccl.h")):
                    paths.append(inc_dir)
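The walrus form above is mainly about narrowing: binding the possibly-None attribute to a name inside the condition lets mypy prove it is iterable in the loop. A standalone sketch of the same pattern (the function name is made up for illustration):

import importlib.util

def namespace_locations(package: str) -> list[str]:
    # `submodule_search_locations` may be None; the walrus binding plus the
    # `and` check narrows `locs` to a non-None value inside the branch.
    spec = importlib.util.find_spec(package)
    if spec and (locs := spec.submodule_search_locations):
        return list(locs)
    return []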

View File

@@ -72,7 +72,7 @@ def get_ip() -> str:
    return "0.0.0.0"
def test_loopback_bind(address, family):
def test_loopback_bind(address: str, family: int) -> bool:
    try:
        s = socket.socket(family, socket.SOCK_DGRAM)
        s.bind((address, 0))  # Port 0 = auto assign

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from typing import Any, TypeVar
_T = TypeVar("_T", bound=type)
class ExtensionManager:
@@ -34,7 +36,7 @@ class ExtensionManager:
        Decorator to register a class with the given name.
        """
        def wrap(cls_to_register):
        def wrap(cls_to_register: _T) -> _T:
            self.name2class[name] = cls_to_register
            return cls_to_register
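A standalone sketch (a hypothetical `Registry`, not vllm's `ExtensionManager`) of why annotating the inner `wrap` with a `TypeVar` bound to `type` matters: the decorator then returns the very class it was given, so the registered class keeps its own type under mypy.

from collections.abc import Callable
from typing import TypeVar

_T = TypeVar("_T", bound=type)

class Registry:
    def __init__(self) -> None:
        self.name2class: dict[str, type] = {}

    def register(self, name: str) -> Callable[[_T], _T]:
        def wrap(cls_to_register: _T) -> _T:
            self.name2class[name] = cls_to_register
            return cls_to_register

        return wrap

registry = Registry()

@registry.register("linear")
class LinearPooler:
    pass

# `LinearPooler` is still `type[LinearPooler]` after decoration, so mypy keeps
# checking constructor calls and attribute access on it.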

View File

@@ -13,7 +13,7 @@ import numpy.typing as npt
import torch
from packaging import version
from packaging.version import Version
from torch.library import Library
from torch.library import Library, infer_schema
import vllm.envs as envs
@@ -78,7 +78,6 @@ def guard_cuda_initialization():
        yield
        return
    had_key = "CUDA_VISIBLE_DEVICES" in os.environ
    old_value = os.environ.get("CUDA_VISIBLE_DEVICES")
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
@@ -90,10 +89,10 @@ def guard_cuda_initialization():
        err_msg = str(e)
        raise RuntimeError(err_msg) from e
    finally:
        if had_key:
            os.environ["CUDA_VISIBLE_DEVICES"] = old_value
        if old_value is None:
            del os.environ["CUDA_VISIBLE_DEVICES"]
        else:
            os.environ.pop("CUDA_VISIBLE_DEVICES")
            os.environ["CUDA_VISIBLE_DEVICES"] = old_value
def get_dtype_size(dtype: torch.dtype) -> int:
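The simplified restore above relies on `old_value is None` meaning "the variable was unset", so the separate `had_key` flag becomes redundant. A generic sketch of that pattern as a reusable context manager (`temp_env` is a made-up helper, not vllm code):

import os
from collections.abc import Iterator
from contextlib import contextmanager

@contextmanager
def temp_env(name: str, value: str) -> Iterator[None]:
    old_value = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if old_value is None:
            # The variable did not exist before, so remove it again.
            del os.environ[name]
        else:
            os.environ[name] = old_value

# Usage: hide GPUs from anything initialized inside the block.
with temp_env("CUDA_VISIBLE_DEVICES", ""):
    pass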
@@ -525,8 +524,7 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
# Helper function used in testing.
def _is_torch_equal_or_newer(torch_version: str, target: str) -> bool:
    torch_version = version.parse(torch_version)
    return torch_version >= version.parse(target)
    return version.parse(torch_version) >= version.parse(target)
def is_torch_equal_or_newer(target: str) -> bool:
@@ -640,15 +638,8 @@ def direct_register_custom_op(
        dispatch_key = current_platform.dispatch_key
    import torch.library
    schema_str = infer_schema(op_func, mutates_args=mutates_args)
    if hasattr(torch.library, "infer_schema"):
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl
        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
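With the PyTorch 2.4 fallback removed, `infer_schema` can be imported directly from `torch.library`. A hedged sketch of that flow on PyTorch >= 2.5 (the `demo_ops` namespace and `scale_op` function are illustrative, not vllm's registration helper):

import torch
from torch.library import Library, infer_schema

def scale_op(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

my_lib = Library("demo_ops", "FRAGMENT")
# infer_schema derives e.g. "(Tensor x, float factor) -> Tensor" from the hints.
schema_str = infer_schema(scale_op, mutates_args=())
my_lib.define("scale_op" + schema_str)
my_lib.impl("scale_op", scale_op, dispatch_key="CPU")

out = torch.ops.demo_ops.scale_op(torch.ones(2), 3.0)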

View File

@@ -67,16 +67,16 @@ def build_pooling_cursor(
    n_seq = len(num_scheduled_tokens)
    index = list(range(n_seq))
    num_scheduled_tokens = torch.tensor(num_scheduled_tokens, device="cpu")
    num_scheduled_tokens_cpu = torch.tensor(num_scheduled_tokens, device="cpu")
    cumsum = torch.zeros(
        n_seq + 1, dtype=torch.int64, pin_memory=pin_memory, device="cpu"
    )
    torch.cumsum(num_scheduled_tokens, dim=0, out=cumsum[1:])
    torch.cumsum(num_scheduled_tokens_cpu, dim=0, out=cumsum[1:])
    cumsum = cumsum.to(device, non_blocking=True)
    return PoolingCursor(
        index=index,
        first_token_indices_gpu=cumsum[:n_seq],
        last_token_indices_gpu=cumsum[1:] - 1,
        prompt_lens_cpu=prompt_lens,
        num_scheduled_tokens_cpu=num_scheduled_tokens,
        num_scheduled_tokens_cpu=num_scheduled_tokens_cpu,
    )
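A small numeric sketch of the cumsum bookkeeping above (standalone, toy values): given per-sequence token counts, `cumsum[:n_seq]` is each sequence's first flattened token index and `cumsum[1:] - 1` its last.

import torch

num_scheduled_tokens = [3, 1, 4]
n_seq = len(num_scheduled_tokens)
counts = torch.tensor(num_scheduled_tokens, device="cpu")

cumsum = torch.zeros(n_seq + 1, dtype=torch.int64)
torch.cumsum(counts, dim=0, out=cumsum[1:])

first_token_indices = cumsum[:n_seq]  # tensor([0, 3, 4])
last_token_indices = cumsum[1:] - 1   # tensor([2, 3, 7])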