Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:45:29 +08:00)

[V1][spec decode] return logprobs for spec decoding (#26060)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

Parent: ff93cc8c84
Commit: 6644796bf4
@@ -2,12 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import itertools
import math
from collections.abc import Generator
from typing import get_args

import pytest
import torch

from tests.utils import large_gpu_mark
from tests.v1.sample.utils import (
    BatchLogprobsComposition,
    BatchLogprobsSpecType,
@@ -17,6 +19,7 @@ from tests.v1.sample.utils import (
)
from vllm import SamplingParams
from vllm.config.model import LogprobsMode
from vllm.distributed import cleanup_dist_env_and_memory

from ...conftest import HfRunner, VllmRunner

@@ -508,3 +511,94 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
    if logprobs_mode in ("raw_logits", "processed_logits"):
        assert positive_values > 0
    del llm


@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
@pytest.mark.parametrize(
    "model_setup",
    [
        pytest.param(
            (
                "eagle",
                "meta-llama/Llama-3.1-8B-Instruct",
                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
            ),
            marks=large_gpu_mark(min_gb=32),
        ),
    ],
)
def test_spec_decode_logprobs(
    logprobs_mode: LogprobsMode,
    model_setup: tuple[str, str, str],
    monkeypatch: pytest.MonkeyPatch,
):
    """Spec decode logprobs should match those of the base model.

    Args:
        logprobs_mode: logprobs mode.
        model_setup: Spec decode method, base model name, and
            draft model name.
    """
    from vllm import LLM

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        prompt = "Hello world"
        sampling_params = SamplingParams(
            temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
        )
        method, model_name, spec_model_name = model_setup
        max_model_len = 256

        # Run base LLM.
        ref_llm = LLM(
            model=model_name,
            max_logprobs=5,
            max_model_len=max_model_len,
            seed=42,
            logprobs_mode=logprobs_mode,
            gpu_memory_utilization=0.4,
        )
        ref_results = ref_llm.generate([prompt], sampling_params)
        # Collect logprobs outputs from reference LLM.
        ref_logprobs = []
        for output in ref_results[0].outputs:
            for logprobs in output.logprobs:
                for token_id in logprobs:
                    ref_logprobs.append(logprobs[token_id])
        del ref_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Run spec decode LLM.
        spec_llm = LLM(
            model_name,
            speculative_config={
                "method": method,
                "model": spec_model_name,
                "num_speculative_tokens": 3,
                "max_model_len": max_model_len,
            },
            max_logprobs=5,
            max_model_len=max_model_len,
            seed=42,
            logprobs_mode=logprobs_mode,
            gpu_memory_utilization=0.4,
        )
        spec_results = spec_llm.generate([prompt], sampling_params)
        # Collect logprobs outputs from spec decode LLM.
        spec_logprobs = []
        for output in spec_results[0].outputs:
            for logprobs in output.logprobs:
                for token_id in logprobs:
                    spec_logprobs.append(logprobs[token_id])
        del spec_llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

        # Per-token logprobs are expected to be the same.
        assert len(ref_logprobs) == len(spec_logprobs)
        for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
            assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
            assert ref_logprob.rank == spec_logprob.rank
            assert ref_logprob.decoded_token == spec_logprob.decoded_token
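Note: the comparison above leans on the shape of vLLM's logprobs output, where each generated position maps token IDs to entries carrying `logprob`, `rank`, and `decoded_token`. A minimal self-contained sketch of the same flatten-and-compare logic, using a hypothetical stand-in class rather than vLLM's own Logprob type:

import math
from typing import NamedTuple

class FakeLogprob(NamedTuple):  # hypothetical stand-in, not vLLM's Logprob class
    logprob: float
    rank: int
    decoded_token: str

def flatten(outputs):
    # Mirror the nested iteration in the test: per output, per position, per token id.
    return [pos[tok] for out in outputs for pos in out for tok in pos]

ref = flatten([[{1: FakeLogprob(-0.10, 1, "a")}, {2: FakeLogprob(-0.50, 2, "b")}]])
spec = flatten([[{1: FakeLogprob(-0.1002, 1, "a")}, {2: FakeLogprob(-0.5001, 2, "b")}]])
assert len(ref) == len(spec)
for r, s in zip(ref, spec):
    assert math.isclose(r.logprob, s.logprob, abs_tol=1e-3)
    assert r.rank == s.rank and r.decoded_token == s.decoded_token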
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from unittest.mock import Mock

import pytest
import torch
@@ -11,6 +12,7 @@ from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

DEVICE = current_platform.device_type
@@ -18,7 +20,28 @@ DEVICE = current_platform.device_type

@pytest.fixture
def rejection_sampler():
    return RejectionSampler()
    mock_sampler = Mock(spec=Sampler)
    mock_sampler.logprobs_mode = "raw_logprobs"
    return RejectionSampler(mock_sampler)


def mock_sampler_output(
    rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
):
    rejection_sampler.sampler.return_value = SamplerOutput(
        sampled_token_ids=bonus_token_ids, logprobs_tensors=None
    )


def create_spec_decode_metadata(
    spec_tokens: list[list[int]], logits: torch.Tensor
) -> SpecDecodeMetadata:
    metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
    metadata.target_logits_indices = torch.arange(logits.shape[0])
    # Output bonus token ids are mocked, so the bonus logit indices should
    # be empty.
    metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
    return metadata


def create_logits_tensor(
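Note: RejectionSampler now takes the Sampler as a constructor argument and reads its `logprobs_mode`, so the tests stub it with `unittest.mock.Mock(spec=Sampler)`. A short standalone illustration of why `spec=` is used, with a hypothetical class standing in for vLLM's Sampler:

from unittest.mock import Mock

class FakeSampler:  # hypothetical stand-in for the real Sampler
    logprobs_mode = "raw_logprobs"

mock = Mock(spec=FakeSampler)
mock.logprobs_mode = "raw_logprobs"   # allowed: the attribute exists on the spec
mock.return_value = "stubbed output"  # calling the mock returns the canned value
assert mock() == "stubbed output"
try:
    mock.not_a_real_attribute          # attributes missing from the spec fail fast
except AttributeError:
    pass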
@@ -111,19 +134,17 @@ def test_perfect_match(rejection_sampler):
    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)


def test_early_mismatch(rejection_sampler):
@@ -134,15 +155,13 @@ def test_early_mismatch(rejection_sampler):
    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
@@ -150,7 +169,7 @@ def test_early_mismatch(rejection_sampler):
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)


def test_multiple_sequences(rejection_sampler):
@@ -163,21 +182,19 @@ def test_multiple_sequences(rejection_sampler):
    bonus_token_tensor = torch.tensor(
        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
        [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device
    )
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)


def test_single_token_sequence(rejection_sampler):
@@ -188,19 +205,17 @@ def test_single_token_sequence(rejection_sampler):
    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)


def test_empty_sequence(rejection_sampler):
@@ -211,19 +226,17 @@ def test_empty_sequence(rejection_sampler):
    metadata = create_sampling_metadata(all_greedy=True)
    logits = create_logits_tensor(output_tokens)
    bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)


def test_multiple_mismatches(rejection_sampler):
@@ -236,15 +249,13 @@ def test_multiple_mismatches(rejection_sampler):
    bonus_token_tensor = torch.tensor(
        [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
@@ -255,7 +266,7 @@ def test_multiple_mismatches(rejection_sampler):
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)


@pytest.mark.parametrize(
@@ -277,19 +288,17 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expec
    bonus_token_tensor = torch.tensor(
        [tokens[-1] for tokens in output_tokens], device=logits.device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device)
    assert torch.equal(output, expected_tensor)
    assert torch.equal(output.sampled_token_ids, expected_tensor)


########################### Tests for Random Sampling ###################
@@ -331,18 +340,19 @@ def test_deterministic_when_seeded(
        sampling_metadata = create_sampling_metadata(
            all_greedy=False, temperature=temperature, generators=seeded_seqs
        )
        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
            draft_token_ids.tolist(), device=DEVICE
        spec_decode_metadata = create_spec_decode_metadata(
            draft_token_ids.tolist(), target_logits
        )

        mock_sampler_output(rejection_sampler, bonus_token_ids)
        rep_result = rejection_sampler(
            spec_decode_metadata,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            draft_probs=None,
            logits=target_logits,
            sampling_metadata=sampling_metadata,
        )

        results.append(rep_result)
        results.append(rep_result.sampled_token_ids)

    for i in range(batch_size):
        if seeded_mask[i]:
@@ -460,7 +470,9 @@ def estimate_rejection_sampling_pdf(
    Returns:
        Estimated probability distribution of the output tokens.
    """
    rejection_sampler = RejectionSampler()
    mock_sampler = Mock(spec=Sampler)
    mock_sampler.logprobs_mode = "raw_logprobs"
    rejection_sampler = RejectionSampler(mock_sampler)
    num_tokens = num_samples * k
    # Repeat draft probs num_samples * k times.
    draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
@@ -483,17 +495,18 @@ def estimate_rejection_sampling_pdf(
    sampling_metadata = create_sampling_metadata(
        all_greedy=False, temperature=temperature
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids.tolist(), device=bonus_token_ids.device
    spec_decode_metadata = create_spec_decode_metadata(
        draft_token_ids.tolist(), target_logits
    )
    output_token_ids = rejection_sampler(

    mock_sampler_output(rejection_sampler, bonus_token_ids)
    sampler_output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        logits=target_logits,
        sampling_metadata=sampling_metadata,
    )
    output_token_ids = output_token_ids[:, :-1].flatten()
    output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()

    hist = torch.histogram(
        output_token_ids.to(dtype=torch.float, device="cpu"),
@@ -532,22 +545,19 @@ def _test_masked_logits(
    bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)

    # Create spec decode metadata
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids,
        device=DEVICE,
    )
    spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)

    # Run rejection sampling
    output_token_ids = rejection_sampler(
    mock_sampler_output(rejection_sampler, bonus_token_ids)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        logits=target_logits,
        sampling_metadata=sampling_metadata,
    )

    # Remove bonus tokens and reshape
    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
    output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist()

    # Check that all sampled tokens are within the unmasked indices.
    for i in range(num_tokens):
@@ -665,11 +675,11 @@ def test_frequency_penalties(rejection_sampler):
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )
    expected = torch.tensor(
@@ -677,7 +687,7 @@ def test_frequency_penalties(rejection_sampler):
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)


def test_bad_words(rejection_sampler):
@@ -707,14 +717,12 @@ def test_bad_words(rejection_sampler):
    bonus_token_tensor = torch.tensor(
        [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )

@@ -723,7 +731,7 @@ def test_bad_words(rejection_sampler):
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)


def test_allowed_token_ids(rejection_sampler):
@@ -756,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler):
    bonus_token_tensor = torch.tensor(
        [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
    )
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        spec_tokens, device=logits.device
    )
    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
    mock_sampler_output(rejection_sampler, bonus_token_tensor)
    output = rejection_sampler(
        spec_decode_metadata,
        draft_probs=None,
        target_logits=logits,
        bonus_token_ids=bonus_token_tensor,
        logits=logits,
        sampling_metadata=metadata,
    )

@@ -772,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler):
        dtype=torch.int,
        device=logits.device,
    )
    assert torch.equal(output, expected)
    assert torch.equal(output.sampled_token_ids, expected)
@@ -66,7 +66,7 @@ class LogprobsProcessor:
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
        token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists

        for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst):
            # Detokenize (non-incrementally).
@@ -14,34 +14,49 @@ else:


class LogprobsLists(NamedTuple):
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
    logprob_token_ids: list[list[int]]
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
    logprobs: list[list[float]]
    # [num_reqs]
    # [num_reqs x num_generated_tokens]
    sampled_token_ranks: list[int]
    # [num_reqs]
    # Used for slicing the logprobs in cases like speculative
    # decoding where the number of generated tokens may be
    # different for each request.
    cu_num_generated_tokens: list[int] | None = None

    def slice(self, start: int, end: int):
    def slice(self, start_req_idx: int, end_req_idx: int):
        if self.cu_num_generated_tokens:
            start = self.cu_num_generated_tokens[start_req_idx]
            end = self.cu_num_generated_tokens[end_req_idx]
        else:
            start = start_req_idx
            end = end_req_idx
        return LogprobsLists(
            self.logprob_token_ids[start:end],
            self.logprobs[start:end],
            self.sampled_token_ranks[start:end],
            self.cu_num_generated_tokens[start_req_idx:end_req_idx]
            if self.cu_num_generated_tokens
            else None,
        )


class LogprobsTensors(NamedTuple):
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
    logprob_token_ids: torch.Tensor
    # [num_reqs, max_num_logprobs + 1]
    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
    logprobs: torch.Tensor
    # [num_reqs]
    # [num_reqs x num_generated_tokens]
    selected_token_ranks: torch.Tensor

    def tolists(self):
    def tolists(self, cu_num_generated_tokens: list[int] | None = None):
        return LogprobsLists(
            self.logprob_token_ids.tolist(),
            self.logprobs.tolist(),
            self.selected_token_ranks.tolist(),
            cu_num_generated_tokens,
        )

    @staticmethod
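Note: when speculative decoding accepts a variable number of tokens per request, the flat logprob lists can no longer be sliced one row per request; `cu_num_generated_tokens` carries cumulative token offsets so that a request range maps to a token range. A small self-contained sketch of that offset arithmetic (the data values are made up):

# Cumulative token offsets for three requests that produced 2, 4, and 1 tokens.
cu_num_generated_tokens = [0, 2, 6, 7]
flat_logprobs = [[-0.1], [-0.2], [-0.3], [-0.4], [-0.5], [-0.6], [-0.7]]

def slice_requests(start_req_idx: int, end_req_idx: int) -> list[list[float]]:
    # Same idea as LogprobsLists.slice: request indices -> flat token index range.
    start = cu_num_generated_tokens[start_req_idx]
    end = cu_num_generated_tokens[end_req_idx]
    return flat_logprobs[start:end]

# Request 1 owns flat positions 2..5 (its four generated tokens).
assert slice_requests(1, 2) == [[-0.3], [-0.4], [-0.5], [-0.6]]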
@@ -1,15 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import replace

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.sampler import Sampler
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

logger = init_logger(__name__)
@@ -44,17 +48,22 @@ class RejectionSampler(nn.Module):
    output tokens = accepted tokens + recovered tokens + bonus tokens
    """

    def __init__(self, sampler: Sampler):
        super().__init__()
        self.sampler = sampler
        logprobs_mode = self.sampler.logprobs_mode
        self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
        self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")

    def forward(
        self,
        metadata: SpecDecodeMetadata,
        # [num_tokens, vocab_size]
        draft_probs: torch.Tensor | None,
        # [num_tokens, vocab_size]
        target_logits: torch.Tensor,
        # [batch_size, 1]
        bonus_token_ids: torch.Tensor,
        # [num_tokens + batch_size, vocab_size]
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
    ) -> SamplerOutput:
        """
        Args:
            metadata:
@@ -63,43 +72,65 @@ class RejectionSampler(nn.Module):
                Probability distribution for the draft tokens. Shape is
                [num_tokens, vocab_size]. Can be None if probabilities are
                not provided, which is the case for ngram spec decode.
            target_logits (torch.Tensor):
            logits (torch.Tensor):
                Target model's logits probability distribution.
                Shape is [num_tokens, vocab_size]. Here, probabilities from
                different requests are flattened into a single tensor because
                this is the shape of the output logits.
                NOTE: `target_logits` can be updated in place to save memory.
            bonus_token_ids (torch.Tensor):
                A tensor containing bonus tokens. Shape is [batch_size, 1].
                Bonus tokens are added to the end of the sequence if all
                proposed tokens are accepted. We generate the bonus tokens
                outside of the rejection sampler with the default sampling
                strategy. It allows for more flexibility in the sampling
                process such as top_p, top_k sampling.
                Shape is [num_tokens + batch_size, vocab_size]. Here,
                probabilities from different requests are flattened into a
                single tensor because this is the shape of the output logits.
                NOTE: `logits` can be updated in place to save memory.
            sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                Additional metadata needed for sampling, such as temperature,
                top-k/top-p parameters, or other relevant information.
        Returns:
            output_token_ids (torch.Tensor):
                A tensor containing the final output token IDs.
            SamplerOutput:
                Contains the final output token IDs and their logprobs if
                requested.
        """
        assert metadata.max_spec_len <= MAX_SPEC_LEN

        # Use float32 for the target_logits.
        target_logits = target_logits.to(torch.float32)
        bonus_logits_indices = metadata.bonus_logits_indices
        target_logits_indices = metadata.target_logits_indices

        target_logits = self.apply_logits_processors(
            target_logits, sampling_metadata, metadata
        # When indexing with a tensor (bonus_logits_indices), PyTorch
        # creates a new tensor with separate storage from the original
        # logits tensor. This means any in-place operations on bonus_logits
        # won't affect the original logits tensor.
        assert logits is not None
        bonus_logits = logits[bonus_logits_indices]
        bonus_sampler_output = self.sampler(
            logits=bonus_logits,
            sampling_metadata=replace(
                sampling_metadata,
                max_num_logprobs=-1,
            ),
            predict_bonus_token=True,
            # Override the logprobs mode to return logits because they are
            # needed later to compute the accepted token logprobs.
            logprobs_mode_override="processed_logits"
            if self.is_processed_logprobs_mode
            else "raw_logits",
        )
        bonus_token_ids = bonus_sampler_output.sampled_token_ids

        # Just like `bonus_logits`, `target_logits` is a new tensor with
        # separate storage from the original `logits` tensor. Therefore,
        # it is safe to update `target_logits` in place.
        raw_target_logits = logits[target_logits_indices]
        # Use float32 for the target_logits.
        raw_target_logits = raw_target_logits.to(torch.float32)
        target_logits = self.apply_logits_processors(
            raw_target_logits, sampling_metadata, metadata
        )
        # [num_tokens, vocab_size]
        # NOTE(woosuk): `target_logits` can be updated in place inside the
        # `compute_probs` function.
        target_probs = compute_probs(
        # `apply_sampling_constraints` function.
        target_logits = apply_sampling_constraints(
            target_logits,
            metadata.cu_num_draft_tokens,
            sampling_metadata,
        )
        # Compute probability distribution from target logits.
        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

        output_token_ids = rejection_sample(
            metadata.draft_token_ids,
@@ -111,7 +142,63 @@
            bonus_token_ids,
            sampling_metadata,
        )
        return output_token_ids

        logprobs_tensors = None
        if sampling_metadata.max_num_logprobs:
            logprobs_tensors = self._get_logprobs_tensors(
                sampling_metadata.max_num_logprobs,
                metadata,
                logits,
                target_logits if self.is_processed_logprobs_mode else raw_target_logits,
                bonus_sampler_output.logprobs_tensors.logprobs,
                output_token_ids,
            )

        return SamplerOutput(
            sampled_token_ids=output_token_ids,
            logprobs_tensors=logprobs_tensors,
        )

    def _get_logprobs_tensors(
        self,
        max_num_logprobs: int,
        metadata: SpecDecodeMetadata,
        logits: torch.Tensor,
        target_logits: torch.Tensor,
        bonus_logits: torch.Tensor,
        sampled_token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
        cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]

        # Collect target and bonus logits.
        bonus_logits_indices = metadata.bonus_logits_indices
        target_logits_indices = metadata.target_logits_indices
        final_logits = torch.zeros_like(logits, dtype=torch.float32)
        final_logits[target_logits_indices] = target_logits.to(torch.float32)
        final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)

        # Compute accepted token indices.
        accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
        num_accepted_tokens = accepted_mask.sum(dim=-1)
        accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
        accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
            num_accepted_tokens
        )

        # Compute logprobs for accepted tokens.
        accepted_logits = final_logits[accepted_logit_indices]
        accepted_logprobs = (
            accepted_logits
            if self.is_logits_logprobs_mode
            else self.sampler.compute_logprobs(accepted_logits)
        )
        accepted_tokens = sampled_token_ids[accepted_mask]
        return self.sampler.gather_logprobs(
            accepted_logprobs,
            max_num_logprobs,
            accepted_tokens.to(torch.int64),
        )

    @staticmethod
    def parse_output(
@@ -119,14 +206,12 @@
        vocab_size: int,
    ) -> list[list[int]]:
        """Parse the output of the rejection sampler.

        Args:
            output_token_ids: The sampled token IDs in shape
                [batch_size, max_spec_len + 1]. The rejected tokens are
                replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                and will be filtered out in this function.
            vocab_size: The size of the vocabulary.

        Returns:
            A list of lists of token IDs.
        """
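Note: the index arithmetic in `_get_logprobs_tensors` — mask out rejected placeholder slots, then shift each row's surviving positions by that request's cumulative sampled-token offset — can be reproduced in isolation with plain PyTorch; the tensor values below are illustrative only:

import torch

PLACEHOLDER_TOKEN_ID = -1  # sentinel used for rejected slots
# Two requests, each with up to 2 draft slots + 1 bonus slot.
sampled_token_ids = torch.tensor([[10, 11, 12], [20, -1, -1]])
# Exclusive prefix sum of sampled slots: request 0 starts at row 0, request 1 at row 3.
cu_num_sampled_tokens = torch.tensor([0, 3])

accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
num_accepted_tokens = accepted_mask.sum(dim=-1)                    # tensor([3, 1])
accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]   # per-row positions
accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(num_accepted_tokens)
# Rows of the flattened [num_sampled_tokens, vocab] logits that were accepted:
assert accepted_logit_indices.tolist() == [0, 1, 2, 3]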
@@ -328,27 +413,26 @@
    return output_token_ids


def compute_probs(
def apply_sampling_constraints(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Compute probability distribution from logits based on sampling metadata.
    """Process logits based on sampling metadata.

    This function applies temperature scaling to the logits and converts
    them to probabilities using softmax. For greedy decoding, it returns
    This function applies temperature scaling to the logits,
    as well as top-k and top-p. For greedy decoding, it returns
    the original logits.

    Args:
        logits: Input logits tensor to be converted to probabilities.
        logits: Input logits tensor to be processed.
        cu_num_draft_tokens: Cumulative number of draft tokens.
        sampling_metadata: Metadata containing sampling parameters such as
            temperature and whether greedy sampling is used.

    Returns:
        torch.Tensor: Probability distribution (softmax of scaled logits)
            if non-greedy sampling is used, otherwise returns the
            original logits.
        torch.Tensor: Processed logits if non-greedy sampling is used,
            otherwise returns the original logits.
    """
    assert logits.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
@@ -384,9 +468,7 @@

    # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
    # which is slow for large vocab sizes. This may cause performance issues.
    logits = apply_top_k_top_p(logits, top_k, top_p)
    output_prob = logits.softmax(dim=-1, dtype=torch.float32)
    return output_prob
    return apply_top_k_top_p(logits, top_k, top_p)


def expand_batch_to_tokens(
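Note: the rename from `compute_probs` to `apply_sampling_constraints` moves the softmax out to the caller, so the constrained logits stay available for logprob reporting. A toy standalone illustration of that split (these helpers are simplified stand-ins, not vLLM's functions):

import torch

def apply_constraints_sketch(logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
    # Simplified: temperature scaling plus a top-k mask; the real code also handles
    # top-p, greedy requests, and per-token expansion of per-request parameters.
    scaled = logits / temperature
    kth = scaled.topk(top_k, dim=-1).values[..., -1, None]
    return scaled.masked_fill(scaled < kth, float("-inf"))

logits = torch.randn(2, 8)
constrained = apply_constraints_sketch(logits, temperature=0.7, top_k=3)
probs = constrained.softmax(dim=-1, dtype=torch.float32)  # what rejection sampling consumes
logprobs = constrained.log_softmax(dim=-1)                # what "processed" logprob modes report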
@@ -69,16 +69,18 @@ class Sampler(nn.Module):
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> SamplerOutput:
        logprobs_mode = logprobs_mode_override or self.logprobs_mode
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
            if logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
            elif logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
@@ -97,13 +99,18 @@
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = (
            None
            if num_logprobs is None
            else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
        )
        if num_logprobs is None:
            logprobs_tensors = None
        elif num_logprobs == -1:
            # Return the full unsorted and unranked logprobs.
            logprobs_tensors = LogprobsTensors(
                torch.empty(0), raw_logprobs, torch.empty(0)
            )
        else:
            # Gather the logprobs and ranks of the topk and sampled token.
            logprobs_tensors = self.gather_logprobs(
                raw_logprobs, num_logprobs, token_ids=sampled
            )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)
@@ -138,6 +145,7 @@
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        logprobs_mode_override: LogprobsMode | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Sample logits based on sampling metadata.

@@ -145,6 +153,7 @@
        may update the logits tensor in-place.
        """

        logprobs_mode = logprobs_mode_override or self.logprobs_mode
        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
@@ -153,9 +162,9 @@
        if sampling_metadata.all_greedy:
            processed_logprobs = None
            if sampling_metadata.max_num_logprobs is not None:
                if self.logprobs_mode == "processed_logits":
                if logprobs_mode == "processed_logits":
                    processed_logprobs = logits
                elif self.logprobs_mode == "processed_logprobs":
                elif logprobs_mode == "processed_logprobs":
                    processed_logprobs = self.compute_logprobs(logits)
        return greedy_sampled, processed_logprobs
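Note: the Sampler changes reduce to two rules: an explicit `logprobs_mode_override` wins over the configured mode, and `max_num_logprobs == -1` means "hand back the full, un-gathered logprob matrix" (used internally for the bonus-token pass). A compact sketch of that control flow with simplified names, not the real class:

from typing import Optional

def select_logprobs(raw_logprobs, num_logprobs: Optional[int],
                    configured_mode: str, override: Optional[str] = None):
    mode = override or configured_mode      # override takes precedence
    if num_logprobs is None:
        return None                          # logprobs were not requested
    if num_logprobs == -1:
        return ("full", mode, raw_logprobs)  # full unsorted/unranked matrix
    return ("topk", mode, num_logprobs)      # gather top-k plus the sampled token

assert select_logprobs([[0.0]], None, "raw_logprobs") is None
assert select_logprobs([[0.0]], -1, "raw_logprobs", override="raw_logits")[1] == "raw_logits"
assert select_logprobs([[0.0]], 3, "raw_logprobs")[0] == "topk"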
@@ -14,6 +14,8 @@ class SpecDecodeMetadata:
    num_draft_tokens: list[int]
    # [batch_size]
    cu_num_draft_tokens: torch.Tensor
    # [batch_size]
    cu_num_sampled_tokens: torch.Tensor
    # [num_tokens]
    target_logits_indices: torch.Tensor
    # [batch_size]
@@ -32,6 +34,7 @@
    ) -> "SpecDecodeMetadata":
        batch_size = len(draft_token_ids)
        num_draft_tokens = [len(ids) for ids in draft_token_ids]
        num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]
        flattened_draft_token_ids = sum(draft_token_ids, [])
        num_tokens = len(flattened_draft_token_ids)

@@ -40,6 +43,10 @@
        )
        cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
        cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
            device
        )

        target_logits_indices = torch.zeros(
            num_tokens, dtype=torch.int32, device=device
@@ -52,6 +59,7 @@
            draft_token_ids=draft_token_ids_tensor,
            num_draft_tokens=num_draft_tokens,
            cu_num_draft_tokens=cu_num_draft_tokens_tensor,
            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
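Note: `cu_num_sampled_tokens` is just the cumulative sum of `num_draft_tokens[i] + 1` (each request samples its draft slots plus one bonus slot); the rejection sampler later shifts it into an exclusive prefix sum. A tiny numeric check of that relationship (values are illustrative):

import numpy as np

num_draft_tokens = [2, 0, 3]                              # per-request draft lengths
num_sampled_tokens = [n + 1 for n in num_draft_tokens]    # + 1 bonus slot each

cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)      # [2, 2, 5]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)  # [3, 4, 8]

# Exclusive prefix sum, as rebuilt in _get_logprobs_tensors:
exclusive = np.zeros_like(cu_num_sampled_tokens)
exclusive[1:] = cu_num_sampled_tokens[:-1]
assert exclusive.tolist() == [0, 3, 4]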
@@ -327,7 +327,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                "Unknown speculative decoding method: "
                f"{self.speculative_config.method}"
            )
        self.rejection_sampler = RejectionSampler()
        self.rejection_sampler = RejectionSampler(self.sampler)

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
@@ -1624,6 +1624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
            self.device, non_blocking=True
        )
        cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
            self.device, non_blocking=True
        )
        logits_indices = torch.from_numpy(logits_indices).to(
            self.device, non_blocking=True
        )
@@ -1639,15 +1642,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        draft_token_ids = self.input_ids.gpu[logits_indices]
        draft_token_ids = draft_token_ids[target_logits_indices + 1]

        metadata = SpecDecodeMetadata(
        return SpecDecodeMetadata(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens.tolist(),
            cu_num_draft_tokens=cu_num_draft_tokens,
            cu_num_sampled_tokens=cu_num_sampled_tokens,
            target_logits_indices=target_logits_indices,
            bonus_logits_indices=bonus_logits_indices,
            logits_indices=logits_indices,
        )
        return metadata

    def _prepare_kv_sharing_fast_prefill(
        self,
@@ -2221,32 +2224,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                sampling_metadata=sampling_metadata,
            )

            # When indexing with a tensor (bonus_logits_indices), PyTorch
            # creates a new tensor with separate storage from the original
            # logits tensor. This means any in-place operations on bonus_logits
            # won't affect the original logits tensor.
            assert logits is not None
            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
            sampler_output = self.sampler(
                logits=bonus_logits,
                sampling_metadata=sampling_metadata,
                predict_bonus_token=True,
            )
            bonus_token_ids = sampler_output.sampled_token_ids

            # Just like `bonus_logits`, `target_logits` is a new tensor with
            # separate storage from the original `logits` tensor. Therefore,
            # it is safe to update `target_logits` in place.
            target_logits = logits[spec_decode_metadata.target_logits_indices]
            output_token_ids = self.rejection_sampler(
            sampler_output = self.rejection_sampler(
                spec_decode_metadata,
                None,  # draft_probs
                target_logits,
                bonus_token_ids,
                logits,
                sampling_metadata,
            )
            sampler_output.sampled_token_ids = output_token_ids
            self._update_states_after_model_execute(output_token_ids)
            self._update_states_after_model_execute(sampler_output.sampled_token_ids)
        return sampler_output

    def _bookkeeping_sync(
@@ -2256,6 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        logits: torch.Tensor | None,
        hidden_states: torch.Tensor,
        num_scheduled_tokens: int,
        spec_decode_metadata: SpecDecodeMetadata | None,
    ) -> tuple[
        dict[str, int],
        LogprobsLists | None,
@@ -2282,19 +2267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        req_ids_output_copy = self.input_batch.req_ids.copy()
        req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = (
            logprobs_tensors.tolists() if logprobs_tensors is not None else None
        )

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output.num_scheduled_tokens,
        )

        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
        sampled_token_ids = sampler_output.sampled_token_ids
        invalid_req_indices = []
@@ -2335,6 +2307,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        # the sampled tokens back, because there's no direct communication
        # between the first-stage worker and the last-stage worker.
        req_ids = self.input_batch.req_ids
        logprobs_tensors = sampler_output.logprobs_tensors
        cu_num_accepted_tokens = (
            [0] if spec_decode_metadata and logprobs_tensors else None
        )
        for req_idx in range(num_sampled_tokens):
            if self.use_async_scheduling:
                sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2360,6 +2336,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            req_state = self.requests[req_id]
            req_state.output_token_ids.extend(sampled_ids)

            if cu_num_accepted_tokens is not None:
                cu_num_accepted_tokens.append(
                    cu_num_accepted_tokens[-1] + len(sampled_ids)
                )

        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        logprobs_lists = (
            logprobs_tensors.tolists(cu_num_accepted_tokens)
            if logprobs_tensors is not None
            else None
        )

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states[:num_scheduled_tokens],
            scheduler_output.num_scheduled_tokens,
        )

        return (
            num_nans_in_logits,
            logprobs_lists,
@@ -2644,6 +2639,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            logits,
            hidden_states,
            num_scheduled_tokens,
            spec_decode_metadata,
        )

        if (
@@ -3560,20 +3556,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        #     num_tokens, logits.shape[-1], device=self.device,
        #     dtype=logits.dtype)
        draft_probs = None
        target_logits = torch.randn(
            num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype
        )
        # NOTE(woosuk): Here, we should use int32 because the sampler uses
        # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
        # will occur at runtime.
        bonus_token_ids = torch.zeros(
            num_reqs, device=self.device, dtype=torch.int32
        logits = torch.randn(
            num_tokens + num_reqs,
            logits.shape[-1],
            device=self.device,
            dtype=logits.dtype,
        )
        self.rejection_sampler(
            dummy_spec_decode_metadata,
            draft_probs,
            target_logits,
            bonus_token_ids,
            logits,
            dummy_metadata,
        )
        return sampler_output
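Note: the bookkeeping loop accumulates `cu_num_accepted_tokens` while it walks each request's accepted IDs, so the later `tolists(cu_num_accepted_tokens)` call can map flattened logprob rows back to requests. A minimal standalone sketch of that accumulation (the request data is made up):

# Per-request accepted token ids after rejection sampling (lengths differ).
sampled_ids_per_req = [[11, 12, 13], [21], [31, 32]]

cu_num_accepted_tokens = [0]
for sampled_ids in sampled_ids_per_req:
    # Same running sum the model runner keeps while updating request states.
    cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + len(sampled_ids))

assert cu_num_accepted_tokens == [0, 3, 4, 6]
# Request r owns flattened logprob rows [cu[r], cu[r + 1]), e.g. request 1 -> row 3 only.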