Merge 101c957bff1c1174b63f063b6383505f8862826b into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
Chukwuma Nwaugha 2025-12-25 08:07:49 +08:00 committed by GitHub
commit 4144c41d11
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -58,8 +58,6 @@ from vllm.distributed import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.logprobs import Logprob
from vllm.multimodal.base import MediaWithBytes
from vllm.multimodal.utils import fetch_image from vllm.multimodal.utils import fetch_image
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams from vllm.sampling_params import BeamSearchParams
@@ -71,6 +69,7 @@ from torch._inductor.utils import fresh_cache
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.logprobs import LogprobsOnePosition
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.generation.utils import GenerateOutput from transformers.generation.utils import GenerateOutput
@@ -93,8 +92,7 @@ PromptVideoInput = _PromptMultiModalInput[np.ndarray]
def _read_prompts(filename: str) -> list[str]: def _read_prompts(filename: str) -> list[str]:
with open(filename) as f: with open(filename) as f:
prompts = f.readlines() return f.readlines()
return prompts
class ImageAssetPrompts(TypedDict): class ImageAssetPrompts(TypedDict):
@@ -1021,9 +1019,10 @@ class VllmRunner:
perplexities = [] perplexities = []
for output in outputs: for output in outputs:
output = cast(TokensTextLogprobsPromptLogprobs, output) output = cast(TokensTextLogprobsPromptLogprobs, output)
token_datas = cast(list[dict[int, Logprob] | None], output[3]) token_datas = cast(list[LogprobsOnePosition | None], output[3])
assert token_datas[0] is None assert token_datas[0] is None
token_log_probs = [] token_log_probs = []
for token_data in token_datas[1:]: for token_data in token_datas[1:]:
assert token_data is not None assert token_data is not None
assert len(token_data) == 1 assert len(token_data) == 1
@@ -1052,12 +1051,11 @@ class VllmRunner:
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens), BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens),
concurrency_limit=concurrency_limit, concurrency_limit=concurrency_limit,
) )
returned_outputs = []
for output in outputs: return [
token_ids = [x.tokens for x in output.sequences] ([x.tokens for x in output.sequences], [x.text for x in output.sequences])
texts = [x.text for x in output.sequences] for output in outputs
returned_outputs.append((token_ids, texts)) ]
return returned_outputs
def classify(self, prompts: list[str]) -> list[list[float]]: def classify(self, prompts: list[str]) -> list[list[float]]:
req_outputs = self.llm.classify(prompts) req_outputs = self.llm.classify(prompts)
@@ -1432,11 +1430,7 @@ class LocalAssetServer:
return f"{self.base_url}/{name}" return f"{self.base_url}/{name}"
def get_image_asset(self, name: str) -> Image.Image: def get_image_asset(self, name: str) -> Image.Image:
image = fetch_image(self.url_for(name)) return fetch_image(self.url_for(name))
# Unwrap MediaWithBytes if present
if isinstance(image, MediaWithBytes):
image = image.media
return image
@pytest.fixture(scope="session") @pytest.fixture(scope="session")