mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 11:02:17 +08:00
[BugFix][Spec Decode] Use float64 for uniform_probs (#23803)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
67cee40da0
commit
a3432f18fd
@@ -138,7 +138,7 @@ def main():
|
|||||||
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
|
||||||
if not args.custom_mm_prompts:
|
if not args.custom_mm_prompts:
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
TokensPrompt(prompt_token_ids=prompt_ids),
|
[TokensPrompt(prompt_token_ids=x) for x in prompt_ids],
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -365,9 +365,14 @@ def generate_uniform_probs(
|
|||||||
A tensor of shape `(num_tokens, )` containing uniform
|
A tensor of shape `(num_tokens, )` containing uniform
|
||||||
random values in the range [0, 1).
|
random values in the range [0, 1).
|
||||||
"""
|
"""
|
||||||
|
# NOTE(woosuk): We deliberately use float64 instead of float32 here
|
||||||
|
# because when using float32, there's a non-negligible chance that
|
||||||
|
# uniform_prob is sampled to be exactly 0.0 as reported in
|
||||||
|
# https://github.com/pytorch/pytorch/issues/16706. Using float64
|
||||||
|
# mitigates the issue.
|
||||||
uniform_probs = torch.rand(
|
uniform_probs = torch.rand(
|
||||||
(num_tokens, ),
|
(num_tokens, ),
|
||||||
dtype=torch.float32,
|
dtype=torch.float64,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
start_idx = 0
|
start_idx = 0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user