diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 8f6c292620c2..3ce4a5f65819 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -25,7 +25,7 @@ def mock_causal_accepted_tensor( accepted = (torch.arange(k).expand(batch_size, k) <= last_accepted_indices.unsqueeze(-1).broadcast_to( - batch_size, k)).to(device="cuda") + batch_size, k)) # Sprinkle accepted values after the contiguous initial accepted values. # This replicates the behavior of rejection sampling, which may "accept" @@ -33,7 +33,7 @@ def mock_causal_accepted_tensor( sprinkle_candidates = ( torch.arange(k).expand(batch_size, k) > last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1) - sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5 + sprinkle = torch.rand(batch_size, k) > 0.5 accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates] return accepted @@ -86,7 +86,7 @@ def test_correct_output_format(which_tokens_accepted: str, rejection_sampler = RejectionSampler( disable_bonus_tokens=disable_bonus_tokens) - rejection_sampler.init_gpu_tensors(rank=0) + rejection_sampler.init_gpu_tensors(device=device) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, recovered_token_ids, @@ -138,7 +138,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, device: str): torch.set_default_device(device) rejection_sampler = RejectionSampler() - rejection_sampler.init_gpu_tensors(rank=0) + rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -167,7 +167,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, device: str): torch.set_default_device(device) rejection_sampler = RejectionSampler() - rejection_sampler.init_gpu_tensors(rank=0) + 
rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -211,7 +211,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, torch.set_default_device(device) rejection_sampler = RejectionSampler(strict_mode=True) - rejection_sampler.init_gpu_tensors(rank=0) + rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) @@ -339,7 +339,7 @@ class _CorrectnessTestHelper: self.vocab_size = vocab_size self.vocab_range = (0, vocab_size) - self.rejection_sampler.init_gpu_tensors(rank=0) + self.rejection_sampler.init_gpu_tensors(device=0) # Keep test simple, use k=1 self.k = 1 diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 4f6290795b2c..aa3c1d29bdb3 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -78,7 +78,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, """ torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler() - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, @@ -111,7 +111,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, vocab_size = 30_000 torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = 
torch.randint(low=0, high=vocab_size, @@ -171,7 +171,7 @@ def test_uniform_target_distribution_accepts_all_tokens( torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) draft_token_ids = torch.randint(low=0, high=vocab_size, @@ -225,7 +225,7 @@ def test_temperature_zero_target_distribution(seed: int, typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) # Simulate temperature 0 probability distribution for target probabilities # and create target probabilities such that only 1 token id has # probability 1.0 @@ -285,7 +285,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) # For sequences 0 and 2 set the distribution to a temperature # zero distribution. For sequences 1 and 3 set it to a uniform # distribution. @@ -352,7 +352,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) # Create a temperature zero target probability distribution and ensure # all draft token ids correspond to the tokens with 1.0 probability. 
# Verify that all of them are accepted. @@ -414,7 +414,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) # Simulate temperature 0 probability distribution for target # probabilities and create target probabilities such that only 1 token # id has probability 1.0 and others have a very low probability of @@ -447,7 +447,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, disable_bonus_tokens=disable_bonus_tokens, posterior_threshold=0.0, posterior_alpha=0.0) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) output_token_ids = typical_acceptance_sampler( target_probs, bonus_token_ids, @@ -485,7 +485,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) - typical_acceptance_sampler.init_gpu_tensors(rank=0) + typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) expected_replacement_tokens = -torch.ones( (batch_size, k), dtype=torch.long) diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 3091e639727b..467c43c41550 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch import torch.jit @@ -36,9 +36,12 @@ class SpecDecodeBaseSampler(nn.Module): self.num_emitted_tokens: Optional[torch.Tensor] = 
None self.num_draft_tokens: int = 0 - def init_gpu_tensors(self, rank: int) -> None: + def init_gpu_tensors(self, device: Union[int, str]) -> None: assert self.num_accepted_tokens is None - device = f"cuda:{rank}" + if isinstance(device, int): + device = f"cuda:{device}" + elif not isinstance(device, str): + raise ValueError(f"Device must be int or str, got {type(device)}") self.num_accepted_tokens = torch.tensor(0, dtype=torch.long, device=device)