diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 8f0b38ecb34d..37ce5bef8403 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -164,12 +164,12 @@ def rejection_sample(
     assert target_probs.shape == (num_tokens, vocab_size)
 
     # Create output buffer.
-    output_token_ids = torch.empty(
+    output_token_ids = torch.full(
         (batch_size, max_spec_len + 1),
+        PLACEHOLDER_TOKEN_ID,
         dtype=torch.int32,  # Consistent with SamplerOutput.sampled_token_ids.
         device=device,
     )
-    output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
 
     if sampling_metadata.all_greedy:
         is_greedy = None
@@ -186,7 +186,6 @@ def rejection_sample(
         bonus_token_ids,
         is_greedy,
         max_spec_len,
-        num_warps=1,
     )
     if sampling_metadata.all_greedy:
         return output_token_ids
@@ -227,7 +226,6 @@ def rejection_sample(
         max_spec_len,
         vocab_size,
         NO_DRAFT_PROBS=draft_probs is None,
-        num_warps=1,
     )
     return output_token_ids
 
@@ -329,7 +327,6 @@ def expand_batch_to_tokens(
         replace_from,
         replace_to,
         MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
-        num_warps=1,
     )
     return expanded_x
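
The diff contains two independent cleanups: the output buffer is now allocated and initialized in a single `torch.full` call instead of a `torch.empty` followed by an in-place `fill_`, and the hard-coded `num_warps=1` is dropped from the three Triton kernel launches, letting Triton fall back to its default warp count (4). A minimal standalone sketch of the buffer-creation change, assuming example sizes and `PLACEHOLDER_TOKEN_ID = -1` as defined in `rejection_sampler.py`:

```python
import torch

# Assumed values for illustration; the real constant lives in
# vllm/v1/sample/rejection_sampler.py.
PLACEHOLDER_TOKEN_ID = -1
batch_size, max_spec_len = 4, 8
device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: allocate uninitialized memory, then dispatch a separate fill.
before = torch.empty(
    (batch_size, max_spec_len + 1),
    dtype=torch.int32,
    device=device,
)
before.fill_(PLACEHOLDER_TOKEN_ID)

# After: allocate and initialize in one call.
after = torch.full(
    (batch_size, max_spec_len + 1),
    PLACEHOLDER_TOKEN_ID,
    dtype=torch.int32,
    device=device,
)

# Both paths produce the same buffer contents.
assert torch.equal(before, after)
```

The two forms are functionally equivalent; `torch.full` simply avoids the extra `fill_` dispatch on the device.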