diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index d5d02fdeb7f4..05f3c3b314d1 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -50,9 +50,9 @@ steps:
 - tests/multimodal
 - tests/test_utils
 - tests/worker
- - tests/standalone_tests/lazy_torch_compile.py
+ - tests/standalone_tests/lazy_imports.py
 commands:
- - python3 standalone_tests/lazy_torch_compile.py
+ - python3 standalone_tests/lazy_imports.py
 - pytest -v -s mq_llm_engine # MQLLMEngine
 - pytest -v -s async_engine # AsyncLLMEngine
 - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
diff --git a/tests/standalone_tests/lazy_torch_compile.py b/tests/standalone_tests/lazy_imports.py
similarity index 56%
rename from tests/standalone_tests/lazy_torch_compile.py
rename to tests/standalone_tests/lazy_imports.py
index b3b5809525c9..61e3b387973b 100644
--- a/tests/standalone_tests/lazy_torch_compile.py
+++ b/tests/standalone_tests/lazy_imports.py
@@ -8,7 +8,17 @@ from contextlib import nullcontext
 from vllm_test_utils import BlameResult, blame
 
-module_name = "torch._inductor.async_compile"
+# List of modules that should not be imported too early.
+# Lazy import `torch._inductor.async_compile` to avoid creating
+# too many processes before we set the number of compiler threads.
+# Lazy import `cv2` to avoid bothering users who only use text models.
+# `cv2` can easily mess up the environment.
+module_names = ["torch._inductor.async_compile", "cv2"]
+
+
+def any_module_imported():
+    return any(module_name in sys.modules for module_name in module_names)
+
 
 # In CI, we only check finally if the module is imported.
 # If it is indeed imported, we can rerun the test with `use_blame=True`,
@@ -16,8 +26,7 @@ module_name = "torch._inductor.async_compile"
 # and help find the root cause.
 # We don't run it in CI by default because it is slow.
 use_blame = False
-context = blame(
-    lambda: module_name in sys.modules) if use_blame else nullcontext()
+context = blame(any_module_imported) if use_blame else nullcontext()
 
 with context as result:
     import vllm  # noqa
@@ -25,6 +34,6 @@ if use_blame:
     assert isinstance(result, BlameResult)
     print(f"the first import location is:\n{result.trace_stack}")
 
-assert module_name not in sys.modules, (
-    f"Module {module_name} is imported. To see the first"
+assert not any_module_imported(), (
+    f"Some of the modules in {module_names} are imported. To see the first"
    f" import location, run the test with `use_blame=True`.")
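For reference, the property this renamed test enforces boils down to: import `vllm` in a fresh interpreter and fail if any module that should stay lazy shows up in `sys.modules`. Below is a hedged, standalone sketch of that idea only; the real test above additionally wires in `blame`/`BlameResult` from vLLM's internal `vllm_test_utils` helper to report the first import location, and the `HEAVY_MODULES` name here is illustrative.

```python
# Illustrative sketch only (not part of the patch): verify that importing vllm
# does not eagerly pull in modules we want to keep lazy.
import subprocess
import sys

HEAVY_MODULES = ("torch._inductor.async_compile", "cv2")

code = "\n".join([
    "import sys",
    "import vllm  # noqa: F401",
    f"bad = [m for m in {HEAVY_MODULES!r} if m in sys.modules]",
    "assert not bad, f'eagerly imported: {bad}'",
])

# Run in a subprocess so the check starts from a clean sys.modules.
subprocess.run([sys.executable, "-c", code], check=True)
```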
diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py
index 88f184399722..78a2918e3ed3 100644
--- a/vllm/multimodal/video.py
+++ b/vllm/multimodal/video.py
@@ -6,7 +6,6 @@ from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Optional
 
-import cv2
 import numpy as np
 import numpy.typing as npt
 from PIL import Image
@@ -95,6 +94,8 @@ def resize_video(frames: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
     new_height, new_width = size
     resized_frames = np.empty((num_frames, new_height, new_width, channels),
                               dtype=frames.dtype)
+    # lazy import cv2 to avoid bothering users who only use text models
+    import cv2
     for i, frame in enumerate(frames):
         resized_frame = cv2.resize(frame, (new_width, new_height))
         resized_frames[i] = resized_frame
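The `video.py` change is the classic function-local import: the heavy, optional dependency is paid for only by callers that actually resize video frames. A minimal hedged sketch of the same pattern follows; `resize_frame` is an illustrative name, not part of vLLM's API.

```python
# Sketch of the function-local lazy-import pattern used in resize_video().
import numpy.typing as npt


def resize_frame(frame: npt.NDArray, size: tuple[int, int]) -> npt.NDArray:
    # cv2 is imported only when resizing is actually requested, so text-only
    # users never load (or clash with) OpenCV at import time.
    import cv2
    new_height, new_width = size
    # cv2.resize expects the target size as (width, height).
    return cv2.resize(frame, (new_width, new_height))
```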
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index cecafcc78fa1..1550f978ed20 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -8,21 +8,18 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
-from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.tokens.tokenizers.base import SpecialTokens
-# yapf: disable
-from mistral_common.tokens.tokenizers.mistral import (
-    MistralTokenizer as PublicMistralTokenizer)
-# yapf: enable
-from mistral_common.tokens.tokenizers.sentencepiece import (
-    SentencePieceTokenizer)
-from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
-                                                     Tekkenizer)
 
 from vllm.logger import init_logger
 from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
+    # make sure `mistral_common` is imported lazily,
+    # so that users who only use non-mistral models
+    # will not be bothered by the dependency.
+    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    from mistral_common.tokens.tokenizers.mistral import (
+        MistralTokenizer as PublicMistralTokenizer)
+
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 
 logger = init_logger(__name__)
@@ -33,7 +30,7 @@ class Encoding:
     input_ids: Union[List[int], List[List[int]]]
 
-def maybe_serialize_tool_calls(request: ChatCompletionRequest):
+def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
     # SEE: https://github.com/vllm-project/vllm/pull/9951
     # Credits go to: @gcalmettes
     # NOTE: There is currently a bug in pydantic where attributes
@@ -108,12 +105,16 @@ def find_tokenizer_file(files: List[str]):
 
 class MistralTokenizer:
 
-    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
+    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.mistral = tokenizer
         self.instruct = tokenizer.instruct_tokenizer
 
         tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
+        from mistral_common.tokens.tokenizers.tekken import (
+            SpecialTokenPolicy, Tekkenizer)
         self.is_tekken = isinstance(tokenizer_, Tekkenizer)
+        from mistral_common.tokens.tokenizers.sentencepiece import (
+            SentencePieceTokenizer)
         self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
         if self.is_tekken:
             # Make sure special tokens will not raise
@@ -153,6 +154,8 @@ class MistralTokenizer:
             assert Path(
                 path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
 
+        from mistral_common.tokens.tokenizers.mistral import (
+            MistralTokenizer as PublicMistralTokenizer)
         mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
         return cls(mistral_tokenizer)
 
@@ -181,6 +184,8 @@ class MistralTokenizer:
     # by the guided structured output backends.
     @property
     def all_special_tokens_extended(self) -> List[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
+
         # tekken defines its own extended special tokens list
         if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
             special_tokens = self.tokenizer.SPECIAL_TOKENS
@@ -284,6 +289,8 @@ class MistralTokenizer:
         if last_message["role"] == "assistant":
             last_message["prefix"] = True
 
+        from mistral_common.protocol.instruct.request import (
+            ChatCompletionRequest)
         request = ChatCompletionRequest(messages=messages,
                                         tools=tools)  # type: ignore[type-var]
 
         encoded = self.mistral.encode_chat_completion(request)
@@ -292,6 +299,7 @@ class MistralTokenizer:
         return encoded.tokens
 
     def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
         if self.is_tekken:
             tokens = [
                 t for t in tokens
@@ -363,6 +371,8 @@ class MistralTokenizer:
         ids: List[int],
         skip_special_tokens: bool = True,
     ) -> List[str]:
+        from mistral_common.tokens.tokenizers.base import SpecialTokens
+
         # TODO(Patrick) - potentially allow special tokens to not be skipped
         assert (
             skip_special_tokens
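The `mistral.py` hunks combine two standard tricks: type-only imports under `if TYPE_CHECKING:` paired with string annotations, and deferred runtime imports at the point of first use. Below is a condensed, hedged sketch of that combination; the class and method names are illustrative stand-ins, not vLLM's actual wrapper, though `PublicMistralTokenizer.from_file` is used exactly as in the patch.

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    from mistral_common.tokens.tokenizers.mistral import (
        MistralTokenizer as PublicMistralTokenizer)


class LazyMistralWrapper:
    """Illustrative stand-in for the real MistralTokenizer wrapper."""

    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
        # The annotation is a string, so no runtime import is needed here.
        self.mistral = tokenizer

    @classmethod
    def from_file(cls, tokenizer_file: str) -> "LazyMistralWrapper":
        # mistral_common is imported only when a Mistral tokenizer is
        # actually constructed, keeping `import vllm` free of the dependency.
        from mistral_common.tokens.tokenizers.mistral import (
            MistralTokenizer as PublicMistralTokenizer)
        return cls(PublicMistralTokenizer.from_file(tokenizer_file))
```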