mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 03:47:03 +08:00
[V0][Bugfix] Fix parallel sampling performance regression when guided decoding is enabled (#17731)
Signed-off-by: Madeesh Kannan <shadeMe@users.noreply.github.com> Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
4ce64e2df4
commit
e493e48524
@ -1,4 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import copy
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
@ -34,9 +35,24 @@ class GuidanceLogitsProcessor:
|
||||
self.grammar = grammar
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer_name = tokenizer.name_or_path
|
||||
self.ll_tokenizer = None
|
||||
self.ll_matcher = None
|
||||
self.bitmask = None
|
||||
self.new_sampling = False
|
||||
self.initialized = False
|
||||
|
||||
def clone(self) -> "GuidanceLogitsProcessor":
    """Return a copy of this processor with its own matcher state.

    The result starts as a shallow copy, so attributes such as the grammar,
    tokenizer and ``ll_tokenizer`` are shared with the original. When this
    processor has already been initialized, the copy additionally receives a
    newly constructed ``LLMatcher`` and a newly allocated token bitmask so
    that the two processors do not share those mutable objects. An
    uninitialized processor is returned as a plain shallow copy.

    Returns:
        The cloned processor; ``self`` is left unmodified.
    """
    cloned = copy.copy(self)
    if self.initialized:
        cloned.ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,  # type: ignore[assignment]
            self.grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )
        # Allocate the fresh bitmask on the clone. Assigning it to ``self``
        # would mutate the source processor and leave the clone holding the
        # source's previous buffer.
        cloned.bitmask = llguidance.torch.allocate_token_bitmask(
            1, self.ll_tokenizer.vocab_size)  # type: ignore[attr-defined]
    return cloned
|
||||
|
||||
def _initialize(self):
|
||||
if self.initialized:
|
||||
return
|
||||
@ -56,7 +72,7 @@ class GuidanceLogitsProcessor:
|
||||
|
||||
# create reusable bitmask
|
||||
self.bitmask = llguidance.torch.allocate_token_bitmask(
|
||||
1, self.ll_tokenizer.vocab_size)
|
||||
1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined]
|
||||
|
||||
self.initialized = True
|
||||
|
||||
@ -70,15 +86,17 @@ class GuidanceLogitsProcessor:
|
||||
self._initialize()
|
||||
|
||||
if self.new_sampling and len(input_ids) > 0:
|
||||
self.ll_matcher.consume_token(input_ids[-1])
|
||||
err = self.ll_matcher.get_error()
|
||||
self.ll_matcher.consume_token( # type: ignore[attr-defined]
|
||||
input_ids[-1])
|
||||
err = self.ll_matcher.get_error() # type: ignore[attr-defined]
|
||||
if err:
|
||||
logger.warning("Error in LLMatcher: %s", err)
|
||||
|
||||
llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
|
||||
0)
|
||||
llguidance.torch.apply_token_bitmask_inplace(
|
||||
scores, self.bitmask.to(scores.device))
|
||||
scores,
|
||||
self.bitmask.to(scores.device)) # type: ignore[attr-defined]
|
||||
|
||||
self.new_sampling = True
|
||||
|
||||
|
||||
@ -56,6 +56,12 @@ class BaseLogitsProcessor:
|
||||
self._fsm_state: defaultdict[int, Union[int,
|
||||
CFGState]] = defaultdict(int)
|
||||
|
||||
def clone(self) -> "BaseLogitsProcessor":
    """Return a copy of this processor with independent guide and FSM state.

    All other attributes are carried over by a shallow copy and therefore
    remain shared with the original instance.
    """
    duplicate = copy.copy(self)
    # The guide supplies its own ``copy()``; the per-sequence state mapping
    # is deep-copied so later updates on either side stay isolated.
    guide_copy = self._guide.copy()
    state_copy = copy.deepcopy(self._fsm_state)
    duplicate._guide = guide_copy
    duplicate._fsm_state = state_copy
    return duplicate
|
||||
|
||||
def __call__(self, input_ids: list[int],
|
||||
scores: torch.Tensor) -> torch.Tensor:
|
||||
"""Use the FSM to bias the logits before sampling the next token."""
|
||||
@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
|
||||
reasoner)
|
||||
self._guide = self._guide.copy()
|
||||
|
||||
def clone(self) -> "CFGLogitsProcessor":
    """Return a copy of this processor with independent FSM state and guide.

    The copy is shallow apart from ``_fsm_state`` (deep-copied) and
    ``_guide`` (duplicated through the guide's own ``copy()``).
    """
    duplicate = copy.copy(self)
    state_copy = copy.deepcopy(self._fsm_state)
    guide_copy = self._guide.copy()
    duplicate._fsm_state = state_copy
    duplicate._guide = guide_copy
    return duplicate
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
||||
|
||||
@ -302,8 +302,9 @@ class XGrammarLogitsProcessor:
|
||||
prefilled: bool = field(default=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.tokenizer_info = self.config.tokenizer_info(
|
||||
self.config.tokenizer_data)
|
||||
if self.tokenizer_info is None:
|
||||
self.tokenizer_info = self.config.tokenizer_info(
|
||||
self.config.tokenizer_data)
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
    """Return the state used for pickling: only ``config`` and ``reasoner``.

    No other attribute of the instance is included in the returned mapping.
    """
    state: dict[str, Any] = {}
    state['config'] = self.config
    state['reasoner'] = self.reasoner
    return state
|
||||
@ -400,7 +401,8 @@ class XGrammarLogitsProcessor:
|
||||
def clone(self) -> XGrammarLogitsProcessor:
|
||||
"""Create a new instance with shared compiled grammar
|
||||
but separate state"""
|
||||
new_processor = XGrammarLogitsProcessor(self.config, self.reasoner)
|
||||
new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
|
||||
None, self.tokenizer_info)
|
||||
|
||||
# Share the compiled grammar context (immutable after compilation)
|
||||
new_processor.ctx = self.ctx
|
||||
|
||||
@ -1494,7 +1494,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
|
||||
for i in range(original_params.n):
|
||||
request_id_i = f"{request_id}_parallel_sample_{i}"
|
||||
group.seq_id_to_index[request_id_i] = i
|
||||
params = copy.deepcopy(original_params)
|
||||
params = params.clone()
|
||||
params.n = 1
|
||||
if params.seed is not None:
|
||||
params.seed += i
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user