From 44be2b7349c12723b0695353ebb2bec3de4b5ae1 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 22 Sep 2025 13:23:45 +0100 Subject: [PATCH] Make `mypy` behave like a proper pre-commit hook (#25313) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: yewentao256 --- .github/CODEOWNERS | 1 + .pre-commit-config.yaml | 34 ++++----- pyproject.toml | 21 ------ tools/mypy.sh | 35 --------- tools/pre_commit/mypy.py | 140 +++++++++++++++++++++++++++++++++++ vllm/entrypoints/llm.py | 4 +- vllm/entrypoints/renderer.py | 2 +- vllm/utils/__init__.py | 9 ++- vllm/utils/tensor_schema.py | 7 +- 9 files changed, 166 insertions(+), 87 deletions(-) delete mode 100755 tools/mypy.sh create mode 100755 tools/pre_commit/mypy.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 37bd0ace98a97..9d749fe8d3238 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -72,6 +72,7 @@ mkdocs.yaml @hmellor # Linting .markdownlint.yaml @hmellor .pre-commit-config.yaml @hmellor +/tools/pre_commit @hmellor # CPU /vllm/v1/worker/cpu* @bigPYJ1151 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bf36db7d15ed9..8ca414ee4269b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -60,38 +60,32 @@ repos: files: ^requirements/test\.(in|txt)$ - id: mypy-local name: Run mypy for local Python installation - entry: tools/mypy.sh 0 "local" - language: python - types: [python] - additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic] + entry: python tools/pre_commit/mypy.py 0 "local" stages: [pre-commit] # Don't run in CI + <<: &mypy_common + language: python + types_or: [python, pyi] + require_serial: true + additional_dependencies: [mypy==1.11.1, regex, types-cachetools, types-setuptools, types-PyYAML, types-requests, types-torch, pydantic] - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.9 - entry: tools/mypy.sh 1 "3.9" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.9" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.10 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.10 - entry: tools/mypy.sh 1 "3.10" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.10" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.11 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.11 - entry: tools/mypy.sh 1 "3.11" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.11" + <<: *mypy_common stages: [manual] # Only run in CI - id: mypy-3.12 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward name: Run mypy for Python 3.12 - entry: tools/mypy.sh 1 "3.12" - language: python - types: [python] - additional_dependencies: *mypy_deps + entry: python tools/pre_commit/mypy.py 1 "3.12" + <<: *mypy_common stages: [manual] # Only run in CI - id: shellcheck name: Lint shell scripts diff --git a/pyproject.toml b/pyproject.toml index f43ae69e00bdd..88c5c4067f5ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,27 +110,6 @@ ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" -# After fixing type errors resulting from follow_imports: "skip" -> "silent", -# move the directory here and remove it from tools/mypy.sh -files = [ - "vllm/*.py", - "vllm/assets", - "vllm/entrypoints", - "vllm/inputs", - "vllm/logging_utils", - "vllm/multimodal", - "vllm/platforms", - "vllm/transformers_utils", - "vllm/triton_utils", - "vllm/usage", -] -# TODO(woosuk): Include the code from Megatron and HuggingFace. -exclude = [ - "vllm/model_executor/parallel_utils/|vllm/model_executor/models/", - # Ignore triton kernels in ops. - 'vllm/attention/ops/.*\.py$' -] - [tool.isort] skip_glob = [ ".buildkite/*", diff --git a/tools/mypy.sh b/tools/mypy.sh deleted file mode 100755 index 63e3b9a916634..0000000000000 --- a/tools/mypy.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash - -CI=${1:-0} -PYTHON_VERSION=${2:-local} - -if [ "$CI" -eq 1 ]; then - set -e -fi - -if [ $PYTHON_VERSION == "local" ]; then - PYTHON_VERSION=$(python -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') -fi - -run_mypy() { - echo "Running mypy on $1" - if [ "$CI" -eq 1 ] && [ -z "$1" ]; then - mypy --python-version "${PYTHON_VERSION}" "$@" - return - fi - mypy --follow-imports skip --python-version "${PYTHON_VERSION}" "$@" -} - -run_mypy # Note that this is less strict than CI -run_mypy tests -run_mypy vllm/attention -run_mypy vllm/compilation -run_mypy vllm/distributed -run_mypy vllm/engine -run_mypy vllm/executor -run_mypy vllm/inputs -run_mypy vllm/lora -run_mypy --exclude 'vllm/model_executor/layers/fla/ops' vllm/model_executor -run_mypy vllm/plugins -run_mypy vllm/worker -run_mypy vllm/v1 diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py new file mode 100755 index 0000000000000..039cf6075f631 --- /dev/null +++ b/tools/pre_commit/mypy.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Run mypy on changed files. + +This script is designed to be used as a pre-commit hook. It runs mypy +on files that have been changed. It groups files into different mypy calls +based on their directory to avoid import following issues. + +Usage: + python tools/pre_commit/mypy.py + +Args: + ci: "1" if running in CI, "0" otherwise. In CI, follow_imports is set to + "silent" for the main group of files. + python_version: Python version to use (e.g., "3.10") or "local" to use + the local Python version. + changed_files: List of changed files to check. +""" + +import subprocess +import sys +from typing import Optional + +import regex as re + +FILES = [ + "vllm/*.py", + "vllm/assets", + "vllm/entrypoints", + "vllm/inputs", + "vllm/logging_utils", + "vllm/multimodal", + "vllm/platforms", + "vllm/transformers_utils", + "vllm/triton_utils", + "vllm/usage", +] + +# After fixing errors resulting from changing follow_imports +# from "skip" to "silent", move the following directories to FILES +SEPARATE_GROUPS = [ + "tests", + "vllm/attention", + "vllm/compilation", + "vllm/distributed", + "vllm/engine", + "vllm/executor", + "vllm/inputs", + "vllm/lora", + "vllm/model_executor", + "vllm/plugins", + "vllm/worker", + "vllm/v1", +] + +# TODO(woosuk): Include the code from Megatron and HuggingFace. +EXCLUDE = [ + "vllm/model_executor/parallel_utils", + "vllm/model_executor/models", + "vllm/model_executor/layers/fla/ops", + # Ignore triton kernels in ops. + "vllm/attention/ops", +] + + +def group_files(changed_files: list[str]) -> dict[str, list[str]]: + """ + Group changed files into different mypy calls. + + Args: + changed_files: List of changed files. + + Returns: + A dictionary mapping file group names to lists of changed files. + """ + exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*") + files_pattern = re.compile(f"^({'|'.join(FILES)}).*") + file_groups = {"": []} + file_groups.update({k: [] for k in SEPARATE_GROUPS}) + for changed_file in changed_files: + # Skip files which should be ignored completely + if exclude_pattern.match(changed_file): + continue + # Group files by mypy call + if files_pattern.match(changed_file): + file_groups[""].append(changed_file) + continue + else: + for directory in SEPARATE_GROUPS: + if re.match(f"^{directory}.*", changed_file): + file_groups[directory].append(changed_file) + break + return file_groups + + +def mypy(targets: list[str], python_version: Optional[str], + follow_imports: Optional[str], file_group: str) -> int: + """ + Run mypy on the given targets. + + Args: + targets: List of files or directories to check. + python_version: Python version to use (e.g., "3.10") or None to use + the default mypy version. + follow_imports: Value for the --follow-imports option or None to use + the default mypy behavior. + file_group: The file group name for logging purposes. + + Returns: + The return code from mypy. + """ + args = ["mypy"] + if python_version is not None: + args += ["--python-version", python_version] + if follow_imports is not None: + args += ["--follow-imports", follow_imports] + print(f"$ {' '.join(args)} {file_group}") + return subprocess.run(args + targets, check=False).returncode + + +def main(): + ci = sys.argv[1] == "1" + python_version = sys.argv[2] + file_groups = group_files(sys.argv[3:]) + + if python_version == "local": + python_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + returncode = 0 + for file_group, changed_files in file_groups.items(): + follow_imports = None if ci and file_group == "" else "skip" + if changed_files: + returncode |= mypy(changed_files, python_version, follow_imports, + file_group) + return returncode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 092d3f276d1c5..c41f44aa47187 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1468,7 +1468,7 @@ class LLM: def _validate_and_add_requests( self, - prompts: Union[PromptType, Sequence[PromptType]], + prompts: Union[PromptType, Sequence[PromptType], DataPrompt], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], *, @@ -1478,7 +1478,7 @@ class LLM: ) -> None: if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] + prompts = [prompts] # type: ignore[list-item] num_requests = len(prompts) if isinstance(params, Sequence) and len(params) != num_requests: diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py index fb859d57be9fe..d7ce57c728ba6 100644 --- a/vllm/entrypoints/renderer.py +++ b/vllm/entrypoints/renderer.py @@ -280,7 +280,7 @@ class CompletionRenderer(BaseRenderer): if truncate_prompt_tokens < 0: truncate_prompt_tokens = self.model_config.max_model_len - if max_length is not None and truncate_prompt_tokens > max_length: + if max_length is not None and truncate_prompt_tokens > max_length: # type: ignore[operator] raise ValueError( f"truncate_prompt_tokens ({truncate_prompt_tokens}) " f"cannot be greater than max_length ({max_length}). " diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index b74b746a35830..022e35a399c53 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -551,9 +551,10 @@ class AsyncMicrobatchTokenizer: # If every request uses identical kwargs we can run a single # batched tokenizer call for a big speed-up. if can_batch and len(prompts) > 1: - encode_fn = partial(self.tokenizer, prompts, **kwargs) + batch_encode_fn = partial(self.tokenizer, prompts, + **kwargs) results = await self._loop.run_in_executor( - self._executor, encode_fn) + self._executor, batch_encode_fn) for i, fut in enumerate(result_futures): if not fut.done(): @@ -889,7 +890,7 @@ def get_open_port() -> int: def get_open_ports_list(count: int = 5) -> list[int]: """Get a list of open ports.""" - ports = set() + ports = set[int]() while len(ports) < count: ports.add(get_open_port()) return list(ports) @@ -1279,7 +1280,7 @@ def as_list(maybe_list: Iterable[T]) -> list[T]: def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]: if isinstance(obj, str) or not isinstance(obj, Iterable): - obj = [obj] + return [obj] # type: ignore[list-item] return obj diff --git a/vllm/utils/tensor_schema.py b/vllm/utils/tensor_schema.py index 21d3249fe1547..d75dbcd5401b2 100644 --- a/vllm/utils/tensor_schema.py +++ b/vllm/utils/tensor_schema.py @@ -22,9 +22,8 @@ class TensorShape: self.dims = dims self.dynamic_dims = dynamic_dims if dynamic_dims else set() - def resolve(self, **bindings: dict[str, - int]) -> tuple[Union[int, str], ...]: - resolved = [] + def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]: + resolved = list[Union[int, str]]() for dim in self.dims: if isinstance(dim, str) and dim in bindings: resolved.append(bindings[dim]) @@ -159,7 +158,7 @@ class TensorSchema: def validate(self) -> None: type_hints = get_type_hints(self.__class__, include_extras=True) - shape_env = {} + shape_env = dict[str, int]() for field_name, field_type in type_hints.items(): # Check if field is missing