Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Fix spec decoding when seed is none in a batch (#10863)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit: 86c2d8fd1c (parent: b880ffb87e)
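In a mixed batch, `_multinomial` filled the exponential-noise rows of unseeded requests with `q[non_seeded_indices].exponential_(1.0)`. Indexing a tensor with a Python list is advanced indexing, which returns a copy, so the in-place fill landed on a temporary and the unseeded rows of `q` were never written. The hunks below add a mixed-batch regression test and rewrite the loop to fill every k-row slice in place, passing `generator=None` for unseeded requests.

A minimal standalone sketch of the copy-vs-view pitfall (illustrative, not part of the commit):

    import torch

    q = torch.zeros(4, 8)

    # Advanced indexing with a list returns a copy: the in-place fill
    # hits the temporary, and q itself stays all zeros.
    q[[0, 2]].exponential_(1.0)
    assert q.sum() == 0

    # A slice is a view of q, so the same in-place op really mutates q.
    q[0:2].exponential_(1.0)
    assert (q[0:2] > 0).all()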
@@ -200,6 +200,69 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
             assert torch.equal(results[j][i], results[0][i])
 
 
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [3, 8, 32, 128])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("use_flashinfer", [True, False])
+@torch.inference_mode()
+def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
+                            device: str, use_flashinfer: bool):
+    torch.set_default_device(device)
+    set_random_seed(0)
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size,
+                              k + 1,
+                              vocab_size,
+                              dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    single_batches = []
+    for i in range(batch_size):
+        single_batches.append((draft_probs[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0),
+                               target_probs[i].clone().unsqueeze(0),
+                               bonus_token_ids[i].clone().unsqueeze(0),
+                               draft_token_ids[i].clone().unsqueeze(0)))
+
+    set_random_seed(0)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+
+    results = []
+    seeded_seqs = {
+        i: torch.Generator(device=device).manual_seed(i)
+        for i in range(1, batch_size)  # 0 is seed None
+    }
+    batch_result = rejection_sampler(target_probs.clone(),
+                                     bonus_token_ids.clone(),
+                                     draft_probs.clone(),
+                                     draft_token_ids.clone(), seeded_seqs)
+
+    set_random_seed(0)
+
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
+    rejection_sampler.init_gpu_tensors(device=device)
+    for i in range(batch_size):
+        request_seeded_seqs = {
+            0: torch.Generator(device=device).manual_seed(i)
+        } if seeded_seqs.get(i) is not None else None
+        (draft_probs, draft_token_ids, target_probs, bonus_token_ids,
+         draft_token_ids) = single_batches[i]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, request_seeded_seqs))
+    for i in range(batch_size):
+        assert torch.equal(batch_result[i], results[i].squeeze(0))
+
+
 @pytest.mark.parametrize("k", [1, 3, 6])
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
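The new test pins down the contract: request 0 carries no seed while requests 1..batch_size-1 each carry their own generator, and every request must produce exactly the same tokens as when it is run as a batch of one. This holds because a seeded `torch.Generator` stream depends only on its seed, not on how much the default RNG is consumed around it. A small sketch of that property (illustrative, not from the diff):

    import torch

    g = torch.Generator().manual_seed(7)
    a = torch.empty(3).exponential_(1.0, generator=g)

    torch.rand(100)  # consume the default RNG between the two draws

    g = torch.Generator().manual_seed(7)
    b = torch.empty(3).exponential_(1.0, generator=g)
    assert torch.equal(a, b)  # the seeded stream reproduces regardless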
@@ -1,6 +1,6 @@
 from functools import cached_property
 from importlib.util import find_spec
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
 import torch.jit
@@ -386,16 +386,12 @@ def _multinomial(
     if not seeded_seqs:
         q.exponential_(1.0)
     else:
-        non_seeded_indices: List[int] = []
         start = 0
         for idx in range(len(q) // k):
             end = start + k
             generator = seeded_seqs.get(idx)
-            if generator is None:
-                non_seeded_indices.extend(list(range(start, end)))
-            else:
-                q[start:end].exponential_(1.0, generator=generator)
+            # Note: generator might be None for non seeded
+            q[start:end].exponential_(1.0, generator=generator)
             start = end
-        q[non_seeded_indices].exponential_(1.0)
 
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
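The rewritten loop leans on two PyTorch facts: a slice like `q[start:end]` is a view, so the in-place fill writes through to `q`, and `Tensor.exponential_` falls back to the default generator when `generator=None`, which is exactly what `seeded_seqs.get(idx)` yields for unseeded requests. A hedged standalone sketch of the pattern (`fill_noise` is an illustrative name, not vLLM's API):

    import torch
    from typing import Dict

    def fill_noise(q: torch.Tensor, k: int,
                   seeded_seqs: Dict[int, torch.Generator]) -> None:
        # Each request owns k consecutive rows of q; fill each block in
        # place, falling back to the default RNG when no seed was given.
        start = 0
        for idx in range(len(q) // k):
            end = start + k
            q[start:end].exponential_(1.0, generator=seeded_seqs.get(idx))
            start = end

    q = torch.empty(6, 10)
    fill_noise(q, k=3, seeded_seqs={1: torch.Generator().manual_seed(0)})
    assert (q > 0).all()  # unseeded and seeded blocks were both filled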