From e493e48524e9e78ab33eafec6461b3940e361189 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Fri, 23 May 2025 12:38:23 +0200 Subject: [PATCH] [V0][Bugfix] Fix parallel sampling performance regression when guided decoding is enabled (#17731) Signed-off-by: Madeesh Kannan Co-authored-by: Russell Bryant --- .../guidance_logits_processors.py | 26 ++++++++++++++++--- .../outlines_logits_processors.py | 12 +++++++++ .../guided_decoding/xgrammar_decoding.py | 8 +++--- vllm/sequence.py | 2 +- 4 files changed, 40 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/guided_decoding/guidance_logits_processors.py b/vllm/model_executor/guided_decoding/guidance_logits_processors.py index 4b45c272adc52..e17df68b4b4da 100644 --- a/vllm/model_executor/guided_decoding/guidance_logits_processors.py +++ b/vllm/model_executor/guided_decoding/guidance_logits_processors.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +import copy import os from typing import Any @@ -34,9 +35,24 @@ class GuidanceLogitsProcessor: self.grammar = grammar self.tokenizer = tokenizer self.tokenizer_name = tokenizer.name_or_path + self.ll_tokenizer = None + self.ll_matcher = None + self.bitmask = None self.new_sampling = False self.initialized = False + def clone(self) -> "GuidanceLogitsProcessor": + cloned = copy.copy(self) + if self.initialized: + cloned.ll_matcher = llguidance.LLMatcher( + self.ll_tokenizer, # type: ignore[assignment] + self.grammar, + log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), + ) + cloned.bitmask = llguidance.torch.allocate_token_bitmask( + 1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined] + return cloned + def _initialize(self): if self.initialized: return @@ -56,7 +72,7 @@ class GuidanceLogitsProcessor: # create reusable bitmask self.bitmask = llguidance.torch.allocate_token_bitmask( - 1, self.ll_tokenizer.vocab_size) + 1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined] self.initialized = True @@ -70,15 +86,17 @@ class 
GuidanceLogitsProcessor: self._initialize() if self.new_sampling and len(input_ids) > 0: - self.ll_matcher.consume_token(input_ids[-1]) - err = self.ll_matcher.get_error() + self.ll_matcher.consume_token( # type: ignore[attr-defined] + input_ids[-1]) + err = self.ll_matcher.get_error() # type: ignore[attr-defined] if err: logger.warning("Error in LLMatcher: %s", err) llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask, 0) llguidance.torch.apply_token_bitmask_inplace( - scores, self.bitmask.to(scores.device)) + scores, + self.bitmask.to(scores.device)) # type: ignore[attr-defined] self.new_sampling = True diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 8ae7c7b6b2c78..6986b6554c230 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -56,6 +56,12 @@ class BaseLogitsProcessor: self._fsm_state: defaultdict[int, Union[int, CFGState]] = defaultdict(int) + def clone(self) -> "BaseLogitsProcessor": + cloned = copy.copy(self) + cloned._guide = self._guide.copy() + cloned._fsm_state = copy.deepcopy(self._fsm_state) + return cloned + def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" @@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor): reasoner) self._guide = self._guide.copy() + def clone(self) -> "CFGLogitsProcessor": + cloned = copy.copy(self) + cloned._fsm_state = copy.deepcopy(self._fsm_state) + cloned._guide = self._guide.copy() + return cloned + @lru_cache(maxsize=32) def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py index 8e40da4b3aa99..7ca7bab818fca 100644 --- 
a/vllm/model_executor/guided_decoding/xgrammar_decoding.py +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -302,8 +302,9 @@ class XGrammarLogitsProcessor: prefilled: bool = field(default=False) def __post_init__(self): - self.tokenizer_info = self.config.tokenizer_info( - self.config.tokenizer_data) + if self.tokenizer_info is None: + self.tokenizer_info = self.config.tokenizer_info( + self.config.tokenizer_data) def __getstate__(self) -> dict[str, Any]: return {'config': self.config, 'reasoner': self.reasoner} @@ -400,7 +401,8 @@ class XGrammarLogitsProcessor: def clone(self) -> XGrammarLogitsProcessor: """Create a new instance with shared compiled grammar but separate state""" - new_processor = XGrammarLogitsProcessor(self.config, self.reasoner) + new_processor = XGrammarLogitsProcessor(self.config, self.reasoner, + None, self.tokenizer_info) # Share the compiled grammar context (immutable after compilation) new_processor.ctx = self.ctx diff --git a/vllm/sequence.py b/vllm/sequence.py index f5f9c56a7db23..f3dfd32d9169e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1494,7 +1494,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): for i in range(original_params.n): request_id_i = f"{request_id}_parallel_sample_{i}" group.seq_id_to_index[request_id_i] = i - params = copy.deepcopy(original_params) + params = original_params.clone() params.n = 1 if params.seed is not None: params.seed += i