[V1][spec decode] return logprobs for spec decoding (#26060)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Giancarlo Delfin 2025-10-22 22:59:59 -07:00 committed by GitHub
parent ff93cc8c84
commit 6644796bf4
8 changed files with 392 additions and 186 deletions

View File

@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools import itertools
import math
from collections.abc import Generator from collections.abc import Generator
from typing import get_args from typing import get_args
import pytest import pytest
import torch import torch
from tests.utils import large_gpu_mark
from tests.v1.sample.utils import ( from tests.v1.sample.utils import (
BatchLogprobsComposition, BatchLogprobsComposition,
BatchLogprobsSpecType, BatchLogprobsSpecType,
@ -17,6 +19,7 @@ from tests.v1.sample.utils import (
) )
from vllm import SamplingParams from vllm import SamplingParams
from vllm.config.model import LogprobsMode from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory
from ...conftest import HfRunner, VllmRunner from ...conftest import HfRunner, VllmRunner
@ -508,3 +511,94 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
if logprobs_mode in ("raw_logits", "processed_logits"): if logprobs_mode in ("raw_logits", "processed_logits"):
assert positive_values > 0 assert positive_values > 0
del llm del llm
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
"model_setup",
[
pytest.param(
(
"eagle",
"meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
),
marks=large_gpu_mark(min_gb=32),
),
],
)
def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, str],
monkeypatch: pytest.MonkeyPatch,
):
"""Spec decode logprobs should match those of the base model.
Args:
logprobs_mode: logprobs mode.
model_setup: Spec decode method, base model name, and
draft model name.
"""
from vllm import LLM
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
prompt = "Hello world"
sampling_params = SamplingParams(
temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
)
method, model_name, spec_model_name = model_setup
max_model_len = 256
# Run base LLM.
ref_llm = LLM(
model=model_name,
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
ref_results = ref_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from reference LLM.
ref_logprobs = []
for output in ref_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
del ref_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Run spec decode LLM.
spec_llm = LLM(
model_name,
speculative_config={
"method": method,
"model": spec_model_name,
"num_speculative_tokens": 3,
"max_model_len": max_model_len,
},
max_logprobs=5,
max_model_len=max_model_len,
seed=42,
logprobs_mode=logprobs_mode,
gpu_memory_utilization=0.4,
)
spec_results = spec_llm.generate([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_results[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
del spec_llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
# Per-token logprobs are expected to be the same.
assert len(ref_logprobs) == len(spec_logprobs)
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
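Note: for orientation, a minimal end-to-end sketch of the behavior this test exercises. The model pair and settings mirror the test parameters above (EAGLE draft model, a ~32 GB GPU is assumed); treat the snippet as illustrative rather than part of the commit.

    # Illustrative only: request per-token logprobs while speculative decoding is on.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",           # target model (as in the test)
        speculative_config={
            "method": "eagle",
            "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",   # draft model (as in the test)
            "num_speculative_tokens": 3,
        },
        max_logprobs=5,
    )
    params = SamplingParams(temperature=0, logprobs=3, max_tokens=10)
    out = llm.generate(["Hello world"], params)[0].outputs[0]

    # With this change, every generated position carries top-k logprobs even when
    # several tokens were accepted from the draft model in one verification step.
    for pos, logprob_dict in enumerate(out.logprobs):
        for token_id, lp in logprob_dict.items():
            print(pos, token_id, lp.logprob, lp.rank, lp.decoded_token)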

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any from typing import Any
from unittest.mock import Mock
import pytest import pytest
import torch import torch
@ -11,6 +12,7 @@ from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = current_platform.device_type DEVICE = current_platform.device_type
@ -18,7 +20,28 @@ DEVICE = current_platform.device_type
@pytest.fixture @pytest.fixture
def rejection_sampler(): def rejection_sampler():
return RejectionSampler() mock_sampler = Mock(spec=Sampler)
mock_sampler.logprobs_mode = "raw_logprobs"
return RejectionSampler(mock_sampler)
def mock_sampler_output(
rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
):
rejection_sampler.sampler.return_value = SamplerOutput(
sampled_token_ids=bonus_token_ids, logprobs_tensors=None
)
def create_spec_decode_metadata(
spec_tokens: list[list[int]], logits: torch.Tensor
) -> SpecDecodeMetadata:
metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
metadata.target_logits_indices = torch.arange(logits.shape[0])
# Output bonus token ids are mocked, so the bonus logit indices should
# be empty.
metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
return metadata
def create_logits_tensor( def create_logits_tensor(
@ -111,19 +134,17 @@ def test_perfect_match(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True) metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device) expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
def test_early_mismatch(rejection_sampler): def test_early_mismatch(rejection_sampler):
@ -134,15 +155,13 @@ def test_early_mismatch(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True) metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
expected = torch.tensor( expected = torch.tensor(
@ -150,7 +169,7 @@ def test_early_mismatch(rejection_sampler):
dtype=torch.int, dtype=torch.int,
device=logits.device, device=logits.device,
) )
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
def test_multiple_sequences(rejection_sampler): def test_multiple_sequences(rejection_sampler):
@ -163,21 +182,19 @@ def test_multiple_sequences(rejection_sampler):
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
) )
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
expected = torch.tensor( expected = torch.tensor(
[[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device
) )
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
def test_single_token_sequence(rejection_sampler): def test_single_token_sequence(rejection_sampler):
@ -188,19 +205,17 @@ def test_single_token_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True) metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
def test_empty_sequence(rejection_sampler): def test_empty_sequence(rejection_sampler):
@ -211,19 +226,17 @@ def test_empty_sequence(rejection_sampler):
metadata = create_sampling_metadata(all_greedy=True) metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device) bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
def test_multiple_mismatches(rejection_sampler): def test_multiple_mismatches(rejection_sampler):
@ -236,15 +249,13 @@ def test_multiple_mismatches(rejection_sampler):
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
) )
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
expected = torch.tensor( expected = torch.tensor(
@ -255,7 +266,7 @@ def test_multiple_mismatches(rejection_sampler):
dtype=torch.int, dtype=torch.int,
device=logits.device, device=logits.device,
) )
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -277,19 +288,17 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expec
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[tokens[-1] for tokens in output_tokens], device=logits.device [tokens[-1] for tokens in output_tokens], device=logits.device
) )
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device
)
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device) expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device)
assert torch.equal(output, expected_tensor) assert torch.equal(output.sampled_token_ids, expected_tensor)
########################### Tests for Random Sampling ################### ########################### Tests for Random Sampling ###################
@ -331,18 +340,19 @@ def test_deterministic_when_seeded(
sampling_metadata = create_sampling_metadata( sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=seeded_seqs all_greedy=False, temperature=temperature, generators=seeded_seqs
) )
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.tolist(), device=DEVICE draft_token_ids.tolist(), target_logits
) )
mock_sampler_output(rejection_sampler, bonus_token_ids)
rep_result = rejection_sampler( rep_result = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=draft_probs, draft_probs=None,
target_logits=target_logits, logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
results.append(rep_result) results.append(rep_result.sampled_token_ids)
for i in range(batch_size): for i in range(batch_size):
if seeded_mask[i]: if seeded_mask[i]:
@ -460,7 +470,9 @@ def estimate_rejection_sampling_pdf(
Returns: Returns:
Estimated probability distribution of the output tokens. Estimated probability distribution of the output tokens.
""" """
rejection_sampler = RejectionSampler() mock_sampler = Mock(spec=Sampler)
mock_sampler.logprobs_mode = "raw_logprobs"
rejection_sampler = RejectionSampler(mock_sampler)
num_tokens = num_samples * k num_tokens = num_samples * k
# Repeat draft probs num_samples * k times. # Repeat draft probs num_samples * k times.
draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1) draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
@ -483,17 +495,18 @@ def estimate_rejection_sampling_pdf(
sampling_metadata = create_sampling_metadata( sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature all_greedy=False, temperature=temperature
) )
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.tolist(), device=bonus_token_ids.device draft_token_ids.tolist(), target_logits
) )
output_token_ids = rejection_sampler(
mock_sampler_output(rejection_sampler, bonus_token_ids)
sampler_output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=draft_probs, draft_probs=draft_probs,
target_logits=target_logits, logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
output_token_ids = output_token_ids[:, :-1].flatten() output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()
hist = torch.histogram( hist = torch.histogram(
output_token_ids.to(dtype=torch.float, device="cpu"), output_token_ids.to(dtype=torch.float, device="cpu"),
@ -532,22 +545,19 @@ def _test_masked_logits(
bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE) bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
# Create spec decode metadata # Create spec decode metadata
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
draft_token_ids,
device=DEVICE,
)
# Run rejection sampling # Run rejection sampling
output_token_ids = rejection_sampler( mock_sampler_output(rejection_sampler, bonus_token_ids)
output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=draft_probs, draft_probs=draft_probs,
target_logits=target_logits, logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
# Remove bonus tokens and reshape # Remove bonus tokens and reshape
output_token_ids = output_token_ids[:, :-1].flatten().tolist() output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist()
# Check that all sampled tokens are within the unmasked indices. # Check that all sampled tokens are within the unmasked indices.
for i in range(num_tokens): for i in range(num_tokens):
@ -665,11 +675,11 @@ def test_frequency_penalties(rejection_sampler):
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = SpecDecodeMetadata.make_dummy(
spec_tokens, device=logits.device spec_tokens, device=logits.device
) )
mock_sampler_output(rejection_sampler, bonus_token_tensor)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
expected = torch.tensor( expected = torch.tensor(
@ -677,7 +687,7 @@ def test_frequency_penalties(rejection_sampler):
dtype=torch.int, dtype=torch.int,
device=logits.device, device=logits.device,
) )
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
def test_bad_words(rejection_sampler): def test_bad_words(rejection_sampler):
@ -707,14 +717,12 @@ def test_bad_words(rejection_sampler):
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
) )
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device mock_sampler_output(rejection_sampler, bonus_token_tensor)
)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
@ -723,7 +731,7 @@ def test_bad_words(rejection_sampler):
dtype=torch.int, dtype=torch.int,
device=logits.device, device=logits.device,
) )
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
def test_allowed_token_ids(rejection_sampler): def test_allowed_token_ids(rejection_sampler):
@ -756,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler):
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
) )
spec_decode_metadata = SpecDecodeMetadata.make_dummy( spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
spec_tokens, device=logits.device mock_sampler_output(rejection_sampler, bonus_token_tensor)
)
output = rejection_sampler( output = rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
draft_probs=None, draft_probs=None,
target_logits=logits, logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata, sampling_metadata=metadata,
) )
@ -772,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler):
dtype=torch.int, dtype=torch.int,
device=logits.device, device=logits.device,
) )
assert torch.equal(output, expected) assert torch.equal(output.sampled_token_ids, expected)
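Note: the fixture change above injects a mocked Sampler so each test can pin the bonus tokens the rejection sampler would otherwise sample itself. A standalone sketch of that injection pattern, with illustrative names only (it does not import vLLM):

    # The component under test calls an injected sampler; the test fixes its output.
    from dataclasses import dataclass
    from unittest.mock import Mock
    import torch

    @dataclass
    class FakeSamplerOutput:
        sampled_token_ids: torch.Tensor
        logprobs_tensors: None = None

    class ComponentUnderTest:
        def __init__(self, sampler):
            self.sampler = sampler

        def run(self, logits: torch.Tensor) -> torch.Tensor:
            # The real rejection sampler calls self.sampler(...) to get bonus tokens.
            return self.sampler(logits=logits).sampled_token_ids

    mock_sampler = Mock()
    mock_sampler.logprobs_mode = "raw_logprobs"
    mock_sampler.return_value = FakeSamplerOutput(sampled_token_ids=torch.tensor([[5]]))

    component = ComponentUnderTest(mock_sampler)
    assert component.run(torch.zeros(1, 8)).tolist() == [[5]]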

View File

@ -66,7 +66,7 @@ class LogprobsProcessor:
assert self.logprobs is not None assert self.logprobs is not None
assert self.cumulative_logprob is not None assert self.cumulative_logprob is not None
token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst): for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst):
# Detokenize (non-incrementally). # Detokenize (non-incrementally).

View File

@ -14,34 +14,49 @@ else:
class LogprobsLists(NamedTuple): class LogprobsLists(NamedTuple):
# [num_reqs, max_num_logprobs + 1] # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprob_token_ids: list[list[int]] logprob_token_ids: list[list[int]]
# [num_reqs, max_num_logprobs + 1] # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprobs: list[list[float]] logprobs: list[list[float]]
# [num_reqs] # [num_reqs x num_generated_tokens]
sampled_token_ranks: list[int] sampled_token_ranks: list[int]
# [num_reqs]
# Used for slicing the logprobs in cases like speculative
# decoding where the number of generated tokens may be
# different for each request.
cu_num_generated_tokens: list[int] | None = None
def slice(self, start: int, end: int): def slice(self, start_req_idx: int, end_req_idx: int):
if self.cu_num_generated_tokens:
start = self.cu_num_generated_tokens[start_req_idx]
end = self.cu_num_generated_tokens[end_req_idx]
else:
start = start_req_idx
end = end_req_idx
return LogprobsLists( return LogprobsLists(
self.logprob_token_ids[start:end], self.logprob_token_ids[start:end],
self.logprobs[start:end], self.logprobs[start:end],
self.sampled_token_ranks[start:end], self.sampled_token_ranks[start:end],
self.cu_num_generated_tokens[start_req_idx:end_req_idx]
if self.cu_num_generated_tokens
else None,
) )
class LogprobsTensors(NamedTuple): class LogprobsTensors(NamedTuple):
# [num_reqs, max_num_logprobs + 1] # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprob_token_ids: torch.Tensor logprob_token_ids: torch.Tensor
# [num_reqs, max_num_logprobs + 1] # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
logprobs: torch.Tensor logprobs: torch.Tensor
# [num_reqs] # [num_reqs x num_generated_tokens]
selected_token_ranks: torch.Tensor selected_token_ranks: torch.Tensor
def tolists(self): def tolists(self, cu_num_generated_tokens: list[int] | None = None):
return LogprobsLists( return LogprobsLists(
self.logprob_token_ids.tolist(), self.logprob_token_ids.tolist(),
self.logprobs.tolist(), self.logprobs.tolist(),
self.selected_token_ranks.tolist(), self.selected_token_ranks.tolist(),
cu_num_generated_tokens,
) )
@staticmethod @staticmethod
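Note: a small self-contained sketch of how the new cu_num_generated_tokens field drives per-request slicing when requests emit different numbers of tokens. It mirrors the slice() logic above with made-up values; the exclusive-prefix boundaries are an assumption consistent with the model-runner bookkeeping later in this commit.

    # Request 0 produced 3 tokens (2 accepted drafts + 1 bonus), request 1 produced 1.
    logprob_token_ids = [[11, 12], [13, 14], [15, 16], [21, 22]]
    logprobs = [[-0.1, -2.0], [-0.2, -1.9], [-0.3, -1.8], [-0.4, -1.7]]
    sampled_token_ranks = [1, 1, 2, 1]
    # Cumulative generated-token counts: one boundary per request slice point.
    cu_num_generated_tokens = [0, 3, 4]

    def slice_request(start_req_idx: int, end_req_idx: int):
        # Mirrors LogprobsLists.slice: map request indices to flat token rows.
        start = cu_num_generated_tokens[start_req_idx]
        end = cu_num_generated_tokens[end_req_idx]
        return (
            logprob_token_ids[start:end],
            logprobs[start:end],
            sampled_token_ranks[start:end],
        )

    # Rows for request 0 only: its three generated tokens.
    assert slice_request(0, 1) == (
        [[11, 12], [13, 14], [15, 16]],
        [[-0.1, -2.0], [-0.2, -1.9], [-0.3, -1.8]],
        [1, 1, 2],
    )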

View File

@ -1,15 +1,19 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
@ -44,17 +48,22 @@ class RejectionSampler(nn.Module):
output tokens = accepted tokens + recovered tokens + bonus tokens output tokens = accepted tokens + recovered tokens + bonus tokens
""" """
def __init__(self, sampler: Sampler):
super().__init__()
self.sampler = sampler
logprobs_mode = self.sampler.logprobs_mode
self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
def forward( def forward(
self, self,
metadata: SpecDecodeMetadata, metadata: SpecDecodeMetadata,
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
draft_probs: torch.Tensor | None, draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size] # [num_tokens + batch_size, vocab_size]
target_logits: torch.Tensor, logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> SamplerOutput:
""" """
Args: Args:
metadata: metadata:
@ -63,43 +72,65 @@ class RejectionSampler(nn.Module):
Probability distribution for the draft tokens. Shape is Probability distribution for the draft tokens. Shape is
[num_tokens, vocab_size]. Can be None if probabilities are [num_tokens, vocab_size]. Can be None if probabilities are
not provided, which is the case for ngram spec decode. not provided, which is the case for ngram spec decode.
target_logits (torch.Tensor): logits (torch.Tensor):
Target model's logits probability distribution. Target model's logits probability distribution.
Shape is [num_tokens, vocab_size]. Here, probabilities from Shape is [num_tokens + batch_size, vocab_size]. Here,
different requests are flattened into a single tensor because probabilities from different requests are flattened into a
this is the shape of the output logits. single tensor because this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory. NOTE: `logits` can be updated in place to save memory.
bonus_token_ids (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata): sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
Additional metadata needed for sampling, such as temperature, Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information. top-k/top-p parameters, or other relevant information.
Returns: Returns:
output_token_ids (torch.Tensor): SamplerOutput:
A tensor containing the final output token IDs. Contains the final output token IDs and their logprobs if
requested.
""" """
assert metadata.max_spec_len <= MAX_SPEC_LEN assert metadata.max_spec_len <= MAX_SPEC_LEN
# Use float32 for the target_logits. bonus_logits_indices = metadata.bonus_logits_indices
target_logits = target_logits.to(torch.float32) target_logits_indices = metadata.target_logits_indices
target_logits = self.apply_logits_processors( # When indexing with a tensor (bonus_logits_indices), PyTorch
target_logits, sampling_metadata, metadata # creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[bonus_logits_indices]
bonus_sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=replace(
sampling_metadata,
max_num_logprobs=-1,
),
predict_bonus_token=True,
# Override the logprobs mode to return logits because they are
# needed later to compute the accepted token logprobs.
logprobs_mode_override="processed_logits"
if self.is_processed_logprobs_mode
else "raw_logits",
) )
bonus_token_ids = bonus_sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
raw_target_logits = logits[target_logits_indices]
# Use float32 for the target_logits.
raw_target_logits = raw_target_logits.to(torch.float32)
target_logits = self.apply_logits_processors(
raw_target_logits, sampling_metadata, metadata
)
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the # NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function. # `apply_sampling_constraints` function.
target_probs = compute_probs( target_logits = apply_sampling_constraints(
target_logits, target_logits,
metadata.cu_num_draft_tokens, metadata.cu_num_draft_tokens,
sampling_metadata, sampling_metadata,
) )
# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
output_token_ids = rejection_sample( output_token_ids = rejection_sample(
metadata.draft_token_ids, metadata.draft_token_ids,
@ -111,7 +142,63 @@ class RejectionSampler(nn.Module):
bonus_token_ids, bonus_token_ids,
sampling_metadata, sampling_metadata,
) )
return output_token_ids
logprobs_tensors = None
if sampling_metadata.max_num_logprobs:
logprobs_tensors = self._get_logprobs_tensors(
sampling_metadata.max_num_logprobs,
metadata,
logits,
target_logits if self.is_processed_logprobs_mode else raw_target_logits,
bonus_sampler_output.logprobs_tensors.logprobs,
output_token_ids,
)
return SamplerOutput(
sampled_token_ids=output_token_ids,
logprobs_tensors=logprobs_tensors,
)
def _get_logprobs_tensors(
self,
max_num_logprobs: int,
metadata: SpecDecodeMetadata,
logits: torch.Tensor,
target_logits: torch.Tensor,
bonus_logits: torch.Tensor,
sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]
# Collect target and bonus logits.
bonus_logits_indices = metadata.bonus_logits_indices
target_logits_indices = metadata.target_logits_indices
final_logits = torch.zeros_like(logits, dtype=torch.float32)
final_logits[target_logits_indices] = target_logits.to(torch.float32)
final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
# Compute accepted token indices.
accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
num_accepted_tokens = accepted_mask.sum(dim=-1)
accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
num_accepted_tokens
)
# Compute logprobs for accepted tokens.
accepted_logits = final_logits[accepted_logit_indices]
accepted_logprobs = (
accepted_logits
if self.is_logits_logprobs_mode
else self.sampler.compute_logprobs(accepted_logits)
)
accepted_tokens = sampled_token_ids[accepted_mask]
return self.sampler.gather_logprobs(
accepted_logprobs,
max_num_logprobs,
accepted_tokens.to(torch.int64),
)
@staticmethod @staticmethod
def parse_output( def parse_output(
@ -119,14 +206,12 @@ class RejectionSampler(nn.Module):
vocab_size: int, vocab_size: int,
) -> list[list[int]]: ) -> list[list[int]]:
"""Parse the output of the rejection sampler. """Parse the output of the rejection sampler.
Args: Args:
output_token_ids: The sampled token IDs in shape output_token_ids: The sampled token IDs in shape
[batch_size, max_spec_len + 1]. The rejected tokens are [batch_size, max_spec_len + 1]. The rejected tokens are
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function. and will be filtered out in this function.
vocab_size: The size of the vocabulary. vocab_size: The size of the vocabulary.
Returns: Returns:
A list of lists of token IDs. A list of lists of token IDs.
""" """
@ -328,27 +413,26 @@ def rejection_sample(
return output_token_ids return output_token_ids
def compute_probs( def apply_sampling_constraints(
logits: torch.Tensor, # [num_tokens, vocab_size] logits: torch.Tensor, # [num_tokens, vocab_size]
cu_num_draft_tokens: torch.Tensor, # [batch_size] cu_num_draft_tokens: torch.Tensor, # [batch_size]
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Compute probability distribution from logits based on sampling metadata. """Process logits based on sampling metadata.
This function applies temperature scaling to the logits and converts This function applies temperature scaling to the logits,
them to probabilities using softmax. For greedy decoding, it returns as well as top-k and top-p. For greedy decoding, it returns
the original logits. the original logits.
Args: Args:
logits: Input logits tensor to be converted to probabilities. logits: Input logits tensor to be processed.
cu_num_draft_tokens: Cumulative number of draft tokens. cu_num_draft_tokens: Cumulative number of draft tokens.
sampling_metadata: Metadata containing sampling parameters such as sampling_metadata: Metadata containing sampling parameters such as
temperature and whether greedy sampling is used. temperature and whether greedy sampling is used.
Returns: Returns:
torch.Tensor: Probability distribution (softmax of scaled logits) torch.Tensor: Processed logits if non-greedy sampling is used,
if non-greedy sampling is used, otherwise returns the otherwise returns the original logits.
original logits.
""" """
assert logits.ndim == 2 assert logits.ndim == 2
assert cu_num_draft_tokens.ndim == 1 assert cu_num_draft_tokens.ndim == 1
@ -384,9 +468,7 @@ def compute_probs(
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask, # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues. # which is slow for large vocab sizes. This may cause performance issues.
logits = apply_top_k_top_p(logits, top_k, top_p) return apply_top_k_top_p(logits, top_k, top_p)
output_prob = logits.softmax(dim=-1, dtype=torch.float32)
return output_prob
def expand_batch_to_tokens( def expand_batch_to_tokens(
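Note: the core of the new _get_logprobs_tensors step is selecting only the accepted slots from the flattened logits and computing their logprobs. A standalone sketch in plain PyTorch with toy shapes; everything except the PLACEHOLDER convention is illustrative.

    import torch

    PLACEHOLDER_TOKEN_ID = -1
    vocab = 6
    # Two requests, up to 2 draft tokens + 1 bonus each -> 3 sampled slots per row.
    sampled_token_ids = torch.tensor(
        [[1, 2, 5], [3, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]]
    )
    # Flattened per-slot logits (num_reqs * 3 rows in this toy case).
    final_logits = torch.randn(6, vocab)
    # Exclusive prefix sum: row offset of each request in the flattened logits.
    cu_num_sampled_tokens = torch.tensor([0, 3])

    accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
    num_accepted = accepted_mask.sum(dim=-1)                  # tensor([3, 1])
    # Column index of each accepted slot, shifted by its request's row offset.
    accepted_idx = accepted_mask.nonzero(as_tuple=True)[1]
    accepted_idx += cu_num_sampled_tokens.repeat_interleave(num_accepted)

    accepted_logprobs = final_logits[accepted_idx].log_softmax(dim=-1)
    accepted_tokens = sampled_token_ids[accepted_mask]
    # Logprob of each accepted token under the (toy) target distribution.
    token_logprobs = accepted_logprobs.gather(-1, accepted_tokens.unsqueeze(-1).long())
    print(token_logprobs.shape)  # torch.Size([4, 1]): 3 + 1 accepted tokens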

View File

@ -69,16 +69,18 @@ class Sampler(nn.Module):
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
predict_bonus_token: bool = False, predict_bonus_token: bool = False,
logprobs_mode_override: LogprobsMode | None = None,
) -> SamplerOutput: ) -> SamplerOutput:
logprobs_mode = logprobs_mode_override or self.logprobs_mode
# NOTE(woosuk): Use the original logits (before any penalties or # NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs. # temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that # This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling). # is used for sampling (after penalties and temperature scaling).
num_logprobs = sampling_metadata.max_num_logprobs num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None: if num_logprobs is not None:
if self.logprobs_mode == "raw_logprobs": if logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(logits) raw_logprobs = self.compute_logprobs(logits)
elif self.logprobs_mode == "raw_logits": elif logprobs_mode == "raw_logits":
raw_logprobs = logits.clone() raw_logprobs = logits.clone()
# Use float32 for the logits. # Use float32 for the logits.
@ -97,12 +99,17 @@ class Sampler(nn.Module):
# return int32 (while PyTorch argmax and topk return int64). # return int32 (while PyTorch argmax and topk return int64).
sampled = sampled.long() sampled = sampled.long()
# Gather the logprobs of the topk and sampled token (if requested). if num_logprobs is None:
# Get logprobs and rank tensors (if requested) logprobs_tensors = None
logprobs_tensors = ( elif num_logprobs == -1:
None # Return the full unsorted and unranked logprobs.
if num_logprobs is None logprobs_tensors = LogprobsTensors(
else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled) torch.empty(0), raw_logprobs, torch.empty(0)
)
else:
# Gather the logprobs and ranks of the topk and sampled token.
logprobs_tensors = self.gather_logprobs(
raw_logprobs, num_logprobs, token_ids=sampled
) )
# Use int32 to reduce the tensor size. # Use int32 to reduce the tensor size.
@ -138,6 +145,7 @@ class Sampler(nn.Module):
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
logprobs_mode_override: LogprobsMode | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Sample logits based on sampling metadata. """Sample logits based on sampling metadata.
@ -145,6 +153,7 @@ class Sampler(nn.Module):
may update the logits tensor in-place. may update the logits tensor in-place.
""" """
logprobs_mode = logprobs_mode_override or self.logprobs_mode
assert not (sampling_metadata.all_greedy and sampling_metadata.all_random) assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
if sampling_metadata.all_random: if sampling_metadata.all_random:
greedy_sampled = None greedy_sampled = None
@ -153,9 +162,9 @@ class Sampler(nn.Module):
if sampling_metadata.all_greedy: if sampling_metadata.all_greedy:
processed_logprobs = None processed_logprobs = None
if sampling_metadata.max_num_logprobs is not None: if sampling_metadata.max_num_logprobs is not None:
if self.logprobs_mode == "processed_logits": if logprobs_mode == "processed_logits":
processed_logprobs = logits processed_logprobs = logits
elif self.logprobs_mode == "processed_logprobs": elif logprobs_mode == "processed_logprobs":
processed_logprobs = self.compute_logprobs(logits) processed_logprobs = self.compute_logprobs(logits)
return greedy_sampled, processed_logprobs return greedy_sampled, processed_logprobs
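Note: the sampler now treats max_num_logprobs == -1 as a sentinel meaning "return the full logprobs tensor" so the rejection sampler can reuse it, alongside the usual None (no logprobs) and top-k paths. A toy sketch of that three-way handling; the helper name and the sampled-plus-topk layout are assumptions for illustration.

    import torch

    def pick_logprobs(raw_logprobs: torch.Tensor, sampled: torch.Tensor, num_logprobs: int | None):
        if num_logprobs is None:
            return None                      # no logprobs requested
        if num_logprobs == -1:
            # Full, unsorted and unranked logprobs: the caller (the rejection
            # sampler in this commit) slices and ranks them itself afterwards.
            return raw_logprobs
        # Otherwise gather the sampled token plus the top-k entries.
        topk_vals, topk_ids = raw_logprobs.topk(num_logprobs, dim=-1)
        sampled_vals = raw_logprobs.gather(-1, sampled.unsqueeze(-1))
        return (
            torch.cat([sampled_vals, topk_vals], dim=-1),
            torch.cat([sampled.unsqueeze(-1), topk_ids], dim=-1),
        )

    logits = torch.randn(2, 8)
    raw = logits.log_softmax(dim=-1)
    sampled = raw.argmax(dim=-1)
    assert pick_logprobs(raw, sampled, None) is None
    assert pick_logprobs(raw, sampled, -1) is raw
    vals, ids = pick_logprobs(raw, sampled, 3)
    assert vals.shape == (2, 4) and ids.shape == (2, 4)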

View File

@ -14,6 +14,8 @@ class SpecDecodeMetadata:
num_draft_tokens: list[int] num_draft_tokens: list[int]
# [batch_size] # [batch_size]
cu_num_draft_tokens: torch.Tensor cu_num_draft_tokens: torch.Tensor
# [batch_size]
cu_num_sampled_tokens: torch.Tensor
# [num_tokens] # [num_tokens]
target_logits_indices: torch.Tensor target_logits_indices: torch.Tensor
# [batch_size] # [batch_size]
@ -32,6 +34,7 @@ class SpecDecodeMetadata:
) -> "SpecDecodeMetadata": ) -> "SpecDecodeMetadata":
batch_size = len(draft_token_ids) batch_size = len(draft_token_ids)
num_draft_tokens = [len(ids) for ids in draft_token_ids] num_draft_tokens = [len(ids) for ids in draft_token_ids]
num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]
flattened_draft_token_ids = sum(draft_token_ids, []) flattened_draft_token_ids = sum(draft_token_ids, [])
num_tokens = len(flattened_draft_token_ids) num_tokens = len(flattened_draft_token_ids)
@ -40,6 +43,10 @@ class SpecDecodeMetadata:
) )
cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32) cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device) cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
device
)
target_logits_indices = torch.zeros( target_logits_indices = torch.zeros(
num_tokens, dtype=torch.int32, device=device num_tokens, dtype=torch.int32, device=device
@ -52,6 +59,7 @@ class SpecDecodeMetadata:
draft_token_ids=draft_token_ids_tensor, draft_token_ids=draft_token_ids_tensor,
num_draft_tokens=num_draft_tokens, num_draft_tokens=num_draft_tokens,
cu_num_draft_tokens=cu_num_draft_tokens_tensor, cu_num_draft_tokens=cu_num_draft_tokens_tensor,
cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
target_logits_indices=target_logits_indices, target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices, bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices, logits_indices=logits_indices,
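Note: a short worked example of the new cu_num_sampled_tokens field, matching the "+ 1" bonus-token accounting above (token values are illustrative).

    import numpy as np

    draft_token_ids = [[7, 8, 9], [4], []]          # 3 requests with 3, 1, 0 draft tokens
    num_draft_tokens = [len(ids) for ids in draft_token_ids]        # [3, 1, 0]
    # Each request may also emit one bonus token, hence the "+ 1".
    num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]  # [4, 2, 1]

    cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)      # [3, 4, 4]
    cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)  # [4, 6, 7]
    # cu_num_sampled_tokens gives, per request, the end offset of its rows in the
    # flattened [num_tokens + batch_size, vocab] logits handled by the sampler.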

View File

@ -327,7 +327,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"Unknown speculative decoding method: " "Unknown speculative decoding method: "
f"{self.speculative_config.method}" f"{self.speculative_config.method}"
) )
self.rejection_sampler = RejectionSampler() self.rejection_sampler = RejectionSampler(self.sampler)
# Request states. # Request states.
self.requests: dict[str, CachedRequestState] = {} self.requests: dict[str, CachedRequestState] = {}
@ -1624,6 +1624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to( cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
self.device, non_blocking=True self.device, non_blocking=True
) )
cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
self.device, non_blocking=True
)
logits_indices = torch.from_numpy(logits_indices).to( logits_indices = torch.from_numpy(logits_indices).to(
self.device, non_blocking=True self.device, non_blocking=True
) )
@ -1639,15 +1642,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
draft_token_ids = self.input_ids.gpu[logits_indices] draft_token_ids = self.input_ids.gpu[logits_indices]
draft_token_ids = draft_token_ids[target_logits_indices + 1] draft_token_ids = draft_token_ids[target_logits_indices + 1]
metadata = SpecDecodeMetadata( return SpecDecodeMetadata(
draft_token_ids=draft_token_ids, draft_token_ids=draft_token_ids,
num_draft_tokens=num_draft_tokens.tolist(), num_draft_tokens=num_draft_tokens.tolist(),
cu_num_draft_tokens=cu_num_draft_tokens, cu_num_draft_tokens=cu_num_draft_tokens,
cu_num_sampled_tokens=cu_num_sampled_tokens,
target_logits_indices=target_logits_indices, target_logits_indices=target_logits_indices,
bonus_logits_indices=bonus_logits_indices, bonus_logits_indices=bonus_logits_indices,
logits_indices=logits_indices, logits_indices=logits_indices,
) )
return metadata
def _prepare_kv_sharing_fast_prefill( def _prepare_kv_sharing_fast_prefill(
self, self,
@ -2221,32 +2224,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
# When indexing with a tensor (bonus_logits_indices), PyTorch sampler_output = self.rejection_sampler(
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert logits is not None
bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
sampler_output = self.sampler(
logits=bonus_logits,
sampling_metadata=sampling_metadata,
predict_bonus_token=True,
)
bonus_token_ids = sampler_output.sampled_token_ids
# Just like `bonus_logits`, `target_logits` is a new tensor with
# separate storage from the original `logits` tensor. Therefore,
# it is safe to update `target_logits` in place.
target_logits = logits[spec_decode_metadata.target_logits_indices]
output_token_ids = self.rejection_sampler(
spec_decode_metadata, spec_decode_metadata,
None, # draft_probs None, # draft_probs
target_logits, logits,
bonus_token_ids,
sampling_metadata, sampling_metadata,
) )
sampler_output.sampled_token_ids = output_token_ids self._update_states_after_model_execute(sampler_output.sampled_token_ids)
self._update_states_after_model_execute(output_token_ids)
return sampler_output return sampler_output
def _bookkeeping_sync( def _bookkeeping_sync(
@ -2256,6 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits: torch.Tensor | None, logits: torch.Tensor | None,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
num_scheduled_tokens: int, num_scheduled_tokens: int,
spec_decode_metadata: SpecDecodeMetadata | None,
) -> tuple[ ) -> tuple[
dict[str, int], dict[str, int],
LogprobsLists | None, LogprobsLists | None,
@ -2282,19 +2267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_ids_output_copy = self.input_batch.req_ids.copy() req_ids_output_copy = self.input_batch.req_ids.copy()
req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy() req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = (
logprobs_tensors.tolists() if logprobs_tensors is not None else None
)
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output.num_scheduled_tokens,
)
num_sampled_tokens = sampler_output.sampled_token_ids.shape[0] num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
sampled_token_ids = sampler_output.sampled_token_ids sampled_token_ids = sampler_output.sampled_token_ids
invalid_req_indices = [] invalid_req_indices = []
@ -2335,6 +2307,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# the sampled tokens back, because there's no direct communication # the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker. # between the first-stage worker and the last-stage worker.
req_ids = self.input_batch.req_ids req_ids = self.input_batch.req_ids
logprobs_tensors = sampler_output.logprobs_tensors
cu_num_accepted_tokens = (
[0] if spec_decode_metadata and logprobs_tensors else None
)
for req_idx in range(num_sampled_tokens): for req_idx in range(num_sampled_tokens):
if self.use_async_scheduling: if self.use_async_scheduling:
sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@ -2360,6 +2336,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id] req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids) req_state.output_token_ids.extend(sampled_ids)
if cu_num_accepted_tokens is not None:
cu_num_accepted_tokens.append(
cu_num_accepted_tokens[-1] + len(sampled_ids)
)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_lists = (
logprobs_tensors.tolists(cu_num_accepted_tokens)
if logprobs_tensors is not None
else None
)
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states[:num_scheduled_tokens],
scheduler_output.num_scheduled_tokens,
)
return ( return (
num_nans_in_logits, num_nans_in_logits,
logprobs_lists, logprobs_lists,
@ -2644,6 +2639,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits, logits,
hidden_states, hidden_states,
num_scheduled_tokens, num_scheduled_tokens,
spec_decode_metadata,
) )
if ( if (
@ -3560,20 +3556,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# num_tokens, logits.shape[-1], device=self.device, # num_tokens, logits.shape[-1], device=self.device,
# dtype=logits.dtype) # dtype=logits.dtype)
draft_probs = None draft_probs = None
target_logits = torch.randn( logits = torch.randn(
num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype num_tokens + num_reqs,
) logits.shape[-1],
# NOTE(woosuk): Here, we should use int32 because the sampler uses device=self.device,
# int32 for bonus_token_ids. If the dtype mismatches, re-compilation dtype=logits.dtype,
# will occur at runtime.
bonus_token_ids = torch.zeros(
num_reqs, device=self.device, dtype=torch.int32
) )
self.rejection_sampler( self.rejection_sampler(
dummy_spec_decode_metadata, dummy_spec_decode_metadata,
draft_probs, draft_probs,
target_logits, logits,
bonus_token_ids,
dummy_metadata, dummy_metadata,
) )
return sampler_output return sampler_output
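Note: the bookkeeping change above builds cu_num_accepted_tokens on the CPU while extending each request's output tokens, then hands it to tolists() so the flattened logprob rows can later be split per request via LogprobsLists.slice. A toy sketch with made-up data:

    valid_sampled_token_ids = [[1, 2, 5], [3], [9, 9]]   # accepted tokens per request

    cu_num_accepted_tokens = [0]
    for sampled_ids in valid_sampled_token_ids:
        cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + len(sampled_ids))
    # -> [0, 3, 4, 6]; one boundary per request, consumed by LogprobsLists.slice.

    flat_logprob_rows = ["r0t0", "r0t1", "r0t2", "r1t0", "r2t0", "r2t1"]
    per_request = [
        flat_logprob_rows[cu_num_accepted_tokens[i]:cu_num_accepted_tokens[i + 1]]
        for i in range(len(valid_sampled_token_ids))
    ]
    assert per_request == [["r0t0", "r0t1", "r0t2"], ["r1t0"], ["r2t0", "r2t1"]]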