Add unit tests for batched guided and non-guided requests (#23389)

Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Author: Yong Hoon Shin, 2025-08-22 10:31:24 -07:00 (committed by GitHub)
parent 341923b982
commit b6d7d34fc6

@@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any
import jsonschema
import pytest
import regex as re
import torch
from pydantic import BaseModel
from tests.reasoning.utils import run_reasoning_extraction
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
@@ -727,3 +729,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
    assert "a4" not in generated
    assert "a5" not in generated
    assert "a6" not in generated


@pytest.mark.parametrize("guided_decoding_backend",
                         ["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_guided_requests(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    guided_decoding_backend: str,
):
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # Don't use eager execution on TPUs because we want to test for no
    # recompilation at runtime
    enforce_eager = bool(not current_platform.is_tpu())

    llm = LLM(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        enforce_eager=enforce_eager,
        max_model_len=1024,
        guided_decoding_backend=guided_decoding_backend,
        guided_decoding_disable_any_whitespace=(guided_decoding_backend
                                                in {"xgrammar", "guidance"}),
    )

    guided_prompt = (
        "Give an example JSON for an employee profile that fits this "
        "schema. Make the response as short as possible. Schema: "
        f"{sample_json_schema}")
    non_guided_prompt = "The diameter of the Earth in kilometers is "

    prompts = [guided_prompt, non_guided_prompt]
    sampling_params = [
        SamplingParams(
            temperature=1.0,
            max_tokens=400,
            guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
        # No max tokens, temp=0 to assert on contents
        SamplingParams(
            seed=42,
            temperature=0,
            top_p=1.0,
        ),
    ]

    outputs = llm.generate(prompts=prompts,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    # Free memory as soon as possible as failed assertions
    # will short circuit and not free up memory
    del llm
    torch.cuda.empty_cache()
    cleanup_dist_env_and_memory()

    for index, output in enumerate(outputs):
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")

        if index == 0:
            # First prompt is guided, expect valid JSON
            assert "\n" not in generated_text
            output_json = json.loads(generated_text)
            jsonschema.validate(instance=output_json,
                                schema=sample_json_schema)
        else:
            # Second prompt is not guided, expect valid output
            # Cannot assert on exact output, but we can expect it to be factual
            assert "12,742" in generated_text

            # non-guided requests should not return a valid JSON here
            with pytest.raises(ValueError):
                output_json = json.loads(generated_text)
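
For readers unfamiliar with the API this test exercises: LLM.generate accepts a list of SamplingParams, one per prompt, which is what allows a structured-output request and a free-form request to share a single batch. The snippet below is a minimal standalone sketch of that pattern, not part of this commit; the model name, schema, and decoding settings are illustrative placeholders.

# Standalone sketch (illustrative, not part of this commit): batching a
# guided (structured-output) request and a plain request in one call by
# passing one SamplingParams object per prompt.
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {  # placeholder schema for illustration only
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", max_model_len=1024)

prompts = [
    f"Give an example JSON for a person that fits this schema: {schema}",
    "The diameter of the Earth in kilometers is ",
]
sampling_params = [
    # Guided request: constrained to emit JSON matching the schema.
    SamplingParams(max_tokens=400,
                   guided_decoding=GuidedDecodingParams(json=schema)),
    # Non-guided request: ordinary greedy decoding.
    SamplingParams(temperature=0, max_tokens=32),
]

# Prompts and sampling params are paired by index within the batch.
for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)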