[ci] Fix sampler tests (#11922)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent d85c47d6ad
commit 241ad7b301
@@ -214,6 +214,7 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
+  - tests/conftest.py
   commands:
   - pytest -v -s samplers
   - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers

@@ -28,12 +28,13 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
+                         TokensPrompt, to_enc_dec_tuple_list,
+                         zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity)
+                        identity, is_list_of)

 logger = init_logger(__name__)

@@ -886,6 +887,12 @@ class VllmRunner:
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
+        if is_list_of(prompts, str, check="all"):
+            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
+        else:
+            prompts = [
+                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
+            ]
         outputs = self.model.beam_search(
             prompts,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
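
The hunk above normalizes the prompts argument before handing it to beam_search, so the test helper accepts either plain strings or pre-tokenized prompts. Below is a minimal standalone sketch of that normalization, assuming the TextPrompt/TokensPrompt inputs and the is_list_of helper behave as shown in the imports above; the function name normalize_beam_search_prompts is illustrative and not part of the commit.

# Minimal sketch of the prompt normalization added above (illustrative only).
# Assumes vllm.inputs.TextPrompt / TokensPrompt and vllm.utils.is_list_of as
# imported in the diff; "normalize_beam_search_prompts" is a hypothetical name.
from typing import List, Union

from vllm.inputs import TextPrompt, TokensPrompt
from vllm.utils import is_list_of


def normalize_beam_search_prompts(
        prompts: Union[List[str], List[List[int]]]):
    # A batch of plain strings becomes a list of TextPrompt dicts.
    if is_list_of(prompts, str, check="all"):
        return [TextPrompt(prompt=prompt) for prompt in prompts]
    # Otherwise the elements are treated as lists of token IDs.
    return [TokensPrompt(prompt_token_ids=tokens) for tokens in prompts]


# Example: both call styles produce prompt dicts that beam_search understands.
print(normalize_beam_search_prompts(["Hello, world"]))
print(normalize_beam_search_prompts([[1, 2, 3, 4]]))

Note that check="all" verifies every element is a str before taking the TextPrompt branch, so a batch of token-ID lists falls through to TokensPrompt.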