From 8279201ce6ab178131fedff211a5539dc3ef2710 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Mon, 24 Mar 2025 19:37:54 -0400 Subject: [PATCH] [Build] Cython compilation support fix (#14296) Signed-off-by: Gregory Shtrasberg --- Dockerfile.rocm | 2 +- pyproject.toml | 1 + tests/build_cython.py | 38 +++++++++++++++++++++++++++ vllm/engine/llm_engine.py | 2 +- vllm/model_executor/layers/sampler.py | 3 ++- vllm/utils.py | 6 ++--- 6 files changed, 46 insertions(+), 6 deletions(-) create mode 100644 tests/build_cython.py diff --git a/Dockerfile.rocm b/Dockerfile.rocm index f852f3d69759..841e7978a424 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -40,7 +40,7 @@ ARG USE_CYTHON RUN cd vllm \ && python3 -m pip install -r requirements/rocm.txt \ && python3 setup.py clean --all \ - && if [ ${USE_CYTHON} -eq "1" ]; then python3 setup_cython.py build_ext --inplace; fi \ + && if [ ${USE_CYTHON} -eq "1" ]; then python3 tests/build_cython.py build_ext --inplace; fi \ && python3 setup.py bdist_wheel --dist-dir=dist FROM scratch AS export_vllm ARG COMMON_WORKDIR diff --git a/pyproject.toml b/pyproject.toml index ee4e2ed0b7ce..07616c858f1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ exclude = [ "vllm/triton_utils/**/*.py" = ["UP006", "UP035"] "vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"] +"vllm/utils.py" = ["UP006", "UP035"] [tool.ruff.lint] select = [ diff --git a/tests/build_cython.py b/tests/build_cython.py new file mode 100644 index 000000000000..9dea6bcd62f3 --- /dev/null +++ b/tests/build_cython.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +import Cython.Compiler.Options +from Cython.Build import cythonize +from setuptools import setup + +Cython.Compiler.Options.annotate = True + +infiles = [] + +infiles += [ + "vllm/engine/llm_engine.py", + "vllm/transformers_utils/detokenizer.py", + "vllm/engine/output_processor/single_step.py", + "vllm/outputs.py", + "vllm/engine/output_processor/stop_checker.py", +] + +infiles += [ + "vllm/core/scheduler.py", + "vllm/sequence.py", + "vllm/core/block_manager.py", +] + +infiles += [ + "vllm/model_executor/layers/sampler.py", + "vllm/sampling_params.py", + "vllm/utils.py", +] + +setup(ext_modules=cythonize(infiles, + annotate=False, + force=True, + compiler_directives={ + 'language_level': "3", + 'infer_types': True + })) + +# example usage: python3 build_cython.py build_ext --inplace diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b9a8b6a53065..3d019ea58c5e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1249,7 +1249,7 @@ class LLMEngine: return None def _advance_to_next_step( - self, output: List[SamplerOutput], + self, output: SamplerOutput, seq_group_metadata_list: List[SequenceGroupMetadata], scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: """Given model output from a single run, append the tokens to the diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 07ee75593f7b..1ee1332ac45e 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1187,7 +1187,8 @@ def _build_sampler_output( deferred_sample_results_args=deferred_sample_results_args) -def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: +def _get_next_prompt_tokens( + seq_group: SequenceGroupToSample) -> tuple[int, ...]: """Get a list of next prompt tokens to compute logprob from a given sequence group. diff --git a/vllm/utils.py b/vllm/utils.py index d87ec44c75fd..9e14a628993f 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -37,7 +37,7 @@ from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, TypeVar, Union) + Optional, Type, TypeVar, Union) from uuid import uuid4 import cloudpickle @@ -1544,9 +1544,9 @@ class LazyDict(Mapping[str, T], Generic[T]): return len(self._factory) -class ClassRegistry(UserDict[type[T], _V]): +class ClassRegistry(UserDict[Type[T], _V]): - def __getitem__(self, key: type[T]) -> _V: + def __getitem__(self, key: Type[T]) -> _V: for cls in key.mro(): if cls in self.data: return self.data[cls]