[CI] Speed up V1 structured output tests (#15718)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-03-29 00:10:45 -04:00 committed by GitHub
parent 1286211f57
commit 7a7992085b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -23,20 +23,46 @@ MODELS_TO_TEST = [
] ]
class CarType(str, Enum):
    """Closed set of car categories used by the structured-output tests.

    Subclasses str so each member IS its string value, letting the enum
    values appear directly in a pydantic-generated JSON schema.
    """
    # NOTE(review): values deliberately mix casing ("sedan" vs "SUV" /
    # "Truck" / "Coupe") — presumably to exercise case handling in
    # guided decoding; confirm before normalizing.
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"
class CarDescription(BaseModel):
    """Pydantic model whose JSON schema drives a structured-output test.

    Its model_json_schema() output is passed as a guided-decoding JSON
    schema, and the generated text is validated against that schema.
    """
    # Free-form manufacturer name.
    brand: str
    # Free-form model name.
    model: str
    # Constrained to the CarType enum's string values.
    car_type: CarType
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1) GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion( def test_structured_output(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any], sample_json_schema: dict[str, Any],
unsupported_json_schema: dict[str, Any],
sample_sql_ebnf: str,
sample_sql_lark: str,
sample_regex: str,
sample_guided_choice: str,
guided_decoding_backend: str, guided_decoding_backend: str,
model_name: str, model_name: str,
): ):
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
# Use a single LLM instance for several scenarios to
# speed up the test suite.
llm = LLM(model=model_name, llm = LLM(model=model_name,
enforce_eager=True,
max_model_len=1024, max_model_len=1024,
guided_decoding_backend=guided_decoding_backend) guided_decoding_backend=guided_decoding_backend)
#
# Test 1: Generate JSON output based on a provided schema
#
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=1000, max_tokens=1000,
@ -63,20 +89,9 @@ def test_guided_json_completion(
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema) jsonschema.validate(instance=output_json, schema=sample_json_schema)
#
@pytest.mark.skip_global_cleanup # Test 2: Generate JSON object without a schema
@pytest.mark.parametrize("guided_decoding_backend", #
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_object(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=100, max_tokens=100,
@ -111,21 +126,9 @@ def test_guided_json_object(
allowed_types = (dict, list) allowed_types = (dict, list)
assert isinstance(parsed_json, allowed_types) assert isinstance(parsed_json, allowed_types)
#
@pytest.mark.skip_global_cleanup # Test 3: test a jsonschema incompatible with xgrammar
@pytest.mark.parametrize("guided_decoding_backend", #
GUIDED_DECODING_BACKENDS_V1 + ["auto"])
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_unsupported_schema(
monkeypatch: pytest.MonkeyPatch,
unsupported_json_schema: dict[str, Any],
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=1000, max_tokens=1000,
@ -141,8 +144,6 @@ def test_guided_json_unsupported_schema(
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
else: else:
# This should work for both "guidance" and "auto".
outputs = llm.generate( outputs = llm.generate(
prompts=("Give an example JSON object for a grade " prompts=("Give an example JSON object for a grade "
"that fits this schema: " "that fits this schema: "
@ -161,21 +162,9 @@ def test_guided_json_unsupported_schema(
parsed_json = json.loads(generated_text) parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict) assert isinstance(parsed_json, dict)
#
@pytest.mark.skip_global_cleanup # Test 4: Generate SQL statement using EBNF grammar
@pytest.mark.parametrize("guided_decoding_backend", #
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_ebnf(
monkeypatch: pytest.MonkeyPatch,
sample_sql_ebnf: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
@ -205,21 +194,9 @@ def test_guided_grammar_ebnf(
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
@pytest.mark.skip_global_cleanup # Test 5: Generate SQL statement using Lark grammar
@pytest.mark.parametrize("guided_decoding_backend", #
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_lark(
monkeypatch: pytest.MonkeyPatch,
sample_sql_lark: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
@ -254,20 +231,9 @@ def test_guided_grammar_lark(
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
@pytest.mark.skip_global_cleanup # Test 6: Test invalid grammar input
@pytest.mark.parametrize("guided_decoding_backend", #
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_ebnf_invalid(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
@ -281,21 +247,9 @@ def test_guided_grammar_ebnf_invalid(
use_tqdm=True, use_tqdm=True,
) )
#
@pytest.mark.skip_global_cleanup # Test 7: Generate text based on a regex pattern
@pytest.mark.parametrize("guided_decoding_backend", #
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_regex(
monkeypatch: pytest.MonkeyPatch,
sample_regex: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
@ -319,21 +273,9 @@ def test_guided_regex(
assert re.fullmatch(sample_regex, generated_text) is not None assert re.fullmatch(sample_regex, generated_text) is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
@pytest.mark.skip_global_cleanup # Test 8: Generate text based on a choices
@pytest.mark.parametrize("guided_decoding_backend", #
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_choice_completion(
monkeypatch: pytest.MonkeyPatch,
sample_guided_choice: str,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
@ -353,33 +295,9 @@ def test_guided_choice_completion(
assert generated_text in sample_guided_choice assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
#
class CarType(str, Enum): # Test 9: Generate structured output using a Pydantic model with an enum
sedan = "sedan" #
suv = "SUV"
truck = "Truck"
coupe = "Coupe"
class CarDescription(BaseModel):
brand: str
model: str
car_type: CarType
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_with_enum(
monkeypatch: pytest.MonkeyPatch,
guided_decoding_backend: str,
model_name: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=model_name,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend)
json_schema = CarDescription.model_json_schema() json_schema = CarDescription.model_json_schema()
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
@ -403,3 +321,41 @@ def test_guided_json_completion_with_enum(
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text) output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=json_schema) jsonschema.validate(instance=output_json, schema=json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_structured_output_auto_mode(
    monkeypatch: pytest.MonkeyPatch,
    unsupported_json_schema: dict[str, Any],
    model_name: str,
):
    """Verify the "auto" guided-decoding backend falls back gracefully.

    The sample schema is unsupported by the default "xgrammar" backend;
    with guided_decoding_backend="auto" the engine is expected to fall
    back to a backend that can handle it and still emit a valid JSON
    object.
    """
    monkeypatch.setenv("VLLM_USE_V1", "1")

    # enforce_eager=True matches the LLM construction in
    # test_structured_output above (added there to speed up CI) and
    # keeps this test consistent with the rest of the suite.
    llm = LLM(model=model_name,
              enforce_eager=True,
              max_model_len=1024,
              guided_decoding_backend="auto")

    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))

    # This would fail with the default of "xgrammar", but in "auto"
    # we will handle fallback automatically.
    outputs = llm.generate(prompts=("Give an example JSON object for a grade "
                                    "that fits this schema: "
                                    f"{unsupported_json_schema}"),
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(generated_text)

        # Parse to verify it is valid JSON
        parsed_json = json.loads(generated_text)
        assert isinstance(parsed_json, dict)