Add unit tests for batched guided and non-guided requests (#23389)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Author: Yong Hoon Shin <yhshin@meta.com>
Date: 2025-08-22 10:31:24 -07:00 (committed by GitHub)
parent 341923b982
commit b6d7d34fc6

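For readers following the diff: the new test relies on the suite's sample_json_schema pytest fixture. A minimal sketch of what such a fixture could look like, assuming a simple employee-profile schema (hypothetical; the actual fixture in vLLM's conftest may differ):

import pytest


@pytest.fixture
def sample_json_schema() -> dict:
    # Hypothetical stand-in for the fixture defined in the test suite's
    # conftest; the real schema may use different fields.
    return {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
            "skills": {"type": "array", "items": {"type": "string"}},
        },
        "required": ["name", "age", "skills"],
    }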

@@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any
import jsonschema
import pytest
import regex as re
import torch
from pydantic import BaseModel

from tests.reasoning.utils import run_reasoning_extraction
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
@@ -727,3 +729,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
    assert "a4" not in generated
    assert "a5" not in generated
    assert "a6" not in generated


@pytest.mark.parametrize("guided_decoding_backend",
                         ["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_guided_requests(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    guided_decoding_backend: str,
):
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # Don't enforce eager execution on TPUs, because we want to test
    # that no recompilation happens at runtime
    enforce_eager = bool(not current_platform.is_tpu())

    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        enforce_eager=enforce_eager,
        max_model_len=1024,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=(guided_decoding_backend
                                                in {"xgrammar", "guidance"}),
    )

    guided_prompt = (
        "Give an example JSON for an employee profile that fits this "
        "schema. Make the response as short as possible. Schema: "
        f"{sample_json_schema}")
    non_guided_prompt = "The diameter of the Earth in kilometers is "
    prompts = [guided_prompt, non_guided_prompt]

    sampling_params = [
        SamplingParams(
            temperature=1.0,
            max_tokens=400,
            guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
        # No max_tokens; temp=0 with a fixed seed so we can assert on
        # the generated contents
        SamplingParams(
            seed=42,
            temperature=0,
            top_p=1.0,
        ),
    ]

    outputs = llm.generate(prompts=prompts,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    # Free memory as soon as possible, since a failed assertion below
    # would short-circuit the test and never reach this cleanup
    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    for index, output in enumerate(outputs):
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")

        if index == 0:
            # First prompt is guided, so expect valid JSON
            assert "\n" not in generated_text
            output_json = json.loads(generated_text)
            jsonschema.validate(instance=output_json,
                                schema=sample_json_schema)
        else:
            # Second prompt is not guided; we cannot assert on the exact
            # output, but with temp=0 we can expect it to be factual
            assert "12,742" in generated_text
            # The non-guided request should not return valid JSON here
            # (json.JSONDecodeError is a subclass of ValueError)
            with pytest.raises(ValueError):
                output_json = json.loads(generated_text)
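
A note on the final assertion: json.loads raises json.JSONDecodeError on malformed input, and that exception is a subclass of ValueError, so pytest.raises(ValueError) catches the parse failure. A quick self-contained check:

import json

# JSONDecodeError has been a ValueError subclass since Python 3.5
assert issubclass(json.JSONDecodeError, ValueError)

try:
    json.loads("The diameter of the Earth in kilometers is 12,742.")
except ValueError as exc:
    print(type(exc).__name__)  # -> JSONDecodeError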