[V0][Bugfix] Fix parallel sampling performance regression when guided decoding is enabled (#17731)

Signed-off-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Madeesh Kannan 2025-05-23 12:38:23 +02:00 committed by GitHub
parent 4ce64e2df4
commit e493e48524
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 40 additions and 8 deletions

View File

@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import copy
import os
from typing import Any
@@ -34,9 +35,24 @@ class GuidanceLogitsProcessor:
self.grammar = grammar
self.tokenizer = tokenizer
self.tokenizer_name = tokenizer.name_or_path
self.ll_tokenizer = None
self.ll_matcher = None
self.bitmask = None
self.new_sampling = False
self.initialized = False
def clone(self) -> "GuidanceLogitsProcessor":
    """Return a copy of this processor with independent mutable state.

    A shallow copy shares the immutable pieces (grammar, tokenizer),
    while the per-sequence state — the matcher and the token bitmask —
    is re-created so parallel samples never share it.
    """
    cloned = copy.copy(self)
    if self.initialized:
        # Each clone needs its own matcher: LLMatcher is stateful and
        # consumes tokens as its sequence grows.
        cloned.ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,  # type: ignore[assignment]
            self.grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )
        # Bug fix: allocate the fresh bitmask on the *clone*. The
        # original assigned to self.bitmask, clobbering the source
        # processor's bitmask and leaving the clone sharing it.
        cloned.bitmask = llguidance.torch.allocate_token_bitmask(
            1, self.ll_tokenizer.vocab_size)  # type: ignore[attr-defined]
    return cloned
def _initialize(self):
if self.initialized:
return
@@ -56,7 +72,7 @@ class GuidanceLogitsProcessor:
# create reusable bitmask
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size)
1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined]
self.initialized = True
@@ -70,15 +86,17 @@ class GuidanceLogitsProcessor:
self._initialize()
if self.new_sampling and len(input_ids) > 0:
self.ll_matcher.consume_token(input_ids[-1])
err = self.ll_matcher.get_error()
self.ll_matcher.consume_token( # type: ignore[attr-defined]
input_ids[-1])
err = self.ll_matcher.get_error() # type: ignore[attr-defined]
if err:
logger.warning("Error in LLMatcher: %s", err)
llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
0)
llguidance.torch.apply_token_bitmask_inplace(
scores, self.bitmask.to(scores.device))
scores,
self.bitmask.to(scores.device)) # type: ignore[attr-defined]
self.new_sampling = True

View File

@@ -56,6 +56,12 @@ class BaseLogitsProcessor:
self._fsm_state: defaultdict[int, Union[int,
CFGState]] = defaultdict(int)
def clone(self) -> "BaseLogitsProcessor":
    """Produce a state-independent duplicate of this processor.

    Shallow-copies the instance, then replaces the guide and the FSM
    state map with private copies so the two processors never mutate
    each other's decoding state.
    """
    duplicate = copy.copy(self)
    duplicate._guide = self._guide.copy()
    duplicate._fsm_state = copy.deepcopy(self._fsm_state)
    return duplicate
def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
@@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
reasoner)
self._guide = self._guide.copy()
def clone(self) -> "CFGLogitsProcessor":
    """Return an independent copy for use by another sampled sequence.

    The shallow copy shares immutable configuration; the FSM state and
    the guide are duplicated so per-sequence decoding state diverges
    safely.
    """
    fresh = copy.copy(self)
    fresh._fsm_state = copy.deepcopy(self._fsm_state)
    fresh._guide = self._guide.copy()
    return fresh
@lru_cache(maxsize=32)
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):

View File

@@ -302,8 +302,9 @@ class XGrammarLogitsProcessor:
prefilled: bool = field(default=False)
def __post_init__(self):
    """Finish dataclass initialization: build tokenizer_info lazily.

    When the processor is produced by clone(), tokenizer_info is passed
    in directly; unconditionally recomputing it here per parallel sample
    was the performance regression, so compute it only when absent.
    """
    if self.tokenizer_info is None:
        self.tokenizer_info = self.config.tokenizer_info(
            self.config.tokenizer_data)
def __getstate__(self) -> dict[str, Any]:
    """Pickle support: serialize only the config and reasoner fields."""
    state: dict[str, Any] = {}
    state['config'] = self.config
    state['reasoner'] = self.reasoner
    return state
@@ -400,7 +401,8 @@ class XGrammarLogitsProcessor:
def clone(self) -> XGrammarLogitsProcessor:
"""Create a new instance with shared compiled grammar
but separate state"""
new_processor = XGrammarLogitsProcessor(self.config, self.reasoner)
new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
None, self.tokenizer_info)
# Share the compiled grammar context (immutable after compilation)
new_processor.ctx = self.ctx

View File

@@ -1494,7 +1494,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i
params = copy.deepcopy(original_params)
params = params.clone()
params.n = 1
if params.seed is not None:
params.seed += i