diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index a108cd7bf9a1e..c43e4d793e112 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -195,7 +195,7 @@ def rejection_sample( num_warps=1, ) if sampling_metadata.all_greedy: - return output_token_ids + return output_token_ids, output_probs # Generate uniform probabilities for rejection sampling. # [num_tokens] @@ -475,8 +475,8 @@ def rejection_greedy_sample_kernel( if draft_token_id != target_argmax_id: # Reject. rejected = True - tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, - not rejected) + tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos, + draft_token_id == target_argmax_id) if not rejected: # If all tokens are accepted, append the bonus token.