From a3432f18fdd85eb18e29fc32327507fe1063ad57 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 28 Aug 2025 05:26:45 -0700 Subject: [PATCH] [BugFix][Spec Decode] Use float64 for uniform_probs (#23803) Signed-off-by: Woosuk Kwon --- examples/offline_inference/spec_decode.py | 2 +- vllm/v1/sample/rejection_sampler.py | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index c4972f02d0f8e..5af232cb6af6a 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -138,7 +138,7 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - TokensPrompt(prompt_token_ids=prompt_ids), + [TokensPrompt(prompt_token_ids=x) for x in prompt_ids], sampling_params=sampling_params, ) else: diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py index 2d9ce3101b6c9..511cdb3234253 100644 --- a/vllm/v1/sample/rejection_sampler.py +++ b/vllm/v1/sample/rejection_sampler.py @@ -365,9 +365,14 @@ def generate_uniform_probs( A tensor of shape `(num_tokens, )` containing uniform random values in the range [0, 1). """ + # NOTE(woosuk): We deliberately use float64 instead of float32 here + # because when using float32, there's a non-negligible chance that + # uniform_prob is sampled to be exact 0.0 as reported in + # https://github.com/pytorch/pytorch/issues/16706. Using float64 + # mitigates the issue. uniform_probs = torch.rand( (num_tokens, ), - dtype=torch.float32, + dtype=torch.float64, device=device, ) start_idx = 0