mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-09 12:27:12 +08:00
simplify the returned value from generate_beam_search
> includes cleaning up wrap_device, generate_prompt_perplexity and get_image_asset Signed-off-by: Chukwuma Nwaugha <nwaughac@gmail.com>
This commit is contained in:
parent
6af70e11a0
commit
95af481818
@ -58,8 +58,7 @@ from vllm.distributed import (
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import Logprob
|
||||
from vllm.multimodal.base import MediaWithBytes
|
||||
from vllm.logprobs import LogprobsOnePosition
|
||||
from vllm.multimodal.utils import fetch_image
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
@ -93,8 +92,7 @@ PromptVideoInput = _PromptMultiModalInput[np.ndarray]
|
||||
|
||||
def _read_prompts(filename: str) -> list[str]:
|
||||
with open(filename) as f:
|
||||
prompts = f.readlines()
|
||||
return prompts
|
||||
return f.readlines()
|
||||
|
||||
|
||||
class ImageAssetPrompts(TypedDict):
|
||||
@ -267,7 +265,7 @@ class HfRunner:
|
||||
if isinstance(x, dict):
|
||||
return {k: self.wrap_device(v, device) for k, v in x.items()}
|
||||
|
||||
if hasattr(x, "device") and x.device.type == device:
|
||||
if hasattr(x.device, "type") and x.device.type == device:
|
||||
return x
|
||||
|
||||
return x.to(device)
|
||||
@ -993,8 +991,8 @@ class VllmRunner:
|
||||
|
||||
perplexities = []
|
||||
for output in outputs:
|
||||
output = cast(TokensTextLogprobsPromptLogprobs, output)
|
||||
token_datas = cast(list[dict[int, Logprob] | None], output[3])
|
||||
assert isinstance(output, TokensTextLogprobsPromptLogprobs)
|
||||
token_datas = cast(list[LogprobsOnePosition | None], output[3])
|
||||
assert token_datas[0] is None
|
||||
token_log_probs = []
|
||||
for token_data in token_datas[1:]:
|
||||
@ -1025,12 +1023,11 @@ class VllmRunner:
|
||||
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens),
|
||||
concurrency_limit=concurrency_limit,
|
||||
)
|
||||
returned_outputs = []
|
||||
for output in outputs:
|
||||
token_ids = [x.tokens for x in output.sequences]
|
||||
texts = [x.text for x in output.sequences]
|
||||
returned_outputs.append((token_ids, texts))
|
||||
return returned_outputs
|
||||
|
||||
return [
|
||||
([x.tokens for x in output.sequences], [x.text for x in output.sequences])
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
def classify(self, prompts: list[str]) -> list[list[float]]:
|
||||
req_outputs = self.llm.classify(prompts)
|
||||
@ -1405,11 +1402,7 @@ class LocalAssetServer:
|
||||
return f"{self.base_url}/{name}"
|
||||
|
||||
def get_image_asset(self, name: str) -> Image.Image:
|
||||
image = fetch_image(self.url_for(name))
|
||||
# Unwrap MediaWithBytes if present
|
||||
if isinstance(image, MediaWithBytes):
|
||||
image = image.media
|
||||
return image
|
||||
return fetch_image(self.url_for(name))
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user