From 062c89e7c9c6fa9fd7fb2d28fd50321c6f78f389 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 30 Sep 2024 19:34:25 -0600 Subject: [PATCH] [Frontend][Core] Move guided decoding params into sampling params (#8252) Signed-off-by: Joe Runde Co-authored-by: Nick Hill --- tests/entrypoints/llm/test_guided_generate.py | 66 +++++++++----- tests/model_executor/conftest.py | 49 ++++++++++ .../test_guided_processors.py | 35 +++++--- vllm/engine/async_llm_engine.py | 44 +++++++++ vllm/engine/llm_engine.py | 54 +++++++++++ vllm/engine/multiprocessing/client.py | 14 +++ vllm/entrypoints/llm.py | 48 ++++++---- vllm/entrypoints/openai/protocol.py | 82 ++++++++++------- vllm/entrypoints/openai/serving_chat.py | 5 -- vllm/entrypoints/openai/serving_completion.py | 4 - vllm/entrypoints/openai/serving_engine.py | 13 +-- .../guided_decoding/__init__.py | 68 ++++---------- .../guided_decoding/guided_fields.py | 1 + .../lm_format_enforcer_decoding.py | 90 ++++--------------- .../guided_decoding/outlines_decoding.py | 72 ++++----------- vllm/sampling_params.py | 77 +++++++++++++++- 16 files changed, 441 insertions(+), 281 deletions(-) create mode 100644 tests/model_executor/conftest.py rename tests/{entrypoints/openai => model_executor}/test_guided_processors.py (69%) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 873e115421257..2841dfc6bd9c2 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -7,7 +7,7 @@ import pytest from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...conftest import cleanup @@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - ) - outputs = llm.generate( - prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_regex=sample_regex)) + guided_decoding=GuidedDecodingParams(regex=sample_regex)) + outputs = llm.generate(prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None for output in outputs: @@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm): sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - ) - outputs = llm.generate( - prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_json=sample_json_schema)) + guided_decoding=GuidedDecodingParams(json=sample_json_schema)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None @@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - ) + guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) outputs = llm.generate( prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_choice=sample_guided_choice)) + use_tqdm=True) assert 
outputs is not None for output in outputs: @@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm): temperature=0.8, top_p=0.95, max_tokens=1000, - ) + guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements)) outputs = llm.generate( prompts=("Generate a sql state that select col_1 from " "table_1 where it is equals to 1"), sampling_params=sampling_params, use_tqdm=True, - guided_options_request=dict(guided_grammar=sample_sql_statements)) + ) assert outputs is not None for output in outputs: @@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm): assert generated_text.strip() == ground_truth print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_options_request_deprecation_warning(sample_regex, llm): + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + with pytest.warns(DeprecationWarning, match="guided_options_request"): + llm.generate(prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) + + +@pytest.mark.skip_global_cleanup +def test_validation_against_both_guided_decoding_options(sample_regex, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams(regex=sample_regex)) + + with pytest.raises(ValueError, match="Cannot set both"): + llm.generate(prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) diff --git a/tests/model_executor/conftest.py b/tests/model_executor/conftest.py new file mode 100644 index 0000000000000..10792b0a04999 --- /dev/null +++ b/tests/model_executor/conftest.py @@ -0,0 +1,49 @@ +import pytest + + +@pytest.fixture +def sample_regex(): + return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + + +@pytest.fixture +def sample_json_schema(): + return { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] + } diff --git a/tests/entrypoints/openai/test_guided_processors.py b/tests/model_executor/test_guided_processors.py similarity index 69% rename from tests/entrypoints/openai/test_guided_processors.py rename to tests/model_executor/test_guided_processors.py index 85cb4d52200c3..45fab8e96b968 100644 --- a/tests/entrypoints/openai/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,14 +1,12 @@ -# This unit test should be moved to a new -# tests/test_guided_decoding directory. 
import pytest import torch from transformers import AutoTokenizer -from vllm.entrypoints.openai.protocol import CompletionRequest from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams def test_guided_logits_processors(sample_regex, sample_json_schema): @@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') token_ids = tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") - regex_request = CompletionRequest(model='test', - prompt=token_ids, - guided_regex=sample_regex) + regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = await get_guided_decoding_logits_processor( - backend, regex_request, tokenizer) + regex_request, tokenizer) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -59,14 +55,31 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, token_ids = tokenizer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) - json_request = CompletionRequest(model='test', - prompt=token_ids, - guided_json=sample_json_schema) + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) json_lp = await get_guided_decoding_logits_processor( - backend, json_request, tokenizer) + json_request, tokenizer) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) tensor = json_lp(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + + +def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): + with pytest.raises(ValueError, + match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, regex=sample_regex) + + with pytest.raises(ValueError, + match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, json_object=True) + + with pytest.raises(ValueError, + match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"]) + + with pytest.raises(ValueError, + match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7778732dd8be0..9664bb29a3667 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -20,6 +20,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams @@ -477,6 +479,18 @@ class _AsyncLLMEngine(LLMEngine): ) processed_inputs = self.input_processor(preprocessed_inputs) + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + # Guided decoding has an async implementation for building logits + # processors in a separate threadpool. 
+ # We want to invoke that here instead of using the blocking + # implementation in the LLMEngine + params = await build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=self.get_tokenizer(lora_request), + default_guided_backend=self.decoding_config. + guided_decoding_backend) + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, @@ -494,6 +508,36 @@ class _AsyncLLMEngine(LLMEngine): self.model_executor.check_health() +async def build_guided_decoding_logits_processor_async( + sampling_params: SamplingParams, tokenizer: AnyTokenizer, + default_guided_backend: str) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Modifies sampling params in-place and returns + the modified sampling params.""" + if (guided_decoding := sampling_params.guided_decoding) is None: + return sampling_params + + logger.debug("Building guided decoding logits processor. " + "Params: %s", guided_decoding) + + guided_decoding.backend = guided_decoding.backend or default_guided_backend + + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) + + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params + + class AsyncLLMEngine: """An asynchronous wrapper for :class:`LLMEngine`. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e3cd822f648fe..3550759f85dde 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -25,6 +25,7 @@ from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.util import create_output_by_sequence_group +from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster @@ -33,6 +34,8 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_local_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) @@ -843,6 +846,9 @@ class LLMEngine: raise ValueError(f"Cannot request more than " f"{max_logprobs} logprobs.") + sampling_params = self._build_logits_processors( + sampling_params, lora_request) + # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() @@ -1895,3 +1901,51 @@ class LLMEngine: # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def _build_logits_processors( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + 
"""Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Returns the modified sampling params.""" + + logits_processors = [] + if (guided_decoding := sampling_params.guided_decoding) is not None: + + logger.debug( + "Building guided decoding logits processor in " + "LLMEngine. Params: %s", guided_decoding) + + tokenizer = self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.guided_decoding_backend + + processor = get_local_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + if processor: + logits_processors.append(processor) + + # Unset so this doesn't get passed down to the model + sampling_params.guided_decoding = None + + if (sampling_params.logit_bias or sampling_params.allowed_token_ids): + tokenizer = self.get_tokenizer(lora_request=lora_request) + + processors = get_logits_processors( + logit_bias=sampling_params.logit_bias, + allowed_token_ids=sampling_params.allowed_token_ids, + tokenizer=tokenizer) + logits_processors.extend(processors) + + # Unset so these don't get passed down to the model + sampling_params.logit_bias = None + sampling_params.allowed_token_ids = None + + if logits_processors: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = logits_processors + else: + sampling_params.logits_processors.extend(logits_processors) + + return sampling_params diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 700e65000e052..79da0be97fdbf 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -16,6 +16,8 @@ from vllm.config import DecodingConfig, EngineConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block # yapf: disable +from vllm.engine.async_llm_engine import ( + build_guided_decoding_logits_processor_async) from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, @@ -512,6 +514,18 @@ class MQLLMEngineClient: if self._errored_with is not None: raise ENGINE_DEAD_ERROR(self._errored_with) + # Constructing guided decoding logits processors is expensive, so we do + # it here to avoid contending with cpu resources and the GIL on the + # backend process. + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + params = await \ + build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer(lora_request), + default_guided_backend=self.decoding_config.guided_decoding_backend + ) + # 1) Create output queue for this requests. 
queue: asyncio.Queue[Union[RequestOutput, BaseException]] = asyncio.Queue() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index bd009ae915c93..98d6df944da67 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,4 +1,5 @@ import itertools +import warnings from contextlib import contextmanager from dataclasses import dataclass from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, @@ -16,13 +17,13 @@ from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.guided_decoding import ( - GuidedDecodingRequest, get_local_guided_decoding_logits_processor) -from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest, LLMGuidedOptions) from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind, + SamplingParams) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -798,6 +799,14 @@ class LLM: guided_options: Optional[GuidedDecodingRequest] = None, priority: Optional[List[int]] = None, ) -> None: + if guided_options is not None: + warnings.warn( + "guided_options_request is deprecated, use " + "SamplingParams.guided_decoding instead", + DeprecationWarning, + stacklevel=2, + ) + if isinstance(prompts, (str, dict)): # Convert a single prompt to a list. 
prompts = [prompts] @@ -813,7 +822,7 @@ class LLM: for sp in params if isinstance(params, list) else (params, ): if isinstance(sp, SamplingParams): - self._add_guided_processor(sp, guided_options) + self._add_guided_params(sp, guided_options) # We only care about the final output sp.output_kind = RequestOutputKind.FINAL_ONLY @@ -847,22 +856,25 @@ class LLM: priority=priority, ) - def _add_guided_processor( + def _add_guided_params( self, params: SamplingParams, guided_options: Optional[GuidedDecodingRequest] = None): - if guided_options: - if guided_options.guided_decoding_backend is None: - decoding_config = self.llm_engine.get_decoding_config() - guided_options.guided_decoding_backend = ( - decoding_config.guided_decoding_backend) - guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa - guided_options.guided_decoding_backend, guided_options, - self.get_tokenizer()) - if guided_logits_processor: - if params.logits_processors is None: - params.logits_processors = [] - params.logits_processors.append(guided_logits_processor) + if guided_options is None: + return params + + if params.guided_decoding is not None: + raise ValueError("Cannot set both guided_options_request and" + "params.guided_decoding.") + + params.guided_decoding = GuidedDecodingParams( + json=guided_options.guided_json, + regex=guided_options.guided_regex, + choice=guided_options.guided_choice, + grammar=guided_options.guided_grammar, + json_object=guided_options.guided_json_object, + backend=guided_options.guided_decoding_backend, + whitespace_pattern=guided_options.guided_whitespace_pattern) return params def _run_engine( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f716e4a0458bf..c3101ca2b6900 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated, Required, TypedDict from vllm.entrypoints.chat_utils import ChatCompletionMessageParam -from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import (LogitsProcessor, RequestOutputKind, +from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind, SamplingParams) from vllm.sequence import Logprob -from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid # torch is mocked during docs generation, @@ -284,10 +282,7 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params - def to_sampling_params( - self, tokenizer: AnyTokenizer, - guided_decode_logits_processor: Optional[LogitsProcessor], - default_max_tokens: int) -> SamplingParams: + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens @@ -296,14 +291,19 @@ class ChatCompletionRequest(OpenAIBaseModel): if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs - # We now allow logprobs being true without top_logrobs. 
- logits_processors = get_logits_processors( - logit_bias=self.logit_bias, - allowed_token_ids=None, - tokenizer=tokenizer, - ) - if guided_decode_logits_processor: - logits_processors.append(guided_decode_logits_processor) + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self._get_guided_json_from_tool() or self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) return SamplingParams.from_optional( n=self.n, @@ -329,11 +329,29 @@ class ChatCompletionRequest(OpenAIBaseModel): spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, - logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, - ) + guided_decoding=guided_decoding, + logit_bias=self.logit_bias) + + def _get_guided_json_from_tool( + self) -> Optional[Union[str, dict, BaseModel]]: + # user has chosen to not use any tool + if self.tool_choice == "none" or self.tools is None: + return None + + # user has chosen to use a named tool + if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = self.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in self.tools} + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + return tool.parameters + + return None @model_validator(mode="before") @classmethod @@ -537,10 +555,7 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params - def to_sampling_params( - self, tokenizer: AnyTokenizer, - guided_decode_logits_processor: Optional[LogitsProcessor], - default_max_tokens: int) -> SamplingParams: + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens @@ -551,13 +566,19 @@ class CompletionRequest(OpenAIBaseModel): echo_without_generation = self.echo and self.max_tokens == 0 - logits_processors = get_logits_processors( - logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids, - tokenizer=tokenizer, - ) - if guided_decode_logits_processor: - logits_processors.append(guided_decode_logits_processor) + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) return SamplingParams.from_optional( n=self.n, @@ -583,11 +604,12 @@ class CompletionRequest(OpenAIBaseModel): spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, - logits_processors=logits_processors, truncate_prompt_tokens=self.truncate_prompt_tokens, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, - ) 
+ guided_decoding=guided_decoding, + logit_bias=self.logit_bias, + allowed_token_ids=self.allowed_token_ids) @model_validator(mode="before") @classmethod diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5625e34cca003..29a5b11b595c7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -187,9 +187,6 @@ class OpenAIServingChat(OpenAIServing): raw_request.state.request_metadata = request_metadata try: - guided_decode_logits_processor = ( - await self._guided_decode_logits_processor(request, tokenizer)) - if isinstance(prompt, str): prompt_inputs = self._tokenize_prompt_input( request, @@ -208,8 +205,6 @@ class OpenAIServingChat(OpenAIServing): assert prompt_inputs is not None sampling_params = request.to_sampling_params( - tokenizer, - guided_decode_logits_processor, default_max_tokens=self.max_model_len - len(prompt_inputs["prompt_token_ids"])) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 0e8609002e39e..a0161611288de 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -110,8 +110,6 @@ class OpenAIServingCompletion(OpenAIServing): tokenizer = await self.engine_client.get_tokenizer(lora_request) - guided_decode_logits_processor = ( - await self._guided_decode_logits_processor(request, tokenizer)) prompts = list( self._tokenize_prompt_input_or_inputs( request, @@ -123,8 +121,6 @@ class OpenAIServingCompletion(OpenAIServing): for i, prompt_inputs in enumerate(prompts): sampling_params = request.to_sampling_params( - tokenizer, - guided_decode_logits_processor, default_max_tokens=self.max_model_len - len(prompt_inputs["prompt_token_ids"])) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 9c4e8d8bb671a..1a0669d8d12c5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -27,11 +27,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import LogitsProcessor, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import AtomicCounter @@ -168,15 +166,6 @@ class OpenAIServing: }) return json_str - async def _guided_decode_logits_processor( - self, request: Union[ChatCompletionRequest, CompletionRequest], - tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: - decoding_config = await self.engine_client.get_decoding_config() - guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend - return await get_guided_decoding_logits_processor( - guided_decoding_backend, request, tokenizer) - async def _check_model( self, request: AnyRequest, diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 7161e83952a3d..368436aa14613 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,77 +1,45 @@ -from typing import Optional, 
Union +from typing import Optional -from vllm.entrypoints.openai.protocol import ( - ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, - CompletionRequest) -from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingRequest) -from vllm.sampling_params import LogitsProcessor +from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor async def get_guided_decoding_logits_processor( - guided_decoding_backend: str, request: Union[CompletionRequest, - ChatCompletionRequest], + guided_params: GuidedDecodingParams, tokenizer) -> Optional[LogitsProcessor]: - request = _adapt_request_for_tool_use(request) - - if guided_decoding_backend == 'outlines': + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines' or guided_params.grammar: # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) return await get_outlines_guided_decoding_logits_processor( - request, tokenizer) - if guided_decoding_backend == 'lm-format-enforcer': + guided_params, tokenizer) + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa - get_lm_format_enforcer_guided_decoding_logits_processor) - return await get_lm_format_enforcer_guided_decoding_logits_processor( - request, tokenizer) + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params, tokenizer) raise ValueError( - f"Unknown guided decoding backend '{guided_decoding_backend}'. " + f"Unknown guided decoding backend '{guided_params.backend}'. " "Must be one of 'outlines, 'lm-format-enforcer'") def get_local_guided_decoding_logits_processor( - guided_decoding_backend: str, guided_options: GuidedDecodingRequest, + guided_params: GuidedDecodingParams, tokenizer) -> Optional[LogitsProcessor]: - # request = _adapt_request_for_tool_use(request) - - if guided_decoding_backend == 'outlines': + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines' or guided_params.grammar: # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) return get_local_outlines_guided_decoding_logits_processor( - guided_options, tokenizer) - if guided_decoding_backend == 'lm-format-enforcer': + guided_params, tokenizer) + if guided_params.backend == 'lm-format-enforcer': from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( - guided_options, tokenizer) + guided_params, tokenizer) raise ValueError( - f"Unknown guided decoding backend '{guided_decoding_backend}'. " + f"Unknown guided decoding backend '{guided_params.backend}'. " "Must be one of 'outlines, 'lm-format-enforcer'") - - -def _adapt_request_for_tool_use(request: Union[CompletionRequest, - ChatCompletionRequest]): - # the legacy completion API does not support tool use - if type(request) is CompletionRequest: - return request - - # user has chosen to not use any tool, - # OR is allowing the model to choose a tool. 
- if request.tool_choice == "none" or request.tool_choice == "auto": - return request - - # user has chosen to use a named tool - if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam: - tool_name = request.tool_choice.function.name - tools = {tool.function.name: tool.function for tool in request.tools} - if tool_name not in tools: - raise ValueError( - f"Tool '{tool_name}' has not been passed in `tools`.") - tool = tools[tool_name] - request.guided_json = tool.parameters - - return request diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py index 3082ac1510ccc..8deb4c949824a 100644 --- a/vllm/model_executor/guided_decoding/guided_fields.py +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union from pydantic import BaseModel +# These classes are deprecated, see SamplingParams class LLMGuidedOptions(TypedDict, total=False): guided_json: Union[Dict, BaseModel, str] guided_regex: str diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index 51f947981cac8..cf2162ed7720d 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, TokenEnforcerTokenizerData, UnionParser) from lmformatenforcer.integrations.vllm import ( build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) -from pydantic import BaseModel from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) -from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingRequest) -from vllm.sampling_params import LogitsProcessor - - -async def get_lm_format_enforcer_guided_decoding_logits_processor( - request: Union[CompletionRequest, ChatCompletionRequest], - tokenizer) -> Optional[LogitsProcessor]: - """ - Given an OpenAI-compatible request, check for guided decoding parameters - and get the necessary logits processor for the given guide. - We cache logit processors by (guide, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying FSM. 
- """ - - tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( - tokenizer) - character_level_parser: CharacterLevelParser - if request.guided_json: - schema = _normalize_json_schema_object(request.guided_json) - character_level_parser = JsonSchemaParser(schema) - elif request.guided_choice: - character_level_parser = UnionParser( - [StringParser(choice) for choice in request.guided_choice]) - elif request.guided_regex: - character_level_parser = RegexParser(request.guided_regex) - elif request.guided_grammar: - # CFG grammar not supported by LMFE, revert to outlines - - # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 - from vllm.model_executor.guided_decoding.outlines_decoding import ( - get_outlines_guided_decoding_logits_processor) - return await get_outlines_guided_decoding_logits_processor( - request, tokenizer) - elif (request.response_format is not None - and request.response_format.type == "json_object"): - character_level_parser = JsonSchemaParser( - None) # None means any json object - elif (request.response_format is not None - and request.response_format.type == "json_schema" - and request.response_format.json_schema is not None - and request.response_format.json_schema.json_schema is not None): - schema = _normalize_json_schema_object( - request.response_format.json_schema.json_schema) - character_level_parser = JsonSchemaParser(schema) - else: - return None - - logits_processor = build_vllm_logits_processor(tokenizer_data, - character_level_parser) - return logits_processor +from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor def get_local_lm_format_enforcer_guided_decoding_logits_processor( - guided_options: GuidedDecodingRequest, + guided_params: GuidedDecodingParams, tokenizer) -> Optional[LogitsProcessor]: """ Given an OpenAI-compatible request, check for guided decoding parameters @@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( tokenizer) character_level_parser: CharacterLevelParser - if guided_options.guided_json: - schema = _normalize_json_schema_object(guided_options.guided_json) - character_level_parser = JsonSchemaParser(schema) - elif guided_options.guided_choice: + if guided_params.json: + schema_dict = _normalize_json_schema_object(guided_params.json) + character_level_parser = JsonSchemaParser(schema_dict) + elif guided_params.choice: character_level_parser = UnionParser( - [StringParser(choice) for choice in guided_options.guided_choice]) - elif guided_options.guided_regex: - character_level_parser = RegexParser(guided_options.guided_regex) - elif guided_options.guided_grammar: - # CFG grammar not supported by LMFE, revert to outlines - - # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 - from vllm.model_executor.guided_decoding.outlines_decoding import ( - get_local_outlines_guided_decoding_logits_processor) - return get_local_outlines_guided_decoding_logits_processor( - guided_options, tokenizer) - elif guided_options.guided_json_object: + [StringParser(choice) for choice in guided_params.choice]) + elif guided_params.regex: + character_level_parser = RegexParser(guided_params.regex) + elif guided_params.grammar: + # CFG grammar not supported by LMFE + raise ValueError("Cannot construct a guided decoding logits processor" + " using the grammar option with the" + " lm_format_enforcer backend.") + elif guided_params.json_object: # None means any 
json object character_level_parser = JsonSchemaParser(None) else: @@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( return logits_processor -def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: +def _normalize_json_schema_object(schema: Union[str, dict]) -> dict: if isinstance(schema, str): return json_loads(schema) if isinstance(schema, dict): return schema - if isinstance(schema, BaseModel): - return schema.model_json_schema() raise AssertionError(f"Unsupported schema type {schema}") diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index e1f5b380120c5..8a7ff38bfeb1a 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -5,16 +5,11 @@ from json import dumps as json_dumps from re import escape as regex_escape from typing import Tuple, Union -from pydantic import BaseModel from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import ( - ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, - CompletionRequest) -from vllm.model_executor.guided_decoding.guided_fields import ( - GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams class GuidedDecodingMode(Enum): @@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm async def get_outlines_guided_decoding_logits_processor( - request: Union[CompletionRequest, - ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor( we make a shallow copy to reuse the same underlying FSM. """ global global_thread_pool - guide, mode = _get_guide_and_mode(request) + guide, mode = _get_guide_and_mode(guided_params) if not guide or not mode: return None @@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor( return await loop.run_in_executor(global_thread_pool, _get_logits_processor, guide, tokenizer, - mode, request.guided_whitespace_pattern) + mode, guided_params.whitespace_pattern) def get_local_outlines_guided_decoding_logits_processor( - guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, None]: """ @@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor( We cache logit processors by (guide, tokenizer), and on cache hit we make a shallow copy to reuse the same underlying FSM. 
""" - guide, mode = _get_guide_and_mode(guided_options) + guide, mode = _get_guide_and_mode(guided_params) if not guide or not mode: return None return _get_logits_processor(guide, tokenizer, mode, - guided_options.guided_whitespace_pattern) + guided_params.whitespace_pattern) def _get_guide_and_mode( - request: Union[CompletionRequest, ChatCompletionRequest, - GuidedDecodingRequest] + guided_params: GuidedDecodingParams ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: - # if the request is a chat completion request, AND the tool choice is a - # named tool choice, do guided decoding - # using that tool as the JSON schema - if isinstance(request, ChatCompletionRequest) and isinstance( - request.tool_choice, ChatCompletionNamedToolChoiceParam): - # Guided generation for tools/functions parameters - if request.tool_choice.type == "function": - for tool in request.tools: - if (tool.type == "function" and tool.function.name - == request.tool_choice.function.name): - json = json_dumps(tool.function.parameters, sort_keys=True) - return json, GuidedDecodingMode.JSON - return None, None - - elif request.guided_json: - if isinstance(request.guided_json, dict): + if guided_params.json: + if isinstance(guided_params.json, dict): # turn dict into hashable string - json = json_dumps(request.guided_json) - elif isinstance(request.guided_json, BaseModel): - # use pydantic signature so that different model classes - # with the same fields will get hashed the same - json = str(request.guided_json.__signature__) + json = json_dumps(guided_params.json) else: - json = request.guided_json + json = guided_params.json return json, GuidedDecodingMode.JSON - elif request.guided_regex: - return request.guided_regex, GuidedDecodingMode.REGEX - elif request.guided_choice: + elif guided_params.regex: + return guided_params.regex, GuidedDecodingMode.REGEX + elif guided_params.choice: # choice just uses regex choices = [ - regex_escape(str(choice)) for choice in request.guided_choice + regex_escape(str(choice)) for choice in guided_params.choice ] choices_regex = "(" + "|".join(choices) + ")" return choices_regex, GuidedDecodingMode.CHOICE - elif request.guided_grammar: - return request.guided_grammar, GuidedDecodingMode.GRAMMAR - elif (not isinstance(request, GuidedDecodingRequest) - and request.response_format is not None - and request.response_format.type == "json_object"): + elif guided_params.grammar: + return guided_params.grammar, GuidedDecodingMode.GRAMMAR + elif guided_params.json_object: return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR - elif (not isinstance(request, GuidedDecodingRequest) - and request.response_format is not None - and request.response_format.type == "json_schema" - and request.response_format.json_schema is not None - and request.response_format.json_schema.json_schema is not None): - json = json_dumps(request.response_format.json_schema.json_schema) - return json, GuidedDecodingMode.JSON else: return None, None diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index f9ba4b4777e4d..83f76410882de 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,11 +1,13 @@ """Sampling parameters for text generation.""" import copy +from dataclasses import dataclass from enum import Enum, IntEnum from functools import cached_property from typing import Any, Callable, Dict, List, Optional, Set, Union import msgspec import torch +from pydantic import BaseModel from typing_extensions import Annotated import vllm.envs as envs @@ -34,6 +36,54 @@ first argument, and 
returns a modified tensor of logits to sample from.""" +# maybe make msgspec? +@dataclass +class GuidedDecodingParams: + """One of these fields will be used to build a logit processor.""" + json: Optional[Union[str, Dict]] = None + regex: Optional[str] = None + choice: Optional[List[str]] = None + grammar: Optional[str] = None + json_object: Optional[bool] = None + """These are other options that can be set""" + backend: Optional[str] = None + whitespace_pattern: Optional[str] = None + + @staticmethod + def from_optional( + json: Optional[Union[Dict, BaseModel, str]], + regex: Optional[str] = None, + choice: Optional[List[str]] = None, + grammar: Optional[str] = None, + json_object: Optional[bool] = None, + backend: Optional[str] = None, + whitespace_pattern: Optional[str] = None, + ) -> "GuidedDecodingParams": + # Extract json schemas from pydantic models + if isinstance(json, (BaseModel, type(BaseModel))): + json = json.model_json_schema() + return GuidedDecodingParams( + json=json, + regex=regex, + choice=choice, + grammar=grammar, + json_object=json_object, + backend=backend, + whitespace_pattern=whitespace_pattern, + ) + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.json is not None, self.regex is not None, self.choice + is not None, self.grammar is not None, self.json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") + + class RequestOutputKind(Enum): # Return entire output so far in every RequestOutput CUMULATIVE = 0 @@ -124,6 +174,13 @@ class SamplingParams( truncate_prompt_tokens: If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). Defaults to None (i.e., no truncation). + guided_decoding: If provided, the engine will construct a guided + decoding logits processor from these parameters. Defaults to None. + logit_bias: If provided, the engine will construct a logits processor + that applies these logit biases. Defaults to None. + allowed_token_ids: If provided, the engine will construct a logits + processor which only retains scores for the given token ids. + Defaults to None. 
""" n: int = 1 @@ -164,6 +221,11 @@ class SamplingParams( output_text_buffer_length: int = 0 _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) + # Fields used to construct logits processors + guided_decoding: Optional[GuidedDecodingParams] = None + logit_bias: Optional[Dict[int, float]] = None + allowed_token_ids: Optional[List[int]] = None + @staticmethod def from_optional( n: Optional[int] = 1, @@ -194,7 +256,16 @@ class SamplingParams( truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, + guided_decoding: Optional[GuidedDecodingParams] = None, + logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None, + allowed_token_ids: Optional[List[int]] = None, ) -> "SamplingParams": + if logit_bias is not None: + logit_bias = { + int(token): bias + for token, bias in logit_bias.items() + } + return SamplingParams( n=1 if n is None else n, best_of=best_of, @@ -226,6 +297,9 @@ class SamplingParams( logits_processors=logits_processors, truncate_prompt_tokens=truncate_prompt_tokens, output_kind=output_kind, + guided_decoding=guided_decoding, + logit_bias=logit_bias, + allowed_token_ids=allowed_token_ids, ) def __post_init__(self) -> None: @@ -454,4 +528,5 @@ class SamplingParams( f"skip_special_tokens={self.skip_special_tokens}, " "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " - f"truncate_prompt_tokens={self.truncate_prompt_tokens})") + f"truncate_prompt_tokens={self.truncate_prompt_tokens}), " + f"guided_decoding={self.guided_decoding}")