[ci] Fix sampler tests (#11922)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2025-01-10 20:45:33 +08:00 committed by GitHub
parent d85c47d6ad
commit 241ad7b301
2 changed files with 10 additions and 2 deletions


@@ -214,6 +214,7 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
+  - tests/conftest.py
   commands:
   - pytest -v -s samplers
   - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
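
For reference, the two commands added to this CI step can be reproduced locally. The sketch below is only an illustration: the subprocess-based runner is not how Buildkite actually invokes the step, and the assumption is that it runs from the vLLM tests/ directory so the `samplers` path resolves.

# Illustrative only: runs the sampler tests twice, once with the default
# sampler and once with VLLM_USE_FLASHINFER_SAMPLER=1, mirroring the two
# commands in the CI step above.
import os
import subprocess


def run_sampler_tests() -> None:
    # First pass: default sampler.
    subprocess.run(["pytest", "-v", "-s", "samplers"], check=True)
    # Second pass: same tests with the FlashInfer sampler enabled via env var.
    env = dict(os.environ, VLLM_USE_FLASHINFER_SAMPLER="1")
    subprocess.run(["pytest", "-v", "-s", "samplers"], check=True, env=env)


if __name__ == "__main__":
    run_sampler_tests()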


@@ -28,12 +28,13 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
+                         TokensPrompt, to_enc_dec_tuple_list,
+                         zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity)
+                        identity, is_list_of)
 
 logger = init_logger(__name__)
@@ -886,6 +887,12 @@ class VllmRunner:
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
+        if is_list_of(prompts, str, check="all"):
+            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
+        else:
+            prompts = [
+                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
+            ]
         outputs = self.model.beam_search(
             prompts,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
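
For context, the new branch lets the beam-search test helper accept both string prompts and pre-tokenized prompts: lists of strings are wrapped in TextPrompt, lists of token-ID lists in TokensPrompt, before being handed to beam_search. Below is a minimal, self-contained sketch of that wrapping logic; DemoTextPrompt, DemoTokensPrompt, and the local is_list_of are simplified stand-ins for the vLLM definitions, not the library's actual code.

# Self-contained sketch of the prompt-wrapping branch added above.
# The types and helper below mimic vllm.inputs.TextPrompt/TokensPrompt and
# vllm.utils.is_list_of; they are illustrative stand-ins only.
from typing import List, Sequence, TypedDict, Union


class DemoTextPrompt(TypedDict):
    prompt: str


class DemoTokensPrompt(TypedDict):
    prompt_token_ids: List[int]


def is_list_of(value: Sequence, typ: type, *, check: str = "first") -> bool:
    # With check="all", every element is inspected; with check="first",
    # only the first element is checked.
    if not isinstance(value, list):
        return False
    if check == "all":
        return all(isinstance(item, typ) for item in value)
    return len(value) == 0 or isinstance(value[0], typ)


def wrap_prompts(
    prompts: Union[List[str], List[List[int]]]
) -> List[Union[DemoTextPrompt, DemoTokensPrompt]]:
    # String prompts become TextPrompt-style dicts; token-ID prompts become
    # TokensPrompt-style dicts, matching the branch added to VllmRunner.
    if is_list_of(prompts, str, check="all"):
        return [DemoTextPrompt(prompt=p) for p in prompts]
    return [DemoTokensPrompt(prompt_token_ids=tokens) for tokens in prompts]


if __name__ == "__main__":
    print(wrap_prompts(["Hello world", "The capital of France is"]))
    print(wrap_prompts([[1, 2, 3], [4, 5]]))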