Add unit tests for batched guided and non-guided requests (#23389)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
parent 341923b982
commit b6d7d34fc6
@@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any
 import jsonschema
 import pytest
 import regex as re
+import torch
 from pydantic import BaseModel

 from tests.reasoning.utils import run_reasoning_extraction
+from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import RequestOutput
 from vllm.platforms import current_platform
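Note: the new test below relies on the existing sample_json_schema pytest fixture, which is defined elsewhere in the test suite and not shown in this diff. As a rough sketch only, a fixture of the following shape would satisfy the test's "employee profile" prompt; the field names are illustrative assumptions, not the fixture's actual contents.

import pytest


@pytest.fixture
def sample_json_schema() -> dict:
    # Hypothetical stand-in for the real vLLM test fixture; the shape is
    # guessed from the "employee profile" wording in the test prompt.
    return {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
            "skills": {"type": "array", "items": {"type": "string"}},
        },
        "required": ["name", "age", "skills"],
    }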
@@ -727,3 +729,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
     assert "a4" not in generated
     assert "a5" not in generated
     assert "a6" not in generated
+
+
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["guidance", "xgrammar", "outlines"])
+def test_structured_output_batched_with_non_guided_requests(
+    monkeypatch: pytest.MonkeyPatch,
+    sample_json_schema: dict[str, Any],
+    guided_decoding_backend: str,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    # Don't use eager execution on TPUs because we want to test for no
+    # recompilation at runtime
+    enforce_eager = bool(not current_platform.is_tpu())
+
+    llm = LLM(
+        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
+        enforce_eager=enforce_eager,
+        max_model_len=1024,
+        guided_decoding_backend=guided_decoding_backend,
+        guided_decoding_disable_any_whitespace=(guided_decoding_backend
+                                                in {"xgrammar", "guidance"}),
+    )
+
+    guided_prompt = (
+        "Give an example JSON for an employee profile that fits this "
+        "schema. Make the response as short as possible. Schema: "
+        f"{sample_json_schema}")
+
+    non_guided_prompt = "The diameter of the Earth in kilometers is "
+
+    prompts = [guided_prompt, non_guided_prompt]
+    sampling_params = [
+        SamplingParams(
+            temperature=1.0,
+            max_tokens=400,
+            guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
+        # No max tokens, temp=0 to assert on contents
+        SamplingParams(
+            seed=42,
+            temperature=0,
+            top_p=1.0,
+        ),
+    ]
+
+    outputs = llm.generate(prompts=prompts,
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+
+    assert outputs is not None
+
+    # Free memory as soon as possible as failed assertions
+    # will short circuit and not free up memory
+    del llm
+    torch.cuda.empty_cache()
+    cleanup_dist_env_and_memory()
+
+    for index, output in enumerate(outputs):
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        prompt = output.prompt
+
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")
+
+        if index == 0:
+            # First prompt is guided, expect valid JSON
+            assert "\n" not in generated_text
+            output_json = json.loads(generated_text)
+            jsonschema.validate(instance=output_json,
+                                schema=sample_json_schema)
+        else:
+            # Second prompt is not guided, expect valid output
+            # Cannot assert on exact output, but we can expect it to be factual
+            assert "12,742" in generated_text
+
+            # non-guided requests should not return a valid JSON here
+            with pytest.raises(ValueError):
+                output_json = json.loads(generated_text)
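Note: json, SamplingParams, and GuidedDecodingParams are used by the new test but imported outside the hunks shown above. In the non-guided branch, json.loads raises json.JSONDecodeError, a subclass of ValueError, so the pytest.raises(ValueError) block catches it. To run only the new test for a single backend, an invocation along these lines should work; the test file path here is an assumption, not something stated in the diff.

import pytest

# Hypothetical runner: the file path is a guess at where this test lives
# in the repository; adjust to the actual location. The "and xgrammar"
# clause selects one parametrized backend by its test id.
if __name__ == "__main__":
    raise SystemExit(
        pytest.main([
            "tests/v1/entrypoints/llm/test_struct_output_generate.py",
            "-k",
            "test_structured_output_batched_with_non_guided_requests"
            " and xgrammar",
            "-v",
        ]))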