diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 20a15bbc31e38..65462c316c036 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -75,7 +75,7 @@ class SampleRequest:
     Represents a single inference request for benchmarking.
     """
 
-    prompt: str | list[str]
+    prompt: str | list[str] | list[int]
     prompt_len: int
     expected_output_len: int
     multi_modal_data: MultiModalDataDict | dict | list[dict] | None = None
@@ -402,8 +402,9 @@ def gen_prompt_decode_to_target_len(
     remain_num_try = max_retry
     token_mismatch = 0
     while True:
-        prompt = tokenizer.decode(token_sequence)
-        token_sequence = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
+        # Use the raw token-ID sequence directly as the prompt (SampleRequest
+        # accepts list[int]); skips the lossy decode/re-encode round trip.
+        prompt = token_sequence
         if remain_num_try <= 0:
             if len(token_sequence) != target_token_len:
                 token_mismatch = len(token_sequence) - target_token_len