Add unit tests for batched guided and non-guided requests (#23389)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
parent 341923b982
commit b6d7d34fc6
@@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any

import jsonschema
import pytest
import regex as re
import torch
from pydantic import BaseModel

from tests.reasoning.utils import run_reasoning_extraction
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
@@ -727,3 +729,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
    assert "a4" not in generated
    assert "a5" not in generated
    assert "a6" not in generated


@pytest.mark.parametrize("guided_decoding_backend",
                         ["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_guided_requests(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    guided_decoding_backend: str,
):
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # Don't use eager execution on TPUs because we want to test for no
    # recompilation at runtime.
    enforce_eager = bool(not current_platform.is_tpu())

    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        enforce_eager=enforce_eager,
        max_model_len=1024,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=(guided_decoding_backend
                                                in {"xgrammar", "guidance"}),
    )

    guided_prompt = (
        "Give an example JSON for an employee profile that fits this "
        "schema. Make the response as short as possible. Schema: "
        f"{sample_json_schema}")

    non_guided_prompt = "The diameter of the Earth in kilometers is "

    prompts = [guided_prompt, non_guided_prompt]
    sampling_params = [
        SamplingParams(
            temperature=1.0,
            max_tokens=400,
            guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
        # No max_tokens, temp=0 so we can assert on the contents
        SamplingParams(
            seed=42,
            temperature=0,
            top_p=1.0,
        ),
    ]

    outputs = llm.generate(prompts=prompts,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    # Free memory as soon as possible, since failed assertions below
    # would short-circuit the test and never reach a later cleanup.
    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    for index, output in enumerate(outputs):
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")

        if index == 0:
            # First prompt is guided, so expect valid JSON
            assert "\n" not in generated_text
            output_json = json.loads(generated_text)
            jsonschema.validate(instance=output_json,
                                schema=sample_json_schema)
        else:
            # Second prompt is not guided; we cannot assert on the exact
            # output, but we can expect it to be factual
            assert "12,742" in generated_text

            # Non-guided requests should not return valid JSON here
            with pytest.raises(ValueError):
                output_json = json.loads(generated_text)
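For readers who want to try the same mixed-batch pattern outside the test suite, below is a minimal standalone sketch. The model name, schema, and prompts are illustrative assumptions and are not taken from this diff; only the guided/non-guided batching idea mirrors the test above.

# Sketch: one JSON-schema-guided request and one free-form request in the
# same llm.generate() batch. Model, schema, and prompts are placeholders.
import json

import jsonschema

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name", "age"],
}

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", max_model_len=1024)

prompts = [
    f"Give an example JSON object matching this schema: {schema}",
    "The diameter of the Earth in kilometers is ",
]
sampling_params = [
    # Guided request: decoding is constrained to the JSON schema.
    SamplingParams(max_tokens=200,
                   guided_decoding=GuidedDecodingParams(json=schema)),
    # Non-guided request sharing the same batch.
    SamplingParams(temperature=0, max_tokens=32),
]

outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)

# The guided output must parse and validate against the schema; the
# free-form output carries no such guarantee.
jsonschema.validate(instance=json.loads(outputs[0].outputs[0].text),
                    schema=schema)
print(outputs[1].outputs[0].text)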