[SpecDecode] [Minor] Fix spec decode sampler tests (#7183)
parent 00afc78590
commit 5c60c8c423
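In brief: `SpecDecodeBaseSampler.init_gpu_tensors` now takes a `device` argument (a CUDA device index or a device string) instead of a `rank`, and the rejection sampler and typical acceptance sampler tests pass the test's own `device` through rather than hard-coding `rank=0` and `device="cuda"`. This lets the sampler tests respect `torch.set_default_device(device)` end to end.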
--- a/tests/samplers/test_rejection_sampler.py
+++ b/tests/samplers/test_rejection_sampler.py
@@ -25,7 +25,7 @@ def mock_causal_accepted_tensor(
 
     accepted = (torch.arange(k).expand(batch_size, k) <=
                 last_accepted_indices.unsqueeze(-1).broadcast_to(
-                    batch_size, k)).to(device="cuda")
+                    batch_size, k))
 
     # Sprinkle accepted values after the contiguous initial accepted values.
     # This replicates the behavior of rejection sampling, which may "accept"
@@ -33,7 +33,7 @@ def mock_causal_accepted_tensor(
     sprinkle_candidates = (
         torch.arange(k).expand(batch_size, k) >
         last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
-    sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
+    sprinkle = torch.rand(batch_size, k) > 0.5
     accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
     return accepted
 
@@ -86,7 +86,7 @@ def test_correct_output_format(which_tokens_accepted: str,
 
     rejection_sampler = RejectionSampler(
         disable_bonus_tokens=disable_bonus_tokens)
-    rejection_sampler.init_gpu_tensors(rank=0)
+    rejection_sampler.init_gpu_tensors(device=device)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,
         recovered_token_ids,
@@ -138,7 +138,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     device: str):
     torch.set_default_device(device)
     rejection_sampler = RejectionSampler()
-    rejection_sampler.init_gpu_tensors(rank=0)
+    rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -167,7 +167,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                    device: str):
     torch.set_default_device(device)
     rejection_sampler = RejectionSampler()
-    rejection_sampler.init_gpu_tensors(rank=0)
+    rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -211,7 +211,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     torch.set_default_device(device)
 
     rejection_sampler = RejectionSampler(strict_mode=True)
-    rejection_sampler.init_gpu_tensors(rank=0)
+    rejection_sampler.init_gpu_tensors(device=device)
 
     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -339,7 +339,7 @@ class _CorrectnessTestHelper:
         self.vocab_size = vocab_size
         self.vocab_range = (0, vocab_size)
 
-        self.rejection_sampler.init_gpu_tensors(rank=0)
+        self.rejection_sampler.init_gpu_tensors(device=0)
 
         # Keep test simple, use k=1
         self.k = 1
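Dropping the explicit `device="cuda"` in the mock helper means the accepted mask is created on whatever default device is active. A minimal standalone sketch of that behavior (simplified: the sprinkle step is omitted, and the index values here are made up for illustration):

import torch

def mock_accepted_mask(k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
    # Positions up to each row's last accepted index are True. No explicit
    # device= is passed, so the mask lands on torch's current default device.
    batch_size = last_accepted_indices.shape[0]
    return (torch.arange(k).expand(batch_size, k) <=
            last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k))

torch.set_default_device("cpu")  # tests parametrize this; "cuda:0" works too
mask = mock_accepted_mask(k=4, last_accepted_indices=torch.tensor([1, 3]))
print(mask)  # row 0: first 2 positions True; row 1: all 4 True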
--- a/tests/samplers/test_typical_acceptance_sampler.py
+++ b/tests/samplers/test_typical_acceptance_sampler.py
@@ -78,7 +78,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     """
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler()
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
@@ -111,7 +111,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     vocab_size = 30_000
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
                                     high=vocab_size,
@@ -171,7 +171,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     draft_token_ids = torch.randint(low=0,
                                     high=vocab_size,
@@ -225,7 +225,7 @@ def test_temperature_zero_target_distribution(seed: int,
 
     typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Simulate temperature 0 probability distribution for target probabilities
     # and create target probabilities such that only 1 token id has
     # probability 1.0
@@ -285,7 +285,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     # For sequences 0 and 2 set the distribution to a temperature
     # zero distribution. For sequences 1 and 3 set it to a uniform
     # distribution.
@@ -352,7 +352,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Create a temperature zero target probability distribution and ensure
     # all draft token ids correspond to the tokens with 1.0 probability.
     # Verify that all of them are accepted.
@@ -414,7 +414,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Simulate temperature 0 probability distribution for target
     # probabilities and create target probabilities such that only 1 token
     # id has probability 1.0 and others have a very low probability of
@@ -447,7 +447,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
         disable_bonus_tokens=disable_bonus_tokens,
         posterior_threshold=0.0,
         posterior_alpha=0.0)
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     output_token_ids = typical_acceptance_sampler(
         target_probs,
         bonus_token_ids,
@@ -485,7 +485,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     torch.set_default_device(device)
     typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
-    typical_acceptance_sampler.init_gpu_tensors(rank=0)
+    typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     expected_replacement_tokens = -torch.ones(
         (batch_size, k), dtype=torch.long)
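Every test above follows the same two-step convention: set the global default device, then pass that same device to `init_gpu_tensors`. A condensed sketch of the pattern (`DummySampler` is a hypothetical stand-in with the new signature, not the vLLM class):

from typing import Optional, Union

import torch


class DummySampler:
    """Stand-in mirroring the new init_gpu_tensors(device=...) signature."""

    def __init__(self) -> None:
        self.num_accepted_tokens: Optional[torch.Tensor] = None

    def init_gpu_tensors(self, device: Union[int, str]) -> None:
        if isinstance(device, int):
            device = f"cuda:{device}"  # an int is treated as a CUDA index
        self.num_accepted_tokens = torch.tensor(0, dtype=torch.long, device=device)


device = "cpu"  # the real tests parametrize over accelerator devices
torch.set_default_device(device)
sampler = DummySampler()
sampler.init_gpu_tensors(device=device)
# Factory calls with no device= follow the default device, so test tensors
# and sampler state end up co-located:
target_probs = torch.rand(4, 1, 30_000, dtype=torch.float32)
assert target_probs.device == sampler.num_accepted_tokens.device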
--- a/vllm/model_executor/layers/spec_decode_base_sampler.py
+++ b/vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Dict, Optional
+from typing import Dict, Optional, Union
 
 import torch
 import torch.jit
@@ -36,9 +36,12 @@ class SpecDecodeBaseSampler(nn.Module):
         self.num_emitted_tokens: Optional[torch.Tensor] = None
         self.num_draft_tokens: int = 0
 
-    def init_gpu_tensors(self, rank: int) -> None:
+    def init_gpu_tensors(self, device: Union[int, str]) -> None:
         assert self.num_accepted_tokens is None
-        device = f"cuda:{rank}"
+        if isinstance(device, int):
+            device = f"cuda:{device}"
+        elif not isinstance(device, str):
+            raise ValueError(f"Device must be int or str, get {type(device)}")
         self.num_accepted_tokens = torch.tensor(0,
                                                 dtype=torch.long,
                                                 device=device)
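For reference, the new device handling can be exercised on its own. A minimal sketch reproducing the normalization logic from the hunk above (a standalone function for illustration, not the vLLM method itself):

from typing import Union


def normalize_device(device: Union[int, str]) -> str:
    # Same rules as the new init_gpu_tensors: an int is a CUDA device
    # index, a string is used verbatim, anything else is rejected.
    if isinstance(device, int):
        return f"cuda:{device}"
    elif not isinstance(device, str):
        raise ValueError(f"Device must be int or str, get {type(device)}")
    return device


assert normalize_device(0) == "cuda:0"        # matches the old rank=0 behavior
assert normalize_device("cuda:1") == "cuda:1"

Note that `_CorrectnessTestHelper` still passes `device=0`, exercising the int branch, while the device-parametrized tests pass device strings.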
|
|||||||
Loading…
x
Reference in New Issue
Block a user