fix(spec_decode): sync ngram draft tokens across TP ranks

When using tensor parallelism with external_launcher, ngram draft
tokens could diverge across TP ranks due to non-determinism in
numba parallel execution. This caused verification failures and
crashes in speculative decoding.
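For illustration, one common source of this kind of divergence is a parallel reduction whose accumulation order depends on thread scheduling. This is a hedged sketch, not the actual proposer kernel; `parallel_sum` is a made-up example:

```python
# Illustration only: a numba parallel reduction whose floating-point
# accumulation order depends on thread scheduling, so results can differ
# run to run (and hence rank to rank) in the low-order bits.
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def parallel_sum(x):
    total = 0.0
    for i in prange(x.shape[0]):
        total += x[i]  # reduction; the combine order is unspecified
    return total

x = np.random.rand(1_000_000).astype(np.float32)
print(parallel_sum(x), parallel_sum(x))  # may differ in the last bits
```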

The fix ensures that only TP rank 0 computes the draft tokens and
broadcasts them to the other ranks via broadcast_object(), so all
TP ranks verify an identical set of drafts.
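
For reference, a minimal sketch of the same rank-0-compute-and-broadcast pattern using plain torch.distributed (vLLM's `tp_group.broadcast_object()` wraps its own group coordinator; `propose_synced` and `propose_fn` below are hypothetical names):

```python
# Sketch of the pattern, assuming torch.distributed is already
# initialized (e.g. via torchrun). propose_fn is assumed to return a
# picklable list[list[int]].
import torch.distributed as dist

def propose_synced(propose_fn, group=None):
    if dist.get_rank(group) == 0:
        payload = [propose_fn()]  # compute drafts once, on rank 0
    else:
        payload = [None]          # placeholder, filled by the broadcast
    # broadcast_object_list pickles the object on src and writes it
    # in place into the list on every other rank.
    dist.broadcast_object_list(payload, src=0, group=group)
    drafts = payload[0]
    assert drafts is not None     # rank 0 always produced a value
    return drafts
```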

Fixes #31154

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: yurekami <yurekami@users.noreply.github.com>
yurekami 2025-12-24 23:09:18 +09:00
parent 7cd288a4b3
commit 0f94a71de5


@@ -6,6 +6,7 @@
 import numpy as np
 from numba import get_num_threads, jit, njit, prange, set_num_threads
 
 from vllm.config import VllmConfig
+from vllm.distributed import get_tp_group
@@ -137,33 +138,47 @@ class NgramProposer:
         token_ids_cpu: np.ndarray,
         spec_decode_unsupported_reqs: set,
     ) -> list[list[int]]:
-        # find which requests need ngram proposals
-        valid_ngram_requests = []
-        for i, sampled_ids in enumerate(sampled_token_ids):
-            num_sampled_ids = len(sampled_ids)
-            if not num_sampled_ids:
-                # Skip speculative decoding.
-                continue
-
-            # Skip requests that require sampling parameters that are not
-            # supported with speculative decoding.
-            req_id = req_ids[i]
-            if req_id in spec_decode_unsupported_reqs:
-                continue
-
-            num_tokens = num_tokens_no_spec[i]
-            if num_tokens >= self.max_model_len:
-                # Skip requests that have already reached the max model length.
-                continue
-
-            valid_ngram_requests.append(i)
-
-        draft_token_ids = self.batch_propose(
-            len(sampled_token_ids),
-            valid_ngram_requests,
-            num_tokens_no_spec,
-            token_ids_cpu,
-        )
+        # Only compute draft tokens on TP rank 0 and broadcast to other ranks.
+        # This ensures all TP ranks have identical draft tokens, which is
+        # required because numba parallel execution can produce different
+        # results across ranks due to non-determinism.
+        tp_group = get_tp_group()
+        if tp_group.is_first_rank:
+            # find which requests need ngram proposals
+            valid_ngram_requests = []
+            for i, sampled_ids in enumerate(sampled_token_ids):
+                num_sampled_ids = len(sampled_ids)
+                if not num_sampled_ids:
+                    # Skip speculative decoding.
+                    continue
+
+                # Skip requests that require sampling parameters that are not
+                # supported with speculative decoding.
+                req_id = req_ids[i]
+                if req_id in spec_decode_unsupported_reqs:
+                    continue
+
+                num_tokens = num_tokens_no_spec[i]
+                if num_tokens >= self.max_model_len:
+                    # Skip requests that have already reached the max model length.
+                    continue
+
+                valid_ngram_requests.append(i)
+
+            draft_token_ids = self.batch_propose(
+                len(sampled_token_ids),
+                valid_ngram_requests,
+                num_tokens_no_spec,
+                token_ids_cpu,
+            )
+        else:
+            draft_token_ids = None
+
+        # Broadcast draft tokens from rank 0 to all other ranks.
+        # Rank 0 always computes valid draft_token_ids, so broadcast
+        # will never return None.
+        draft_token_ids = tp_group.broadcast_object(draft_token_ids, src=0)
+        assert draft_token_ids is not None
         return draft_token_ids
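
A hypothetical way to exercise the patched path end to end (a sketch assuming vLLM's `external_launcher` backend and an ngram speculative config; `repro.py` and the model choice are placeholders, launched with `torchrun --nproc-per-node=2 repro.py`):

```python
# repro.py -- hypothetical sketch, not a test from this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",  # placeholder model
    tensor_parallel_size=2,
    distributed_executor_backend="external_launcher",
    speculative_config={
        "method": "ngram",
        "num_speculative_tokens": 3,
        "prompt_lookup_max": 4,
        "prompt_lookup_min": 2,
    },
)
outputs = llm.generate(["The quick brown fox"], SamplingParams(max_tokens=32))
# With the fix, every TP rank verifies the same broadcast drafts instead
# of diverging on locally computed ones.
print(outputs[0].outputs[0].text)
```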