From 57d4ede520b6071341ebd310c0ddd4c6f4d54917 Mon Sep 17 00:00:00 2001 From: Jingkai He Date: Fri, 29 Aug 2025 03:05:20 +0800 Subject: [PATCH] [bugfix] [spec-decoding] fix data race in sample_recovered_tokens_kernel (vLLM v1) (#23829) Signed-off-by: He-Jingkai --- vllm/v1/sample/rejection_sampler.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 511cdb3234253..3d5e59addfcfa 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -598,17 +598,10 @@ def sample_recovered_tokens_kernel( vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) if NO_DRAFT_PROBS: draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos) - orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + - draft_token_id) - # Temporarily zero out the probability of the draft token. - # This is essentially the same as target_prob - draft_prob, except that - # n-gram does not have draft_prob. We regard it as 1. - tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - 0) prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, - mask=vocab_offset < vocab_size, + mask=((vocab_offset < vocab_size) & + (vocab_offset != draft_token_id)), other=0) else: draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + @@ -628,9 +621,3 @@ def sample_recovered_tokens_kernel( other=float("-inf")) recovered_id = tl.argmax(prob / q, axis=-1) tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) - - if NO_DRAFT_PROBS: - # Restore the original probability. - tl.store( - target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id, - orig_prob)