mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-10 11:24:36 +08:00
[Bugfix] Fix OpenAI parallel sampling when using xgrammar (#11637)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
a2a40bcd0d
commit
74fa1d123c
@ -28,6 +28,8 @@ PA_NAME = "swapnilbp/llama_tweet_ptune"
|
||||
# need to change to match the prompt adapter
|
||||
PA_NUM_VIRTUAL_TOKENS = 8
|
||||
|
||||
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def zephyr_lora_files():
|
||||
@ -635,8 +637,7 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_json_schema):
|
||||
@ -658,8 +659,7 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_regex):
|
||||
@ -680,8 +680,7 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_guided_choice):
|
||||
@ -761,8 +760,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str,
|
||||
sample_json_schema, sample_regex):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# noqa: UP007
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -309,3 +310,7 @@ class XGrammarLogitsProcessor:
|
||||
scores = scores.to(device_type).squeeze()
|
||||
|
||||
return scores
|
||||
|
||||
def clone(self) -> XGrammarLogitsProcessor:
|
||||
"""Deepcopy due to per-sequence state in the matchers"""
|
||||
return copy.deepcopy(self)
|
||||
|
||||
@ -450,15 +450,16 @@ class SamplingParams(
|
||||
return self._all_stop_token_ids
|
||||
|
||||
def clone(self) -> "SamplingParams":
|
||||
"""Deep copy excluding LogitsProcessor objects.
|
||||
"""Deep copy, but maybe not the LogitsProcessor objects.
|
||||
|
||||
LogitsProcessor objects are excluded because they may contain an
|
||||
arbitrary, nontrivial amount of data.
|
||||
LogitsProcessor objects may contain an arbitrary, nontrivial amount of
|
||||
data that is expensive to copy. However, if not copied, the processor
|
||||
needs to support parallel decoding for multiple sequences.
|
||||
See https://github.com/vllm-project/vllm/issues/3087
|
||||
"""
|
||||
|
||||
logit_processor_refs = None if self.logits_processors is None else {
|
||||
id(lp): lp
|
||||
id(lp): lp.clone() if hasattr(lp, 'clone') else lp
|
||||
for lp in self.logits_processors
|
||||
}
|
||||
return copy.deepcopy(self, memo=logit_processor_refs)
|
||||
|
||||
@ -1372,7 +1372,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
|
||||
@staticmethod
|
||||
def add_request(request_id: str, engine, params, **kwargs):
|
||||
original_params = params
|
||||
params = copy.deepcopy(original_params)
|
||||
params = original_params.clone()
|
||||
params.n = 1
|
||||
group = ParallelSampleSequenceGroup(request_id)
|
||||
seqs = []
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user