diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py
index 8ef121ebe848e..f46064931dbac 100644
--- a/examples/offline_inference/structured_outputs.py
+++ b/examples/offline_inference/structured_outputs.py
@@ -15,6 +15,8 @@ from pydantic import BaseModel
 from vllm import LLM, SamplingParams
 from vllm.sampling_params import GuidedDecodingParams
 
+MAX_TOKENS = 50
+
 # Guided decoding by Choice (list of possible options)
 guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
 sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
@@ -23,7 +25,9 @@ prompt_choice = "Classify this sentiment: vLLM is wonderful!"
 # Guided decoding by Regex
 guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
 sampling_params_regex = SamplingParams(
-    guided_decoding=guided_decoding_params_regex, stop=["\n"]
+    guided_decoding=guided_decoding_params_regex,
+    stop=["\n"],
+    max_tokens=MAX_TOKENS,
 )
 prompt_regex = (
     "Generate an email address for Alan Turing, who works in Enigma."
@@ -48,7 +52,10 @@ class CarDescription(BaseModel):
 
 json_schema = CarDescription.model_json_schema()
 guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
-sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
+sampling_params_json = SamplingParams(
+    guided_decoding=guided_decoding_params_json,
+    max_tokens=MAX_TOKENS,
+)
 prompt_json = (
     "Generate a JSON with the brand, model and car_type of"
     "the most iconic car from the 90's"
@@ -64,7 +71,10 @@ condition ::= column "= " number
 number ::= "1 " | "2 "
 """
 guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
-sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
+sampling_params_grammar = SamplingParams(
+    guided_decoding=guided_decoding_params_grammar,
+    max_tokens=MAX_TOKENS,
+)
 prompt_grammar = (
     "Generate an SQL query to show the 'username' and 'email'from the 'users' table."
 )