[Spec Decode] Efficient padded speculation (#24539)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett 2025-09-18 01:07:24 -04:00 committed by GitHub
parent 5c65a72bb1
commit b7433ca1a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 507 additions and 104 deletions

View File

@ -19,6 +19,8 @@ from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
@ -64,6 +66,86 @@ def _create_proposer(
device=current_platform.device_type)
def test_prepare_next_token_ids():
"""
Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
Each will produce a device tensor of next_token_ids, taking as input
either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
or the CPU python list[list[int]] with the rejected tokens removed.
"""
device = torch.device(current_platform.device_type)
num_requests = 4
num_speculative_tokens = 4
batch_spec = BatchSpec(
seq_lens=[num_speculative_tokens + 1] * num_requests,
query_lens=[num_speculative_tokens + 1] * num_requests,
)
req_ids = [f"req_{i+1}" for i in range(num_requests)]
mock_input_batch = mock.MagicMock(spec=InputBatch)
mock_input_batch.req_ids = req_ids
mock_input_batch.num_reqs = num_requests
mock_input_batch.vocab_size = 100
mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
mock_requests = {}
for req_id in req_ids:
mock_request = mock.MagicMock(spec=CachedRequestState)
# Each request will have a backup next token id of 10, 20, 30, 40
mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
mock_request.num_computed_tokens = 0
mock_requests[req_id] = mock_request
sampled_token_ids = [
[0, 1, -1, -1, -1], # 1 accepted, 3 rejected, "1" sampled
[0, 1, 2, 3, 4], # all accepted, "4" sampled
[-1, -1, -1, -1, -1], # sampling skipped, use backup token "30"
[-1, -1, -1, -1, -1] # this request will be discarded
]
sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
dtype=torch.int32,
device=device)
sampled_token_ids_cpu = [[i for i in seq if i != -1]
for seq in sampled_token_ids]
expected_next_token_ids_cpu = [1, 4, 30, 40]
expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
dtype=torch.int32,
device=device)
proposer = _create_proposer("eagle", num_speculative_tokens)
next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
sampled_token_ids_cpu, mock_requests, mock_input_batch,
mock_num_scheduled_tokens)
assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)
discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
num_discarded_reqs = 1
expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
dtype=torch.int32,
device=device)
next_token_ids_from_padded, valid_sampled_tokens_count = \
proposer.prepare_next_token_ids_padded(
common_attn_metadata, sampled_token_ids_tensor, mock_requests,
mock_input_batch, discarded_req_indices, num_discarded_reqs)
assert torch.equal(next_token_ids_from_padded,
expected_next_token_ids_tensor)
assert torch.equal(valid_sampled_tokens_count,
expected_valid_sampled_tokens_count)
def test_prepare_inputs():
"""
cu_target_query_lens: [0, a, a + b, a + b + c]
@ -90,10 +172,24 @@ def test_prepare_inputs():
device=device,
)
# Rejected tokens per request: [1, 3, 2]
num_rejected_tokens = torch.tensor([1, 3, 2],
dtype=torch.int32,
device=device)
# If there are `k` sampled tokens, then `k-1` tokens are draft tokens
# from the previous iteration, and the last token is the bonus token sampled
# from the base model.
num_draft_tokens = [3, 6, 4] # one less than query_lens
# num rejected tokens is [1, 3, 2]
ACCEPT_TOKEN = 0
BONUS_TOKEN = 1
REJECT_TOKEN = -1
sampled_token_ids = [
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
[
ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
],
[ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
]
sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
for seq in sampled_token_ids]
# Expected calculations:
# query_len_per_req = [4, 7, 5]
@ -125,7 +221,7 @@ def test_prepare_inputs():
proposer = _create_proposer("eagle", 1)
updated_metadata, token_indices = proposer.prepare_inputs(
common_attn_metadata, num_rejected_tokens.cpu())
common_attn_metadata, sampled_token_ids, num_draft_tokens)
assert torch.equal(updated_metadata.query_start_loc,
expected_cu_num_tokens)
@ -133,6 +229,77 @@ def test_prepare_inputs():
assert torch.equal(token_indices, expected_token_indices)
def test_prepare_inputs_padded():
"""
Input scenario is 3 requests with num_speculative_tokens == 2 and:
- Request 1: query_len = 3, rejected = 1
- Request 2: query_len = 3, rejected = 0
- Request 3: query_len = 3, rejected = 2
Expected outputs:
token_indices: [0, 1, 2,
3, 4, 5,
6, 7, 8]
Reason: Deferred computation should not disturb the original indices.
token_indices_to_sample: [1, 5, 6]
Reason: After accounting for rejections, these are the valid token positions
from the original indices to sample from.
"""
device = torch.device(current_platform.device_type)
expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
dtype=torch.int32,
device=device)
expected_token_indices_to_sample = torch.tensor([1, 5, 6],
dtype=torch.int32,
device=device)
num_speculative_tokens = 2
batch_spec = BatchSpec(
seq_lens=[3, 3, 3],
query_lens=[3, 3, 3],
)
common_attn_metadata = create_common_attn_metadata(
batch_spec,
block_size=16,
device=device,
)
# Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
expected_query_start_loc = torch.tensor([0, 3, 6, 9],
dtype=torch.int32,
device=device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids=[[0] * num_speculative_tokens] * 3,
device=device,
)
# num_rejected_tokens = [1, 0, 2]
# num_draft_tokens = [2, 2, 2]
# valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
valid_sampled_tokens_count = torch.tensor([2, 3, 1],
dtype=torch.int32,
device=device)
proposer = _create_proposer("eagle", num_speculative_tokens)
output_metadata, token_indices, token_indices_to_sample = \
proposer.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
assert output_metadata.max_query_len == 3
assert torch.equal(output_metadata.query_start_loc,
expected_query_start_loc)
assert torch.equal(token_indices, expected_token_indices)
assert torch.equal(token_indices_to_sample,
expected_token_indices_to_sample)
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
get_attn_backend_list_based_on_platform())
@ -373,6 +540,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
@ -526,6 +694,7 @@ def test_propose_tree(spec_token_tree):
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=None,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata)
assert result.shape == (batch_size, num_speculative_tokens)

View File

@ -83,6 +83,11 @@ class SpeculativeConfig:
disable_by_batch_size: Optional[int] = None
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
disable_padded_drafter_batch: bool = False
"""Disable input padding for speculative decoding. If set to True,
speculative input batches can contain sequences of different lengths,
which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation."""
# Ngram proposer configuration
prompt_lookup_max: Optional[int] = None

View File

@ -27,6 +27,9 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
@ -94,20 +97,26 @@ class EagleProposer:
dtype=self.dtype,
device=device)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.arange = torch.arange(
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
max_batch_size + 1,
device=device,
dtype=torch.int32,
)
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
self.arange = torch.arange(max_num_slots_for_arange,
device=device,
dtype=torch.int32)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
self.backup_next_token_ids = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True)
# Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
if current_platform.is_rocm():
@ -156,13 +165,16 @@ class EagleProposer:
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
@ -228,6 +240,12 @@ class EagleProposer:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
draft_token_ids = logits.argmax(dim=-1)
return draft_token_ids.view(-1, 1)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
@ -245,15 +263,12 @@ class EagleProposer:
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
assert isinstance(attn_metadata, self.allowed_attn_types)
if not isinstance(attn_metadata, self.allowed_attn_types):
raise ValueError(
f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: "
f"{type(attn_metadata)}. Supported types are: "
f"{self.allowed_attn_types}")
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
@ -263,10 +278,13 @@ class EagleProposer:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[:batch_size + 1]).clone()
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
@ -286,27 +304,38 @@ class EagleProposer:
positions)
# Increment the sequence lengths.
attn_metadata.max_seq_len += 1
attn_metadata.seq_lens += 1
# Consider max model length.
attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
self.max_model_len)
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
1)
common_attn_metadata.num_computed_tokens_cpu = \
common_attn_metadata.seq_lens_cpu - 1
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = attn_metadata.block_table.gather(
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
attn_metadata.slot_mapping = (block_ids * self.block_size +
clamped_positions % self.block_size)
common_attn_metadata.slot_mapping = (
block_ids * self.block_size +
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)
common_attn_metadata.slot_mapping.masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID)
# Rebuild attention metadata
attn_metadata_builder = \
self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
attn_metadata = attn_metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=token_index + 1)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
@ -347,6 +376,158 @@ class EagleProposer:
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def prepare_next_token_ids_cpu(
self, sampled_token_ids: list[list[int]],
requests: dict[str,
CachedRequestState], gpu_input_batch: InputBatch,
num_scheduled_tokens: dict[str, int]) -> torch.Tensor:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids for each request based on the sampled
token ids from the CPU. If a request has no sampled token ids (e.g.,
during the initial decoding steps), it falls back to using the request
state to get the next token id.
"""
req_ids = gpu_input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = req_ids[i]
req_state = requests[req_id]
seq_len = (req_state.num_computed_tokens +
num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.input_ids.device)
return next_token_ids
def prepare_next_token_ids_padded(self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int) -> \
tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array([
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item())
for i in range(num_reqs)
])
self.backup_next_token_ids.copy_to_gpu(num_reqs)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = \
discard_request_indices[:num_discarded_requests]
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
valid_sampled_token_ids_gpu.index_fill_(
0, discard_sampled_tokens_req_indices, -1)
# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
dtype=torch.bool)
else:
valid_mask = (
(valid_sampled_token_ids_gpu != -1) &
(valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1,
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1, selected_tokens,
self.backup_next_token_ids.gpu[:batch_size])
return next_token_ids, valid_sampled_tokens_count
def prepare_inputs_padded(self,
common_attn_metadata: CommonAttentionMetadata,
spec_decode_metadata: SpecDecodeMetadata,
valid_sampled_tokens_count: torch.Tensor) -> \
tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding
It updates the common_attn_metadata for speculative decoding,
but does not consider the rejected tokens. Instead, all tokens
are included as inputs to the speculator, with the rejected tokens
used as padding and filtered out later by `token_indices_to_sample`.
No blocking CPU operations should be introduced in this function.
"""
num_draft_tokens_gpu = torch.cat([
spec_decode_metadata.cu_num_draft_tokens[0:1],
spec_decode_metadata.cu_num_draft_tokens[1:] -
spec_decode_metadata.cu_num_draft_tokens[:-1]
])
num_rejected_tokens_gpu = torch.where(
num_draft_tokens_gpu > 0,
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
torch.zeros_like(num_draft_tokens_gpu))
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_query_len_per_req = (query_start_loc_cpu[1:] -
query_start_loc_cpu[:-1])
total_num_tokens = query_start_loc_cpu[-1].item()
token_indices = self.arange[:total_num_tokens]
spec_common_attn_metadata = CommonAttentionMetadata(
query_start_loc=common_attn_metadata.query_start_loc,
seq_lens=common_attn_metadata.seq_lens,
query_start_loc_cpu=query_start_loc_cpu,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata.
num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
causal=True,
)
token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
- num_rejected_tokens_gpu
return spec_common_attn_metadata, token_indices, token_indices_to_sample
def propose_tree(
self,
batch_size: int,
@ -520,11 +701,11 @@ class EagleProposer:
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,
# [batch_size]
num_rejected_tokens: torch.Tensor
sampled_token_ids: list[list[int]],
num_draft_tokens: list[int],
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
"""
This function is used to prepare the inputs for the spec decode.
This function is used to prepare the inputs for speculative decoding.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
@ -545,6 +726,13 @@ class EagleProposer:
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens = torch.tensor(num_rejected_tokens,
dtype=torch.int32)
device = common_attn_metadata.query_start_loc.device
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \

View File

@ -64,7 +64,10 @@ class CachedRequestState:
def get_token_id(self, idx: int) -> int:
if idx < self.num_prompt_tokens:
return self.prompt_token_ids[idx]
return self.output_token_ids[idx - self.num_prompt_tokens]
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
return self.output_token_ids[idx - self.num_prompt_tokens]
else:
return -1
class InputBatch:

View File

@ -344,6 +344,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.hidden_size,
dtype=self.dtype,
numpy=False)
self.discard_request_indices = self._make_buffer(self.max_num_reqs,
dtype=torch.int64)
self.num_discarded_requests = 0
self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
dtype=torch.int32)
self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
@ -974,6 +978,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
seq_lens = self.seq_lens.gpu[:num_reqs]
max_seq_len = self.seq_lens.np[:num_reqs].max().item()
num_tokens = [
self.requests[r].num_tokens for r in self.input_batch.req_ids
]
num_tokens_np = np.array(num_tokens, dtype=np.int32)
# Record the index of requests that should not be sampled,
# so that we could clear the sampled tokens before returning
discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
discard_request_indices = np.nonzero(discard_requests_mask)[0]
self.num_discarded_requests = len(discard_request_indices)
self.discard_request_indices.np[:self.num_discarded_requests] = (
discard_request_indices)
self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
# Copy the tensors to the GPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
@ -1973,23 +1992,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
num_nans_in_logits = self._get_nans_in_logits(logits)
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
discard_sampled_tokens_req_indices = []
for i, req_id in enumerate(self.input_batch.req_ids):
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
if seq_len < req_state.num_tokens:
# Ignore the sampled token for partial prefills.
# Rewind the generator state as if the token was not sampled.
# This relies on cuda-specific torch-internal impl details
generator = self.input_batch.generators.get(i)
if generator is not None:
generator.set_offset(generator.get_offset() - 4)
# Record the index of the request that should not be sampled,
# so that we could clear the sampled tokens before returning.
discard_sampled_tokens_req_indices.append(i)
discard_sampled_tokens_req_indices = \
self.discard_request_indices.np[:self.num_discarded_requests]
for i in discard_sampled_tokens_req_indices:
gen = self.input_batch.generators.get(int(i))
if gen is not None:
gen.set_offset(gen.get_offset() - 4)
# Copy some objects so they don't get modified after returning.
# This is important when using async scheduling.
@ -2026,10 +2034,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[i].clear()
valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
invalid_req_indices = list(discard_sampled_tokens_req_indices)
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
invalid_req_indices_set = set(invalid_req_indices)
assert sampled_token_ids.shape[-1] == 1
@ -2229,6 +2237,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
def propose_draft_token_ids(sampled_token_ids):
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("Draft"):
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
sampled_token_ids,
self.input_batch.sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
)
use_padded_batch_for_eagle = self.speculative_config and \
self.speculative_config.use_eagle() and \
not self.speculative_config.disable_padded_drafter_batch
if use_padded_batch_for_eagle:
# EAGLE speculative decoding can use the GPU sampled tokens
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)
with record_function_or_nullcontext("Bookkeep"):
(
num_nans_in_logits,
@ -2242,19 +2272,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits, hidden_states,
num_scheduled_tokens)
if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("Draft"):
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
self.input_batch.sampling_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
spec_decode_metadata,
spec_decode_common_attn_metadata,
)
if self.speculative_config and not use_padded_batch_for_eagle:
# ngram and other speculative decoding methods use the sampled
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)
with record_function_or_nullcontext("EPLB"):
self.eplb_step()
@ -2294,7 +2315,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
sampled_token_ids: list[list[int]],
sampled_token_ids: Union[torch.Tensor, list[list[int]]],
sampling_metadata: SamplingMetadata,
hidden_states: torch.Tensor,
sample_hidden_states: torch.Tensor,
@ -2304,11 +2325,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> Union[list[list[int]], torch.Tensor]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
elif self.speculative_config.method == "medusa":
assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer)
if sample_hidden_states.shape[0] == len(sampled_token_ids):
# The input to the target model does not include draft tokens.
hidden_states = sample_hidden_states
@ -2329,27 +2353,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
elif self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# TODO(woosuk): Refactor the loop.
req_ids = self.input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = req_ids[i]
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
if self.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
next_token_ids = self.drafter.prepare_next_token_ids_cpu(
sampled_token_ids, self.requests, self.input_batch,
scheduler_output.num_scheduled_tokens)
else:
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
assert isinstance(sampled_token_ids, torch.Tensor), \
"sampled_token_ids should be a torch.Tensor when" \
"padded-batch is enabled."
next_token_ids, valid_sampled_tokens_count = \
self.drafter.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.requests,
self.input_batch,
self.discard_request_indices.gpu,
self.num_discarded_requests
)
if spec_decode_metadata is None:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
# TODO(woosuk): Support M-RoPE.
@ -2361,17 +2395,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
# TODO(woosuk): Refactor this.
num_draft_tokens = spec_decode_metadata.num_draft_tokens
num_rejected_tokens = [
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
for i, n in enumerate(num_draft_tokens)
]
num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
dtype=torch.int32)
common_attn_metadata, token_indices =\
self.drafter.prepare_inputs(
common_attn_metadata, num_rejected_tokens_cpu)
if self.speculative_config.disable_padded_drafter_batch:
token_indices_to_sample = None
common_attn_metadata, token_indices =\
self.drafter.prepare_inputs(
common_attn_metadata,
sampled_token_ids,
spec_decode_metadata.num_draft_tokens)
else:
common_attn_metadata, token_indices, \
token_indices_to_sample =\
self.drafter.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
target_token_ids = self.input_ids.gpu[token_indices]
# TODO(woosuk): Support M-RoPE.
@ -2391,6 +2428,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embeds=mm_embeds,