[CI] Speed up V1 structured output tests (#15718)

Signed-off-by: Russell Bryant <rbryant@redhat.com>

This commit is contained in:
parent 1286211f57
commit 7a7992085b
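
The commit folds eight separate test functions, each of which constructed its own LLM engine, into a single test_structured_output that builds one engine and drives every scenario through it, so the model is loaded once per backend/model combination instead of once per test. Below is a minimal sketch of that pattern, assuming an illustrative model name, prompts, and constraints rather than the suite's fixtures:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Build the (expensive) engine once. enforce_eager skips CUDA graph capture,
# trading a little steady-state speed for much faster startup in CI.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          enforce_eager=True,
          max_model_len=1024)

# Reuse the same engine for several structured-output scenarios.
scenarios = [
    GuidedDecodingParams(json={"type": "object"}),          # free-form JSON object
    GuidedDecodingParams(regex=r"\w+@\w+\.com\n"),          # regex-constrained text
    GuidedDecodingParams(choice=["Positive", "Negative"]),  # one of a fixed set
]
for guided in scenarios:
    params = SamplingParams(temperature=1.0, max_tokens=100,
                            guided_decoding=guided)
    outputs = llm.generate(prompts="Answer in the requested format: ",
                           sampling_params=params)
    print(outputs[0].outputs[0].text)

The diff below applies exactly this restructuring to the test file.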
@@ -23,20 +23,46 @@ MODELS_TO_TEST = [
 ]
 
 
+class CarType(str, Enum):
+    sedan = "sedan"
+    suv = "SUV"
+    truck = "Truck"
+    coupe = "Coupe"
+
+
+class CarDescription(BaseModel):
+    brand: str
+    model: str
+    car_type: CarType
+
+
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize("guided_decoding_backend",
                          GUIDED_DECODING_BACKENDS_V1)
 @pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion(
+def test_structured_output(
     monkeypatch: pytest.MonkeyPatch,
     sample_json_schema: dict[str, Any],
+    unsupported_json_schema: dict[str, Any],
+    sample_sql_ebnf: str,
+    sample_sql_lark: str,
+    sample_regex: str,
+    sample_guided_choice: str,
     guided_decoding_backend: str,
     model_name: str,
 ):
     monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    # Use a single LLM instance for several scenarios to
+    # speed up the test suite.
     llm = LLM(model=model_name,
+              enforce_eager=True,
               max_model_len=1024,
               guided_decoding_backend=guided_decoding_backend)
+
+    #
+    # Test 1: Generate JSON output based on a provided schema
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
@@ -63,20 +89,9 @@ def test_guided_json_completion(
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=sample_json_schema)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_object(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 2: Generate JSON object without a schema
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=100,
@@ -111,21 +126,9 @@ def test_guided_json_object(
         allowed_types = (dict, list)
         assert isinstance(parsed_json, allowed_types)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_unsupported_schema(
-    monkeypatch: pytest.MonkeyPatch,
-    unsupported_json_schema: dict[str, Any],
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 3: test a jsonschema incompatible with xgrammar
+    #
     sampling_params = SamplingParams(
         temperature=1.0,
         max_tokens=1000,
@@ -141,8 +144,6 @@ def test_guided_json_unsupported_schema(
                          sampling_params=sampling_params,
                          use_tqdm=True)
     else:
-        # This should work for both "guidance" and "auto".
-
         outputs = llm.generate(
             prompts=("Give an example JSON object for a grade "
                      "that fits this schema: "
@@ -161,21 +162,9 @@ def test_guided_json_unsupported_schema(
         parsed_json = json.loads(generated_text)
         assert isinstance(parsed_json, dict)
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_ebnf(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_sql_ebnf: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 4: Generate SQL statement using EBNF grammar
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -205,21 +194,9 @@ def test_guided_grammar_ebnf(
 
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_lark(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_sql_lark: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 5: Generate SQL statement using Lark grammar
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -254,20 +231,9 @@ def test_guided_grammar_lark(
 
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_grammar_ebnf_invalid(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 6: Test invalid grammar input
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -281,21 +247,9 @@ def test_guided_grammar_ebnf_invalid(
             use_tqdm=True,
         )
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_regex(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_regex: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 7: Generate text based on a regex pattern
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -319,21 +273,9 @@ def test_guided_regex(
         assert re.fullmatch(sample_regex, generated_text) is not None
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_choice_completion(
-    monkeypatch: pytest.MonkeyPatch,
-    sample_guided_choice: str,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 8: Generate text based on a list of choices
+    #
     sampling_params = SamplingParams(
         temperature=0.8,
         top_p=0.95,
@@ -353,33 +295,9 @@ def test_guided_choice_completion(
         assert generated_text in sample_guided_choice
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
-
-class CarType(str, Enum):
-    sedan = "sedan"
-    suv = "SUV"
-    truck = "Truck"
-    coupe = "Coupe"
-
-
-class CarDescription(BaseModel):
-    brand: str
-    model: str
-    car_type: CarType
-
-
-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize("guided_decoding_backend",
-                         GUIDED_DECODING_BACKENDS_V1)
-@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
-def test_guided_json_completion_with_enum(
-    monkeypatch: pytest.MonkeyPatch,
-    guided_decoding_backend: str,
-    model_name: str,
-):
-    monkeypatch.setenv("VLLM_USE_V1", "1")
-    llm = LLM(model=model_name,
-              max_model_len=1024,
-              guided_decoding_backend=guided_decoding_backend)
+    #
+    # Test 9: Generate structured output using a Pydantic model with an enum
+    #
     json_schema = CarDescription.model_json_schema()
     sampling_params = SamplingParams(
         temperature=1.0,
@@ -403,3 +321,41 @@ def test_guided_json_completion_with_enum(
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
         output_json = json.loads(generated_text)
         jsonschema.validate(instance=output_json, schema=json_schema)
+
+
+@pytest.mark.skip_global_cleanup
+@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
+def test_structured_output_auto_mode(
+    monkeypatch: pytest.MonkeyPatch,
+    unsupported_json_schema: dict[str, Any],
+    model_name: str,
+):
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+
+    llm = LLM(model=model_name,
+              max_model_len=1024,
+              guided_decoding_backend="auto")
+
+    sampling_params = SamplingParams(
+        temperature=1.0,
+        max_tokens=1000,
+        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
+
+    # This would fail with the default of "xgrammar", but in "auto"
+    # we will handle fallback automatically.
+    outputs = llm.generate(prompts=("Give an example JSON object for a grade "
+                                    "that fits this schema: "
+                                    f"{unsupported_json_schema}"),
+                           sampling_params=sampling_params,
+                           use_tqdm=True)
+    assert outputs is not None
+    for output in outputs:
+        assert output is not None
+        assert isinstance(output, RequestOutput)
+        generated_text = output.outputs[0].text
+        assert generated_text is not None
+        print(generated_text)
+
+        # Parse to verify it is valid JSON
+        parsed_json = json.loads(generated_text)
+        assert isinstance(parsed_json, dict)
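
The new test_structured_output_auto_mode added above exercises the "auto" backend selection: a JSON schema the xgrammar backend rejects is refused under guided_decoding_backend="xgrammar" (Test 3 above covers that path), while "auto" falls back to a backend that can serve the request. A minimal sketch of that behavior, assuming a schema with a feature (a regex pattern) that xgrammar did not support at the time; the test's actual unsupported_json_schema fixture may differ:

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Illustrative schema; "pattern" is the kind of keyword that forces a fallback.
schema = {
    "type": "object",
    "properties": {
        "grade": {"type": "string", "pattern": "^[A-F]$"},
    },
    "required": ["grade"],
}

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          max_model_len=1024,
          guided_decoding_backend="auto")  # pick a backend that supports the schema
params = SamplingParams(max_tokens=1000,
                        guided_decoding=GuidedDecodingParams(json=schema))
outputs = llm.generate(prompts=f"Give a JSON grade object for: {schema}",
                       sampling_params=params)
print(outputs[0].outputs[0].text)  # should parse as JSON matching the schema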