[V1][spec decode] return logprobs for spec decoding (#26060)

Signed-off-by: Giancarlo Delfin <gdelfin@meta.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

Parent: ff93cc8c84
Commit: 6644796bf4
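This commit makes the V1 rejection sampler return a full SamplerOutput (sampled token IDs plus optional logprobs), so per-token logprobs can be requested while speculative decoding is active. A minimal usage sketch, based on the test added below; the model names, speculative_config keys, and logprobs parameters are taken from that test and may vary across vLLM versions:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
    },
    max_logprobs=5,
)
params = SamplingParams(temperature=0, logprobs=3, max_tokens=10)
outputs = llm.generate(["Hello world"], params)
# Each generated token now carries its top-k logprobs, exactly as in normal decoding.
for token_logprobs in outputs[0].outputs[0].logprobs:
    print(token_logprobs)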
@@ -2,12 +2,14 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import itertools
+import math
 from collections.abc import Generator
 from typing import get_args
 
 import pytest
 import torch
 
+from tests.utils import large_gpu_mark
 from tests.v1.sample.utils import (
     BatchLogprobsComposition,
     BatchLogprobsSpecType,
@@ -17,6 +19,7 @@ from tests.v1.sample.utils import (
 )
 from vllm import SamplingParams
 from vllm.config.model import LogprobsMode
+from vllm.distributed import cleanup_dist_env_and_memory
 
 from ...conftest import HfRunner, VllmRunner
 
@@ -508,3 +511,94 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
     if logprobs_mode in ("raw_logits", "processed_logits"):
         assert positive_values > 0
     del llm
+
+
+@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
+@pytest.mark.parametrize(
+    "model_setup",
+    [
+        pytest.param(
+            (
+                "eagle",
+                "meta-llama/Llama-3.1-8B-Instruct",
+                "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
+            ),
+            marks=large_gpu_mark(min_gb=32),
+        ),
+    ],
+)
+def test_spec_decode_logprobs(
+    logprobs_mode: LogprobsMode,
+    model_setup: tuple[str, str, str],
+    monkeypatch: pytest.MonkeyPatch,
+):
+    """Spec decode logprobs should match those of the base model.
+
+    Args:
+        logprobs_mode: logprobs mode.
+        model_setup: Spec decode method, base model name, and
+            draft model name.
+    """
+    from vllm import LLM
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        prompt = "Hello world"
+        sampling_params = SamplingParams(
+            temperature=0, logprobs=3, max_tokens=10, ignore_eos=False
+        )
+        method, model_name, spec_model_name = model_setup
+        max_model_len = 256
+
+        # Run base LLM.
+        ref_llm = LLM(
+            model=model_name,
+            max_logprobs=5,
+            max_model_len=max_model_len,
+            seed=42,
+            logprobs_mode=logprobs_mode,
+            gpu_memory_utilization=0.4,
+        )
+        ref_results = ref_llm.generate([prompt], sampling_params)
+        # Collect logprobs outputs from reference LLM.
+        ref_logprobs = []
+        for output in ref_results[0].outputs:
+            for logprobs in output.logprobs:
+                for token_id in logprobs:
+                    ref_logprobs.append(logprobs[token_id])
+        del ref_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        # Run spec decode LLM.
+        spec_llm = LLM(
+            model_name,
+            speculative_config={
+                "method": method,
+                "model": spec_model_name,
+                "num_speculative_tokens": 3,
+                "max_model_len": max_model_len,
+            },
+            max_logprobs=5,
+            max_model_len=max_model_len,
+            seed=42,
+            logprobs_mode=logprobs_mode,
+            gpu_memory_utilization=0.4,
+        )
+        spec_results = spec_llm.generate([prompt], sampling_params)
+        # Collect logprobs outputs from spec decode LLM.
+        spec_logprobs = []
+        for output in spec_results[0].outputs:
+            for logprobs in output.logprobs:
+                for token_id in logprobs:
+                    spec_logprobs.append(logprobs[token_id])
+        del spec_llm
+        torch.cuda.empty_cache()
+        cleanup_dist_env_and_memory()
+
+        # Per-token logprobs are expected to be the same.
+        assert len(ref_logprobs) == len(spec_logprobs)
+        for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
+            assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, abs_tol=1e-3)
+            assert ref_logprob.rank == spec_logprob.rank
+            assert ref_logprob.decoded_token == spec_logprob.decoded_token
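Both runs above use temperature=0, so the speculative run is expected to reproduce the base model's tokens exactly; the per-token ranks and decoded tokens are compared directly, with a small absolute tolerance (1e-3) on the logprob values to absorb numerical noise from the different logits-gathering path.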
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any
+from unittest.mock import Mock
 
 import pytest
 import torch
@@ -11,6 +12,7 @@ from vllm.platforms import current_platform
 from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
+from vllm.v1.sample.sampler import Sampler, SamplerOutput
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 DEVICE = current_platform.device_type
@@ -18,7 +20,28 @@ DEVICE = current_platform.device_type
 
 @pytest.fixture
 def rejection_sampler():
-    return RejectionSampler()
+    mock_sampler = Mock(spec=Sampler)
+    mock_sampler.logprobs_mode = "raw_logprobs"
+    return RejectionSampler(mock_sampler)
+
+
+def mock_sampler_output(
+    rejection_sampler: RejectionSampler, bonus_token_ids: torch.Tensor
+):
+    rejection_sampler.sampler.return_value = SamplerOutput(
+        sampled_token_ids=bonus_token_ids, logprobs_tensors=None
+    )
+
+
+def create_spec_decode_metadata(
+    spec_tokens: list[list[int]], logits: torch.Tensor
+) -> SpecDecodeMetadata:
+    metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
+    metadata.target_logits_indices = torch.arange(logits.shape[0])
+    # Output bonus token ids are mocked, so the bonus logit indices should
+    # be empty.
+    metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
+    return metadata
+
+
 def create_logits_tensor(
@@ -111,19 +134,17 @@ def test_perfect_match(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
 
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor([[1, 2, 3, 4]], dtype=torch.int, device=logits.device)
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_early_mismatch(rejection_sampler):
@@ -134,15 +155,13 @@ def test_early_mismatch(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
 
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor(
@@ -150,7 +169,7 @@ def test_early_mismatch(rejection_sampler):
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_multiple_sequences(rejection_sampler):
@@ -163,21 +182,19 @@ def test_multiple_sequences(rejection_sampler):
     bonus_token_tensor = torch.tensor(
         [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
     )
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
 
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor(
         [[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]], dtype=torch.int, device=logits.device
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_single_token_sequence(rejection_sampler):
@@ -188,19 +205,17 @@ def test_single_token_sequence(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
 
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_empty_sequence(rejection_sampler):
@@ -211,19 +226,17 @@ def test_empty_sequence(rejection_sampler):
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([output_tokens[0][-1]], device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
 
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
        sampling_metadata=metadata,
     )
     expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_multiple_mismatches(rejection_sampler):
@@ -236,15 +249,13 @@ def test_multiple_mismatches(rejection_sampler):
     bonus_token_tensor = torch.tensor(
         [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device
     )
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
 
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor(
@@ -255,7 +266,7 @@ def test_multiple_mismatches(rejection_sampler):
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 @pytest.mark.parametrize(
@@ -277,19 +288,17 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens, expec
     bonus_token_tensor = torch.tensor(
         [tokens[-1] for tokens in output_tokens], device=logits.device
     )
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
 
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected_tensor = torch.tensor(expected, dtype=torch.int, device=logits.device)
-    assert torch.equal(output, expected_tensor)
+    assert torch.equal(output.sampled_token_ids, expected_tensor)
 
 
 ########################### Tests for Random Sampling ###################
@@ -331,18 +340,19 @@ def test_deterministic_when_seeded(
         sampling_metadata = create_sampling_metadata(
             all_greedy=False, temperature=temperature, generators=seeded_seqs
         )
-        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-            draft_token_ids.tolist(), device=DEVICE
+        spec_decode_metadata = create_spec_decode_metadata(
+            draft_token_ids.tolist(), target_logits
         )
+
+        mock_sampler_output(rejection_sampler, bonus_token_ids)
         rep_result = rejection_sampler(
             spec_decode_metadata,
-            draft_probs=draft_probs,
-            target_logits=target_logits,
-            bonus_token_ids=bonus_token_ids,
+            draft_probs=None,
+            logits=target_logits,
             sampling_metadata=sampling_metadata,
        )
 
-        results.append(rep_result)
+        results.append(rep_result.sampled_token_ids)
 
        for i in range(batch_size):
            if seeded_mask[i]:
@@ -460,7 +470,9 @@ def estimate_rejection_sampling_pdf(
     Returns:
         Estimated probability distribution of the output tokens.
     """
-    rejection_sampler = RejectionSampler()
+    mock_sampler = Mock(spec=Sampler)
+    mock_sampler.logprobs_mode = "raw_logprobs"
+    rejection_sampler = RejectionSampler(mock_sampler)
     num_tokens = num_samples * k
     # Repeat draft probs num_samples * k times.
     draft_probs = draft_probs.reshape(1, 1, vocab_size).repeat(num_samples, k, 1)
@@ -483,17 +495,18 @@ def estimate_rejection_sampling_pdf(
     sampling_metadata = create_sampling_metadata(
         all_greedy=False, temperature=temperature
     )
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        draft_token_ids.tolist(), device=bonus_token_ids.device
+    spec_decode_metadata = create_spec_decode_metadata(
+        draft_token_ids.tolist(), target_logits
     )
-    output_token_ids = rejection_sampler(
+
+    mock_sampler_output(rejection_sampler, bonus_token_ids)
+    sampler_output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=draft_probs,
-        target_logits=target_logits,
-        bonus_token_ids=bonus_token_ids,
+        logits=target_logits,
         sampling_metadata=sampling_metadata,
     )
-    output_token_ids = output_token_ids[:, :-1].flatten()
+    output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()
 
     hist = torch.histogram(
         output_token_ids.to(dtype=torch.float, device="cpu"),
@@ -532,22 +545,19 @@ def _test_masked_logits(
     bonus_token_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=DEVICE)
 
     # Create spec decode metadata
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        draft_token_ids,
-        device=DEVICE,
-    )
+    spec_decode_metadata = create_spec_decode_metadata(draft_token_ids, target_logits)
 
     # Run rejection sampling
-    output_token_ids = rejection_sampler(
+    mock_sampler_output(rejection_sampler, bonus_token_ids)
+    output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=draft_probs,
-        target_logits=target_logits,
-        bonus_token_ids=bonus_token_ids,
+        logits=target_logits,
         sampling_metadata=sampling_metadata,
     )
 
     # Remove bonus tokens and reshape
-    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
+    output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist()
 
     # Check that all sampled tokens are within the unmasked indices.
     for i in range(num_tokens):
@@ -665,11 +675,11 @@ def test_frequency_penalties(rejection_sampler):
     spec_decode_metadata = SpecDecodeMetadata.make_dummy(
         spec_tokens, device=logits.device
     )
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor(
@@ -677,7 +687,7 @@ def test_frequency_penalties(rejection_sampler):
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_bad_words(rejection_sampler):
@@ -707,14 +717,12 @@ def test_bad_words(rejection_sampler):
     bonus_token_tensor = torch.tensor(
         [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
     )
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
 
@@ -723,7 +731,7 @@ def test_bad_words(rejection_sampler):
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
 
 
 def test_allowed_token_ids(rejection_sampler):
@@ -756,14 +764,12 @@ def test_allowed_token_ids(rejection_sampler):
     bonus_token_tensor = torch.tensor(
         [output_tokens[i][-1] for i in range(len(output_tokens))], device=logits.device
     )
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        spec_tokens, device=logits.device
-    )
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
 
@@ -772,4 +778,4 @@ def test_allowed_token_ids(rejection_sampler):
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)
@@ -66,7 +66,7 @@ class LogprobsProcessor:
         assert self.logprobs is not None
         assert self.cumulative_logprob is not None
 
-        token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
+        token_ids_lst, logprobs_lst, ranks_lst, _ = logprobs_lists
 
         for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst, token_ids_lst):
             # Detokenize (non-incrementally).
@@ -14,34 +14,49 @@
 
 
 class LogprobsLists(NamedTuple):
-    # [num_reqs, max_num_logprobs + 1]
+    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
     logprob_token_ids: list[list[int]]
-    # [num_reqs, max_num_logprobs + 1]
+    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
     logprobs: list[list[float]]
-    # [num_reqs]
+    # [num_reqs x num_generated_tokens]
     sampled_token_ranks: list[int]
+    # [num_reqs]
+    # Used for slicing the logprobs in cases like speculative
+    # decoding where the number of generated tokens may be
+    # different for each request.
+    cu_num_generated_tokens: list[int] | None = None
 
-    def slice(self, start: int, end: int):
+    def slice(self, start_req_idx: int, end_req_idx: int):
+        if self.cu_num_generated_tokens:
+            start = self.cu_num_generated_tokens[start_req_idx]
+            end = self.cu_num_generated_tokens[end_req_idx]
+        else:
+            start = start_req_idx
+            end = end_req_idx
         return LogprobsLists(
             self.logprob_token_ids[start:end],
             self.logprobs[start:end],
             self.sampled_token_ranks[start:end],
+            self.cu_num_generated_tokens[start_req_idx:end_req_idx]
+            if self.cu_num_generated_tokens
+            else None,
         )
 
 
 class LogprobsTensors(NamedTuple):
-    # [num_reqs, max_num_logprobs + 1]
+    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
     logprob_token_ids: torch.Tensor
-    # [num_reqs, max_num_logprobs + 1]
+    # [num_reqs x num_generated_tokens, max_num_logprobs + 1]
     logprobs: torch.Tensor
-    # [num_reqs]
+    # [num_reqs x num_generated_tokens]
     selected_token_ranks: torch.Tensor
 
-    def tolists(self):
+    def tolists(self, cu_num_generated_tokens: list[int] | None = None):
         return LogprobsLists(
             self.logprob_token_ids.tolist(),
             self.logprobs.tolist(),
             self.selected_token_ranks.tolist(),
+            cu_num_generated_tokens,
         )
 
     @staticmethod
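To make the new cu_num_generated_tokens field concrete, here is a small, hedged sketch with hypothetical numbers; it assumes LogprobsLists is importable from vllm.v1.outputs, consistent with the imports elsewhere in this diff. With speculative decoding each request can contribute a different number of generated tokens, so request indices are mapped to flat row ranges through the cumulative counts:

from vllm.v1.outputs import LogprobsLists  # assumed module path

# Request 0 contributed 3 generated tokens and request 1 contributed 1,
# so the flat per-token lists have 4 rows and cu_num_generated_tokens = [0, 3, 4].
lists = LogprobsLists(
    logprob_token_ids=[[11], [12], [13], [14]],
    logprobs=[[-0.1], [-0.2], [-0.3], [-0.4]],
    sampled_token_ranks=[0, 0, 0, 0],
    cu_num_generated_tokens=[0, 3, 4],
)
req1 = lists.slice(1, 2)       # request indices -> flat rows 3:4
print(req1.logprob_token_ids)  # [[14]]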
@@ -1,15 +1,19 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from dataclasses import replace
+
 import torch
 import torch.nn as nn
 
 from vllm.logger import init_logger
 from vllm.triton_utils import tl, triton
+from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.bad_words import apply_bad_words_with_drafts
 from vllm.v1.sample.ops.penalties import apply_all_penalties
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.sample.sampler import Sampler
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 logger = init_logger(__name__)
@@ -44,17 +48,22 @@ class RejectionSampler(nn.Module):
     output tokens = accepted tokens + recovered tokens + bonus tokens
     """
 
+    def __init__(self, sampler: Sampler):
+        super().__init__()
+        self.sampler = sampler
+        logprobs_mode = self.sampler.logprobs_mode
+        self.is_processed_logprobs_mode = logprobs_mode.startswith("processed")
+        self.is_logits_logprobs_mode = logprobs_mode.endswith("logits")
+
     def forward(
         self,
         metadata: SpecDecodeMetadata,
         # [num_tokens, vocab_size]
         draft_probs: torch.Tensor | None,
-        # [num_tokens, vocab_size]
-        target_logits: torch.Tensor,
-        # [batch_size, 1]
-        bonus_token_ids: torch.Tensor,
+        # [num_tokens + batch_size, vocab_size]
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
+    ) -> SamplerOutput:
         """
         Args:
             metadata:
@@ -63,43 +72,65 @@ class RejectionSampler(nn.Module):
                 Probability distribution for the draft tokens. Shape is
                 [num_tokens, vocab_size]. Can be None if probabilities are
                 not provided, which is the case for ngram spec decode.
-            target_logits (torch.Tensor):
+            logits (torch.Tensor):
                 Target model's logits probability distribution.
-                Shape is [num_tokens, vocab_size]. Here, probabilities from
-                different requests are flattened into a single tensor because
-                this is the shape of the output logits.
-                NOTE: `target_logits` can be updated in place to save memory.
-            bonus_token_ids (torch.Tensor):
-                A tensor containing bonus tokens. Shape is [batch_size, 1].
-                Bonus tokens are added to the end of the sequence if all
-                proposed tokens are accepted. We generate the bonus tokens
-                outside of the rejection sampler with the default sampling
-                strategy. It allows for more flexibility in the sampling
-                process such as top_p, top_k sampling.
+                Shape is [num_tokens + batch_size, vocab_size]. Here,
+                probabilities from different requests are flattened into a
+                single tensor because this is the shape of the output logits.
+                NOTE: `logits` can be updated in place to save memory.
             sampling_metadata (vllm.v1.sample.metadata.SamplingMetadata):
                 Additional metadata needed for sampling, such as temperature,
                 top-k/top-p parameters, or other relevant information.
         Returns:
-            output_token_ids (torch.Tensor):
-                A tensor containing the final output token IDs.
+            SamplerOutput:
+                Contains the final output token IDs and their logprobs if
+                requested.
         """
         assert metadata.max_spec_len <= MAX_SPEC_LEN
 
-        # Use float32 for the target_logits.
-        target_logits = target_logits.to(torch.float32)
+        bonus_logits_indices = metadata.bonus_logits_indices
+        target_logits_indices = metadata.target_logits_indices
 
-        target_logits = self.apply_logits_processors(
-            target_logits, sampling_metadata, metadata
+        # When indexing with a tensor (bonus_logits_indices), PyTorch
+        # creates a new tensor with separate storage from the original
+        # logits tensor. This means any in-place operations on bonus_logits
+        # won't affect the original logits tensor.
+        assert logits is not None
+        bonus_logits = logits[bonus_logits_indices]
+        bonus_sampler_output = self.sampler(
+            logits=bonus_logits,
+            sampling_metadata=replace(
+                sampling_metadata,
+                max_num_logprobs=-1,
+            ),
+            predict_bonus_token=True,
+            # Override the logprobs mode to return logits because they are
+            # needed later to compute the accepted token logprobs.
+            logprobs_mode_override="processed_logits"
+            if self.is_processed_logprobs_mode
+            else "raw_logits",
         )
+        bonus_token_ids = bonus_sampler_output.sampled_token_ids
+
+        # Just like `bonus_logits`, `target_logits` is a new tensor with
+        # separate storage from the original `logits` tensor. Therefore,
+        # it is safe to update `target_logits` in place.
+        raw_target_logits = logits[target_logits_indices]
+        # Use float32 for the target_logits.
+        raw_target_logits = raw_target_logits.to(torch.float32)
+        target_logits = self.apply_logits_processors(
+            raw_target_logits, sampling_metadata, metadata
+        )
         # [num_tokens, vocab_size]
         # NOTE(woosuk): `target_logits` can be updated in place inside the
-        # `compute_probs` function.
-        target_probs = compute_probs(
+        # `apply_sampling_constraints` function.
+        target_logits = apply_sampling_constraints(
            target_logits,
            metadata.cu_num_draft_tokens,
            sampling_metadata,
        )
+        # Compute probability distribution from target logits.
+        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
 
         output_token_ids = rejection_sample(
             metadata.draft_token_ids,
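In short, RejectionSampler.forward now owns the bonus-token sampling that previously lived in the model runner: it splits the incoming logits rows into bonus rows and target rows via the metadata indices, runs the regular Sampler on the bonus rows with max_num_logprobs=-1 (so the full bonus-row values come back) and with the logprobs mode overridden to a logits variant (so they can later be combined with the target logits), applies penalties, temperature, and top-k/top-p only to the target rows, and finally takes the softmax itself before rejection sampling.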
@@ -111,7 +142,63 @@ class RejectionSampler(nn.Module):
             bonus_token_ids,
             sampling_metadata,
         )
-        return output_token_ids
+
+        logprobs_tensors = None
+        if sampling_metadata.max_num_logprobs:
+            logprobs_tensors = self._get_logprobs_tensors(
+                sampling_metadata.max_num_logprobs,
+                metadata,
+                logits,
+                target_logits if self.is_processed_logprobs_mode else raw_target_logits,
+                bonus_sampler_output.logprobs_tensors.logprobs,
+                output_token_ids,
+            )
+
+        return SamplerOutput(
+            sampled_token_ids=output_token_ids,
+            logprobs_tensors=logprobs_tensors,
+        )
+
+    def _get_logprobs_tensors(
+        self,
+        max_num_logprobs: int,
+        metadata: SpecDecodeMetadata,
+        logits: torch.Tensor,
+        target_logits: torch.Tensor,
+        bonus_logits: torch.Tensor,
+        sampled_token_ids: torch.Tensor,
+    ) -> LogprobsTensors:
+        cu_num_sampled_tokens = torch.zeros_like(metadata.cu_num_sampled_tokens)
+        cu_num_sampled_tokens[1:] = metadata.cu_num_sampled_tokens[:-1]
+
+        # Collect target and bonus logits.
+        bonus_logits_indices = metadata.bonus_logits_indices
+        target_logits_indices = metadata.target_logits_indices
+        final_logits = torch.zeros_like(logits, dtype=torch.float32)
+        final_logits[target_logits_indices] = target_logits.to(torch.float32)
+        final_logits[bonus_logits_indices] = bonus_logits.to(torch.float32)
+
+        # Compute accepted token indices.
+        accepted_mask = sampled_token_ids != PLACEHOLDER_TOKEN_ID
+        num_accepted_tokens = accepted_mask.sum(dim=-1)
+        accepted_logit_indices = accepted_mask.nonzero(as_tuple=True)[1]
+        accepted_logit_indices += cu_num_sampled_tokens.repeat_interleave(
+            num_accepted_tokens
+        )
+
+        # Compute logprobs for accepted tokens.
+        accepted_logits = final_logits[accepted_logit_indices]
+        accepted_logprobs = (
+            accepted_logits
+            if self.is_logits_logprobs_mode
+            else self.sampler.compute_logprobs(accepted_logits)
+        )
+        accepted_tokens = sampled_token_ids[accepted_mask]
+        return self.sampler.gather_logprobs(
+            accepted_logprobs,
+            max_num_logprobs,
+            accepted_tokens.to(torch.int64),
+        )
 
     @staticmethod
     def parse_output(
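To make the index arithmetic in _get_logprobs_tensors concrete, here is a self-contained, hedged sketch with hypothetical sizes (the -1 placeholder stands in for PLACEHOLDER_TOKEN_ID, which marks rejected slots):

import torch

# Two requests, each with 2 draft tokens + 1 bonus slot -> 3 sampled slots per row.
PLACEHOLDER = -1
sampled_token_ids = torch.tensor([[5, 7, 9], [4, PLACEHOLDER, PLACEHOLDER]])
cu_num_sampled_tokens = torch.tensor([3, 6])  # cumulative (num_draft + 1) per request

# Shift to per-request start offsets into the flattened logits rows: [0, 3].
starts = torch.zeros_like(cu_num_sampled_tokens)
starts[1:] = cu_num_sampled_tokens[:-1]

accepted_mask = sampled_token_ids != PLACEHOLDER
num_accepted = accepted_mask.sum(dim=-1)                # tensor([3, 1])
accepted_idx = accepted_mask.nonzero(as_tuple=True)[1]  # column index within each row
accepted_idx += starts.repeat_interleave(num_accepted)  # tensor([0, 1, 2, 3])
print(accepted_idx)  # rows of the flattened logits whose logprobs are gathered

The resulting flat indices select exactly the rows belonging to accepted (and bonus) tokens, whose logprobs are then ranked and gathered via the Sampler helpers.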
@@ -119,14 +206,12 @@ class RejectionSampler(nn.Module):
         vocab_size: int,
     ) -> list[list[int]]:
         """Parse the output of the rejection sampler.
-
         Args:
             output_token_ids: The sampled token IDs in shape
                 [batch_size, max_spec_len + 1]. The rejected tokens are
                 replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
                 and will be filtered out in this function.
             vocab_size: The size of the vocabulary.
-
         Returns:
             A list of lists of token IDs.
         """
@@ -328,27 +413,26 @@ def rejection_sample(
     return output_token_ids
 
 
-def compute_probs(
+def apply_sampling_constraints(
     logits: torch.Tensor,  # [num_tokens, vocab_size]
     cu_num_draft_tokens: torch.Tensor,  # [batch_size]
     sampling_metadata: SamplingMetadata,
 ) -> torch.Tensor:
-    """Compute probability distribution from logits based on sampling metadata.
+    """Process logits based on sampling metadata.
 
-    This function applies temperature scaling to the logits and converts
-    them to probabilities using softmax. For greedy decoding, it returns
+    This function applies temperature scaling to the logits,
+    as well as top-k and top-p. For greedy decoding, it returns
     the original logits.
 
     Args:
-        logits: Input logits tensor to be converted to probabilities.
+        logits: Input logits tensor to be processed.
         cu_num_draft_tokens: Cumulative number of draft tokens.
         sampling_metadata: Metadata containing sampling parameters such as
             temperature and whether greedy sampling is used.
 
     Returns:
-        torch.Tensor: Probability distribution (softmax of scaled logits)
-        if non-greedy sampling is used, otherwise returns the
-        original logits.
+        torch.Tensor: Processed logits if non-greedy sampling is used,
+        otherwise returns the original logits.
     """
     assert logits.ndim == 2
     assert cu_num_draft_tokens.ndim == 1
@@ -384,9 +468,7 @@ def compute_probs(
 
     # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
     # which is slow for large vocab sizes. This may cause performance issues.
-    logits = apply_top_k_top_p(logits, top_k, top_p)
-    output_prob = logits.softmax(dim=-1, dtype=torch.float32)
-    return output_prob
+    return apply_top_k_top_p(logits, top_k, top_p)
 
 
 def expand_batch_to_tokens(
@@ -69,16 +69,18 @@ class Sampler(nn.Module):
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         predict_bonus_token: bool = False,
+        logprobs_mode_override: LogprobsMode | None = None,
     ) -> SamplerOutput:
+        logprobs_mode = logprobs_mode_override or self.logprobs_mode
         # NOTE(woosuk): Use the original logits (before any penalties or
         # temperature scaling) for the top-k logprobs.
         # This is different from the V0 sampler, which uses the logits that
         # is used for sampling (after penalties and temperature scaling).
         num_logprobs = sampling_metadata.max_num_logprobs
         if num_logprobs is not None:
-            if self.logprobs_mode == "raw_logprobs":
+            if logprobs_mode == "raw_logprobs":
                 raw_logprobs = self.compute_logprobs(logits)
-            elif self.logprobs_mode == "raw_logits":
+            elif logprobs_mode == "raw_logits":
                 raw_logprobs = logits.clone()
 
         # Use float32 for the logits.
@@ -97,12 +99,17 @@ class Sampler(nn.Module):
         # return int32 (while PyTorch argmax and topk return int64).
         sampled = sampled.long()
 
-        # Gather the logprobs of the topk and sampled token (if requested).
-        # Get logprobs and rank tensors (if requested)
-        logprobs_tensors = (
-            None
-            if num_logprobs is None
-            else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
-        )
+        if num_logprobs is None:
+            logprobs_tensors = None
+        elif num_logprobs == -1:
+            # Return the full unsorted and unranked logprobs.
+            logprobs_tensors = LogprobsTensors(
+                torch.empty(0), raw_logprobs, torch.empty(0)
+            )
+        else:
+            # Gather the logprobs and ranks of the topk and sampled token.
+            logprobs_tensors = self.gather_logprobs(
+                raw_logprobs, num_logprobs, token_ids=sampled
+            )
 
         # Use int32 to reduce the tensor size.
@@ -138,6 +145,7 @@ class Sampler(nn.Module):
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
+        logprobs_mode_override: LogprobsMode | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         """Sample logits based on sampling metadata.
 
@@ -145,6 +153,7 @@ class Sampler(nn.Module):
         may update the logits tensor in-place.
         """
 
+        logprobs_mode = logprobs_mode_override or self.logprobs_mode
         assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
         if sampling_metadata.all_random:
             greedy_sampled = None
@@ -153,9 +162,9 @@ class Sampler(nn.Module):
         if sampling_metadata.all_greedy:
             processed_logprobs = None
             if sampling_metadata.max_num_logprobs is not None:
-                if self.logprobs_mode == "processed_logits":
+                if logprobs_mode == "processed_logits":
                     processed_logprobs = logits
-                elif self.logprobs_mode == "processed_logprobs":
+                elif logprobs_mode == "processed_logprobs":
                     processed_logprobs = self.compute_logprobs(logits)
             return greedy_sampled, processed_logprobs
 
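Two small Sampler extensions support this: a max_num_logprobs value of -1 now means "return the full, unsorted and unranked logprobs tensor" (used internally by the rejection sampler for the bonus rows, in contrast to None for no logprobs and k >= 0 for top-k gathering), and logprobs_mode_override lets a caller force a raw/processed logits mode for a single call without changing the configured logprobs_mode.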
@@ -14,6 +14,8 @@ class SpecDecodeMetadata:
     num_draft_tokens: list[int]
     # [batch_size]
     cu_num_draft_tokens: torch.Tensor
+    # [batch_size]
+    cu_num_sampled_tokens: torch.Tensor
     # [num_tokens]
     target_logits_indices: torch.Tensor
     # [batch_size]
@@ -32,6 +34,7 @@ class SpecDecodeMetadata:
     ) -> "SpecDecodeMetadata":
         batch_size = len(draft_token_ids)
         num_draft_tokens = [len(ids) for ids in draft_token_ids]
+        num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]
         flattened_draft_token_ids = sum(draft_token_ids, [])
         num_tokens = len(flattened_draft_token_ids)
 
@@ -40,6 +43,10 @@ class SpecDecodeMetadata:
         )
         cu_num_draft_tokens = np.cumsum(num_draft_tokens, dtype=np.int32)
         cu_num_draft_tokens_tensor = torch.from_numpy(cu_num_draft_tokens).to(device)
+        cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
+        cu_num_sampled_tokens_tensor = torch.from_numpy(cu_num_sampled_tokens).to(
+            device
+        )
         target_logits_indices = torch.zeros(
             num_tokens, dtype=torch.int32, device=device
         )
@@ -52,6 +59,7 @@ class SpecDecodeMetadata:
             draft_token_ids=draft_token_ids_tensor,
             num_draft_tokens=num_draft_tokens,
             cu_num_draft_tokens=cu_num_draft_tokens_tensor,
+            cu_num_sampled_tokens=cu_num_sampled_tokens_tensor,
             target_logits_indices=target_logits_indices,
             bonus_logits_indices=bonus_logits_indices,
             logits_indices=logits_indices,
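A quick numeric illustration of the new cu_num_sampled_tokens field (hypothetical draft lists): each request samples its draft tokens plus one bonus token, so the field is the cumulative sum of len(draft) + 1 per request.

import numpy as np

draft_token_ids = [[101, 102], [201], []]
num_sampled_tokens = [len(ids) + 1 for ids in draft_token_ids]      # [3, 2, 1]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens, dtype=np.int32)
print(cu_num_sampled_tokens)  # [3 5 6]: row offsets into the flattened sampled tokens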
@@ -327,7 +327,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 "Unknown speculative decoding method: "
                 f"{self.speculative_config.method}"
             )
-        self.rejection_sampler = RejectionSampler()
+        self.rejection_sampler = RejectionSampler(self.sampler)
 
         # Request states.
         self.requests: dict[str, CachedRequestState] = {}
@@ -1624,6 +1624,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         cu_num_draft_tokens = torch.from_numpy(cu_num_draft_tokens).to(
             self.device, non_blocking=True
         )
+        cu_num_sampled_tokens = torch.from_numpy(cu_num_sampled_tokens).to(
+            self.device, non_blocking=True
+        )
         logits_indices = torch.from_numpy(logits_indices).to(
             self.device, non_blocking=True
         )
@@ -1639,15 +1642,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             draft_token_ids = self.input_ids.gpu[logits_indices]
             draft_token_ids = draft_token_ids[target_logits_indices + 1]
 
-        metadata = SpecDecodeMetadata(
+        return SpecDecodeMetadata(
             draft_token_ids=draft_token_ids,
             num_draft_tokens=num_draft_tokens.tolist(),
             cu_num_draft_tokens=cu_num_draft_tokens,
+            cu_num_sampled_tokens=cu_num_sampled_tokens,
             target_logits_indices=target_logits_indices,
             bonus_logits_indices=bonus_logits_indices,
             logits_indices=logits_indices,
         )
-        return metadata
 
     def _prepare_kv_sharing_fast_prefill(
         self,
@@ -2221,32 +2224,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 sampling_metadata=sampling_metadata,
             )
 
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            assert logits is not None
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-                predict_bonus_token=True,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
-
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
+            sampler_output = self.rejection_sampler(
                 spec_decode_metadata,
                 None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
+                logits,
                 sampling_metadata,
             )
-            sampler_output.sampled_token_ids = output_token_ids
-            self._update_states_after_model_execute(output_token_ids)
+            self._update_states_after_model_execute(sampler_output.sampled_token_ids)
 
             return sampler_output
 
     def _bookkeeping_sync(
@@ -2256,6 +2240,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         logits: torch.Tensor | None,
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
+        spec_decode_metadata: SpecDecodeMetadata | None,
     ) -> tuple[
         dict[str, int],
         LogprobsLists | None,
@@ -2282,19 +2267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         req_ids_output_copy = self.input_batch.req_ids.copy()
         req_id_to_index_output_copy = self.input_batch.req_id_to_index.copy()
 
-        # NOTE: GPU -> CPU Sync happens here.
-        # Move as many CPU operations as possible before this sync point.
-        logprobs_tensors = sampler_output.logprobs_tensors
-        logprobs_lists = (
-            logprobs_tensors.tolists() if logprobs_tensors is not None else None
-        )
-
-        # Compute prompt logprobs if needed.
-        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
-            hidden_states[:num_scheduled_tokens],
-            scheduler_output.num_scheduled_tokens,
-        )
-
         num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
         sampled_token_ids = sampler_output.sampled_token_ids
         invalid_req_indices = []
@@ -2335,6 +2307,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
+        logprobs_tensors = sampler_output.logprobs_tensors
+        cu_num_accepted_tokens = (
+            [0] if spec_decode_metadata and logprobs_tensors else None
+        )
         for req_idx in range(num_sampled_tokens):
             if self.use_async_scheduling:
                 sampled_ids = [-1] if req_idx not in invalid_req_indices_set else None
@@ -2360,6 +2336,25 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_state = self.requests[req_id]
                 req_state.output_token_ids.extend(sampled_ids)
 
+                if cu_num_accepted_tokens is not None:
+                    cu_num_accepted_tokens.append(
+                        cu_num_accepted_tokens[-1] + len(sampled_ids)
+                    )
+
+        # NOTE: GPU -> CPU Sync happens here.
+        # Move as many CPU operations as possible before this sync point.
+        logprobs_lists = (
+            logprobs_tensors.tolists(cu_num_accepted_tokens)
+            if logprobs_tensors is not None
+            else None
+        )
+
+        # Compute prompt logprobs if needed.
+        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output.num_scheduled_tokens,
+        )
+
         return (
             num_nans_in_logits,
             logprobs_lists,
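The bookkeeping change above builds cu_num_accepted_tokens on the CPU while iterating the per-request sampled_ids, then hands it to logprobs_tensors.tolists() so LogprobsLists.slice() can split the flat logprob rows back per request. A tiny sketch with hypothetical accepted-token lists:

cu_num_accepted_tokens = [0]
for sampled_ids in ([5, 7, 9], [4], [1, 2]):
    cu_num_accepted_tokens.append(cu_num_accepted_tokens[-1] + len(sampled_ids))
print(cu_num_accepted_tokens)  # [0, 3, 4, 6]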
@@ -2644,6 +2639,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             logits,
             hidden_states,
             num_scheduled_tokens,
+            spec_decode_metadata,
         )
 
         if (
@@ -3560,20 +3556,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             #     num_tokens, logits.shape[-1], device=self.device,
             #     dtype=logits.dtype)
             draft_probs = None
-            target_logits = torch.randn(
-                num_tokens, logits.shape[-1], device=self.device, dtype=logits.dtype
-            )
-            # NOTE(woosuk): Here, we should use int32 because the sampler uses
-            # int32 for bonus_token_ids. If the dtype mismatches, re-compilation
-            # will occur at runtime.
-            bonus_token_ids = torch.zeros(
-                num_reqs, device=self.device, dtype=torch.int32
+            logits = torch.randn(
+                num_tokens + num_reqs,
+                logits.shape[-1],
+                device=self.device,
+                dtype=logits.dtype,
             )
             self.rejection_sampler(
                 dummy_spec_decode_metadata,
                 draft_probs,
-                target_logits,
-                bonus_token_ids,
+                logits,
                 dummy_metadata,
             )
             return sampler_output