[V1][spec decode] return logprobs for spec decoding (#26060)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Authored by Giancarlo Delfin on 2025-10-22 22:59:59 -07:00, committed by GitHub
parent ff93cc8c84
commit 6644796bf4
8 changed files with 392 additions and 186 deletions

View File

@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import math
from collections.abc import Generator
from typing import get_args
import pytest
import torch
from tests.utils import large_gpu_mark
from tests.v1.sample.utils import (
BatchLogprobsComposition,
BatchLogprobsSpecType,
@ -17,6 +19,7 @@ from tests.v1.sample.utils import (
)
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory
from ...conftest import HfRunner, VllmRunner
@ -508,3 +511,94 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
if logprobs_mode in ("raw_logits", "processed_logits"):
assert positive_values > 0
del llm
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
"model_setup",
[
pytest.param(
(
"eagle",
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
),
marks=large_gpu_mark(min_gb=32),
),
],
)
def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str],
monkeypatch: pytest.MonkeyPatch,
):
"""Spec decode logprobs should match those of the base model.
Args:
logprobs_mode: logprobs mode.
model_setup: Spec decode method, base model name, and
draft model name.
"""
from vllm import LLM
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
prompt = "Hello world"
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
max_model_len = 256
# Run base LLM.
ref_llm = LLM(
model=model_name,
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
ref_results = ref_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from reference LLM.
ref_logprobs = []
for output in ref_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Run spec decode LLM.
spec_llm = LLM(
model_name,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": max_model_len,
},
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Per-token logprobs are expected to be the same.
assert len(ref_logprobs) == len(spec_logprobs)
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from unittest.mock import Mock
import pytest
import torch
@ -11,6 +12,7 @@ from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = current_platform.device_type
@ -18,7 +20,28 @@ DEVICE = current_platform.device_type
@pytest.fixture
def rejection_sampler():
return RejectionSampler()
mock_sampler = Mock(spec=Sampler)
mock_sampler.logprobs_mode = "raw_logprobs"
return RejectionSampler(mock_sampler)
def mock_sampler_output(
rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
):
rejection_sampler.sampler.return_value = SamplerOutput(
sampled_token_ids=bonus_token_ids, logprobs_tensors=None
)
def create_spec_decode_metadata(
spec_tokens: list[list[int]], logits: torch.Tensor
) -> SpecDecodeMetadata:
metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
metadata.target_logits_indices = torch.arange(logits.shape[0])
# Output bonus token ids are mocked, so the bonus logit indices should
# be empty.
metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
return metadata
def create_logits_tensor(
@ -111,19 +134,17 @@ def test_perfect_match(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_early_mismatch(rejection_sampler):
@ -134,15 +155,13 @@ def test_early_mismatch(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor(
@ -150,7 +169,7 @@ def test_early_mismatch(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_multiple_sequences(rejection_sampler):
@ -163,21 +182,19 @@ def test_multiple_sequences(rejection_sampler):
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_single_token_sequence(rejection_sampler):
@ -188,19 +205,17 @@ def test_single_token_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_empty_sequence(rejection_sampler):
@ -211,19 +226,17 @@ def test_empty_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_multiple_mismatches(rejection_sampler):
@ -236,15 +249,13 @@ def test_multiple_mismatches(rejection_sampler):
bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor(
@ -255,7 +266,7 @@ def test_multiple_mismatches(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
@pytest.mark.parametrize(
@ -277,19 +288,17 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expec
bonus_token_tensor = torch.tensor(
[tokens[-1] for tokens in output_tokens], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device)
assert torch.equal(output, expected_tensor)
assert torch.equal(output.sampled_token_ids, expected_tensor)
########################### Tests for Random Sampling ###################
@ -331,18 +340,19 @@ def test_deterministic_when_seeded(
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=seeded_seqs
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=DEVICE
spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.tolist(), target_logits
)
mock_sampler_output(rejection_sampler, bonus_token_ids)
rep_result = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
draft_probs=None,
logits=target_logits,
sampling_metadata=sampling_metadata,
)
results.append(rep_result)
results.append(rep_result.sampled_token_ids)
for i in range(batch_size):
if seeded_mask[i]:
@ -460,7 +470,9 @@ def estimate_rejection_sampling_pdf(
Returns:
Estimated probability distribution of the output tokens.
"""
rejection_sampler = RejectionSampler()
mock_sampler = Mock(spec=Sampler)
mock_sampler.logprobs_mode = "raw_logprobs"
rejection_sampler = RejectionSampler(mock_sampler)
num_tokens = num_samples * k
# Repeat draft probs num_samples * k times.
draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
@ -483,17 +495,18 @@ def estimate_rejection_sampling_pdf(
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=bonus_token_ids.device
spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.tolist(), target_logits
)
output_token_ids = rejection_sampler(
mock_sampler_output(rejection_sampler, bonus_token_ids)
sampler_output = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
logits=target_logits,
sampling_metadata=sampling_metadata,
)
output_token_ids = output_token_ids[:, :-1].flatten()
output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()
hist = torch.histogram(
output_token_ids.to(dtype=torch.float, device="cpu"),
@ -532,22 +545,19 @@ def _test_masked_logits(
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
# Create spec decode metadata
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids,
device=DEVICE,
)
spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
# Run rejection sampling
output_token_ids = rejection_sampler(
mock_sampler_output(rejection_sampler, bonus_token_ids)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
logits=target_logits,
sampling_metadata=sampling_metadata,
)
# Remove bonus tokens and reshape
output_token_ids = output_token_ids[:, :-1].flatten().tolist()
output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist()
# Check that all sampled tokens are within the unmasked indices.
for i in range(num_tokens):
@ -665,11 +675,11 @@ def test_frequency_penalties(rejection_sampler):
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
expected = torch.tensor(
@ -677,7 +687,7 @@ def test_frequency_penalties(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_bad_words(rejection_sampler):
@ -707,14 +717,12 @@ def test_bad_words(rejection_sampler):
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
@ -723,7 +731,7 @@ def test_bad_words(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)
def test_allowed_token_ids(rejection_sampler):
@ -756,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler):
bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device
)
spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
logits=logits,
sampling_metadata=metadata,
)
@ -772,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler):
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected)
assert torch.equal(output.sampled_token_ids, expected)

View File

@ -66,7 +66,7 @@ class LogprobsProcessor:
assert self.logprobs is not None
assert self.cumulative_logprob is not None
token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst):
# Detokenize (non-incrementally).

View File

@ -14,34 +14,49 @@ else:
class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprob_token_ids: list[list[int]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprobs: list[list[float]]
# [num_reqs]
# [num_reqs x num_generated_tokens]
sampled_token_ranks: list[int]
# [num_reqs]
# Used for slicing the logprobs in cases like speculative
# decoding where the number of generated tokens may be
# different for each request.
cu_num_generated_tokens: list[int] | None = None
def slice(self, start: int, end: int):
def slice(self, start_req_idx: int, end_req_idx: int):
if self.cu_num_generated_tokens:
start = self.cu_num_generated_tokens[start_req_idx]
end = self.cu_num_generated_tokens[end_req_idx]
else:
start = start_req_idx
end = end_req_idx
return LogprobsLists(
self.logprob_token_ids[start:end],
self.logprobs[start:end],
self.sampled_token_ranks[start:end],
self.cu_num_generated_tokens[start_req_idx:end_req_idx]
if self.cu_num_generated_tokens
else None,
)
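For illustration, a minimal sketch of how the new cu_num_generated_tokens field drives slicing; the values below are hypothetical, and with spec decoding each request may contribute a different number of rows:
# Two requests: the first emitted 3 tokens (2 accepted drafts + 1 bonus),
# the second emitted 1 token, so the flattened lists hold 4 rows.
lists = LogprobsLists(
    logprob_token_ids=[[11, 7], [12, 9], [13, 4], [21, 6]],
    logprobs=[[-0.1, -2.3], [-0.2, -1.9], [-0.3, -2.5], [-0.4, -1.7]],
    sampled_token_ranks=[1, 1, 1, 1],
    cu_num_generated_tokens=[0, 3, 4],
)
# Request index 1 maps to token rows cu[1]:cu[2] == 3:4.
req1 = lists.slice(1, 2)
assert req1.logprob_token_ids == [[21, 6]]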
class LogprobsTensors(NamedTuple):
# [num_reqs, max_num_logprobs + 1]
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprob_token_ids: torch.Tensor
# [num_reqs, max_num_logprobs + 1]
# [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprobs: torch.Tensor
# [num_reqs]
# [num_reqs x num_generated_tokens]
selected_token_ranks: torch.Tensor
def tolists(self):
def tolists(self, cu_num_generated_tokens: list[int] | None = None):
return LogprobsLists(
self.logprob_token_ids.tolist(),
self.logprobs.tolist(),
self.selected_token_ranks.tolist(),
cu_num_generated_tokens,
)
@staticmethod

View File

@ -1,15 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
import torch
import torch.nn as nn
from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__)
@ -44,17 +48,22 @@ class RejectionSampler(nn.Module):
output tokens = accepted tokens + recovered tokens + bonus tokens
"""
def __init__(self, sampler: Sampler):
super().__init__()
self.sampler = sampler
logprobs_mode = self.sampler.logprobs_mode
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
def forward(
self,
metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size]
draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size]
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
# [num_tokens + batch_size, vocab_size]
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
) -> SamplerOutput:
"""
Args:
metadata:
@ -63,43 +72,65 @@ class RejectionSampler(nn.Module):
Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor):
logits (torch.Tensor):
The target model's output logits.
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
Shape is [num_tokens + batch_size, vocab_size]. Here,
logits from different requests are flattened into a
single tensor because this is the shape of the output logits.
NOTE: `logits` can be updated in place to save memory.
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
SamplerOutput:
Contains the final output token IDs and their logprobs if
requested.
"""
assert metadata.max_spec_len <= MAX_SPEC_LEN
# Use float32 for the target_logits.
target_logits = target_logits.to(torch.float32)
bonus_logits_indices = metadata.bonus_logits_indices
target_logits_indices = metadata.target_logits_indices
target_logits = self.apply_logits_processors(
target_logits, sampling_metadata, metadata
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[bonus_logits_indices]
bonus_sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=replace(
sampling_metadata,
max_num_logprobs=-1,
),
predict_bonus_token=True,
# Override the logprobs mode to return logits because they are
# needed later to compute the accepted token logprobs.
logprobs_mode_override="processed_logits"
if self.is_processed_logprobs_mode
else "raw_logits",
)
bonus_token_ids = bonus_sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
raw_target_logits = logits[target_logits_indices]
# Use float32 for the target_logits.
raw_target_logits = raw_target_logits.to(torch.float32)
target_logits = self.apply_logits_processors(
raw_target_logits, sampling_metadata, metadata
)
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
target_probs = compute_probs(
# `apply_sampling_constraints` function.
target_logits = apply_sampling_constraints(
target_logits,
metadata.cu_num_draft_tokens,
sampling_metadata,
)
# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
output_token_ids = rejection_sample(
metadata.draft_token_ids,
@ -111,7 +142,63 @@ class RejectionSampler(nn.Module):
bonus_token_ids,
sampling_metadata,
)
return output_token_ids
logprobs_tensors = None
if sampling_metadata.max_num_logprobs:
logprobs_tensors = self._get_logprobs_tensors(
sampling_metadata.max_num_logprobs,
metadata,
logits,
target_logits if self.is_processed_logprobs_mode else raw_target_logits,
bonus_sampler_output.logprobs_tensors.logprobs,
output_token_ids,
)
return SamplerOutput(
sampled_token_ids=output_token_ids,
logprobs_tensors=logprobs_tensors,
)
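As a rough sketch of the index layout assumed above (hypothetical sizes): for each request the flattened logits hold one row per draft token followed by one bonus row, and the two index tensors simply partition those rows. Tensor indexing copies the selected rows, which is why the in-place updates noted in the comments are safe:
import torch

# Request 0 has 2 draft tokens, request 1 has 1, so logits has
# num_tokens + batch_size = 3 + 2 = 5 rows.
logits = torch.randn(5, 32)                      # [num_tokens + batch_size, vocab_size]
target_logits_indices = torch.tensor([0, 1, 3])  # draft positions scored by the target model
bonus_logits_indices = torch.tensor([2, 4])      # one bonus position per request
bonus_logits = logits[bonus_logits_indices]      # new storage, separate from `logits`
target_logits = logits[target_logits_indices]    # likewise a copy of the selected rows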
def _get_logprobs_tensors(
self,
max_num_logprobs: int,
metadata: SpecDecodeMetadata,
logits: torch.Tensor,
target_logits: torch.Tensor,
bonus_logits: torch.Tensor,
sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]
# Collect target and bonus logits.
bonus_logits_indices = metadata.bonus_logits_indices
target_logits_indices = metadata.target_logits_indices
final_logits = torch.zeros_like(logits, dtype=torch.float32)
final_logits[target_logits_indices] = target_logits.to(torch.float32)
final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
# Compute accepted token indices.
accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
num_accepted_tokens = accepted_mask.sum(dim=-1)
accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
num_accepted_tokens
)
# Compute logprobs for accepted tokens.
accepted_logits = final_logits[accepted_logit_indices]
accepted_logprobs = (
accepted_logits
if self.is_logits_logprobs_mode
else self.sampler.compute_logprobs(accepted_logits)
)
accepted_tokens = sampled_token_ids[accepted_mask]
return self.sampler.gather_logprobs(
accepted_logprobs,
max_num_logprobs,
accepted_tokens.to(torch.int64),
)
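A small worked example (hypothetical values, with PLACEHOLDER_TOKEN_ID assumed to be -1) of how the accepted-token rows are located inside the flattened logits, following the masking arithmetic above:
import torch

PLACEHOLDER_TOKEN_ID = -1  # sentinel for rejected slots; assumed value for this sketch

# Two requests, padded to max_spec_len + 1 == 3 columns. Request 0 accepted
# both drafts plus the bonus token; request 1 rejected its draft and kept
# only the recovered token.
sampled_token_ids = torch.tensor([[5, 7, 9], [4, -1, -1]])
cu_num_sampled_tokens = torch.tensor([0, 3])     # exclusive prefix sums per request
accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
num_accepted_tokens = accepted_mask.sum(dim=-1)               # tensor([3, 1])
accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(num_accepted_tokens)
assert accepted_logit_indices.tolist() == [0, 1, 2, 3]        # rows gathered from final_logits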
@staticmethod
def parse_output(
@ -119,14 +206,12 @@ class RejectionSampler(nn.Module):
vocab_size: int,
) -> list[list[int]]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
Returns:
A list of lists of token IDs.
"""
@ -328,27 +413,26 @@ def rejection_sample(
return output_token_ids
def compute_probs(
def apply_sampling_constraints(
logits: torch.Tensor, # [num_tokens, vocab_size]
cu_num_draft_tokens: torch.Tensor, # [batch_size]
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""Compute probability distribution from logits based on sampling metadata.
"""Process logits based on sampling metadata.
This function applies temperature scaling to the logits and converts
them to probabilities using softmax. For greedy decoding, it returns
This function applies temperature scaling to the logits,
as well as top-k and top-p. For greedy decoding, it returns
the original logits.
Args:
logits: Input logits tensor to be converted to probabilities.
logits: Input logits tensor to be processed.
cu_num_draft_tokens: Cumulative number of draft tokens.
sampling_metadata: Metadata containing sampling parameters such as
temperature and whether greedy sampling is used.
Returns:
torch.Tensor: Probability distribution (softmax of scaled logits)
if non-greedy sampling is used, otherwise returns the
original logits.
torch.Tensor: Processed logits if non-greedy sampling is used,
otherwise returns the original logits.
"""
assert logits.ndim == 2
assert cu_num_draft_tokens.ndim == 1
@ -384,9 +468,7 @@ def compute_probs(
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
logits = apply_top_k_top_p(logits, top_k, top_p)
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
return output_prob
return apply_top_k_top_p(logits, top_k, top_p)
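For intuition, a minimal sketch (using plain repeat_interleave rather than the expand_batch_to_tokens helper used in this file) of how per-request sampling parameters are broadcast to per-token values via the cumulative draft counts:
import torch

cu_num_draft_tokens = torch.tensor([2, 3])    # request 0 has 2 drafts, request 1 has 1
num_draft = torch.diff(cu_num_draft_tokens, prepend=torch.tensor([0]))
temperature = torch.tensor([0.7, 1.0])        # one value per request
per_token_temperature = temperature.repeat_interleave(num_draft)
# per_token_temperature -> tensor([0.7000, 0.7000, 1.0000])
# scaled = logits / per_token_temperature.unsqueeze(-1)  # then top-k / top-p filtering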
def expand_batch_to_tokens(

View File

@ -69,16 +69,18 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
predict_bonus_token: bool = False,
logprobs_mode_override: LogprobsMode | None = None,
) -> SamplerOutput:
logprobs_mode = logprobs_mode_override or self.logprobs_mode
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
if self.logprobs_mode == "raw_logprobs":
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits)
elif self.logprobs_mode == "raw_logits":
elif logprobs_mode == "raw_logits":
raw_logprobs = logits.clone()
# Use float32 for the logits.
@ -97,13 +99,18 @@ class Sampler(nn.Module):
# return int32 (while PyTorch argmax and topk return int64).
sampled = sampled.long()
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
logprobs_tensors = (
None
if num_logprobs is None
else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
)
if num_logprobs is None:
logprobs_tensors = None
elif num_logprobs == -1:
# Return the full unsorted and unranked logprobs.
logprobs_tensors = LogprobsTensors(
torch.empty(0), raw_logprobs, torch.empty(0)
)
else:
# Gather the logprobs and ranks of the topk and sampled token.
logprobs_tensors = self.gather_logprobs(
raw_logprobs, num_logprobs, token_ids=sampled
)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
@ -138,6 +145,7 @@ class Sampler(nn.Module):
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
logprobs_mode_override: LogprobsMode | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Sample logits based on sampling metadata.
@ -145,6 +153,7 @@ class Sampler(nn.Module):
may update the logits tensor in-place.
"""
logprobs_mode = logprobs_mode_override or self.logprobs_mode
assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
if sampling_metadata.all_random:
greedy_sampled = None
@ -153,9 +162,9 @@ class Sampler(nn.Module):
if sampling_metadata.all_greedy:
processed_logprobs = None
if sampling_metadata.max_num_logprobs is not None:
if self.logprobs_mode == "processed_logits":
if logprobs_mode == "processed_logits":
processed_logprobs = logits
elif self.logprobs_mode == "processed_logprobs":
elif logprobs_mode == "processed_logprobs":
processed_logprobs = self.compute_logprobs(logits)
return greedy_sampled, processed_logprobs

View File

@ -14,6 +14,8 @@ class SpecDecodeMetadata:
num_draft_tokens: list[int]
# [batch_size]
cu_num_draft_tokens: torch.Tensor
# [batch_size]
cu_num_sampled_tokens: torch.Tensor
# [num_tokens]
target_logits_indices: torch.Tensor
# [batch_size]
@ -32,6 +34,7 @@ class SpecDecodeMetadata:
) -> "SpecDecodeMetadata":
batch_size = len(draft_token_ids)
num_draft_tokens = [len(ids) for ids in draft_token_ids]
num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]
flattened_draft_token_ids = sum(draft_token_ids, [])
num_tokens = len(flattened_draft_token_ids)
@ -40,6 +43,10 @@ class SpecDecodeMetadata:
)
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
device
)
target_logits_indices = torch.zeros(
num_tokens, dtype=torch.int32, device=device
@ -52,6 +59,7 @@ class SpecDecodeMetadata:
draft_token_ids=draft_token_ids_tensor,
num_draft_tokens=num_draft_tokens,
cu_num_draft_tokens=cu_num_draft_tokens_tensor,
cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
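A quick sketch of the new field's values as produced by make_dummy above (hypothetical inputs; each request samples its draft tokens plus one bonus token):
import torch
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

md = SpecDecodeMetadata.make_dummy([[1, 2], [3]], device=torch.device("cpu"))
assert md.num_draft_tokens == [2, 1]
assert md.cu_num_draft_tokens.tolist() == [2, 3]
# num_sampled_tokens per request is num_draft + 1, hence cumulative [3, 5].
assert md.cu_num_sampled_tokens.tolist() == [3, 5]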

View File

@ -327,7 +327,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"Unknown speculative decoding method: "
f"{self.speculative_config.method}"
)
self.rejection_sampler = RejectionSampler()
self.rejection_sampler = RejectionSampler(self.sampler)
# Request states.
self.requests: dict[str, CachedRequestState] = {}
@ -1624,6 +1624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True
)
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
self.device, non_blocking=True
)
logits_indices = torch.from_numpy(logits_indices).to(
self.device, non_blocking=True
)
@ -1639,15 +1642,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
draft_token_ids = self.input_ids.gpu[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1]
metadata = SpecDecodeMetadata(
return SpecDecodeMetadata(
draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens,
cu_num_sampled_tokens=cu_num_sampled_tokens,
target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices,
)
return metadata
def _prepare_kv_sharing_fast_prefill(
self,
@ -2221,32 +2224,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata=sampling_metadata,
)
# When indexing with a tensor (bonus_logits_indices), PyTorch
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
predict_bonus_token=True,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
sampler_output = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
target_logits,
bonus_token_ids,
logits,
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
self._update_states_after_model_execute(output_token_ids)
self._update_states_after_model_execute(sampler_output.sampled_token_ids)
return sampler_output
def _bookkeeping_sync(
@ -2256,6 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits: torch.Tensor | None,
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
spec_decode_metadata: SpecDecodeMetadata | None,
) -> tuple[
dict[str, int],
LogprobsLists | None,
@ -2282,19 +2267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_ids_output_copy = self.input_batch.req_ids.copy()
req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = (
logprobs_tensors.tolists() if logprobs_tensors is not None else None
)
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output.num_scheduled_tokens,
)
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids
invalid_req_indices = []
@ -2335,6 +2307,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids
logprobs_tensors = sampler_output.logprobs_tensors
cu_num_accepted_tokens = (
[0] if spec_decode_metadata and logprobs_tensors else None
)
for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling:
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@ -2360,6 +2336,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
if cu_num_accepted_tokens is not None:
cu_num_accepted_tokens.append(
cu_num_accepted_tokens[-1] + len(sampled_ids)
)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_lists = (
logprobs_tensors.tolists(cu_num_accepted_tokens)
if logprobs_tensors is not None
else None
)
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output.num_scheduled_tokens,
)
return (
num_nans_in_logits,
logprobs_lists,
@ -2644,6 +2639,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits,
hidden_states,
num_scheduled_tokens,
spec_decode_metadata,
)
if (
@ -3560,20 +3556,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype)
draft_probs = None
target_logits = torch.randn(
num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype
)
# NOTE(woosuk): Here, we should use int32 because the sampler uses
# int32 for bonus_token_ids. If the dtype mismatches, re-compilation
# will occur at runtime.
bonus_token_ids = torch.zeros(
num_reqs, device=self.device, dtype=torch.int32
logits = torch.randn(
num_tokens + num_reqs,
logits.shape[-1],
device=self.device,
dtype=logits.dtype,
)
self.rejection_sampler(
dummy_spec_decode_metadata,
draft_probs,
target_logits,
bonus_token_ids,
logits,
dummy_metadata,
)
return sampler_output