mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-11 21:47:11 +08:00
[V1][Spec Decode] Change Spec Decode Rejection Sampling API (#13729)
This commit is contained in:
parent
9ba28043b5
commit
5629f26df7
@ -29,7 +29,6 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
|
||||
temperature=torch.tensor([]),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
spec_token_ids=spec_tokens,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
min_p=torch.empty(batch_size, ),
|
||||
@ -55,7 +54,7 @@ def test_perfect_match(sampler):
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
output = sampler(spec_tokens, logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 3, 4]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
@ -70,7 +69,7 @@ def test_early_mismatch(sampler):
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
output = sampler(spec_tokens, logits, metadata)
|
||||
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
@ -85,7 +84,7 @@ def test_multiple_sequences(sampler):
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
output = sampler(spec_tokens, logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
@ -100,7 +99,7 @@ def test_single_token_sequence(sampler):
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
output = sampler(spec_tokens, logits, metadata)
|
||||
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
@ -113,7 +112,7 @@ def test_empty_sequence(sampler):
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
output = sampler(spec_tokens, logits, metadata)
|
||||
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
@ -126,7 +125,7 @@ def test_multiple_mismatches(sampler):
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
output = sampler(spec_tokens, logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
|
||||
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
@ -147,7 +146,7 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
output = sampler(spec_tokens, logits, metadata)
|
||||
expected_tensor = torch.tensor(expected,
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
@ -163,7 +162,7 @@ def test_logits_shape_handling(sampler):
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens, vocab_size)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
output = sampler(spec_tokens, logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
assert logits.shape[-1] == vocab_size
|
||||
|
||||
@ -105,7 +105,6 @@ def _create_default_sampling_metadata(
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=None,
|
||||
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
|
||||
|
||||
@ -123,7 +123,6 @@ def _construct_expected_sampling_metadata(
|
||||
dtype=torch.float,
|
||||
device=device),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=None,
|
||||
min_tokens=min_tokens,
|
||||
no_penalties=(all(x == 0 for x in presence_penalties)
|
||||
and all(x == 0 for x in frequency_penalties)
|
||||
|
||||
@ -13,9 +13,6 @@ class SamplingMetadata:
|
||||
all_greedy: bool
|
||||
all_random: bool
|
||||
|
||||
# None when there are no speculated tokens.
|
||||
spec_token_ids: Optional[List[List[int]]]
|
||||
|
||||
top_p: Optional[torch.Tensor]
|
||||
top_k: Optional[torch.Tensor]
|
||||
min_p: Optional[torch.Tensor]
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
@ -52,62 +54,62 @@ class RejectionSampler(nn.Module):
|
||||
else:
|
||||
self.forward_method = self.forward_native
|
||||
|
||||
def forward(self, logits: torch.Tensor,
|
||||
def forward(self, draft_token_ids: List[List[int]],
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> SamplerOutput:
|
||||
if not sampling_metadata.all_greedy:
|
||||
raise NotImplementedError(
|
||||
"Currently, only greedy sampling is supported by "
|
||||
"rejection sampler.")
|
||||
return self.forward_method(logits, sampling_metadata)
|
||||
return self.forward_method(draft_token_ids, target_probs,
|
||||
sampling_metadata)
|
||||
|
||||
def flashinfer_sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
draft_token_ids: List[List[int]],
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# NOTE: The following input preparationg can be moved
|
||||
# to the model runner with a persistent manner for better
|
||||
# performance.
|
||||
assert sampling_metadata.spec_token_ids is not None
|
||||
spec_token_ids = sampling_metadata.spec_token_ids
|
||||
max_spec_len = max(len(s) for s in spec_token_ids)
|
||||
batch_size = len(spec_token_ids)
|
||||
draft_token_ids = torch.full((batch_size, max_spec_len),
|
||||
INVALID_TOKEN_ID,
|
||||
device="cpu",
|
||||
dtype=torch.long)
|
||||
sample_lens = [len(x) + 1 for x in draft_token_ids]
|
||||
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
|
||||
draft_token_ids = [
|
||||
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
|
||||
]
|
||||
draft_token_ids_tensor = pad_sequence(draft_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
|
||||
target_token_ids = torch.full((batch_size, max_spec_len + 1),
|
||||
fill_value=INVALID_TOKEN_ID,
|
||||
device=logits.device,
|
||||
dtype=torch.long)
|
||||
if sampling_metadata.all_greedy:
|
||||
target_token_ids = target_probs.argmax(dim=-1).view(-1)
|
||||
target_token_ids = target_token_ids.split(sample_lens)
|
||||
target_token_ids = pad_sequence(target_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
|
||||
# TODO: Vectorize the following loop for better performance.
|
||||
start_loc = 0
|
||||
for i in range(batch_size):
|
||||
num_spec_tokens = len(spec_token_ids[i])
|
||||
draft_token_ids[i, :num_spec_tokens] = torch.tensor(
|
||||
spec_token_ids[i], device="cpu", dtype=torch.long)
|
||||
end_loc = start_loc + num_spec_tokens + 1
|
||||
# Assume greedy sampling.
|
||||
target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
|
||||
logits[start_loc:end_loc], dim=-1)
|
||||
start_loc = end_loc
|
||||
|
||||
vocab_size = logits.size(-1)
|
||||
# NOTE: CPU <-> GPU synchronization happens here.
|
||||
draft_token_ids = draft_token_ids.to(logits.device)
|
||||
draft_probs = _create_greedy_token_probs(draft_token_ids, vocab_size,
|
||||
logits.device)
|
||||
target_probs = _create_greedy_token_probs(target_token_ids, vocab_size,
|
||||
logits.device)
|
||||
uniform_samples = torch.zeros(batch_size,
|
||||
max_spec_len + 1,
|
||||
device=logits.device)
|
||||
vocab_size = target_probs.size(-1)
|
||||
# NOTE: CPU <-> GPU synchronization happens here.
|
||||
draft_token_ids_tensor = draft_token_ids_tensor.to(
|
||||
target_probs.device)
|
||||
draft_probs = _create_greedy_token_probs(draft_token_ids_tensor,
|
||||
vocab_size,
|
||||
target_probs.device)
|
||||
target_probs = _create_greedy_token_probs(target_token_ids,
|
||||
vocab_size,
|
||||
target_probs.device)
|
||||
uniform_samples = torch.zeros(draft_token_ids_tensor.size(0),
|
||||
draft_token_ids_tensor.size(1) + 1,
|
||||
device=target_probs.device)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently, only greedy sampling is supported by "
|
||||
"rejection sampler.")
|
||||
|
||||
sampled_token_ids, _, _ = fs.chain_speculative_sampling(
|
||||
draft_probs,
|
||||
draft_token_ids,
|
||||
draft_token_ids_tensor,
|
||||
uniform_samples,
|
||||
target_probs,
|
||||
)
|
||||
@ -117,35 +119,35 @@ class RejectionSampler(nn.Module):
|
||||
# TODO: The following method can be optimized for better performance.
|
||||
def forward_native(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
draft_token_ids: List[List[int]],
|
||||
target_probs: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
assert sampling_metadata.spec_token_ids is not None
|
||||
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
|
||||
# Add 1 to include the 'bonus' token.
|
||||
sample_lens = [x + 1 for x in spec_lens]
|
||||
|
||||
output_token_ids = logits.argmax(dim=-1).view(-1)
|
||||
output_token_ids = output_token_ids.split(sample_lens)
|
||||
output_token_ids = pad_sequence(output_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
|
||||
# Convert spec token IDs to a tensor, split by sample_lens, then pad.
|
||||
spec_token_ids = [
|
||||
torch.tensor(x,
|
||||
dtype=output_token_ids.dtype,
|
||||
device=output_token_ids.device)
|
||||
for x in sampling_metadata.spec_token_ids
|
||||
sample_lens = [len(x) + 1 for x in draft_token_ids]
|
||||
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
|
||||
draft_token_ids = [
|
||||
torch.tensor(x, dtype=int, device='cpu') for x in draft_token_ids
|
||||
]
|
||||
spec_token_ids = pad_sequence(spec_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
|
||||
# Produce a mask that remains 1 (True) until the first
|
||||
# mismatch (cumprod turns 0 after a mismatch).
|
||||
accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
|
||||
dim=1)
|
||||
draft_token_ids_tensor = pad_sequence(draft_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
draft_token_ids_tensor = draft_token_ids_tensor.to(target_probs.device)
|
||||
# Add 1 to include the 'bonus' token.
|
||||
if sampling_metadata.all_greedy:
|
||||
output_token_ids = target_probs.argmax(dim=-1).view(-1)
|
||||
output_token_ids = output_token_ids.split(sample_lens)
|
||||
output_token_ids = pad_sequence(output_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
# Produce a mask that remains 1 (True) until the first
|
||||
# mismatch (cumprod turns 0 after a mismatch).
|
||||
accept_mask = (
|
||||
output_token_ids[:, :-1] == draft_token_ids_tensor).cumprod(
|
||||
dim=1)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently, only greedy sampling is supported by "
|
||||
"rejection sampler.")
|
||||
# Identify valid positions (non-padding).
|
||||
valid_mask = output_token_ids != INVALID_TOKEN_ID
|
||||
# Generate mask with bonus token.
|
||||
|
||||
@ -9,7 +9,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
|
||||
apply_min_token_penalties)
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@ -19,22 +18,12 @@ class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
self.rejection_sampler = RejectionSampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
if sampling_metadata.spec_token_ids:
|
||||
if sampling_metadata.max_num_logprobs:
|
||||
raise NotImplementedError(
|
||||
"Rejection sampling does not support logprobs.")
|
||||
return self.rejection_sampler(
|
||||
logits,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Use the original logits (before any penalties or
|
||||
# temperature scaling) for the top-k logprobs.
|
||||
# This is different from the V0 sampler, which uses the logits that
|
||||
@ -127,6 +116,14 @@ class Sampler(nn.Module):
|
||||
)
|
||||
return sampled
|
||||
|
||||
def compute_probs(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
if sampling_metadata.all_greedy:
|
||||
return logits
|
||||
# Apply temperature. This is an in-place op changing logits.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
return logits.softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
|
||||
@ -490,23 +490,12 @@ class InputBatch:
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(List[List[int]], self.req_output_token_ids),
|
||||
spec_token_ids=None,
|
||||
min_tokens=self.min_tokens,
|
||||
no_penalties=self.no_penalties,
|
||||
logit_bias=self.logit_bias[:num_reqs],
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
)
|
||||
|
||||
def get_sampling_metadata(
|
||||
self,
|
||||
req_id_to_spec_token_ids: Dict[str, List[int]],
|
||||
) -> SamplingMetadata:
|
||||
# Set the new spec token ids in the cached sampling metadata.
|
||||
self.sampling_metadata.spec_token_ids = [
|
||||
req_id_to_spec_token_ids.get(req_id, []) for req_id in self.req_ids
|
||||
] if req_id_to_spec_token_ids else None
|
||||
return self.sampling_metadata
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||
prompt_token_ids_cpu_tensor = torch.empty(
|
||||
|
||||
@ -32,7 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
|
||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
@ -122,7 +122,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.use_spec_decode = False
|
||||
if self.speculative_config:
|
||||
self.use_spec_decode = True
|
||||
|
||||
self.rejection_sampler = RejectionSampler()
|
||||
# TODO: find a better way to check if we are using ngram.
|
||||
assert self.speculative_config.ngram_prompt_lookup_min, \
|
||||
"Currently, only ngram spec decode is supported in V1."
|
||||
@ -951,12 +951,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.get_sampling_metadata(
|
||||
scheduler_output.scheduled_spec_decode_tokens)
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
if not self.use_spec_decode:
|
||||
sampler_output = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
else:
|
||||
target_probs = self.model.sampler.compute_probs(
|
||||
logits, sampling_metadata)
|
||||
scheduled_request_ids = scheduler_output.num_scheduled_tokens.keys(
|
||||
)
|
||||
draft_token_ids = [
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
||||
for req_id in scheduled_request_ids
|
||||
]
|
||||
sampler_output = self.rejection_sampler(draft_token_ids,
|
||||
target_probs,
|
||||
sampling_metadata)
|
||||
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
@ -1293,7 +1305,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
temperature=dummy_tensors(0.5),
|
||||
all_greedy=False,
|
||||
all_random=False,
|
||||
spec_token_ids=None,
|
||||
top_p=dummy_tensors(0.9),
|
||||
top_k=dummy_tensors(logits.size(1) - 1),
|
||||
min_p=None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user