[Frontend][Core] Move guided decoding params into sampling params (#8252)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
Joe Runde 2024-09-30 19:34:25 -06:00 committed by GitHub
parent bce324487a
commit 062c89e7c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 441 additions and 281 deletions

View File

@ -7,7 +7,7 @@ import pytest
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...conftest import cleanup from ...conftest import cleanup
@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm):
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
) guided_decoding=GuidedDecodingParams(regex=sample_regex))
outputs = llm.generate( outputs = llm.generate(prompts=[
prompts=[ f"Give an example IPv4 address with this regex: {sample_regex}"
f"Give an example IPv4 address with this regex: {sample_regex}" ] * 2,
] * 2, sampling_params=sampling_params,
sampling_params=sampling_params, use_tqdm=True)
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm):
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=1.0, temperature=1.0,
max_tokens=1000, max_tokens=1000,
) guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate( outputs = llm.generate(prompts=[
prompts=[ f"Give an example JSON for an employee profile "
f"Give an example JSON for an employee profile " f"that fits this schema: {sample_json_schema}"
f"that fits this schema: {sample_json_schema}" ] * 2,
] * 2, sampling_params=sampling_params,
sampling_params=sampling_params, use_tqdm=True)
use_tqdm=True,
guided_options_request=dict(guided_json=sample_json_schema))
assert outputs is not None assert outputs is not None
@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm):
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
) guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
outputs = llm.generate( outputs = llm.generate(
prompts="The best language for type-safe systems programming is ", prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True, use_tqdm=True)
guided_options_request=dict(guided_choice=sample_guided_choice))
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm):
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
max_tokens=1000, max_tokens=1000,
) guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements))
outputs = llm.generate( outputs = llm.generate(
prompts=("Generate a sql state that select col_1 from " prompts=("Generate a sql state that select col_1 from "
"table_1 where it is equals to 1"), "table_1 where it is equals to 1"),
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True, use_tqdm=True,
guided_options_request=dict(guided_grammar=sample_sql_statements)) )
assert outputs is not None assert outputs is not None
for output in outputs: for output in outputs:
@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm):
assert generated_text.strip() == ground_truth assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
def test_guided_options_request_deprecation_warning(sample_regex, llm):
    """The legacy ``guided_options_request`` kwarg must warn on use.

    Calls ``llm.generate`` with plain sampling params plus the legacy
    keyword and expects a ``DeprecationWarning`` whose message mentions
    ``guided_options_request``.
    """
    plain_params = SamplingParams(temperature=0.8, top_p=0.95)
    legacy_options = {"guided_regex": sample_regex}

    with pytest.warns(DeprecationWarning, match="guided_options_request"):
        llm.generate(
            prompts="This should fail",
            sampling_params=plain_params,
            use_tqdm=True,
            guided_options_request=legacy_options,
        )
@pytest.mark.skip_global_cleanup
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
    """Supplying guided decoding through both channels must be rejected.

    The sampling params already carry ``guided_decoding`` and the call also
    passes the legacy ``guided_options_request``; ``llm.generate`` is
    expected to raise ``ValueError`` matching "Cannot set both".
    """
    guided = GuidedDecodingParams(regex=sample_regex)
    params_with_guidance = SamplingParams(temperature=0.8,
                                          top_p=0.95,
                                          guided_decoding=guided)
    legacy_options = {"guided_regex": sample_regex}

    with pytest.raises(ValueError, match="Cannot set both"):
        llm.generate(
            prompts="This should fail",
            sampling_params=params_with_guidance,
            use_tqdm=True,
            guided_options_request=legacy_options,
        )

View File

@ -0,0 +1,49 @@
import pytest
@pytest.fixture
def sample_regex():
    """Regex with four dot-separated numeric groups (IPv4-style address)."""
    # One numeric group; the full pattern is three "<group>." followed by
    # a final "<group>".
    octet = r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
    return "(" + octet + r"\.){3}" + octet
@pytest.fixture
def sample_json_schema():
    """JSON schema describing an employee profile object.

    The schema requires ``name``, ``age``, ``skills`` (at least three short
    strings) and ``work_history`` (objects with ``company`` and
    ``position`` required, ``duration`` optional).
    """
    string_field = {"type": "string"}

    skills_schema = {
        "type": "array",
        "items": {
            "type": "string",
            "maxLength": 10
        },
        "minItems": 3
    }

    work_entry_schema = {
        "type": "object",
        "properties": {
            "company": dict(string_field),
            "duration": {
                "type": "number"
            },
            "position": dict(string_field)
        },
        "required": ["company", "position"]
    }

    # Key order is kept identical to the literal form because the schema is
    # interpolated into prompt strings by the tests that use it.
    return {
        "type": "object",
        "properties": {
            "name": dict(string_field),
            "age": {
                "type": "integer"
            },
            "skills": skills_schema,
            "work_history": {
                "type": "array",
                "items": work_entry_schema
            }
        },
        "required": ["name", "age", "skills", "work_history"]
    }

View File

@ -1,14 +1,12 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.
import pytest import pytest
import torch import torch
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor) JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
def test_guided_logits_processors(sample_regex, sample_json_schema): def test_guided_logits_processors(sample_regex, sample_json_schema):
@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode( token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {sample_regex}") f"Give an example IPv4 address with this regex: {sample_regex}")
regex_request = CompletionRequest(model='test', regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
prompt=token_ids,
guided_regex=sample_regex)
regex_lp = await get_guided_decoding_logits_processor( regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer) regex_request, tokenizer)
assert regex_lp is not None assert regex_lp is not None
tensor = torch.rand(32000) tensor = torch.rand(32000)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
@ -59,14 +55,31 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
token_ids = tokenizer.encode( token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {sample_json_schema}" f"Give an employee profile that fits this schema: {sample_json_schema}"
) )
json_request = CompletionRequest(model='test', json_request = GuidedDecodingParams(json=sample_json_schema,
prompt=token_ids, backend=backend)
guided_json=sample_json_schema)
json_lp = await get_guided_decoding_logits_processor( json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer) json_request, tokenizer)
assert json_lp is not None assert json_lp is not None
tensor = torch.rand(32000) tensor = torch.rand(32000)
original_tensor = torch.clone(tensor) original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor) tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor) assert not torch.allclose(tensor, original_tensor)
def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
    """Combining ``json`` with any other guided option must raise.

    Each case pairs the JSON schema with one additional guided decoding
    option and expects ``GuidedDecodingParams`` to reject the combination.
    """
    conflicting_extras = (
        {"regex": sample_regex},
        {"json_object": True},
        {"choice": ["a", "b"]},
        {"grammar": "test grammar"},
    )

    for extra in conflicting_extras:
        with pytest.raises(ValueError,
                           match="You can only use one kind of guided"):
            GuidedDecodingParams(json=sample_json_schema, **extra)

View File

@ -20,6 +20,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
@ -477,6 +479,18 @@ class _AsyncLLMEngine(LLMEngine):
) )
processed_inputs = self.input_processor(preprocessed_inputs) processed_inputs = self.input_processor(preprocessed_inputs)
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
# Guided decoding has an async implementation for building logits
# processors in a separate threadpool.
# We want to invoke that here instead of using the blocking
# implementation in the LLMEngine
params = await build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.
guided_decoding_backend)
self._add_processed_request( self._add_processed_request(
request_id=request_id, request_id=request_id,
processed_inputs=processed_inputs, processed_inputs=processed_inputs,
@ -494,6 +508,36 @@ class _AsyncLLMEngine(LLMEngine):
self.model_executor.check_health() self.model_executor.check_health()
async def build_guided_decoding_logits_processor_async(
        sampling_params: SamplingParams, tokenizer: AnyTokenizer,
        default_guided_backend: str) -> SamplingParams:
    """Build the guided decoding logits processor for ``sampling_params``.

    Only the ``guided_decoding`` field is handled here; ``logit_bias`` and
    ``allowed_token_ids`` are not processed by this function.

    The caller's objects are left untouched: the function works on shallow
    copies of ``sampling_params`` and of its ``guided_decoding`` field, so a
    ``SamplingParams`` instance that is reused for several requests gets a
    freshly built processor each time instead of silently sharing the one
    built for the first request.

    Args:
        sampling_params: Sampling params that may carry ``guided_decoding``.
        tokenizer: Tokenizer forwarded to the logits processor factory.
        default_guided_backend: Backend used when the guided decoding params
            do not name one.

    Returns:
        ``sampling_params`` itself when no guided decoding is requested.
        Otherwise a shallow copy whose ``guided_decoding`` is ``None`` and
        whose ``logits_processors`` is a new list ending with the built
        processor (when one was produced). Callers must use the returned
        object.
    """
    import copy

    if sampling_params.guided_decoding is None:
        return sampling_params

    # Shallow copies keep the mutations below (default backend, cleared
    # guided_decoding, extended processor list) off the caller's objects.
    sampling_params = copy.copy(sampling_params)
    guided_decoding = copy.copy(sampling_params.guided_decoding)

    logger.debug("Building guided decoding logits processor. "
                 "Params: %s", guided_decoding)

    guided_decoding.backend = guided_decoding.backend or default_guided_backend

    processor = await get_guided_decoding_logits_processor(
        guided_params=guided_decoding, tokenizer=tokenizer)

    if processor:
        # Build a new list: the shallow copy still references the caller's
        # list, and appending to it would leak the processor back.
        sampling_params.logits_processors = [
            *(sampling_params.logits_processors or ()), processor
        ]

    # Unset guided decoding params after constructing the lp from them
    sampling_params.guided_decoding = None

    return sampling_params
class AsyncLLMEngine: class AsyncLLMEngine:
"""An asynchronous wrapper for :class:`LLMEngine`. """An asynchronous wrapper for :class:`LLMEngine`.

View File

@ -25,6 +25,7 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor) SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
@ -33,6 +34,8 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_local_guided_decoding_logits_processor)
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
@ -843,6 +846,9 @@ class LLMEngine:
raise ValueError(f"Cannot request more than " raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.") f"{max_logprobs} logprobs.")
sampling_params = self._build_logits_processors(
sampling_params, lora_request)
# Defensive copy of SamplingParams, which are used by the sampler, # Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects # this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone() sampling_params = sampling_params.clone()
@ -1895,3 +1901,51 @@ class LLMEngine:
# TODO: Find out how many placeholder tokens are there so we can # TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them # check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens # max_batch_len = self.scheduler_config.max_num_batched_tokens
def _build_logits_processors(
        self, sampling_params: SamplingParams,
        lora_request: Optional[LoRARequest]) -> SamplingParams:
    """Turn declarative sampling fields into logits processors.

    Builds processors from the ``guided_decoding``, ``logit_bias`` and
    ``allowed_token_ids`` fields of ``sampling_params``, clears those
    fields, and places the processors after any existing entries of
    ``logits_processors``.

    The caller's objects are not modified: when any processor has to be
    built, the work happens on a shallow copy of ``sampling_params`` (and
    of its ``guided_decoding`` field) and a new ``logits_processors`` list
    is created. A ``SamplingParams`` instance reused across requests
    therefore keeps its fields and its list does not grow per request.

    Args:
        sampling_params: Sampling params for the request.
        lora_request: LoRA request used to select the tokenizer, if any.

    Returns:
        ``sampling_params`` itself when none of the three fields is set,
        otherwise the modified shallow copy.
    """
    import copy

    guided_decoding = sampling_params.guided_decoding
    has_token_constraints = bool(sampling_params.logit_bias
                                 or sampling_params.allowed_token_ids)

    if guided_decoding is None and not has_token_constraints:
        return sampling_params

    # Copy before mutating so the caller's params stay reusable.
    sampling_params = copy.copy(sampling_params)
    # Both branches below need the tokenizer; fetch it once.
    tokenizer = self.get_tokenizer(lora_request=lora_request)
    logits_processors = []

    if guided_decoding is not None:
        # Copy so the default backend is not written into the caller's
        # guided decoding params.
        guided_decoding = copy.copy(guided_decoding)

        logger.debug(
            "Building guided decoding logits processor in "
            "LLMEngine. Params: %s", guided_decoding)

        guided_decoding.backend = guided_decoding.backend or \
            self.decoding_config.guided_decoding_backend

        processor = get_local_guided_decoding_logits_processor(
            guided_params=guided_decoding, tokenizer=tokenizer)
        if processor:
            logits_processors.append(processor)

        # Unset so this doesn't get passed down to the model
        sampling_params.guided_decoding = None

    if has_token_constraints:
        processors = get_logits_processors(
            logit_bias=sampling_params.logit_bias,
            allowed_token_ids=sampling_params.allowed_token_ids,
            tokenizer=tokenizer)
        logits_processors.extend(processors)

        # Unset so these don't get passed down to the model
        sampling_params.logit_bias = None
        sampling_params.allowed_token_ids = None

    if logits_processors:
        # New list instead of extend(): the shallow copy still shares the
        # caller's list object.
        sampling_params.logits_processors = [
            *(sampling_params.logits_processors or ()), *logits_processors
        ]

    return sampling_params

View File

@ -16,6 +16,8 @@ from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T, IPC_OUTPUT_EXT, RPC_REQUEST_T,
@ -512,6 +514,18 @@ class MQLLMEngineClient:
if self._errored_with is not None: if self._errored_with is not None:
raise ENGINE_DEAD_ERROR(self._errored_with) raise ENGINE_DEAD_ERROR(self._errored_with)
# Constructing guided decoding logits processors is expensive, so we do
# it here to avoid contending with cpu resources and the GIL on the
# backend process.
if isinstance(params, SamplingParams) and \
params.guided_decoding is not None:
params = await \
build_guided_decoding_logits_processor_async(
sampling_params=params,
tokenizer=await self.get_tokenizer(lora_request),
default_guided_backend=self.decoding_config.guided_decoding_backend
)
# 1) Create output queue for this requests. # 1) Create output queue for this requests.
queue: asyncio.Queue[Union[RequestOutput, queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue() BaseException]] = asyncio.Queue()

View File

@ -1,4 +1,5 @@
import itertools import itertools
import warnings
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
@ -16,13 +17,13 @@ from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor) GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
SamplingParams)
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer) get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -798,6 +799,14 @@ class LLM:
guided_options: Optional[GuidedDecodingRequest] = None, guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None, priority: Optional[List[int]] = None,
) -> None: ) -> None:
if guided_options is not None:
warnings.warn(
"guided_options_request is deprecated, use "
"SamplingParams.guided_decoding instead",
DeprecationWarning,
stacklevel=2,
)
if isinstance(prompts, (str, dict)): if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
prompts = [prompts] prompts = [prompts]
@ -813,7 +822,7 @@ class LLM:
for sp in params if isinstance(params, list) else (params, ): for sp in params if isinstance(params, list) else (params, ):
if isinstance(sp, SamplingParams): if isinstance(sp, SamplingParams):
self._add_guided_processor(sp, guided_options) self._add_guided_params(sp, guided_options)
# We only care about the final output # We only care about the final output
sp.output_kind = RequestOutputKind.FINAL_ONLY sp.output_kind = RequestOutputKind.FINAL_ONLY
@ -847,22 +856,25 @@ class LLM:
priority=priority, priority=priority,
) )
def _add_guided_params(
        self,
        params: SamplingParams,
        guided_options: Optional[GuidedDecodingRequest] = None):
    """Copy legacy ``guided_options`` into ``params.guided_decoding``.

    Args:
        params: Sampling params to update in place.
        guided_options: Legacy guided decoding request, or ``None`` when
            the caller did not supply one.

    Returns:
        ``params``, with ``guided_decoding`` populated from
        ``guided_options`` when the latter is given.

    Raises:
        ValueError: If ``guided_options`` is given while
            ``params.guided_decoding`` is already set.
    """
    if guided_options is None:
        return params

    if params.guided_decoding is not None:
        # Trailing space matters: adjacent literals are concatenated as-is.
        raise ValueError("Cannot set both guided_options_request and "
                         "params.guided_decoding.")

    params.guided_decoding = GuidedDecodingParams(
        json=guided_options.guided_json,
        regex=guided_options.guided_regex,
        choice=guided_options.guided_choice,
        grammar=guided_options.guided_grammar,
        json_object=guided_options.guided_json_object,
        backend=guided_options.guided_decoding_backend,
        whitespace_pattern=guided_options.guided_whitespace_pattern)
    return params
def _run_engine( def _run_engine(

View File

@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict from typing_extensions import Annotated, Required, TypedDict
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind, from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
SamplingParams) SamplingParams)
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
# torch is mocked during docs generation, # torch is mocked during docs generation,
@ -284,10 +282,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
def to_sampling_params( def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
@ -296,14 +291,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
if prompt_logprobs is None and self.echo: if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs prompt_logprobs = self.top_logprobs
# We now allow logprobs being true without top_logrobs. guided_json_object = None
logits_processors = get_logits_processors( if (self.response_format is not None
logit_bias=self.logit_bias, and self.response_format.type == "json_object"):
allowed_token_ids=None, guided_json_object = True
tokenizer=tokenizer,
) guided_decoding = GuidedDecodingParams.from_optional(
if guided_decode_logits_processor: json=self._get_guided_json_from_tool() or self.guided_json,
logits_processors.append(guided_decode_logits_processor) regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional( return SamplingParams.from_optional(
n=self.n, n=self.n,
@ -329,11 +329,29 @@ class ChatCompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
) guided_decoding=guided_decoding,
logit_bias=self.logit_bias)
def _get_guided_json_from_tool(
        self) -> Optional[Union[str, dict, BaseModel]]:
    """Return the parameters of the explicitly named tool, if one is chosen.

    Returns:
        The ``parameters`` of the function selected by ``tool_choice`` when
        ``tool_choice`` is a ``ChatCompletionNamedToolChoiceParam``;
        ``None`` when tools are disabled, absent, or no tool is named.

    Raises:
        ValueError: If the named tool is not among ``tools``.
    """
    # Tools explicitly disabled, or none supplied.
    if self.tool_choice == "none" or self.tools is None:
        return None

    # Only an explicitly named tool yields a schema.
    if type(self.tool_choice) is not ChatCompletionNamedToolChoiceParam:
        return None

    requested_name = self.tool_choice.function.name

    functions_by_name = {}
    for candidate in self.tools:
        functions_by_name[candidate.function.name] = candidate.function

    if requested_name in functions_by_name:
        return functions_by_name[requested_name].parameters

    raise ValueError(
        f"Tool '{requested_name}' has not been passed in `tools`.")
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
@ -537,10 +555,7 @@ class CompletionRequest(OpenAIBaseModel):
# doc: end-completion-extra-params # doc: end-completion-extra-params
def to_sampling_params( def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
self, tokenizer: AnyTokenizer,
guided_decode_logits_processor: Optional[LogitsProcessor],
default_max_tokens: int) -> SamplingParams:
max_tokens = self.max_tokens max_tokens = self.max_tokens
if max_tokens is None: if max_tokens is None:
max_tokens = default_max_tokens max_tokens = default_max_tokens
@ -551,13 +566,19 @@ class CompletionRequest(OpenAIBaseModel):
echo_without_generation = self.echo and self.max_tokens == 0 echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = get_logits_processors( guided_json_object = None
logit_bias=self.logit_bias, if (self.response_format is not None
allowed_token_ids=self.allowed_token_ids, and self.response_format.type == "json_object"):
tokenizer=tokenizer, guided_json_object = True
)
if guided_decode_logits_processor: guided_decoding = GuidedDecodingParams.from_optional(
logits_processors.append(guided_decode_logits_processor) json=self.guided_json,
regex=self.guided_regex,
choice=self.guided_choice,
grammar=self.guided_grammar,
json_object=guided_json_object,
backend=self.guided_decoding_backend,
whitespace_pattern=self.guided_whitespace_pattern)
return SamplingParams.from_optional( return SamplingParams.from_optional(
n=self.n, n=self.n,
@ -583,11 +604,12 @@ class CompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens=self.spaces_between_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output, include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty, length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens, truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \ output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY, else RequestOutputKind.FINAL_ONLY,
) guided_decoding=guided_decoding,
logit_bias=self.logit_bias,
allowed_token_ids=self.allowed_token_ids)
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod

View File

@ -187,9 +187,6 @@ class OpenAIServingChat(OpenAIServing):
raw_request.state.request_metadata = request_metadata raw_request.state.request_metadata = request_metadata
try: try:
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
if isinstance(prompt, str): if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input( prompt_inputs = self._tokenize_prompt_input(
request, request,
@ -208,8 +205,6 @@ class OpenAIServingChat(OpenAIServing):
assert prompt_inputs is not None assert prompt_inputs is not None
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len - default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"])) len(prompt_inputs["prompt_token_ids"]))

View File

@ -110,8 +110,6 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request)
guided_decode_logits_processor = (
await self._guided_decode_logits_processor(request, tokenizer))
prompts = list( prompts = list(
self._tokenize_prompt_input_or_inputs( self._tokenize_prompt_input_or_inputs(
request, request,
@ -123,8 +121,6 @@ class OpenAIServingCompletion(OpenAIServing):
for i, prompt_inputs in enumerate(prompts): for i, prompt_inputs in enumerate(prompts):
sampling_params = request.to_sampling_params( sampling_params = request.to_sampling_params(
tokenizer,
guided_decode_logits_processor,
default_max_tokens=self.max_model_len - default_max_tokens=self.max_model_len -
len(prompt_inputs["prompt_token_ids"])) len(prompt_inputs["prompt_token_ids"]))

View File

@ -27,11 +27,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.inputs.parse import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AtomicCounter from vllm.utils import AtomicCounter
@ -168,15 +166,6 @@ class OpenAIServing:
}) })
return json_str return json_str
async def _guided_decode_logits_processor(
self, request: Union[ChatCompletionRequest, CompletionRequest],
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
decoding_config = await self.engine_client.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
return await get_guided_decoding_logits_processor(
guided_decoding_backend, request, tokenizer)
async def _check_model( async def _check_model(
self, self,
request: AnyRequest, request: AnyRequest,

View File

@ -1,77 +1,45 @@
from typing import Optional, Union from typing import Optional
from vllm.entrypoints.openai.protocol import ( from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
async def get_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer) -> Optional[LogitsProcessor]:
    """Build the guided-decoding logits processor for ``guided_params``.

    Dispatches on ``guided_params.backend``. A grammar guide is always
    routed to the outlines backend, whatever backend was requested.

    Args:
        guided_params: Guided decoding options; ``backend`` and ``grammar``
            are read here, the rest is forwarded to the backend.
        tokenizer: Tokenizer forwarded unchanged to the backend factory.

    Returns:
        Whatever the selected backend factory returns (annotated as an
        optional logits processor).

    Raises:
        ValueError: If no grammar is set and ``backend`` is neither
            ``'outlines'`` nor ``'lm-format-enforcer'``.
    """
    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend == 'outlines' or guided_params.grammar:
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
            get_outlines_guided_decoding_logits_processor)
        return await get_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        # The LMFE factory is synchronous, so it is called without await.
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_params, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")
def get_local_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer) -> Optional[LogitsProcessor]:
    """Synchronously build the logits processor for ``guided_params``.

    Dispatches on ``guided_params.backend``. A grammar guide is always
    routed to the outlines backend, whatever backend was requested.

    Args:
        guided_params: Guided decoding options; ``backend`` and ``grammar``
            are read here, the rest is forwarded to the backend.
        tokenizer: Tokenizer forwarded unchanged to the backend factory.

    Returns:
        Whatever the selected backend factory returns (annotated as an
        optional logits processor).

    Raises:
        ValueError: If no grammar is set and ``backend`` is neither
            ``'outlines'`` nor ``'lm-format-enforcer'``.
    """
    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend == 'outlines' or guided_params.grammar:
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
            get_local_outlines_guided_decoding_logits_processor)
        return get_local_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_params, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")
def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
if type(request) is CompletionRequest:
return request
# user has chosen to not use any tool,
# OR is allowing the model to choose a tool.
if request.tool_choice == "none" or request.tool_choice == "auto":
return request
# user has chosen to use a named tool
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
tool_name = request.tool_choice.function.name
tools = {tool.function.name: tool.function for tool in request.tools}
if tool_name not in tools:
raise ValueError(
f"Tool '{tool_name}' has not been passed in `tools`.")
tool = tools[tool_name]
request.guided_json = tool.parameters
return request

View File

@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union
from pydantic import BaseModel from pydantic import BaseModel
# These classes are deprecated, see SamplingParams
class LLMGuidedOptions(TypedDict, total=False): class LLMGuidedOptions(TypedDict, total=False):
guided_json: Union[Dict, BaseModel, str] guided_json: Union[Dict, BaseModel, str]
guided_regex: str guided_regex: str

View File

@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
TokenEnforcerTokenizerData, UnionParser) TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import ( from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.sampling_params import LogitsProcessor
async def get_lm_format_enforcer_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if request.guided_json:
schema = _normalize_json_schema_object(request.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif request.guided_choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in request.guided_choice])
elif request.guided_regex:
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
elif (request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
schema = _normalize_json_schema_object(
request.response_format.json_schema.json_schema)
character_level_parser = JsonSchemaParser(schema)
else:
return None
logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
def get_local_lm_format_enforcer_guided_decoding_logits_processor( def get_local_lm_format_enforcer_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest, guided_params: GuidedDecodingParams,
tokenizer) -> Optional[LogitsProcessor]: tokenizer) -> Optional[LogitsProcessor]:
""" """
Given an OpenAI-compatible request, check for guided decoding parameters Given an OpenAI-compatible request, check for guided decoding parameters
@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer) tokenizer)
character_level_parser: CharacterLevelParser character_level_parser: CharacterLevelParser
if guided_options.guided_json: if guided_params.json:
schema = _normalize_json_schema_object(guided_options.guided_json) schema_dict = _normalize_json_schema_object(guided_params.json)
character_level_parser = JsonSchemaParser(schema) character_level_parser = JsonSchemaParser(schema_dict)
elif guided_options.guided_choice: elif guided_params.choice:
character_level_parser = UnionParser( character_level_parser = UnionParser(
[StringParser(choice) for choice in guided_options.guided_choice]) [StringParser(choice) for choice in guided_params.choice])
elif guided_options.guided_regex: elif guided_params.regex:
character_level_parser = RegexParser(guided_options.guided_regex) character_level_parser = RegexParser(guided_params.regex)
elif guided_options.guided_grammar: elif guided_params.grammar:
# CFG grammar not supported by LMFE, revert to outlines # CFG grammar not supported by LMFE
raise ValueError("Cannot construct a guided decoding logits processor"
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 " using the grammar option with the"
from vllm.model_executor.guided_decoding.outlines_decoding import ( " lm_format_enforcer backend.")
get_local_outlines_guided_decoding_logits_processor) elif guided_params.json_object:
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
elif guided_options.guided_json_object:
# None means any json object # None means any json object
character_level_parser = JsonSchemaParser(None) character_level_parser = JsonSchemaParser(None)
else: else:
@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
return logits_processor return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict]) -> dict:
    """Return ``schema`` in parsed form.

    A ``dict`` is handed back as-is (same object); a ``str`` is parsed
    with ``json_loads``. Any other type raises ``AssertionError``.
    """
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, str):
        return json_loads(schema)
    raise AssertionError(f"Unsupported schema type {schema}")

View File

@ -5,16 +5,11 @@ from json import dumps as json_dumps
from re import escape as regex_escape from re import escape as regex_escape
from typing import Tuple, Union from typing import Tuple, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams
class GuidedDecodingMode(Enum): class GuidedDecodingMode(Enum):
@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm
async def get_outlines_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor(
we make a shallow copy to reuse the same underlying FSM. we make a shallow copy to reuse the same underlying FSM.
""" """
global global_thread_pool global global_thread_pool
guide, mode = _get_guide_and_mode(request) guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode: if not guide or not mode:
return None return None
@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor(
return await loop.run_in_executor(global_thread_pool, return await loop.run_in_executor(global_thread_pool,
_get_logits_processor, guide, tokenizer, _get_logits_processor, guide, tokenizer,
mode, request.guided_whitespace_pattern) mode, guided_params.whitespace_pattern)
def get_local_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor(
guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor(
We cache logit processors by (guide, tokenizer), and on cache hit We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM. we make a shallow copy to reuse the same underlying FSM.
""" """
guide, mode = _get_guide_and_mode(guided_options) guide, mode = _get_guide_and_mode(guided_params)
if not guide or not mode: if not guide or not mode:
return None return None
return _get_logits_processor(guide, tokenizer, mode, return _get_logits_processor(guide, tokenizer, mode,
guided_options.guided_whitespace_pattern) guided_params.whitespace_pattern)
def _get_guide_and_mode(
    guided_params: GuidedDecodingParams
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
    """Pick the guide string and decoding mode from ``guided_params``.

    The first truthy field wins, checked in this order: ``json``,
    ``regex``, ``choice``, ``grammar``, ``json_object``. Returns
    ``(None, None)`` when none of them is set.
    """
    schema = guided_params.json
    if schema:
        # A dict schema is serialised so the guide is a hashable string.
        guide = json_dumps(schema) if isinstance(schema, dict) else schema
        return guide, GuidedDecodingMode.JSON

    if guided_params.regex:
        return guided_params.regex, GuidedDecodingMode.REGEX

    if guided_params.choice:
        # Choices are expressed as a regex alternation of escaped literals.
        alternatives = "|".join(
            [regex_escape(str(option)) for option in guided_params.choice])
        return "(" + alternatives + ")", GuidedDecodingMode.CHOICE

    if guided_params.grammar:
        return guided_params.grammar, GuidedDecodingMode.GRAMMAR

    if guided_params.json_object:
        # "Any JSON object" is enforced through the built-in JSON grammar.
        return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR

    return None, None

View File

@ -1,11 +1,13 @@
"""Sampling parameters for text generation.""" """Sampling parameters for text generation."""
import copy import copy
from dataclasses import dataclass
from enum import Enum, IntEnum from enum import Enum, IntEnum
from functools import cached_property from functools import cached_property
from typing import Any, Callable, Dict, List, Optional, Set, Union from typing import Any, Callable, Dict, List, Optional, Set, Union
import msgspec import msgspec
import torch import torch
from pydantic import BaseModel
from typing_extensions import Annotated from typing_extensions import Annotated
import vllm.envs as envs import vllm.envs as envs
@ -34,6 +36,54 @@ first argument, and returns a modified tensor of logits
to sample from.""" to sample from."""
# maybe make msgspec?
@dataclass
class GuidedDecodingParams:
    """Guided decoding options carried on the sampling parameters.

    One of the guide fields (``json``, ``regex``, ``choice``, ``grammar``,
    ``json_object``) will be used to build a logit processor; setting more
    than one raises ``ValueError`` at construction time.
    """
    # Guide fields: at most one of these may be active.
    json: Optional[Union[str, Dict]] = None
    regex: Optional[str] = None
    choice: Optional[List[str]] = None
    grammar: Optional[str] = None
    json_object: Optional[bool] = None
    # Other options that can be set alongside a guide.
    backend: Optional[str] = None
    whitespace_pattern: Optional[str] = None

    @staticmethod
    def from_optional(
        json: Optional[Union[Dict, BaseModel, str]] = None,
        regex: Optional[str] = None,
        choice: Optional[List[str]] = None,
        grammar: Optional[str] = None,
        json_object: Optional[bool] = None,
        backend: Optional[str] = None,
        whitespace_pattern: Optional[str] = None,
    ) -> "GuidedDecodingParams":
        """Build params from optional values, accepting pydantic models.

        ``json`` may be a pydantic model instance or model class, in which
        case its JSON schema (``model_json_schema()``) is stored instead.
        All arguments default to ``None``.

        Raises:
            ValueError: If more than one guide field ends up set.
        """
        # Extract json schemas from pydantic models
        if isinstance(json, (BaseModel, type(BaseModel))):
            json = json.model_json_schema()
        return GuidedDecodingParams(
            json=json,
            regex=regex,
            choice=choice,
            grammar=grammar,
            json_object=json_object,
            backend=backend,
            whitespace_pattern=whitespace_pattern,
        )

    def __post_init__(self):
        """Validate that some fields are mutually exclusive."""
        guide_count = sum((
            self.json is not None,
            self.regex is not None,
            self.choice is not None,
            self.grammar is not None,
            # ``json_object`` is a flag: only a truthy value selects the
            # JSON-object guide, so an explicit False must not be counted.
            bool(self.json_object),
        ))
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding but multiple are "
                f"specified: {self.__dict__}")
class RequestOutputKind(Enum): class RequestOutputKind(Enum):
# Return entire output so far in every RequestOutput # Return entire output so far in every RequestOutput
CUMULATIVE = 0 CUMULATIVE = 0
@ -124,6 +174,13 @@ class SamplingParams(
truncate_prompt_tokens: If set to an integer k, will use only the last k truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation). (i.e., no truncation).
guided_decoding: If provided, the engine will construct a guided
decoding logits processor from these parameters. Defaults to None.
logit_bias: If provided, the engine will construct a logits processor
that applies these logit biases. Defaults to None.
allowed_token_ids: If provided, the engine will construct a logits
processor which only retains scores for the given token ids.
Defaults to None.
""" """
n: int = 1 n: int = 1
@ -164,6 +221,11 @@ class SamplingParams(
output_text_buffer_length: int = 0 output_text_buffer_length: int = 0
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
# Fields used to construct logits processors
guided_decoding: Optional[GuidedDecodingParams] = None
logit_bias: Optional[Dict[int, float]] = None
allowed_token_ids: Optional[List[int]] = None
@staticmethod @staticmethod
def from_optional( def from_optional(
n: Optional[int] = 1, n: Optional[int] = 1,
@ -194,7 +256,16 @@ class SamplingParams(
truncate_prompt_tokens: Optional[Annotated[int, truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None, msgspec.Meta(ge=1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
allowed_token_ids: Optional[List[int]] = None,
) -> "SamplingParams": ) -> "SamplingParams":
if logit_bias is not None:
logit_bias = {
int(token): bias
for token, bias in logit_bias.items()
}
return SamplingParams( return SamplingParams(
n=1 if n is None else n, n=1 if n is None else n,
best_of=best_of, best_of=best_of,
@ -226,6 +297,9 @@ class SamplingParams(
logits_processors=logits_processors, logits_processors=logits_processors,
truncate_prompt_tokens=truncate_prompt_tokens, truncate_prompt_tokens=truncate_prompt_tokens,
output_kind=output_kind, output_kind=output_kind,
guided_decoding=guided_decoding,
logit_bias=logit_bias,
allowed_token_ids=allowed_token_ids,
) )
def __post_init__(self) -> None: def __post_init__(self) -> None:
@ -454,4 +528,5 @@ class SamplingParams(
f"skip_special_tokens={self.skip_special_tokens}, " f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens=" "spaces_between_special_tokens="
f"{self.spaces_between_special_tokens}, " f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})") f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
f"guided_decoding={self.guided_decoding}")