diff --git a/tests/conftest.py b/tests/conftest.py index 30e25294925ca..e5b2b478aba67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,8 +58,6 @@ from vllm.distributed import ( initialize_model_parallel, ) from vllm.logger import init_logger -from vllm.logprobs import Logprob -from vllm.multimodal.base import MediaWithBytes from vllm.multimodal.utils import fetch_image from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams @@ -71,6 +69,7 @@ from torch._inductor.utils import fresh_cache if TYPE_CHECKING: + from vllm.logprobs import LogprobsOnePosition from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.generation.utils import GenerateOutput @@ -93,8 +92,7 @@ PromptVideoInput = _PromptMultiModalInput[np.ndarray] def _read_prompts(filename: str) -> list[str]: with open(filename) as f: - prompts = f.readlines() - return prompts + return f.readlines() class ImageAssetPrompts(TypedDict): @@ -1021,9 +1019,10 @@ class VllmRunner: perplexities = [] for output in outputs: output = cast(TokensTextLogprobsPromptLogprobs, output) - token_datas = cast(list[dict[int, Logprob] | None], output[3]) + token_datas = cast(list[LogprobsOnePosition | None], output[3]) assert token_datas[0] is None token_log_probs = [] + for token_data in token_datas[1:]: assert token_data is not None assert len(token_data) == 1 @@ -1052,12 +1051,11 @@ class VllmRunner: BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens), concurrency_limit=concurrency_limit, ) - returned_outputs = [] - for output in outputs: - token_ids = [x.tokens for x in output.sequences] - texts = [x.text for x in output.sequences] - returned_outputs.append((token_ids, texts)) - return returned_outputs + + return [ + ([x.tokens for x in output.sequences], [x.text for x in output.sequences]) + for output in outputs + ] def classify(self, prompts: list[str]) -> list[list[float]]: req_outputs = self.llm.classify(prompts) @@ -1432,11 +1430,7 @@ class LocalAssetServer: return f"{self.base_url}/{name}" def get_image_asset(self, name: str) -> Image.Image: - image = fetch_image(self.url_for(name)) - # Unwrap MediaWithBytes if present - if isinstance(image, MediaWithBytes): - image = image.media - return image + return fetch_image(self.url_for(name)) @pytest.fixture(scope="session")