diff --git a/Dockerfile b/Dockerfile index 9bae9a12c0eb..ec6069f605eb 100644 --- a/Dockerfile +++ b/Dockerfile @@ -162,7 +162,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist RUN --mount=type=cache,target=/root/.cache/pip \ . /etc/environment && \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl + python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl #################### vLLM installation IMAGE #################### diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 3ce4a5f65819..91a9d879eb4a 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -44,12 +44,16 @@ def mock_causal_accepted_tensor( ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, - disable_bonus_tokens: bool, seed: int, - device: str): +def test_correct_output_format(which_tokens_accepted: str, seed: int, + disable_bonus_tokens: bool, device: str, + use_flashinfer: bool): """Verify the output has correct format given predetermined accepted matrix. """ + if use_flashinfer and disable_bonus_tokens: + pytest.skip("Flashinfer rejection sampler must enable bonus token.") + set_random_seed(seed) torch.set_default_device(device) @@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str, dtype=torch.int64) rejection_sampler = RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens) + disable_bonus_tokens=disable_bonus_tokens, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, @@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str, @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("batch_size", list(range(1, 32))) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, - device: str): + device: str, use_flashinfer: bool): torch.set_default_device(device) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, @pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) @pytest.mark.parametrize("n_rep", [100]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def 
test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, - frac_seeded: float, n_rep: int, - device: str): + frac_seeded: float, n_rep: int, device: str, + use_flashinfer: bool): torch.set_default_device(device) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, assert torch.equal(results[j][i], results[0][i]) +@pytest.mark.parametrize("k", [1, 3, 6]) +@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) +@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_compare_nonflashinfer_backend(k: int, vocab_size: int, + batch_size: int, device: str): + """ + Test that the flashinfer and nonflashinfer backends generate + the same output metrics. + """ + torch.set_default_device(device) + torch.manual_seed(0) + draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + + num_accepted_tokens = [] + num_emitted_tokens = [] + num_draft_tokens = [] + + def get_seeded_seqs(): + return { + i: torch.Generator(device=device).manual_seed(i) + for i in range(batch_size) + } + + for use_flashinfer in [True, False]: + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) + rejection_sampler.init_gpu_tensors(device=device) + # We use seeded sequences to ensure the same tokens are accepted + # for both flashinfer and nonflashinfer backends.
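+ # A fresh set of generators is built on each loop iteration: generator + # state advances as values are drawn, so reusing one set across both + # backends would feed the two runs different random streams.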
+ seeded_seqs = get_seeded_seqs() + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids, seeded_seqs) + num_accepted_tokens.append(rejection_sampler.num_accepted_tokens) + num_emitted_tokens.append(rejection_sampler.num_emitted_tokens) + num_draft_tokens.append(rejection_sampler.num_draft_tokens) + + assert num_accepted_tokens[0] == num_accepted_tokens[1] + assert num_emitted_tokens[0] == num_emitted_tokens[1] + assert num_draft_tokens[0] == num_draft_tokens[1] + + @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) @pytest.mark.parametrize("which_token_ids", ["bonus_token_ids", "draft_token_ids"]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_raises_when_vocab_oob(above_or_below_vocab_range: str, - which_token_ids: str, device: str): + which_token_ids: str, device: str, + use_flashinfer: bool): k = 3 batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - rejection_sampler = RejectionSampler(strict_mode=True) + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer, + strict_mode=True) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) @pytest.mark.parametrize("seed", list(range(5))) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_rejection_sampling_approximates_target_distribution( - seed: int, draft_and_target_probs_equal: bool): + seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool): """Verify rejection sampling approximates target distribution, despite sampling from a potentially distinct draft distribution. @@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution( """ torch.set_default_device("cpu") set_random_seed(seed) - helper = _CorrectnessTestHelper( vocab_size=10, - rejection_sampler=RejectionSampler(), + rejection_sampler=RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer), ) draft_probs, target_probs, reference_probs = helper.generate_probs_for_test( @@ -398,10 +476,10 @@ class _CorrectnessTestHelper: draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat( num_samples, 1, 1) - # Repeat target probs num_samples * k times. + # Repeat target probs num_samples * (k + 1) times. # Rejection sampler requires bonus token probs, but they aren't used. target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat( - num_samples, self.k, 1) + num_samples, self.k + 1, 1) # Randomly sample draft token ids from draft probs. 
draft_token_ids = torch.multinomial(draft_probs[:, 0, :], diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index aa3c1d29bdb3..e81ec4a0fdf1 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler() typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, size=(batch_size, k), dtype=torch.int64) # Verify that sampling succeeds for all cases. - typical_acceptance_sampler(target_probs, + typical_acceptance_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, oob_token_ids[0][0] = rogue_token_id with pytest.raises(AssertionError): - typical_acceptance_sampler(target_probs, + typical_acceptance_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens( typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) draft_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, k), @@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens( size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int, # Simulate temperature 0 probability distribution for target probabilities # and create target probabilities such that only 1 token id has # probability 1.0 - target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( - batch_size, k, vocab_size) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] # Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0 draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, @@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int, # fallback to the greedy 
sampling for selecting 1 token for each sequence. # Verify the same. output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, # For sequences 0 and 2 set the distribution to a temperature # zero distribution. For sequences 1 and 3 set it to a uniform # distribution. - target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] + target_probs = target_with_bonus_probs[:, :-1] draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, zero_temperature_token_ids) uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) @@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, # Create a temperature zero target probability distribution and ensure # all draft token ids correspond to the tokens with 1.0 probability. # Verify that all of them are accepted. - target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] draft_token_ids = zero_temperature_token_ids bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, draft_token_ids = torch.cat( (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, # 0.00001. Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0. Without any changes to the posterior thresholds # none of the draft tokens are accepted. 
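+ # Note that the prob dist helper now generates k + 1 positions so the + # tensor includes the bonus slot; the extra column of token ids is + # dropped since draft tokens only cover the first k positions.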
- target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( + batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] target_probs[target_probs == 0] = 0.00001 draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, zero_temperature_token_ids) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index cbaffee2f41e..501d05756e01 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -230,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, assert torch.equal(actual.bonus_token_ids, target_token_ids.reshape(batch_size, k + 1)[:, -1:]) - assert torch.equal( - actual.target_probs, - target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) + assert torch.equal(actual.target_with_bonus_probs, + target_token_probs.reshape(batch_size, k + 1, -1)) assert torch.equal(actual.draft_token_ids, proposal_token_ids) assert torch.equal(actual.draft_probs, proposal_probs) diff --git a/vllm/envs.py b/vllm/envs.py index 30320af5fa43..3c6b6adff82f 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: bool = False + VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 2124196d06f9..b2f333a5bcc8 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -1,12 +1,28 @@ from functools import cached_property +from importlib.util import find_spec from typing import Dict, List, Optional, Tuple import torch import torch.jit +import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeStochasticBaseSampler) +logger = init_logger(__name__) + +if find_spec("flashinfer"): + """ + Prefer the FlashInfer rejection sampling kernel when it is available: + it performs the whole chain in a single dedicated, fused kernel rather + than a sequence of Torch tensor operations, which reduces memory I/O + and improves performance. + """ + from flashinfer.sampling import chain_speculative_sampling +else: + chain_speculative_sampling = None + class RejectionSampler(SpecDecodeStochasticBaseSampler): """Apply modified rejection sampling as described in "Accelerating Large @@ -16,7 +32,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): def __init__(self, disable_bonus_tokens: bool = True, - strict_mode: bool = False): + strict_mode: bool = False, + use_flashinfer: Optional[bool] = None): """Create a rejection sampler. Args: @@ -26,13 +43,29 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): strict_mode: Whether or not to perform shape/device/dtype checks during sampling. This catches correctness issues but adds nontrivial latency. + use_flashinfer: Whether to use the FlashInfer rejection sampling + kernel. If None, the value is derived from the + VLLM_USE_FLASHINFER_SAMPLER environment variable and FlashInfer + availability. Passing an explicit value is intended for testing.
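+ For example, RejectionSampler(use_flashinfer=False) forces the + Torch path even when FlashInfer is installed, which the unit + tests rely on to compare the two backends.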
""" super().__init__(disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) + if use_flashinfer is None: + self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and ( + chain_speculative_sampling is not None) + else: + self.use_flashinfer = use_flashinfer + + if self.use_flashinfer: + assert not disable_bonus_tokens, \ + "flashinfer will enable bonus token by default" + logger.info("Use flashinfer for rejection sampling.") + else: + logger.info("Use pytorch for rejection sampling.") def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -50,9 +83,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): sequence. Args: - target_probs: The probability distribution over token ids given - context according to the target model. - shape = [batch_size, num_speculative_tokens, vocab_size] + target_with_bonus_probs: The probability distribution + over token ids given context according to the target model. + shape = [batch_size, num_speculative_tokens + 1, vocab_size] bonus_token_ids: The "bonus" token ids that are accepted iff all speculative tokens in a sequence are accepted. @@ -78,23 +111,52 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. if self._strict_mode: - self._raise_if_incorrect_input(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) - accepted, recovered_token_ids = ( - self._batch_modified_rejection_sampling( - target_probs, - draft_probs, + batch_size, k, _ = draft_probs.shape + + # batch_size = 0 when all requests in the batch are + # non_spec requests. In this case, output_token_ids is + # just an empty tensor. + if batch_size == 0: + return torch.empty(0, k + 1, device=draft_probs.device, dtype=int) + + # If use Flashinfer chain_speculative_sampling kernel + # for rejection sampling + if self.use_flashinfer: + batch_size, k, _ = draft_probs.shape + uniform_samples = self._create_uniform_samples( + seeded_seqs, batch_size, k, draft_probs.device) + output_token_ids, accepted_token_num, emitted_token_num \ + = chain_speculative_sampling( + draft_probs, draft_token_ids, uniform_samples, + target_with_bonus_probs) + + # num_emitted_tokens returned by flashinfer + # does not include the bonus token + # Flashinfer stops at the first token that violates + # the condition p >= q and does not include recovery/bonus token. + # Therefore, we need to add batch_size here. 
+ self.num_accepted_tokens += accepted_token_num.sum() + self.num_emitted_tokens += emitted_token_num.sum() + batch_size + self.num_draft_tokens += batch_size * k + else: + accepted, recovered_token_ids = ( + self._batch_modified_rejection_sampling( + target_with_bonus_probs[:, :-1], + draft_probs, + draft_token_ids, + seeded_seqs, + )) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, draft_token_ids, - seeded_seqs, - )) - - output_token_ids = self._create_output( - accepted, - recovered_token_ids, - draft_token_ids, - bonus_token_ids, - ) + bonus_token_ids, + ) return output_token_ids @@ -135,6 +197,63 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): return accepted, recovered_token_ids + def _create_uniform_samples(self, + seeded_seqs: Optional[Dict[int, + torch.Generator]], + batch_size: int, k: int, + device: torch.device) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + for specific sequences. + + This method creates a tensor of shape `(batch_size, k + 1)` filled + with uniform random values in the range [0, 1). If `seeded_seqs` + is provided, the sequences corresponding to specific indices + will be generated using the provided `torch.Generator` for + reproducibility. The other sequences will be generated without + a seed. + + Args: + seeded_seqs : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. If `None`, all samples are + generated without a seed. + batch_size : int + The number of sequences to generate. + k : int + The number of random samples per sequence. + device : torch.device + The device on which to allocate the tensor. + + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(batch_size, k + 1)` containing uniform + random values in the range [0, 1). 
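+ For example, with batch_size = 2, k = 3 and seeded_seqs = {0: gen}, + the result has shape (2, 4); row 0 is drawn reproducibly from gen + while row 1 comes from the default RNG.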
+ """ + if not seeded_seqs: + return torch.rand(batch_size, k + 1, device=device) + + uniform_rand = torch.empty(batch_size, k + 1, device=device) + + non_seeded_indices = [] + for idx in range(batch_size): + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.append(idx) + else: + uniform_rand[idx, :] = torch.rand(1, + k + 1, + dtype=self.probs_dtype, + device=device, + generator=generator) + if non_seeded_indices: + uniform_rand[non_seeded_indices, :] = torch.rand( + len(non_seeded_indices), + k + 1, + dtype=self.probs_dtype, + device=device) + return uniform_rand + def _get_accepted( self, target_probs: torch.Tensor, # [batch_size, k, vocab_size] @@ -175,29 +294,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): selected_target_probs = target_probs[batch_indices, probs_indicies, draft_token_ids] - if not seeded_seqs: - uniform_rand = torch.rand_like(selected_target_probs) - else: - uniform_rand = torch.empty_like(selected_target_probs) - - non_seeded_indices = [] - for idx in range(batch_size): - generator = seeded_seqs.get(idx) - if generator is None: - non_seeded_indices.append(idx) - else: - uniform_rand[idx, :] = torch.rand( - 1, - k, - dtype=self.probs_dtype, - device=target_probs.device, - generator=generator) - if non_seeded_indices: - uniform_rand[non_seeded_indices, :] = torch.rand( - len(non_seeded_indices), - k, - dtype=self.probs_dtype, - device=target_probs.device) + uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size, + k - 1, target_probs.device) capped_ratio = torch.minimum( selected_target_probs / selected_draft_probs, diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 467c43c41550..f9532dffa92c 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -130,29 +130,35 @@ class SpecDecodeBaseSampler(nn.Module): def _raise_if_incorrect_input( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: - self._raise_if_incorrect_shape(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_incorrect_dtype(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_inconsistent_device(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + self._raise_if_incorrect_shape(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_incorrect_dtype(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_inconsistent_device(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1], draft_token_ids, bonus_token_ids) def _raise_if_incorrect_shape( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: (target_batch_size, num_target_probs, - target_vocab_size) = target_probs.shape + target_vocab_size) = target_with_bonus_probs.shape + + # Does not count the extra token + num_target_probs -= 1 # validate the shape of draft token ids. 
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape @@ -175,12 +181,12 @@ class SpecDecodeBaseSampler(nn.Module): def _raise_if_incorrect_dtype( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: - assert target_probs.dtype == self.probs_dtype + assert target_with_bonus_probs.dtype == self.probs_dtype assert draft_token_ids.dtype == self.token_id_dtype assert bonus_token_ids.dtype == self.token_id_dtype if draft_probs is not None: @@ -188,15 +194,16 @@ class SpecDecodeBaseSampler(nn.Module): def _raise_if_inconsistent_device( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] - if t is not None + t.device for t in [ + target_with_bonus_probs, bonus_token_ids, draft_probs, + draft_token_ids + ] if t is not None ] assert all([devices[0] == device for device in devices]) @@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler): @abstractmethod def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): @abstractmethod def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index a87ea0eee57d..7428d33ea720 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -41,7 +41,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -80,8 +80,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. if self._strict_mode: - self._raise_if_incorrect_input(target_probs, draft_token_ids, - bonus_token_ids) + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids) + target_probs = target_with_bonus_probs[:, :-1] accepted = self._evaluate_accepted_tokens(target_probs, draft_token_ids) recovered_token_ids = self._replacement_token_ids(target_probs) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 78beb2ce4477..91f0a98c7bc3 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -625,8 +625,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): seq_group_metadata_list, proposal_lens_list) original_indices = spec_indices + non_spec_indices - # Get probabilities of target model, excluding bonus token. - proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1] + # Get probabilities of target model, including bonus tokens. 
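+ # proposal_scores.probs has shape [batch_size, k + 1, vocab_size]; + # the sampler now consumes the bonus position itself instead of + # having it sliced off here.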
+ proposal_verifier_probs = proposal_scores.probs[spec_indices] # Get non-speculative sampled tokens from target model. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] @@ -651,13 +651,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase): } accepted_token_ids = self.spec_decode_sampler( - target_probs=proposal_verifier_probs, + target_with_bonus_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_probs=proposal_probs, draft_token_ids=proposal_token_ids, **sampler_extra_kwargs, ) - # Append output tokens from non-speculative sequences to # the accepted token ids tensor. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +