From 0f94a71de55130ace34973fb4cb5237bbe8a6c24 Mon Sep 17 00:00:00 2001 From: yurekami Date: Wed, 24 Dec 2025 23:09:18 +0900 Subject: [PATCH] fix(spec_decode): sync ngram draft tokens across TP ranks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When using tensor parallelism with external_launcher, ngram draft tokens could diverge across TP ranks due to non-determinism in numba parallel execution. This caused verification failures and crashes in speculative decoding. The fix ensures that only TP rank 0 computes draft tokens and broadcasts them to all other ranks using broadcast_object(), guaranteeing identical draft tokens across all TP ranks. Fixes #31154 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 Signed-off-by: yurekami --- vllm/v1/spec_decode/ngram_proposer.py | 61 +++++++++++++++++---------- 1 file changed, 38 insertions(+), 23 deletions(-) diff --git a/vllm/v1/spec_decode/ngram_proposer.py b/vllm/v1/spec_decode/ngram_proposer.py index 1273ca12c3600..75d0b3b835576 100644 --- a/vllm/v1/spec_decode/ngram_proposer.py +++ b/vllm/v1/spec_decode/ngram_proposer.py @@ -6,6 +6,7 @@ import numpy as np from numba import get_num_threads, jit, njit, prange, set_num_threads from vllm.config import VllmConfig +from vllm.distributed import get_tp_group class NgramProposer: @@ -137,33 +138,47 @@ class NgramProposer: token_ids_cpu: np.ndarray, spec_decode_unsupported_reqs: set, ) -> list[list[int]]: - # find which requests need ngram proposals - valid_ngram_requests = [] - for i, sampled_ids in enumerate(sampled_token_ids): - num_sampled_ids = len(sampled_ids) - if not num_sampled_ids: - # Skip speculative decoding. - continue + # Only compute draft tokens on TP rank 0 and broadcast to other ranks. + # This ensures all TP ranks have identical draft tokens, which is + # required because numba parallel execution can produce different + # results across ranks due to non-determinism. + tp_group = get_tp_group() + if tp_group.is_first_rank: + # find which requests need ngram proposals + valid_ngram_requests = [] + for i, sampled_ids in enumerate(sampled_token_ids): + num_sampled_ids = len(sampled_ids) + if not num_sampled_ids: + # Skip speculative decoding. + continue - # Skip requests that require sampling parameters that are not - # supported with speculative decoding. - req_id = req_ids[i] - if req_id in spec_decode_unsupported_reqs: - continue + # Skip requests that require sampling parameters that are not + # supported with speculative decoding. + req_id = req_ids[i] + if req_id in spec_decode_unsupported_reqs: + continue - num_tokens = num_tokens_no_spec[i] - if num_tokens >= self.max_model_len: - # Skip requests that have already reached the max model length. - continue + num_tokens = num_tokens_no_spec[i] + if num_tokens >= self.max_model_len: + # Skip requests that have already reached the max model length. + continue - valid_ngram_requests.append(i) + valid_ngram_requests.append(i) - draft_token_ids = self.batch_propose( - len(sampled_token_ids), - valid_ngram_requests, - num_tokens_no_spec, - token_ids_cpu, - ) + draft_token_ids = self.batch_propose( + len(sampled_token_ids), + valid_ngram_requests, + num_tokens_no_spec, + token_ids_cpu, + ) + else: + draft_token_ids = None + + # Broadcast draft tokens from rank 0 to all other ranks. + # Rank 0 always computes valid draft_token_ids, so broadcast + # will never return None. + draft_token_ids = tp_group.broadcast_object(draft_token_ids, src=0) + assert draft_token_ids is not None return draft_token_ids