diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 511cdb3234253..3d5e59addfcfa 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -598,17 +598,10 @@ def sample_recovered_tokens_kernel(
     vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
     if NO_DRAFT_PROBS:
         draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-        orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
-                            draft_token_id)
-        # Temporarily zero out the probability of the draft token.
-        # This is essentially the same as target_prob - draft_prob, except that
-        # n-gram does not have draft_prob. We regard it as 1.
-        tl.store(
-            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
-            0)
         prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
                        vocab_offset,
-                       mask=vocab_offset < vocab_size,
+                       mask=((vocab_offset < vocab_size) &
+                             (vocab_offset != draft_token_id)),
                        other=0)
     else:
         draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
@@ -628,9 +621,3 @@ def sample_recovered_tokens_kernel(
                 other=float("-inf"))
     recovered_id = tl.argmax(prob / q, axis=-1)
     tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
-
-    if NO_DRAFT_PROBS:
-        # Restore the original probability.
-        tl.store(
-            target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
-            orig_prob)
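
For context on what the diff changes: in the `NO_DRAFT_PROBS` path, the kernel previously zeroed out the draft token's entry of `target_probs` in place, loaded the row, and then restored the original value at the end; the new code instead excludes `draft_token_id` through the load mask (with `other=0`), so `target_probs` is never written. Below is a minimal plain-PyTorch sketch (not the Triton kernel; tensor names are illustrative) showing that the masked read produces the same probabilities as the old zero-out/restore sequence.

```python
import torch

vocab_size = 8
target_probs = torch.rand(vocab_size)
draft_token_id = 3

# Old approach: zero out the draft token's probability, read the row,
# then restore the original value afterwards (two extra writes).
orig_prob = target_probs[draft_token_id].clone()
target_probs[draft_token_id] = 0
prob_old = target_probs.clone()
target_probs[draft_token_id] = orig_prob  # restore

# New approach: mask the draft token at read time; target_probs is never written.
keep = torch.arange(vocab_size) != draft_token_id
prob_new = torch.where(keep, target_probs, torch.zeros_like(target_probs))

assert torch.equal(prob_old, prob_new)
```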