From 6e865b6a83565c5d661091ec0886403edb171794 Mon Sep 17 00:00:00 2001
From: Chukwuma Nwaugha <20521315+nwaughachukwuma@users.noreply.github.com>
Date: Fri, 5 Dec 2025 06:44:32 +0000
Subject: [PATCH] Refactor example prompts fixture (#29854)

Signed-off-by: Chukwuma Nwaugha <nwaughac@gmail.com>
---
 tests/conftest.py | 47 +++++++++++++++++++++++++----------------------
 1 file changed, 25 insertions(+), 22 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index b20c9efef542a..204452b5835ce 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -27,7 +27,7 @@ import threading
 from collections.abc import Generator
 from contextlib import nullcontext
 from enum import Enum
-from typing import Any, Callable, TypedDict, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Callable, TypedDict, TypeVar, cast
 
 import numpy as np
 import pytest
@@ -67,6 +67,11 @@ from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils.collection_utils import is_list_of
 from vllm.utils.torch_utils import set_default_torch_num_threads
 
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+    from transformers.generation.utils import GenerateOutput
+
+
 logger = init_logger(__name__)
 
 _TEST_DIR = os.path.dirname(__file__)
@@ -202,10 +207,7 @@ def dynamo_reset():
 
 @pytest.fixture
 def example_prompts() -> list[str]:
-    prompts = []
-    for filename in _TEST_PROMPTS:
-        prompts += _read_prompts(filename)
-    return prompts
+    return [prompt for filename in _TEST_PROMPTS for prompt in _read_prompts(filename)]
 
 
 @pytest.fixture
@@ -224,10 +226,7 @@ class DecoderPromptType(Enum):
 
 @pytest.fixture
 def example_long_prompts() -> list[str]:
-    prompts = []
-    for filename in _LONG_PROMPTS:
-        prompts += _read_prompts(filename)
-    return prompts
+    return [prompt for filename in _LONG_PROMPTS for prompt in _read_prompts(filename)]
 
 
 @pytest.fixture(scope="session")
@@ -353,10 +352,13 @@ class HfRunner:
                 trust_remote_code=trust_remote_code,
             )
         else:
-            model = auto_cls.from_pretrained(
-                model_name,
-                trust_remote_code=trust_remote_code,
-                **model_kwargs,
+            model = cast(
+                nn.Module,
+                auto_cls.from_pretrained(
+                    model_name,
+                    trust_remote_code=trust_remote_code,
+                    **model_kwargs,
+                ),
             )
 
         # in case some unquantized custom models are not in same dtype
@@ -374,10 +376,12 @@ class HfRunner:
         self.model = model
 
         if not skip_tokenizer_init:
-            self.tokenizer = AutoTokenizer.from_pretrained(
-                model_name,
-                dtype=dtype,
-                trust_remote_code=trust_remote_code,
+            self.tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast" = (
+                AutoTokenizer.from_pretrained(
+                    model_name,
+                    dtype=dtype,
+                    trust_remote_code=trust_remote_code,
+                )
             )
 
         # don't put this import at the top level
@@ -495,7 +499,7 @@ class HfRunner:
 
         outputs: list[tuple[list[list[int]], list[str]]] = []
         for inputs in all_inputs:
-            output_ids = self.model.generate(
+            output_ids: torch.Tensor = self.model.generate(
                 **self.wrap_device(inputs),
                 use_cache=True,
                 **kwargs,
@@ -505,8 +509,7 @@ class HfRunner:
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
             )
-            output_ids = output_ids.cpu().tolist()
-            outputs.append((output_ids, output_str))
+            outputs.append((output_ids.cpu().tolist(), output_str))
         return outputs
 
     def generate_greedy(
@@ -574,7 +577,7 @@ class HfRunner:
 
         all_logprobs: list[list[torch.Tensor]] = []
         for inputs in all_inputs:
-            output = self.model.generate(
+            output: "GenerateOutput" = self.model.generate(
                 **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,
@@ -656,7 +659,7 @@ class HfRunner:
 
         all_output_strs: list[str] = []
        for inputs in all_inputs:
-            output = self.model.generate(
+            output: "GenerateOutput" = self.model.generate(
                **self.wrap_device(inputs),
                 use_cache=True,
                 do_sample=False,