mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 04:22:15 +08:00
[Frontend][Core] Move guided decoding params into sampling params (#8252)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
parent
bce324487a
commit
062c89e7c9
@ -7,7 +7,7 @@ import pytest
|
|||||||
|
|
||||||
from vllm.entrypoints.llm import LLM
|
from vllm.entrypoints.llm import LLM
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||||
|
|
||||||
from ...conftest import cleanup
|
from ...conftest import cleanup
|
||||||
|
|
||||||
@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm):
|
|||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
)
|
guided_decoding=GuidedDecodingParams(regex=sample_regex))
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(prompts=[
|
||||||
prompts=[
|
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
] * 2,
|
||||||
] * 2,
|
sampling_params=sampling_params,
|
||||||
sampling_params=sampling_params,
|
use_tqdm=True)
|
||||||
use_tqdm=True,
|
|
||||||
guided_options_request=dict(guided_regex=sample_regex))
|
|
||||||
|
|
||||||
assert outputs is not None
|
assert outputs is not None
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm):
|
|||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=1.0,
|
temperature=1.0,
|
||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
)
|
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(prompts=[
|
||||||
prompts=[
|
f"Give an example JSON for an employee profile "
|
||||||
f"Give an example JSON for an employee profile "
|
f"that fits this schema: {sample_json_schema}"
|
||||||
f"that fits this schema: {sample_json_schema}"
|
] * 2,
|
||||||
] * 2,
|
sampling_params=sampling_params,
|
||||||
sampling_params=sampling_params,
|
use_tqdm=True)
|
||||||
use_tqdm=True,
|
|
||||||
guided_options_request=dict(guided_json=sample_json_schema))
|
|
||||||
|
|
||||||
assert outputs is not None
|
assert outputs is not None
|
||||||
|
|
||||||
@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm):
|
|||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
)
|
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
prompts="The best language for type-safe systems programming is ",
|
prompts="The best language for type-safe systems programming is ",
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True,
|
use_tqdm=True)
|
||||||
guided_options_request=dict(guided_choice=sample_guided_choice))
|
|
||||||
|
|
||||||
assert outputs is not None
|
assert outputs is not None
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm):
|
|||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
)
|
guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements))
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
prompts=("Generate a sql state that select col_1 from "
|
prompts=("Generate a sql state that select col_1 from "
|
||||||
"table_1 where it is equals to 1"),
|
"table_1 where it is equals to 1"),
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True,
|
use_tqdm=True,
|
||||||
guided_options_request=dict(guided_grammar=sample_sql_statements))
|
)
|
||||||
|
|
||||||
assert outputs is not None
|
assert outputs is not None
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm):
|
|||||||
assert generated_text.strip() == ground_truth
|
assert generated_text.strip() == ground_truth
|
||||||
|
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
def test_guided_options_request_deprecation_warning(sample_regex, llm):
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
|
with pytest.warns(DeprecationWarning, match="guided_options_request"):
|
||||||
|
llm.generate(prompts="This should fail",
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True,
|
||||||
|
guided_options_request=dict(guided_regex=sample_regex))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
def test_validation_against_both_guided_decoding_options(sample_regex, llm):
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.8,
|
||||||
|
top_p=0.95,
|
||||||
|
guided_decoding=GuidedDecodingParams(regex=sample_regex))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Cannot set both"):
|
||||||
|
llm.generate(prompts="This should fail",
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True,
|
||||||
|
guided_options_request=dict(guided_regex=sample_regex))
|
||||||
|
|||||||
49
tests/model_executor/conftest.py
Normal file
49
tests/model_executor/conftest.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_regex():
|
||||||
|
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||||
|
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_json_schema():
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"age": {
|
||||||
|
"type": "integer"
|
||||||
|
},
|
||||||
|
"skills": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"maxLength": 10
|
||||||
|
},
|
||||||
|
"minItems": 3
|
||||||
|
},
|
||||||
|
"work_history": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"company": {
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
"duration": {
|
||||||
|
"type": "number"
|
||||||
|
},
|
||||||
|
"position": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["company", "position"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["name", "age", "skills", "work_history"]
|
||||||
|
}
|
||||||
@ -1,14 +1,12 @@
|
|||||||
# This unit test should be moved to a new
|
|
||||||
# tests/test_guided_decoding directory.
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import CompletionRequest
|
|
||||||
from vllm.model_executor.guided_decoding import (
|
from vllm.model_executor.guided_decoding import (
|
||||||
get_guided_decoding_logits_processor)
|
get_guided_decoding_logits_processor)
|
||||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||||
JSONLogitsProcessor, RegexLogitsProcessor)
|
JSONLogitsProcessor, RegexLogitsProcessor)
|
||||||
|
from vllm.sampling_params import GuidedDecodingParams
|
||||||
|
|
||||||
|
|
||||||
def test_guided_logits_processors(sample_regex, sample_json_schema):
|
def test_guided_logits_processors(sample_regex, sample_json_schema):
|
||||||
@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
|
|||||||
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
||||||
token_ids = tokenizer.encode(
|
token_ids = tokenizer.encode(
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}")
|
f"Give an example IPv4 address with this regex: {sample_regex}")
|
||||||
regex_request = CompletionRequest(model='test',
|
regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend)
|
||||||
prompt=token_ids,
|
|
||||||
guided_regex=sample_regex)
|
|
||||||
regex_lp = await get_guided_decoding_logits_processor(
|
regex_lp = await get_guided_decoding_logits_processor(
|
||||||
backend, regex_request, tokenizer)
|
regex_request, tokenizer)
|
||||||
assert regex_lp is not None
|
assert regex_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
@ -59,14 +55,31 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
|
|||||||
token_ids = tokenizer.encode(
|
token_ids = tokenizer.encode(
|
||||||
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
f"Give an employee profile that fits this schema: {sample_json_schema}"
|
||||||
)
|
)
|
||||||
json_request = CompletionRequest(model='test',
|
json_request = GuidedDecodingParams(json=sample_json_schema,
|
||||||
prompt=token_ids,
|
backend=backend)
|
||||||
guided_json=sample_json_schema)
|
|
||||||
json_lp = await get_guided_decoding_logits_processor(
|
json_lp = await get_guided_decoding_logits_processor(
|
||||||
backend, json_request, tokenizer)
|
json_request, tokenizer)
|
||||||
assert json_lp is not None
|
assert json_lp is not None
|
||||||
tensor = torch.rand(32000)
|
tensor = torch.rand(32000)
|
||||||
original_tensor = torch.clone(tensor)
|
original_tensor = torch.clone(tensor)
|
||||||
tensor = json_lp(token_ids, tensor)
|
tensor = json_lp(token_ids, tensor)
|
||||||
assert tensor.shape == original_tensor.shape
|
assert tensor.shape == original_tensor.shape
|
||||||
assert not torch.allclose(tensor, original_tensor)
|
assert not torch.allclose(tensor, original_tensor)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex):
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="You can only use one kind of guided"):
|
||||||
|
GuidedDecodingParams(json=sample_json_schema, regex=sample_regex)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="You can only use one kind of guided"):
|
||||||
|
GuidedDecodingParams(json=sample_json_schema, json_object=True)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="You can only use one kind of guided"):
|
||||||
|
GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="You can only use one kind of guided"):
|
||||||
|
GuidedDecodingParams(json=sample_json_schema, grammar="test grammar")
|
||||||
@ -20,6 +20,8 @@ from vllm.executor.ray_utils import initialize_ray_cluster
|
|||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.model_executor.guided_decoding import (
|
||||||
|
get_guided_decoding_logits_processor)
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
@ -477,6 +479,18 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
)
|
)
|
||||||
processed_inputs = self.input_processor(preprocessed_inputs)
|
processed_inputs = self.input_processor(preprocessed_inputs)
|
||||||
|
|
||||||
|
if isinstance(params, SamplingParams) and \
|
||||||
|
params.guided_decoding is not None:
|
||||||
|
# Guided decoding has an async implementation for building logits
|
||||||
|
# processors in a separate threadpool.
|
||||||
|
# We want to invoke that here instead of using the blocking
|
||||||
|
# implementation in the LLMEngine
|
||||||
|
params = await build_guided_decoding_logits_processor_async(
|
||||||
|
sampling_params=params,
|
||||||
|
tokenizer=self.get_tokenizer(lora_request),
|
||||||
|
default_guided_backend=self.decoding_config.
|
||||||
|
guided_decoding_backend)
|
||||||
|
|
||||||
self._add_processed_request(
|
self._add_processed_request(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
processed_inputs=processed_inputs,
|
processed_inputs=processed_inputs,
|
||||||
@ -494,6 +508,36 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
self.model_executor.check_health()
|
self.model_executor.check_health()
|
||||||
|
|
||||||
|
|
||||||
|
async def build_guided_decoding_logits_processor_async(
|
||||||
|
sampling_params: SamplingParams, tokenizer: AnyTokenizer,
|
||||||
|
default_guided_backend: str) -> SamplingParams:
|
||||||
|
"""Constructs logits processors based on the guided_decoding,
|
||||||
|
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
|
||||||
|
those fields and adds the constructed logits processors to the
|
||||||
|
logits_processors field. Modifies sampling params in-place and returns
|
||||||
|
the modified sampling params."""
|
||||||
|
if (guided_decoding := sampling_params.guided_decoding) is None:
|
||||||
|
return sampling_params
|
||||||
|
|
||||||
|
logger.debug("Building guided decoding logits processor. "
|
||||||
|
"Params: %s", guided_decoding)
|
||||||
|
|
||||||
|
guided_decoding.backend = guided_decoding.backend or default_guided_backend
|
||||||
|
|
||||||
|
processor = await get_guided_decoding_logits_processor(
|
||||||
|
guided_params=guided_decoding, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
if processor:
|
||||||
|
if sampling_params.logits_processors is None:
|
||||||
|
sampling_params.logits_processors = []
|
||||||
|
sampling_params.logits_processors.append(processor)
|
||||||
|
|
||||||
|
# Unset guided decoding params after constructing the lp from them
|
||||||
|
sampling_params.guided_decoding = None
|
||||||
|
|
||||||
|
return sampling_params
|
||||||
|
|
||||||
|
|
||||||
class AsyncLLMEngine:
|
class AsyncLLMEngine:
|
||||||
"""An asynchronous wrapper for :class:`LLMEngine`.
|
"""An asynchronous wrapper for :class:`LLMEngine`.
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,7 @@ from vllm.engine.output_processor.interfaces import (
|
|||||||
SequenceGroupOutputProcessor)
|
SequenceGroupOutputProcessor)
|
||||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||||
|
from vllm.entrypoints.openai.logits_processors import get_logits_processors
|
||||||
from vllm.executor.executor_base import ExecutorBase
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
from vllm.executor.gpu_executor import GPUExecutor
|
from vllm.executor.gpu_executor import GPUExecutor
|
||||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||||
@ -33,6 +34,8 @@ from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
|
|||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.model_executor.guided_decoding import (
|
||||||
|
get_local_guided_decoding_logits_processor)
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||||
RequestOutputFactory)
|
RequestOutputFactory)
|
||||||
@ -843,6 +846,9 @@ class LLMEngine:
|
|||||||
raise ValueError(f"Cannot request more than "
|
raise ValueError(f"Cannot request more than "
|
||||||
f"{max_logprobs} logprobs.")
|
f"{max_logprobs} logprobs.")
|
||||||
|
|
||||||
|
sampling_params = self._build_logits_processors(
|
||||||
|
sampling_params, lora_request)
|
||||||
|
|
||||||
# Defensive copy of SamplingParams, which are used by the sampler,
|
# Defensive copy of SamplingParams, which are used by the sampler,
|
||||||
# this doesn't deep-copy LogitsProcessor objects
|
# this doesn't deep-copy LogitsProcessor objects
|
||||||
sampling_params = sampling_params.clone()
|
sampling_params = sampling_params.clone()
|
||||||
@ -1895,3 +1901,51 @@ class LLMEngine:
|
|||||||
# TODO: Find out how many placeholder tokens are there so we can
|
# TODO: Find out how many placeholder tokens are there so we can
|
||||||
# check that chunked prefill does not truncate them
|
# check that chunked prefill does not truncate them
|
||||||
# max_batch_len = self.scheduler_config.max_num_batched_tokens
|
# max_batch_len = self.scheduler_config.max_num_batched_tokens
|
||||||
|
|
||||||
|
def _build_logits_processors(
|
||||||
|
self, sampling_params: SamplingParams,
|
||||||
|
lora_request: Optional[LoRARequest]) -> SamplingParams:
|
||||||
|
"""Constructs logits processors based on the guided_decoding,
|
||||||
|
logits_bias, and allowed_token_ids fields in sampling_params. Deletes
|
||||||
|
those fields and adds the constructed logits processors to the
|
||||||
|
logits_processors field. Returns the modified sampling params."""
|
||||||
|
|
||||||
|
logits_processors = []
|
||||||
|
if (guided_decoding := sampling_params.guided_decoding) is not None:
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Building guided decoding logits processor in "
|
||||||
|
"LLMEngine. Params: %s", guided_decoding)
|
||||||
|
|
||||||
|
tokenizer = self.get_tokenizer(lora_request=lora_request)
|
||||||
|
guided_decoding.backend = guided_decoding.backend or \
|
||||||
|
self.decoding_config.guided_decoding_backend
|
||||||
|
|
||||||
|
processor = get_local_guided_decoding_logits_processor(
|
||||||
|
guided_params=guided_decoding, tokenizer=tokenizer)
|
||||||
|
if processor:
|
||||||
|
logits_processors.append(processor)
|
||||||
|
|
||||||
|
# Unset so this doesn't get passed down to the model
|
||||||
|
sampling_params.guided_decoding = None
|
||||||
|
|
||||||
|
if (sampling_params.logit_bias or sampling_params.allowed_token_ids):
|
||||||
|
tokenizer = self.get_tokenizer(lora_request=lora_request)
|
||||||
|
|
||||||
|
processors = get_logits_processors(
|
||||||
|
logit_bias=sampling_params.logit_bias,
|
||||||
|
allowed_token_ids=sampling_params.allowed_token_ids,
|
||||||
|
tokenizer=tokenizer)
|
||||||
|
logits_processors.extend(processors)
|
||||||
|
|
||||||
|
# Unset so these don't get passed down to the model
|
||||||
|
sampling_params.logit_bias = None
|
||||||
|
sampling_params.allowed_token_ids = None
|
||||||
|
|
||||||
|
if logits_processors:
|
||||||
|
if sampling_params.logits_processors is None:
|
||||||
|
sampling_params.logits_processors = logits_processors
|
||||||
|
else:
|
||||||
|
sampling_params.logits_processors.extend(logits_processors)
|
||||||
|
|
||||||
|
return sampling_params
|
||||||
|
|||||||
@ -16,6 +16,8 @@ from vllm.config import DecodingConfig, EngineConfig, ModelConfig
|
|||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
|
from vllm.engine.async_llm_engine import (
|
||||||
|
build_guided_decoding_logits_processor_async)
|
||||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||||
IPC_OUTPUT_EXT, RPC_REQUEST_T,
|
IPC_OUTPUT_EXT, RPC_REQUEST_T,
|
||||||
@ -512,6 +514,18 @@ class MQLLMEngineClient:
|
|||||||
if self._errored_with is not None:
|
if self._errored_with is not None:
|
||||||
raise ENGINE_DEAD_ERROR(self._errored_with)
|
raise ENGINE_DEAD_ERROR(self._errored_with)
|
||||||
|
|
||||||
|
# Constructing guided decoding logits processors is expensive, so we do
|
||||||
|
# it here to avoid contending with cpu resources and the GIL on the
|
||||||
|
# backend process.
|
||||||
|
if isinstance(params, SamplingParams) and \
|
||||||
|
params.guided_decoding is not None:
|
||||||
|
params = await \
|
||||||
|
build_guided_decoding_logits_processor_async(
|
||||||
|
sampling_params=params,
|
||||||
|
tokenizer=await self.get_tokenizer(lora_request),
|
||||||
|
default_guided_backend=self.decoding_config.guided_decoding_backend
|
||||||
|
)
|
||||||
|
|
||||||
# 1) Create output queue for this requests.
|
# 1) Create output queue for this requests.
|
||||||
queue: asyncio.Queue[Union[RequestOutput,
|
queue: asyncio.Queue[Union[RequestOutput,
|
||||||
BaseException]] = asyncio.Queue()
|
BaseException]] = asyncio.Queue()
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
|
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
|
||||||
@ -16,13 +17,13 @@ from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
|||||||
from vllm.inputs.parse import parse_and_batch_prompt
|
from vllm.inputs.parse import parse_and_batch_prompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor.guided_decoding import (
|
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||||
GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
|
GuidedDecodingRequest, LLMGuidedOptions)
|
||||||
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
|
|
||||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
|
||||||
|
SamplingParams)
|
||||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||||
get_cached_tokenizer)
|
get_cached_tokenizer)
|
||||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||||
@ -798,6 +799,14 @@ class LLM:
|
|||||||
guided_options: Optional[GuidedDecodingRequest] = None,
|
guided_options: Optional[GuidedDecodingRequest] = None,
|
||||||
priority: Optional[List[int]] = None,
|
priority: Optional[List[int]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
if guided_options is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"guided_options_request is deprecated, use "
|
||||||
|
"SamplingParams.guided_decoding instead",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(prompts, (str, dict)):
|
if isinstance(prompts, (str, dict)):
|
||||||
# Convert a single prompt to a list.
|
# Convert a single prompt to a list.
|
||||||
prompts = [prompts]
|
prompts = [prompts]
|
||||||
@ -813,7 +822,7 @@ class LLM:
|
|||||||
|
|
||||||
for sp in params if isinstance(params, list) else (params, ):
|
for sp in params if isinstance(params, list) else (params, ):
|
||||||
if isinstance(sp, SamplingParams):
|
if isinstance(sp, SamplingParams):
|
||||||
self._add_guided_processor(sp, guided_options)
|
self._add_guided_params(sp, guided_options)
|
||||||
|
|
||||||
# We only care about the final output
|
# We only care about the final output
|
||||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||||
@ -847,22 +856,25 @@ class LLM:
|
|||||||
priority=priority,
|
priority=priority,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_guided_processor(
|
def _add_guided_params(
|
||||||
self,
|
self,
|
||||||
params: SamplingParams,
|
params: SamplingParams,
|
||||||
guided_options: Optional[GuidedDecodingRequest] = None):
|
guided_options: Optional[GuidedDecodingRequest] = None):
|
||||||
if guided_options:
|
if guided_options is None:
|
||||||
if guided_options.guided_decoding_backend is None:
|
return params
|
||||||
decoding_config = self.llm_engine.get_decoding_config()
|
|
||||||
guided_options.guided_decoding_backend = (
|
if params.guided_decoding is not None:
|
||||||
decoding_config.guided_decoding_backend)
|
raise ValueError("Cannot set both guided_options_request and"
|
||||||
guided_logits_processor = get_local_guided_decoding_logits_processor( #noqa
|
"params.guided_decoding.")
|
||||||
guided_options.guided_decoding_backend, guided_options,
|
|
||||||
self.get_tokenizer())
|
params.guided_decoding = GuidedDecodingParams(
|
||||||
if guided_logits_processor:
|
json=guided_options.guided_json,
|
||||||
if params.logits_processors is None:
|
regex=guided_options.guided_regex,
|
||||||
params.logits_processors = []
|
choice=guided_options.guided_choice,
|
||||||
params.logits_processors.append(guided_logits_processor)
|
grammar=guided_options.guided_grammar,
|
||||||
|
json_object=guided_options.guided_json_object,
|
||||||
|
backend=guided_options.guided_decoding_backend,
|
||||||
|
whitespace_pattern=guided_options.guided_whitespace_pattern)
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def _run_engine(
|
def _run_engine(
|
||||||
|
|||||||
@ -10,12 +10,10 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|||||||
from typing_extensions import Annotated, Required, TypedDict
|
from typing_extensions import Annotated, Required, TypedDict
|
||||||
|
|
||||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||||
from vllm.entrypoints.openai.logits_processors import get_logits_processors
|
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
|
from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
|
||||||
SamplingParams)
|
SamplingParams)
|
||||||
from vllm.sequence import Logprob
|
from vllm.sequence import Logprob
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
# torch is mocked during docs generation,
|
# torch is mocked during docs generation,
|
||||||
@ -284,10 +282,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
# doc: end-chat-completion-extra-params
|
# doc: end-chat-completion-extra-params
|
||||||
|
|
||||||
def to_sampling_params(
|
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
|
||||||
self, tokenizer: AnyTokenizer,
|
|
||||||
guided_decode_logits_processor: Optional[LogitsProcessor],
|
|
||||||
default_max_tokens: int) -> SamplingParams:
|
|
||||||
max_tokens = self.max_tokens
|
max_tokens = self.max_tokens
|
||||||
if max_tokens is None:
|
if max_tokens is None:
|
||||||
max_tokens = default_max_tokens
|
max_tokens = default_max_tokens
|
||||||
@ -296,14 +291,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
if prompt_logprobs is None and self.echo:
|
if prompt_logprobs is None and self.echo:
|
||||||
prompt_logprobs = self.top_logprobs
|
prompt_logprobs = self.top_logprobs
|
||||||
|
|
||||||
# We now allow logprobs being true without top_logrobs.
|
guided_json_object = None
|
||||||
logits_processors = get_logits_processors(
|
if (self.response_format is not None
|
||||||
logit_bias=self.logit_bias,
|
and self.response_format.type == "json_object"):
|
||||||
allowed_token_ids=None,
|
guided_json_object = True
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
guided_decoding = GuidedDecodingParams.from_optional(
|
||||||
if guided_decode_logits_processor:
|
json=self._get_guided_json_from_tool() or self.guided_json,
|
||||||
logits_processors.append(guided_decode_logits_processor)
|
regex=self.guided_regex,
|
||||||
|
choice=self.guided_choice,
|
||||||
|
grammar=self.guided_grammar,
|
||||||
|
json_object=guided_json_object,
|
||||||
|
backend=self.guided_decoding_backend,
|
||||||
|
whitespace_pattern=self.guided_whitespace_pattern)
|
||||||
|
|
||||||
return SamplingParams.from_optional(
|
return SamplingParams.from_optional(
|
||||||
n=self.n,
|
n=self.n,
|
||||||
@ -329,11 +329,29 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||||
length_penalty=self.length_penalty,
|
length_penalty=self.length_penalty,
|
||||||
logits_processors=logits_processors,
|
|
||||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
output_kind=RequestOutputKind.DELTA if self.stream \
|
output_kind=RequestOutputKind.DELTA if self.stream \
|
||||||
else RequestOutputKind.FINAL_ONLY,
|
else RequestOutputKind.FINAL_ONLY,
|
||||||
)
|
guided_decoding=guided_decoding,
|
||||||
|
logit_bias=self.logit_bias)
|
||||||
|
|
||||||
|
def _get_guided_json_from_tool(
|
||||||
|
self) -> Optional[Union[str, dict, BaseModel]]:
|
||||||
|
# user has chosen to not use any tool
|
||||||
|
if self.tool_choice == "none" or self.tools is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# user has chosen to use a named tool
|
||||||
|
if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
||||||
|
tool_name = self.tool_choice.function.name
|
||||||
|
tools = {tool.function.name: tool.function for tool in self.tools}
|
||||||
|
if tool_name not in tools:
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool '{tool_name}' has not been passed in `tools`.")
|
||||||
|
tool = tools[tool_name]
|
||||||
|
return tool.parameters
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -537,10 +555,7 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
# doc: end-completion-extra-params
|
# doc: end-completion-extra-params
|
||||||
|
|
||||||
def to_sampling_params(
|
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
|
||||||
self, tokenizer: AnyTokenizer,
|
|
||||||
guided_decode_logits_processor: Optional[LogitsProcessor],
|
|
||||||
default_max_tokens: int) -> SamplingParams:
|
|
||||||
max_tokens = self.max_tokens
|
max_tokens = self.max_tokens
|
||||||
if max_tokens is None:
|
if max_tokens is None:
|
||||||
max_tokens = default_max_tokens
|
max_tokens = default_max_tokens
|
||||||
@ -551,13 +566,19 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
echo_without_generation = self.echo and self.max_tokens == 0
|
echo_without_generation = self.echo and self.max_tokens == 0
|
||||||
|
|
||||||
logits_processors = get_logits_processors(
|
guided_json_object = None
|
||||||
logit_bias=self.logit_bias,
|
if (self.response_format is not None
|
||||||
allowed_token_ids=self.allowed_token_ids,
|
and self.response_format.type == "json_object"):
|
||||||
tokenizer=tokenizer,
|
guided_json_object = True
|
||||||
)
|
|
||||||
if guided_decode_logits_processor:
|
guided_decoding = GuidedDecodingParams.from_optional(
|
||||||
logits_processors.append(guided_decode_logits_processor)
|
json=self.guided_json,
|
||||||
|
regex=self.guided_regex,
|
||||||
|
choice=self.guided_choice,
|
||||||
|
grammar=self.guided_grammar,
|
||||||
|
json_object=guided_json_object,
|
||||||
|
backend=self.guided_decoding_backend,
|
||||||
|
whitespace_pattern=self.guided_whitespace_pattern)
|
||||||
|
|
||||||
return SamplingParams.from_optional(
|
return SamplingParams.from_optional(
|
||||||
n=self.n,
|
n=self.n,
|
||||||
@ -583,11 +604,12 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
||||||
include_stop_str_in_output=self.include_stop_str_in_output,
|
include_stop_str_in_output=self.include_stop_str_in_output,
|
||||||
length_penalty=self.length_penalty,
|
length_penalty=self.length_penalty,
|
||||||
logits_processors=logits_processors,
|
|
||||||
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
truncate_prompt_tokens=self.truncate_prompt_tokens,
|
||||||
output_kind=RequestOutputKind.DELTA if self.stream \
|
output_kind=RequestOutputKind.DELTA if self.stream \
|
||||||
else RequestOutputKind.FINAL_ONLY,
|
else RequestOutputKind.FINAL_ONLY,
|
||||||
)
|
guided_decoding=guided_decoding,
|
||||||
|
logit_bias=self.logit_bias,
|
||||||
|
allowed_token_ids=self.allowed_token_ids)
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -187,9 +187,6 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
raw_request.state.request_metadata = request_metadata
|
raw_request.state.request_metadata = request_metadata
|
||||||
|
|
||||||
try:
|
try:
|
||||||
guided_decode_logits_processor = (
|
|
||||||
await self._guided_decode_logits_processor(request, tokenizer))
|
|
||||||
|
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
prompt_inputs = self._tokenize_prompt_input(
|
prompt_inputs = self._tokenize_prompt_input(
|
||||||
request,
|
request,
|
||||||
@ -208,8 +205,6 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
assert prompt_inputs is not None
|
assert prompt_inputs is not None
|
||||||
|
|
||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
tokenizer,
|
|
||||||
guided_decode_logits_processor,
|
|
||||||
default_max_tokens=self.max_model_len -
|
default_max_tokens=self.max_model_len -
|
||||||
len(prompt_inputs["prompt_token_ids"]))
|
len(prompt_inputs["prompt_token_ids"]))
|
||||||
|
|
||||||
|
|||||||
@ -110,8 +110,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
|
|
||||||
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
tokenizer = await self.engine_client.get_tokenizer(lora_request)
|
||||||
|
|
||||||
guided_decode_logits_processor = (
|
|
||||||
await self._guided_decode_logits_processor(request, tokenizer))
|
|
||||||
prompts = list(
|
prompts = list(
|
||||||
self._tokenize_prompt_input_or_inputs(
|
self._tokenize_prompt_input_or_inputs(
|
||||||
request,
|
request,
|
||||||
@ -123,8 +121,6 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
|
|
||||||
for i, prompt_inputs in enumerate(prompts):
|
for i, prompt_inputs in enumerate(prompts):
|
||||||
sampling_params = request.to_sampling_params(
|
sampling_params = request.to_sampling_params(
|
||||||
tokenizer,
|
|
||||||
guided_decode_logits_processor,
|
|
||||||
default_max_tokens=self.max_model_len -
|
default_max_tokens=self.max_model_len -
|
||||||
len(prompt_inputs["prompt_token_ids"]))
|
len(prompt_inputs["prompt_token_ids"]))
|
||||||
|
|
||||||
|
|||||||
@ -27,11 +27,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
from vllm.inputs.parse import parse_and_batch_prompt
|
from vllm.inputs.parse import parse_and_batch_prompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor.guided_decoding import (
|
|
||||||
get_guided_decoding_logits_processor)
|
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sampling_params import LogitsProcessor, SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.sequence import Logprob
|
from vllm.sequence import Logprob
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import AtomicCounter
|
from vllm.utils import AtomicCounter
|
||||||
@ -168,15 +166,6 @@ class OpenAIServing:
|
|||||||
})
|
})
|
||||||
return json_str
|
return json_str
|
||||||
|
|
||||||
async def _guided_decode_logits_processor(
|
|
||||||
self, request: Union[ChatCompletionRequest, CompletionRequest],
|
|
||||||
tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
|
|
||||||
decoding_config = await self.engine_client.get_decoding_config()
|
|
||||||
guided_decoding_backend = request.guided_decoding_backend \
|
|
||||||
or decoding_config.guided_decoding_backend
|
|
||||||
return await get_guided_decoding_logits_processor(
|
|
||||||
guided_decoding_backend, request, tokenizer)
|
|
||||||
|
|
||||||
async def _check_model(
|
async def _check_model(
|
||||||
self,
|
self,
|
||||||
request: AnyRequest,
|
request: AnyRequest,
|
||||||
|
|||||||
@ -1,77 +1,45 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
|
||||||
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
|
|
||||||
CompletionRequest)
|
|
||||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
|
||||||
GuidedDecodingRequest)
|
|
||||||
from vllm.sampling_params import LogitsProcessor
|
|
||||||
|
|
||||||
|
|
||||||
async def get_guided_decoding_logits_processor(
|
async def get_guided_decoding_logits_processor(
|
||||||
guided_decoding_backend: str, request: Union[CompletionRequest,
|
guided_params: GuidedDecodingParams,
|
||||||
ChatCompletionRequest],
|
|
||||||
tokenizer) -> Optional[LogitsProcessor]:
|
tokenizer) -> Optional[LogitsProcessor]:
|
||||||
request = _adapt_request_for_tool_use(request)
|
# CFG grammar not supported by LMFE, so we use outlines instead
|
||||||
|
if guided_params.backend == 'outlines' or guided_params.grammar:
|
||||||
if guided_decoding_backend == 'outlines':
|
|
||||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||||
get_outlines_guided_decoding_logits_processor)
|
get_outlines_guided_decoding_logits_processor)
|
||||||
return await get_outlines_guided_decoding_logits_processor(
|
return await get_outlines_guided_decoding_logits_processor(
|
||||||
request, tokenizer)
|
guided_params, tokenizer)
|
||||||
if guided_decoding_backend == 'lm-format-enforcer':
|
if guided_params.backend == 'lm-format-enforcer':
|
||||||
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
||||||
get_lm_format_enforcer_guided_decoding_logits_processor)
|
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
||||||
return await get_lm_format_enforcer_guided_decoding_logits_processor(
|
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
||||||
request, tokenizer)
|
guided_params, tokenizer)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
|
f"Unknown guided decoding backend '{guided_params.backend}'. "
|
||||||
"Must be one of 'outlines, 'lm-format-enforcer'")
|
"Must be one of 'outlines, 'lm-format-enforcer'")
|
||||||
|
|
||||||
|
|
||||||
def get_local_guided_decoding_logits_processor(
|
def get_local_guided_decoding_logits_processor(
|
||||||
guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
|
guided_params: GuidedDecodingParams,
|
||||||
tokenizer) -> Optional[LogitsProcessor]:
|
tokenizer) -> Optional[LogitsProcessor]:
|
||||||
# request = _adapt_request_for_tool_use(request)
|
# CFG grammar not supported by LMFE, so we use outlines instead
|
||||||
|
if guided_params.backend == 'outlines' or guided_params.grammar:
|
||||||
if guided_decoding_backend == 'outlines':
|
|
||||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||||
get_local_outlines_guided_decoding_logits_processor)
|
get_local_outlines_guided_decoding_logits_processor)
|
||||||
return get_local_outlines_guided_decoding_logits_processor(
|
return get_local_outlines_guided_decoding_logits_processor(
|
||||||
guided_options, tokenizer)
|
guided_params, tokenizer)
|
||||||
if guided_decoding_backend == 'lm-format-enforcer':
|
if guided_params.backend == 'lm-format-enforcer':
|
||||||
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
|
||||||
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
get_local_lm_format_enforcer_guided_decoding_logits_processor)
|
||||||
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
return get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
||||||
guided_options, tokenizer)
|
guided_params, tokenizer)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
|
f"Unknown guided decoding backend '{guided_params.backend}'. "
|
||||||
"Must be one of 'outlines, 'lm-format-enforcer'")
|
"Must be one of 'outlines, 'lm-format-enforcer'")
|
||||||
|
|
||||||
|
|
||||||
def _adapt_request_for_tool_use(request: Union[CompletionRequest,
|
|
||||||
ChatCompletionRequest]):
|
|
||||||
# the legacy completion API does not support tool use
|
|
||||||
if type(request) is CompletionRequest:
|
|
||||||
return request
|
|
||||||
|
|
||||||
# user has chosen to not use any tool,
|
|
||||||
# OR is allowing the model to choose a tool.
|
|
||||||
if request.tool_choice == "none" or request.tool_choice == "auto":
|
|
||||||
return request
|
|
||||||
|
|
||||||
# user has chosen to use a named tool
|
|
||||||
if type(request.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
|
||||||
tool_name = request.tool_choice.function.name
|
|
||||||
tools = {tool.function.name: tool.function for tool in request.tools}
|
|
||||||
if tool_name not in tools:
|
|
||||||
raise ValueError(
|
|
||||||
f"Tool '{tool_name}' has not been passed in `tools`.")
|
|
||||||
tool = tools[tool_name]
|
|
||||||
request.guided_json = tool.parameters
|
|
||||||
|
|
||||||
return request
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from typing import Dict, List, Optional, TypedDict, Union
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# These classes are deprecated, see SamplingParams
|
||||||
class LLMGuidedOptions(TypedDict, total=False):
|
class LLMGuidedOptions(TypedDict, total=False):
|
||||||
guided_json: Union[Dict, BaseModel, str]
|
guided_json: Union[Dict, BaseModel, str]
|
||||||
guided_regex: str
|
guided_regex: str
|
||||||
|
|||||||
@ -7,66 +7,13 @@ from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
|
|||||||
TokenEnforcerTokenizerData, UnionParser)
|
TokenEnforcerTokenizerData, UnionParser)
|
||||||
from lmformatenforcer.integrations.vllm import (
|
from lmformatenforcer.integrations.vllm import (
|
||||||
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
|
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
|
||||||
from pydantic import BaseModel
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor
|
||||||
CompletionRequest)
|
|
||||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
|
||||||
GuidedDecodingRequest)
|
|
||||||
from vllm.sampling_params import LogitsProcessor
|
|
||||||
|
|
||||||
|
|
||||||
async def get_lm_format_enforcer_guided_decoding_logits_processor(
|
|
||||||
request: Union[CompletionRequest, ChatCompletionRequest],
|
|
||||||
tokenizer) -> Optional[LogitsProcessor]:
|
|
||||||
"""
|
|
||||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
|
||||||
and get the necessary logits processor for the given guide.
|
|
||||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
|
||||||
we make a shallow copy to reuse the same underlying FSM.
|
|
||||||
"""
|
|
||||||
|
|
||||||
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
|
|
||||||
tokenizer)
|
|
||||||
character_level_parser: CharacterLevelParser
|
|
||||||
if request.guided_json:
|
|
||||||
schema = _normalize_json_schema_object(request.guided_json)
|
|
||||||
character_level_parser = JsonSchemaParser(schema)
|
|
||||||
elif request.guided_choice:
|
|
||||||
character_level_parser = UnionParser(
|
|
||||||
[StringParser(choice) for choice in request.guided_choice])
|
|
||||||
elif request.guided_regex:
|
|
||||||
character_level_parser = RegexParser(request.guided_regex)
|
|
||||||
elif request.guided_grammar:
|
|
||||||
# CFG grammar not supported by LMFE, revert to outlines
|
|
||||||
|
|
||||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
|
||||||
get_outlines_guided_decoding_logits_processor)
|
|
||||||
return await get_outlines_guided_decoding_logits_processor(
|
|
||||||
request, tokenizer)
|
|
||||||
elif (request.response_format is not None
|
|
||||||
and request.response_format.type == "json_object"):
|
|
||||||
character_level_parser = JsonSchemaParser(
|
|
||||||
None) # None means any json object
|
|
||||||
elif (request.response_format is not None
|
|
||||||
and request.response_format.type == "json_schema"
|
|
||||||
and request.response_format.json_schema is not None
|
|
||||||
and request.response_format.json_schema.json_schema is not None):
|
|
||||||
schema = _normalize_json_schema_object(
|
|
||||||
request.response_format.json_schema.json_schema)
|
|
||||||
character_level_parser = JsonSchemaParser(schema)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
logits_processor = build_vllm_logits_processor(tokenizer_data,
|
|
||||||
character_level_parser)
|
|
||||||
return logits_processor
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
def get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
||||||
guided_options: GuidedDecodingRequest,
|
guided_params: GuidedDecodingParams,
|
||||||
tokenizer) -> Optional[LogitsProcessor]:
|
tokenizer) -> Optional[LogitsProcessor]:
|
||||||
"""
|
"""
|
||||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||||
@ -78,23 +25,20 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
|||||||
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
|
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
|
||||||
tokenizer)
|
tokenizer)
|
||||||
character_level_parser: CharacterLevelParser
|
character_level_parser: CharacterLevelParser
|
||||||
if guided_options.guided_json:
|
if guided_params.json:
|
||||||
schema = _normalize_json_schema_object(guided_options.guided_json)
|
schema_dict = _normalize_json_schema_object(guided_params.json)
|
||||||
character_level_parser = JsonSchemaParser(schema)
|
character_level_parser = JsonSchemaParser(schema_dict)
|
||||||
elif guided_options.guided_choice:
|
elif guided_params.choice:
|
||||||
character_level_parser = UnionParser(
|
character_level_parser = UnionParser(
|
||||||
[StringParser(choice) for choice in guided_options.guided_choice])
|
[StringParser(choice) for choice in guided_params.choice])
|
||||||
elif guided_options.guided_regex:
|
elif guided_params.regex:
|
||||||
character_level_parser = RegexParser(guided_options.guided_regex)
|
character_level_parser = RegexParser(guided_params.regex)
|
||||||
elif guided_options.guided_grammar:
|
elif guided_params.grammar:
|
||||||
# CFG grammar not supported by LMFE, revert to outlines
|
# CFG grammar not supported by LMFE
|
||||||
|
raise ValueError("Cannot construct a guided decoding logits processor"
|
||||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
" using the grammar option with the"
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
" lm_format_enforcer backend.")
|
||||||
get_local_outlines_guided_decoding_logits_processor)
|
elif guided_params.json_object:
|
||||||
return get_local_outlines_guided_decoding_logits_processor(
|
|
||||||
guided_options, tokenizer)
|
|
||||||
elif guided_options.guided_json_object:
|
|
||||||
# None means any json object
|
# None means any json object
|
||||||
character_level_parser = JsonSchemaParser(None)
|
character_level_parser = JsonSchemaParser(None)
|
||||||
else:
|
else:
|
||||||
@ -105,13 +49,11 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
|||||||
return logits_processor
|
return logits_processor
|
||||||
|
|
||||||
|
|
||||||
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
|
def _normalize_json_schema_object(schema: Union[str, dict]) -> dict:
|
||||||
if isinstance(schema, str):
|
if isinstance(schema, str):
|
||||||
return json_loads(schema)
|
return json_loads(schema)
|
||||||
if isinstance(schema, dict):
|
if isinstance(schema, dict):
|
||||||
return schema
|
return schema
|
||||||
if isinstance(schema, BaseModel):
|
|
||||||
return schema.model_json_schema()
|
|
||||||
raise AssertionError(f"Unsupported schema type {schema}")
|
raise AssertionError(f"Unsupported schema type {schema}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -5,16 +5,11 @@ from json import dumps as json_dumps
|
|||||||
from re import escape as regex_escape
|
from re import escape as regex_escape
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (
|
|
||||||
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
|
|
||||||
CompletionRequest)
|
|
||||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
|
||||||
GuidedDecodingRequest)
|
|
||||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||||
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
|
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
|
||||||
|
from vllm.sampling_params import GuidedDecodingParams
|
||||||
|
|
||||||
|
|
||||||
class GuidedDecodingMode(Enum):
|
class GuidedDecodingMode(Enum):
|
||||||
@ -55,8 +50,7 @@ global_thread_pool = None # used for generating logits processor fsm
|
|||||||
|
|
||||||
|
|
||||||
async def get_outlines_guided_decoding_logits_processor(
|
async def get_outlines_guided_decoding_logits_processor(
|
||||||
request: Union[CompletionRequest,
|
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
|
||||||
ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase
|
|
||||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
||||||
None]:
|
None]:
|
||||||
"""
|
"""
|
||||||
@ -66,7 +60,7 @@ async def get_outlines_guided_decoding_logits_processor(
|
|||||||
we make a shallow copy to reuse the same underlying FSM.
|
we make a shallow copy to reuse the same underlying FSM.
|
||||||
"""
|
"""
|
||||||
global global_thread_pool
|
global global_thread_pool
|
||||||
guide, mode = _get_guide_and_mode(request)
|
guide, mode = _get_guide_and_mode(guided_params)
|
||||||
if not guide or not mode:
|
if not guide or not mode:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -77,11 +71,11 @@ async def get_outlines_guided_decoding_logits_processor(
|
|||||||
|
|
||||||
return await loop.run_in_executor(global_thread_pool,
|
return await loop.run_in_executor(global_thread_pool,
|
||||||
_get_logits_processor, guide, tokenizer,
|
_get_logits_processor, guide, tokenizer,
|
||||||
mode, request.guided_whitespace_pattern)
|
mode, guided_params.whitespace_pattern)
|
||||||
|
|
||||||
|
|
||||||
def get_local_outlines_guided_decoding_logits_processor(
|
def get_local_outlines_guided_decoding_logits_processor(
|
||||||
guided_options: GuidedDecodingRequest, tokenizer: PreTrainedTokenizerBase
|
guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase
|
||||||
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
|
||||||
None]:
|
None]:
|
||||||
"""
|
"""
|
||||||
@ -90,65 +84,37 @@ def get_local_outlines_guided_decoding_logits_processor(
|
|||||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
We cache logit processors by (guide, tokenizer), and on cache hit
|
||||||
we make a shallow copy to reuse the same underlying FSM.
|
we make a shallow copy to reuse the same underlying FSM.
|
||||||
"""
|
"""
|
||||||
guide, mode = _get_guide_and_mode(guided_options)
|
guide, mode = _get_guide_and_mode(guided_params)
|
||||||
if not guide or not mode:
|
if not guide or not mode:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return _get_logits_processor(guide, tokenizer, mode,
|
return _get_logits_processor(guide, tokenizer, mode,
|
||||||
guided_options.guided_whitespace_pattern)
|
guided_params.whitespace_pattern)
|
||||||
|
|
||||||
|
|
||||||
def _get_guide_and_mode(
|
def _get_guide_and_mode(
|
||||||
request: Union[CompletionRequest, ChatCompletionRequest,
|
guided_params: GuidedDecodingParams
|
||||||
GuidedDecodingRequest]
|
|
||||||
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
|
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
|
||||||
# if the request is a chat completion request, AND the tool choice is a
|
if guided_params.json:
|
||||||
# named tool choice, do guided decoding
|
if isinstance(guided_params.json, dict):
|
||||||
# using that tool as the JSON schema
|
|
||||||
if isinstance(request, ChatCompletionRequest) and isinstance(
|
|
||||||
request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
|
||||||
# Guided generation for tools/functions parameters
|
|
||||||
if request.tool_choice.type == "function":
|
|
||||||
for tool in request.tools:
|
|
||||||
if (tool.type == "function" and tool.function.name
|
|
||||||
== request.tool_choice.function.name):
|
|
||||||
json = json_dumps(tool.function.parameters, sort_keys=True)
|
|
||||||
return json, GuidedDecodingMode.JSON
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
elif request.guided_json:
|
|
||||||
if isinstance(request.guided_json, dict):
|
|
||||||
# turn dict into hashable string
|
# turn dict into hashable string
|
||||||
json = json_dumps(request.guided_json)
|
json = json_dumps(guided_params.json)
|
||||||
elif isinstance(request.guided_json, BaseModel):
|
|
||||||
# use pydantic signature so that different model classes
|
|
||||||
# with the same fields will get hashed the same
|
|
||||||
json = str(request.guided_json.__signature__)
|
|
||||||
else:
|
else:
|
||||||
json = request.guided_json
|
json = guided_params.json
|
||||||
return json, GuidedDecodingMode.JSON
|
return json, GuidedDecodingMode.JSON
|
||||||
elif request.guided_regex:
|
elif guided_params.regex:
|
||||||
return request.guided_regex, GuidedDecodingMode.REGEX
|
return guided_params.regex, GuidedDecodingMode.REGEX
|
||||||
elif request.guided_choice:
|
elif guided_params.choice:
|
||||||
# choice just uses regex
|
# choice just uses regex
|
||||||
choices = [
|
choices = [
|
||||||
regex_escape(str(choice)) for choice in request.guided_choice
|
regex_escape(str(choice)) for choice in guided_params.choice
|
||||||
]
|
]
|
||||||
choices_regex = "(" + "|".join(choices) + ")"
|
choices_regex = "(" + "|".join(choices) + ")"
|
||||||
return choices_regex, GuidedDecodingMode.CHOICE
|
return choices_regex, GuidedDecodingMode.CHOICE
|
||||||
elif request.guided_grammar:
|
elif guided_params.grammar:
|
||||||
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
|
return guided_params.grammar, GuidedDecodingMode.GRAMMAR
|
||||||
elif (not isinstance(request, GuidedDecodingRequest)
|
elif guided_params.json_object:
|
||||||
and request.response_format is not None
|
|
||||||
and request.response_format.type == "json_object"):
|
|
||||||
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
|
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
|
||||||
elif (not isinstance(request, GuidedDecodingRequest)
|
|
||||||
and request.response_format is not None
|
|
||||||
and request.response_format.type == "json_schema"
|
|
||||||
and request.response_format.json_schema is not None
|
|
||||||
and request.response_format.json_schema.json_schema is not None):
|
|
||||||
json = json_dumps(request.response_format.json_schema.json_schema)
|
|
||||||
return json, GuidedDecodingMode.JSON
|
|
||||||
else:
|
else:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
"""Sampling parameters for text generation."""
|
"""Sampling parameters for text generation."""
|
||||||
import copy
|
import copy
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum, IntEnum
|
from enum import Enum, IntEnum
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
import torch
|
import torch
|
||||||
|
from pydantic import BaseModel
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
@ -34,6 +36,54 @@ first argument, and returns a modified tensor of logits
|
|||||||
to sample from."""
|
to sample from."""
|
||||||
|
|
||||||
|
|
||||||
|
# maybe make msgspec?
|
||||||
|
@dataclass
|
||||||
|
class GuidedDecodingParams:
|
||||||
|
"""One of these fields will be used to build a logit processor."""
|
||||||
|
json: Optional[Union[str, Dict]] = None
|
||||||
|
regex: Optional[str] = None
|
||||||
|
choice: Optional[List[str]] = None
|
||||||
|
grammar: Optional[str] = None
|
||||||
|
json_object: Optional[bool] = None
|
||||||
|
"""These are other options that can be set"""
|
||||||
|
backend: Optional[str] = None
|
||||||
|
whitespace_pattern: Optional[str] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_optional(
|
||||||
|
json: Optional[Union[Dict, BaseModel, str]],
|
||||||
|
regex: Optional[str] = None,
|
||||||
|
choice: Optional[List[str]] = None,
|
||||||
|
grammar: Optional[str] = None,
|
||||||
|
json_object: Optional[bool] = None,
|
||||||
|
backend: Optional[str] = None,
|
||||||
|
whitespace_pattern: Optional[str] = None,
|
||||||
|
) -> "GuidedDecodingParams":
|
||||||
|
# Extract json schemas from pydantic models
|
||||||
|
if isinstance(json, (BaseModel, type(BaseModel))):
|
||||||
|
json = json.model_json_schema()
|
||||||
|
return GuidedDecodingParams(
|
||||||
|
json=json,
|
||||||
|
regex=regex,
|
||||||
|
choice=choice,
|
||||||
|
grammar=grammar,
|
||||||
|
json_object=json_object,
|
||||||
|
backend=backend,
|
||||||
|
whitespace_pattern=whitespace_pattern,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
"""Validate that some fields are mutually exclusive."""
|
||||||
|
guide_count = sum([
|
||||||
|
self.json is not None, self.regex is not None, self.choice
|
||||||
|
is not None, self.grammar is not None, self.json_object is not None
|
||||||
|
])
|
||||||
|
if guide_count > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"You can only use one kind of guided decoding but multiple are "
|
||||||
|
f"specified: {self.__dict__}")
|
||||||
|
|
||||||
|
|
||||||
class RequestOutputKind(Enum):
|
class RequestOutputKind(Enum):
|
||||||
# Return entire output so far in every RequestOutput
|
# Return entire output so far in every RequestOutput
|
||||||
CUMULATIVE = 0
|
CUMULATIVE = 0
|
||||||
@ -124,6 +174,13 @@ class SamplingParams(
|
|||||||
truncate_prompt_tokens: If set to an integer k, will use only the last k
|
truncate_prompt_tokens: If set to an integer k, will use only the last k
|
||||||
tokens from the prompt (i.e., left truncation). Defaults to None
|
tokens from the prompt (i.e., left truncation). Defaults to None
|
||||||
(i.e., no truncation).
|
(i.e., no truncation).
|
||||||
|
guided_decoding: If provided, the engine will construct a guided
|
||||||
|
decoding logits processor from these parameters. Defaults to None.
|
||||||
|
logit_bias: If provided, the engine will construct a logits processor
|
||||||
|
that applies these logit biases. Defaults to None.
|
||||||
|
allowed_token_ids: If provided, the engine will construct a logits
|
||||||
|
processor which only retains scores for the given token ids.
|
||||||
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
n: int = 1
|
n: int = 1
|
||||||
@ -164,6 +221,11 @@ class SamplingParams(
|
|||||||
output_text_buffer_length: int = 0
|
output_text_buffer_length: int = 0
|
||||||
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
|
_all_stop_token_ids: Set[int] = msgspec.field(default_factory=set)
|
||||||
|
|
||||||
|
# Fields used to construct logits processors
|
||||||
|
guided_decoding: Optional[GuidedDecodingParams] = None
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None
|
||||||
|
allowed_token_ids: Optional[List[int]] = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_optional(
|
def from_optional(
|
||||||
n: Optional[int] = 1,
|
n: Optional[int] = 1,
|
||||||
@ -194,7 +256,16 @@ class SamplingParams(
|
|||||||
truncate_prompt_tokens: Optional[Annotated[int,
|
truncate_prompt_tokens: Optional[Annotated[int,
|
||||||
msgspec.Meta(ge=1)]] = None,
|
msgspec.Meta(ge=1)]] = None,
|
||||||
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
|
||||||
|
guided_decoding: Optional[GuidedDecodingParams] = None,
|
||||||
|
logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None,
|
||||||
|
allowed_token_ids: Optional[List[int]] = None,
|
||||||
) -> "SamplingParams":
|
) -> "SamplingParams":
|
||||||
|
if logit_bias is not None:
|
||||||
|
logit_bias = {
|
||||||
|
int(token): bias
|
||||||
|
for token, bias in logit_bias.items()
|
||||||
|
}
|
||||||
|
|
||||||
return SamplingParams(
|
return SamplingParams(
|
||||||
n=1 if n is None else n,
|
n=1 if n is None else n,
|
||||||
best_of=best_of,
|
best_of=best_of,
|
||||||
@ -226,6 +297,9 @@ class SamplingParams(
|
|||||||
logits_processors=logits_processors,
|
logits_processors=logits_processors,
|
||||||
truncate_prompt_tokens=truncate_prompt_tokens,
|
truncate_prompt_tokens=truncate_prompt_tokens,
|
||||||
output_kind=output_kind,
|
output_kind=output_kind,
|
||||||
|
guided_decoding=guided_decoding,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
allowed_token_ids=allowed_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
@ -454,4 +528,5 @@ class SamplingParams(
|
|||||||
f"skip_special_tokens={self.skip_special_tokens}, "
|
f"skip_special_tokens={self.skip_special_tokens}, "
|
||||||
"spaces_between_special_tokens="
|
"spaces_between_special_tokens="
|
||||||
f"{self.spaces_between_special_tokens}, "
|
f"{self.spaces_between_special_tokens}, "
|
||||||
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")
|
f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
|
||||||
|
f"guided_decoding={self.guided_decoding}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user