mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 00:22:20 +08:00
[Bugfix] Fix OpenAI parallel sampling when using xgrammar (#11637)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
a2a40bcd0d
commit
74fa1d123c
@ -28,6 +28,8 @@ PA_NAME = "swapnilbp/llama_tweet_ptune"
|
|||||||
# need to change to match the prompt adapter
|
# need to change to match the prompt adapter
|
||||||
PA_NUM_VIRTUAL_TOKENS = 8
|
PA_NUM_VIRTUAL_TOKENS = 8
|
||||||
|
|
||||||
|
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def zephyr_lora_files():
|
def zephyr_lora_files():
|
||||||
@ -635,8 +637,7 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("guided_decoding_backend",
|
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||||
["outlines", "lm-format-enforcer"])
|
|
||||||
async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
sample_json_schema):
|
sample_json_schema):
|
||||||
@ -658,8 +659,7 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("guided_decoding_backend",
|
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||||
["outlines", "lm-format-enforcer"])
|
|
||||||
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
sample_regex):
|
sample_regex):
|
||||||
@ -680,8 +680,7 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("guided_decoding_backend",
|
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||||
["outlines", "lm-format-enforcer"])
|
|
||||||
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
|
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
sample_guided_choice):
|
sample_guided_choice):
|
||||||
@ -761,8 +760,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("guided_decoding_backend",
|
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||||
["outlines", "lm-format-enforcer"])
|
|
||||||
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
|
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
|
||||||
guided_decoding_backend: str,
|
guided_decoding_backend: str,
|
||||||
sample_json_schema, sample_regex):
|
sample_json_schema, sample_regex):
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# noqa: UP007
|
# noqa: UP007
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@ -309,3 +310,7 @@ class XGrammarLogitsProcessor:
|
|||||||
scores = scores.to(device_type).squeeze()
|
scores = scores.to(device_type).squeeze()
|
||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
def clone(self) -> XGrammarLogitsProcessor:
|
||||||
|
"""Deepcopy due to per-sequence state in the matchers"""
|
||||||
|
return copy.deepcopy(self)
|
||||||
|
|||||||
@ -450,15 +450,16 @@ class SamplingParams(
|
|||||||
return self._all_stop_token_ids
|
return self._all_stop_token_ids
|
||||||
|
|
||||||
def clone(self) -> "SamplingParams":
|
def clone(self) -> "SamplingParams":
|
||||||
"""Deep copy excluding LogitsProcessor objects.
|
"""Deep copy, but maybe not the LogitsProcessor objects.
|
||||||
|
|
||||||
LogitsProcessor objects are excluded because they may contain an
|
LogitsProcessor objects may contain an arbitrary, nontrivial amount of
|
||||||
arbitrary, nontrivial amount of data.
|
data that is expensive to copy. However, if not copied, the processor
|
||||||
|
needs to support parallel decoding for multiple sequences
|
||||||
See https://github.com/vllm-project/vllm/issues/3087
|
See https://github.com/vllm-project/vllm/issues/3087
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logit_processor_refs = None if self.logits_processors is None else {
|
logit_processor_refs = None if self.logits_processors is None else {
|
||||||
id(lp): lp
|
id(lp): lp.clone() if hasattr(lp, 'clone') else lp
|
||||||
for lp in self.logits_processors
|
for lp in self.logits_processors
|
||||||
}
|
}
|
||||||
return copy.deepcopy(self, memo=logit_processor_refs)
|
return copy.deepcopy(self, memo=logit_processor_refs)
|
||||||
|
|||||||
@ -1372,7 +1372,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def add_request(request_id: str, engine, params, **kwargs):
|
def add_request(request_id: str, engine, params, **kwargs):
|
||||||
original_params = params
|
original_params = params
|
||||||
params = copy.deepcopy(original_params)
|
params = original_params.clone()
|
||||||
params.n = 1
|
params.n = 1
|
||||||
group = ParallelSampleSequenceGroup(request_id)
|
group = ParallelSampleSequenceGroup(request_id)
|
||||||
seqs = []
|
seqs = []
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user