From c57bb199b3eff82b4bdd5ffb089502936a5b0c9a Mon Sep 17 00:00:00 2001
From: Russell Bryant
Date: Thu, 12 Jun 2025 19:30:09 -0400
Subject: [PATCH] [V1] Resolve failed concurrent structured output requests
 (#19565)

Signed-off-by: Russell Bryant
---
 vllm/v1/worker/gpu_model_runner.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 987a24496d755..2fa9f25c37195 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -66,11 +66,15 @@ from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
 
 if TYPE_CHECKING:
     import xgrammar as xgr
+    import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile  # noqa: E501
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
     from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
+    xgr_torch_compile = LazyLoader(
+        "xgr_torch_compile", globals(),
+        "xgrammar.kernels.apply_token_bitmask_inplace_torch_compile")
 
 logger = init_logger(__name__)
 
@@ -1103,7 +1107,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # so we receive it in that format.
         grammar_bitmask = torch.from_numpy(grammar_bitmask)
 
-        xgr.apply_token_bitmask_inplace(
+        # Force use of the torch.compile implementation from xgrammar to work
+        # around issues with the Triton kernel in concurrent structured output
+        # scenarios. See PR #19565 and issues #19493, #18376 for details.
+        xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(
            logits,
            grammar_bitmask.to(self.device, non_blocking=True),
            indices=out_indices,
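
Note: the hunks above swap xgrammar's default Triton bitmask kernel for its
torch.compile implementation. For reviewers unfamiliar with that API, the
sketch below shows the bare call pattern in isolation. It is not part of the
patch: the shapes and values are invented, and it runs on CPU for simplicity,
whereas vLLM applies the mask to GPU logits.

    # Illustrative sketch only, not part of the patch. Assumes torch and
    # xgrammar are installed; shapes and values here are made up.
    import torch
    import xgrammar.kernels.apply_token_bitmask_inplace_torch_compile as xgr_torch_compile

    batch_size, vocab_size = 2, 64
    logits = torch.randn(batch_size, vocab_size)

    # xgrammar packs one bit per vocab token into int32 words, giving a
    # bitmask of shape (batch, ceil(vocab / 32)); a set bit allows the token.
    bitmask = torch.zeros(batch_size, (vocab_size + 31) // 32, dtype=torch.int32)
    bitmask[:, 0] = 0b1011  # allow only tokens 0, 1, and 3 in each row

    # Same entry point the patched code calls; masking happens in place.
    xgr_torch_compile.apply_token_bitmask_inplace_torch_compile(logits, bitmask)

    assert logits[0, 2] == float("-inf")  # disallowed token is masked out

In the patched code, the optional indices argument (indices=out_indices)
restricts the in-place masking to the rows of the logits batch that belong
to structured output requests, leaving the other rows untouched.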