[Spec Decode] Efficient padded speculation (#24539)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
parent 5c65a72bb1
commit b7433ca1a4
@@ -19,6 +19,8 @@ from vllm.config.load import LoadConfig
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
@@ -64,6 +66,86 @@ def _create_proposer(
        device=current_platform.device_type)


def test_prepare_next_token_ids():
    """
    Test for prepare_next_token_ids_cpu and prepare_next_token_ids_padded.
    Each will produce a device tensor of next_token_ids, taking as input
    either the GPU tensor of sampled_token_ids with -1 for rejected tokens,
    or the CPU python list[list[int]] with the rejected tokens removed.
    """
    device = torch.device(current_platform.device_type)

    num_requests = 4
    num_speculative_tokens = 4
    batch_spec = BatchSpec(
        seq_lens=[num_speculative_tokens + 1] * num_requests,
        query_lens=[num_speculative_tokens + 1] * num_requests,
    )

    req_ids = [f"req_{i+1}" for i in range(num_requests)]
    mock_input_batch = mock.MagicMock(spec=InputBatch)
    mock_input_batch.req_ids = req_ids
    mock_input_batch.num_reqs = num_requests
    mock_input_batch.vocab_size = 100

    mock_num_scheduled_tokens = {req_id: 0 for req_id in req_ids}
    mock_requests = {}
    for req_id in req_ids:
        mock_request = mock.MagicMock(spec=CachedRequestState)
        # Each request will have a backup next token id of 10, 20, 30, 40
        mock_request.get_token_id.return_value = int(req_id.split("_")[1]) * 10
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [-1, -1, -1, -1, -1]  # this request will be discarded
    ]
    sampled_token_ids_tensor = torch.tensor(sampled_token_ids,
                                            dtype=torch.int32,
                                            device=device)
    sampled_token_ids_cpu = [[i for i in seq if i != -1]
                             for seq in sampled_token_ids]

    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(expected_next_token_ids_cpu,
                                                  dtype=torch.int32,
                                                  device=device)

    proposer = _create_proposer("eagle", num_speculative_tokens)

    next_token_ids_from_cpu = proposer.prepare_next_token_ids_cpu(
        sampled_token_ids_cpu, mock_requests, mock_input_batch,
        mock_num_scheduled_tokens)

    assert torch.equal(next_token_ids_from_cpu, expected_next_token_ids_tensor)

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
    num_discarded_reqs = 1

    expected_valid_sampled_tokens_count = torch.tensor([2, 5, 0, 0],
                                                       dtype=torch.int32,
                                                       device=device)

    next_token_ids_from_padded, valid_sampled_tokens_count = \
        proposer.prepare_next_token_ids_padded(
            common_attn_metadata, sampled_token_ids_tensor, mock_requests,
            mock_input_batch, discarded_req_indices, num_discarded_reqs)

    assert torch.equal(next_token_ids_from_padded,
                       expected_next_token_ids_tensor)
    assert torch.equal(valid_sampled_tokens_count,
                       expected_valid_sampled_tokens_count)
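For reference, the padded path selects next tokens entirely on the device: it masks discarded rows, counts valid tokens per row, gathers each row's last valid token, and falls back to a precomputed backup token when a row has none. A minimal standalone sketch of that selection rule, using CPU tensors and names local to the sketch (illustrative only, not part of this commit):

# Illustrative sketch: recreates the selection rule exercised by
# test_prepare_next_token_ids with plain torch ops on CPU.
import torch

sampled = torch.tensor([
    [0, 1, -1, -1, -1],    # 1 accepted, bonus "1"
    [0, 1, 2, 3, 4],       # all accepted, bonus "4"
    [-1, -1, -1, -1, -1],  # nothing sampled -> fall back to backup
    [-1, -1, -1, -1, -1],  # discarded request
], dtype=torch.int32)
backup = torch.tensor([10, 20, 30, 40], dtype=torch.int32)
discard = torch.tensor([3])  # request indices to discard
vocab_size = 100

masked = sampled.clone()
masked.index_fill_(0, discard, -1)              # drop discarded requests
valid = (masked != -1) & (masked < vocab_size)  # valid-token mask
valid_count = valid.sum(dim=1)                  # [2, 5, 0, 0]
last = (valid_count - 1).clamp(min=0)           # rightmost valid index per row
picked = masked.gather(1, last.unsqueeze(1)).squeeze(1)
next_tokens = torch.where(valid_count > 0, picked, backup)
print(next_tokens.tolist())  # [1, 4, 30, 40]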
def test_prepare_inputs():
    """
    cu_target_query_lens: [0, a, a + b, a + b + c]
@@ -90,10 +172,24 @@ def test_prepare_inputs():
        device=device,
    )

    # Rejected tokens per request: [1, 3, 2]
    num_rejected_tokens = torch.tensor([1, 3, 2],
                                       dtype=torch.int32,
                                       device=device)
    # If there are `k` sampled tokens, then `k-1` tokens are draft tokens
    # from the previous iteration, and the last token is the bonus token sampled
    # from the base model.
    num_draft_tokens = [3, 6, 4]  # one less than query_lens
    # num rejected tokens is [1, 3, 2]
    ACCEPT_TOKEN = 0
    BONUS_TOKEN = 1
    REJECT_TOKEN = -1
    sampled_token_ids = [
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, BONUS_TOKEN],
        [
            ACCEPT_TOKEN, ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN,
            REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN
        ],
        [ACCEPT_TOKEN, ACCEPT_TOKEN, REJECT_TOKEN, REJECT_TOKEN, BONUS_TOKEN]
    ]
    sampled_token_ids = [[i for i in seq if i != REJECT_TOKEN]
                         for seq in sampled_token_ids]

    # Expected calculations:
    # query_len_per_req = [4, 7, 5]
@@ -125,7 +221,7 @@ def test_prepare_inputs():
    proposer = _create_proposer("eagle", 1)

    updated_metadata, token_indices = proposer.prepare_inputs(
        common_attn_metadata, num_rejected_tokens.cpu())
        common_attn_metadata, sampled_token_ids, num_draft_tokens)

    assert torch.equal(updated_metadata.query_start_loc,
                       expected_cu_num_tokens)
@@ -133,6 +229,77 @@ def test_prepare_inputs():
    assert torch.equal(token_indices, expected_token_indices)


def test_prepare_inputs_padded():
    """
    Input scenario is 3 requests with num_speculative_tokens == 2 and:
    - Request 1: query_len = 3, rejected = 1
    - Request 2: query_len = 3, rejected = 0
    - Request 3: query_len = 3, rejected = 2

    Expected outputs:
    token_indices: [0, 1, 2,
                    3, 4, 5,
                    6, 7, 8]
    Reason: Deferred computation should not disturb the original indices.

    token_indices_to_sample: [1, 5, 6]
    Reason: After accounting for rejections, these are the valid token positions
    from the original indices to sample from.
    """

    device = torch.device(current_platform.device_type)

    expected_token_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8],
                                          dtype=torch.int32,
                                          device=device)
    expected_token_indices_to_sample = torch.tensor([1, 5, 6],
                                                    dtype=torch.int32,
                                                    device=device)

    num_speculative_tokens = 2
    batch_spec = BatchSpec(
        seq_lens=[3, 3, 3],
        query_lens=[3, 3, 3],
    )

    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size=16,
        device=device,
    )

    # Needed for cu_num_draft_tokens, which is expected to be [3, 6, 9]
    expected_query_start_loc = torch.tensor([0, 3, 6, 9],
                                            dtype=torch.int32,
                                            device=device)
    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
        draft_token_ids=[[0] * num_speculative_tokens] * 3,
        device=device,
    )

    # num_rejected_tokens = [1, 0, 2]
    # num_draft_tokens = [2, 2, 2]
    # valid_sampled_tokens_count = num_draft_tokens + 1 - num_rejected_tokens
    valid_sampled_tokens_count = torch.tensor([2, 3, 1],
                                              dtype=torch.int32,
                                              device=device)

    proposer = _create_proposer("eagle", num_speculative_tokens)

    output_metadata, token_indices, token_indices_to_sample = \
        proposer.prepare_inputs_padded(
            common_attn_metadata,
            spec_decode_metadata,
            valid_sampled_tokens_count)

    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc,
                       expected_query_start_loc)
    assert torch.equal(token_indices, expected_token_indices)
    assert torch.equal(token_indices_to_sample,
                       expected_token_indices_to_sample)
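The expected values above follow from a single rule: each request keeps its full padded query, and only the position to sample from shifts back by the number of rejected tokens. A small worked example of that rule (illustrative only, not part of this commit):

# Illustrative sketch: the arithmetic behind token_indices_to_sample.
import torch

query_start_loc = torch.tensor([0, 3, 6, 9])  # three requests, 3 tokens each
num_rejected = torch.tensor([1, 0, 2])
# Last position of each query, shifted back by the rejected-token count.
token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected
print(token_indices_to_sample.tolist())  # [1, 5, 6]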
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
@@ -373,6 +540,7 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata)

@@ -526,6 +694,7 @@ def test_propose_tree(spec_token_tree):
        target_positions=target_positions,
        target_hidden_states=target_hidden_states,
        next_token_ids=next_token_ids,
        last_token_indices=None,
        common_attn_metadata=common_attn_metadata,
        sampling_metadata=sampling_metadata)
    assert result.shape == (batch_size, num_speculative_tokens)
@@ -83,6 +83,11 @@ class SpeculativeConfig:
    disable_by_batch_size: Optional[int] = None
    """Disable speculative decoding for new incoming requests when the number
    of enqueued requests is larger than this value, if provided."""
    disable_padded_drafter_batch: bool = False
    """Disable input padding for speculative decoding. If set to True,
    speculative input batches can contain sequences of different lengths,
    which may only be supported by certain attention backends. This currently
    only affects the EAGLE method of speculation."""

    # Ngram proposer configuration
    prompt_lookup_max: Optional[int] = None
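Padded drafter batches are enabled by default; the new flag opts back out. A hedged usage sketch, assuming the dict-style speculative_config argument of the offline LLM API and that its keys map onto the dataclass fields above (this wiring is not shown in the commit itself):

# Illustrative sketch: opting out of padded drafter batches. Key names are
# assumed from the SpeculativeConfig fields; model names reuse the ones from
# the test file above.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle",
        "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 4,
        # Keep the previous CPU-side preparation path instead of the new
        # padded, GPU-side path.
        "disable_padded_drafter_batch": True,
    },
)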
@@ -27,6 +27,9 @@ from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id

logger = init_logger(__name__)
@@ -94,20 +97,26 @@ class EagleProposer:
                                 dtype=self.dtype,
                                 device=device)

        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        self.arange = torch.arange(
            # We need +1 here because the arange is used to set query_start_loc,
            # which has one more element than batch_size.
            max_batch_size + 1,
        max_batch_size = vllm_config.scheduler_config.max_num_seqs
        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
        self.arange = torch.arange(max_num_slots_for_arange,
                                   device=device,
                                   dtype=torch.int32,
        )
                                   dtype=torch.int32)

        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)

        self.backup_next_token_ids = CpuGpuBuffer(
            max_batch_size,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
            device=device,
            with_numpy=True)

        # Determine allowed attention backends once during initialization.
        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
        if current_platform.is_rocm():
@@ -156,12 +165,15 @@
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
@@ -228,6 +240,12 @@
        last_hidden_states, hidden_states = ret_hidden_states
        sample_hidden_states = last_hidden_states[last_token_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)

        positions = target_positions[last_token_indices]
        hidden_states = hidden_states[last_token_indices]

@@ -245,15 +263,12 @@
        draft_token_ids = logits.argmax(dim=-1)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            # [batch_size, 1]
            return draft_token_ids.view(-1, 1)

        # TODO: Currently, MTP module released by deepseek only has
        # one layer. Adapt this code to support multiple layers once
        # there's a multi-layer MTP module.
        assert isinstance(attn_metadata, self.allowed_attn_types)
        if not isinstance(attn_metadata, self.allowed_attn_types):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}")

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]
@@ -263,10 +278,13 @@
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size
        attn_metadata.num_actual_tokens = batch_size
        attn_metadata.max_query_len = 1
        attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        for _ in range(self.num_speculative_tokens - 1):

        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[:batch_size + 1]).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
@@ -286,27 +304,38 @@
                                           positions)

            # Increment the sequence lengths.
            attn_metadata.max_seq_len += 1
            attn_metadata.seq_lens += 1
            # Consider max model length.
            attn_metadata.max_seq_len = min(attn_metadata.max_seq_len,
                                            self.max_model_len)
            common_attn_metadata.seq_lens += 1
            common_attn_metadata.seq_lens_cpu += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
                                                       1)

            common_attn_metadata.num_computed_tokens_cpu = \
                common_attn_metadata.seq_lens_cpu - 1

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = attn_metadata.block_table.gather(
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            attn_metadata.slot_mapping = (block_ids * self.block_size +
            common_attn_metadata.slot_mapping = (
                block_ids * self.block_size +
                clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
                                                    PADDING_SLOT_ID)
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID)

            # Rebuild attention metadata
            attn_metadata_builder = \
                self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
            attn_metadata = attn_metadata_builder\
                .build_for_drafting(common_attn_metadata=common_attn_metadata,
                                    draft_index=token_index + 1)
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
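The slot-mapping update in the loop above is plain integer arithmetic on the block table: block id = block_table[request, position // block_size], slot = block id * block_size + position % block_size, with requests past the model length redirected to the padding slot. A worked sketch with made-up values (illustrative only, not part of this commit):

# Illustrative sketch: slot-mapping arithmetic with block_size = 16.
# PADDING_SLOT_ID_EXAMPLE is a stand-in for the real PADDING_SLOT_ID constant.
import torch

block_size = 16
PADDING_SLOT_ID_EXAMPLE = -1
block_table = torch.tensor([[7, 9], [2, 4]])       # [num_reqs, max_blocks]
positions = torch.tensor([17, 3])                  # next position per request
exceeds_max_model_len = torch.tensor([False, True])

clamped = torch.where(exceeds_max_model_len, torch.zeros_like(positions),
                      positions)
block_numbers = clamped // block_size              # [1, 0]
block_ids = block_table.gather(dim=1, index=block_numbers.view(-1, 1)).view(-1)
slot_mapping = block_ids * block_size + clamped % block_size
slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID_EXAMPLE)
print(slot_mapping.tolist())  # [145, -1]  (9 * 16 + 1, then the padding slot)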
@@ -347,6 +376,158 @@
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    def prepare_next_token_ids_cpu(
            self, sampled_token_ids: list[list[int]],
            requests: dict[str,
                           CachedRequestState], gpu_input_batch: InputBatch,
            num_scheduled_tokens: dict[str, int]) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           num_scheduled_tokens[req_id])
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.input_ids.device)
        return next_token_ids

    def prepare_next_token_ids_padded(self,
                                      common_attn_metadata: CommonAttentionMetadata,
                                      sampled_token_ids: torch.Tensor,
                                      requests: dict[str, CachedRequestState],
                                      gpu_input_batch: InputBatch,
                                      discard_request_indices: torch.Tensor,
                                      num_discarded_requests: int) -> \
            tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        self.backup_next_token_ids.np[:num_reqs] = np.array([
            requests[gpu_input_batch.req_ids[i]].get_token_id(
                common_attn_metadata.seq_lens_cpu[i].item())
            for i in range(num_reqs)
        ])
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = \
            discard_request_indices[:num_discarded_requests]

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1)

        # Generate a mask for all valid tokens within those requests
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                         dtype=torch.bool)
        else:
            valid_mask = (
                (valid_sampled_token_ids_gpu != -1) &
                (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1,
            last_valid_indices_safe.unsqueeze(1)).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1, selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size])

        return next_token_ids, valid_sampled_tokens_count

    def prepare_inputs_padded(self,
                              common_attn_metadata: CommonAttentionMetadata,
                              spec_decode_metadata: SpecDecodeMetadata,
                              valid_sampled_tokens_count: torch.Tensor) -> \
            tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.
        """
        num_draft_tokens_gpu = torch.cat([
            spec_decode_metadata.cu_num_draft_tokens[0:1],
            spec_decode_metadata.cu_num_draft_tokens[1:] -
            spec_decode_metadata.cu_num_draft_tokens[:-1]
        ])

        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu))

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        new_query_len_per_req = (query_start_loc_cpu[1:] -
                                 query_start_loc_cpu[:-1])

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
        )

        token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
            - num_rejected_tokens_gpu

        return spec_common_attn_metadata, token_indices, token_indices_to_sample
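The rejection counts used by prepare_inputs_padded are recovered on-device from the cumulative draft counts, so no CPU sync is needed. A worked example with the numbers from test_prepare_inputs_padded (illustrative only, not part of this commit):

# Illustrative sketch: per-request draft counts from cumulative counts, then
# the rejection counts derived from the valid-sample counts.
import torch

cu_num_draft_tokens = torch.tensor([2, 4, 6])         # 2 drafts per request
valid_sampled_tokens_count = torch.tensor([2, 3, 1])

num_draft = torch.cat([
    cu_num_draft_tokens[0:1],
    cu_num_draft_tokens[1:] - cu_num_draft_tokens[:-1],
])                                                     # [2, 2, 2]
num_rejected = torch.where(num_draft > 0,
                           num_draft + 1 - valid_sampled_tokens_count,
                           torch.zeros_like(num_draft))
print(num_rejected.tolist())  # [1, 0, 2]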
    def propose_tree(
        self,
        batch_size: int,
@@ -520,11 +701,11 @@
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        # [batch_size]
        num_rejected_tokens: torch.Tensor
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for the spec decode.
        This function is used to prepare the inputs for speculative decoding.
        It updates to the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.
@@ -545,6 +726,13 @@
        # q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        num_rejected_tokens = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_tokens,
                                           dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
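On the CPU path the rejection counts come from the one-line comprehension above; with the numbers from test_prepare_inputs it reproduces [1, 3, 2] (illustrative only, not part of this commit):

# Illustrative sketch: the rejection-count formula on the test's inputs.
num_draft_tokens = [3, 6, 4]
sampled_token_ids = [[0, 0, 1], [0, 0, 0, 1], [0, 0, 1]]  # rejects removed
num_rejected_tokens = [
    n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
print(num_rejected_tokens)  # [1, 3, 2]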
@@ -64,7 +64,10 @@ class CachedRequestState:
    def get_token_id(self, idx: int) -> int:
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
            return self.output_token_ids[idx - self.num_prompt_tokens]
        else:
            return -1


class InputBatch:
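get_token_id now returns -1 for an index past all known tokens instead of assuming the index is valid, which the padded drafter can hit when it precomputes backup tokens from seq_lens. A toy stand-in class illustrating the branch (MiniRequestState is hypothetical, not the real CachedRequestState):

# Illustrative sketch: prompt token, then output token, then -1 out of range.
class MiniRequestState:
    def __init__(self, prompt_token_ids, output_token_ids):
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids = output_token_ids
        self.num_prompt_tokens = len(prompt_token_ids)

    def get_token_id(self, idx: int) -> int:
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
            return self.output_token_ids[idx - self.num_prompt_tokens]
        else:
            return -1  # index past all known tokens

req = MiniRequestState([11, 12, 13], [21])
print(req.get_token_id(2), req.get_token_id(3), req.get_token_id(4))  # 13 21 -1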
@@ -344,6 +344,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            self.hidden_size,
            dtype=self.dtype,
            numpy=False)
        self.discard_request_indices = self._make_buffer(self.max_num_reqs,
                                                         dtype=torch.int64)
        self.num_discarded_requests = 0

        self.num_draft_tokens = self._make_buffer(self.max_num_reqs,
                                                  dtype=torch.int32)
        self.num_accepted_tokens = self._make_buffer(self.max_num_reqs,
@@ -974,6 +978,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        seq_lens = self.seq_lens.gpu[:num_reqs]
        max_seq_len = self.seq_lens.np[:num_reqs].max().item()

        num_tokens = [
            self.requests[r].num_tokens for r in self.input_batch.req_ids
        ]
        num_tokens_np = np.array(num_tokens, dtype=np.int32)

        # Record the index of requests that should not be sampled,
        # so that we could clear the sampled tokens before returning
        discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
        discard_request_indices = np.nonzero(discard_requests_mask)[0]
        self.num_discarded_requests = len(discard_request_indices)
        self.discard_request_indices.np[:self.num_discarded_requests] = (
            discard_request_indices)

        self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)

        # Copy the tensors to the GPU.
        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
@@ -1973,23 +1992,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        discard_sampled_tokens_req_indices = []
        for i, req_id in enumerate(self.input_batch.req_ids):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            if seq_len < req_state.num_tokens:
                # Ignore the sampled token for partial prefills.
                # Rewind the generator state as if the token was not sampled.
                # This relies on cuda-specific torch-internal impl details
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    generator.set_offset(generator.get_offset() - 4)
                # Record the index of the request that should not be sampled,
                # so that we could clear the sampled tokens before returning.
                discard_sampled_tokens_req_indices.append(i)
        discard_sampled_tokens_req_indices = \
            self.discard_request_indices.np[:self.num_discarded_requests]
        for i in discard_sampled_tokens_req_indices:
            gen = self.input_batch.generators.get(int(i))
            if gen is not None:
                gen.set_offset(gen.get_offset() - 4)

        # Copy some objects so they don't get modified after returning.
        # This is important when using async scheduling.
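The replaced loop rewound each discarded request's random generator by one sampling step; the vectorized version keeps the same rewind. A hedged sketch of that call, assuming a CUDA Philox generator and reusing the step size of 4 from the code above rather than deriving it (illustrative only, not part of this commit):

# Illustrative sketch: rewinding a CUDA generator's Philox offset.
import torch

if torch.cuda.is_available():
    gen = torch.Generator(device="cuda")
    gen.manual_seed(0)
    torch.rand(8, device="cuda", generator=gen)  # some sampling-style draws
    # Undo one "sampling step" the way the runner does; the offset step of 4
    # is carried over from the code above, not derived here.
    gen.set_offset(gen.get_offset() - 4)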
@@ -2026,10 +2034,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            )
            # Mask out the sampled tokens that should not be sampled.
            for i in discard_sampled_tokens_req_indices:
                valid_sampled_token_ids[i].clear()
                valid_sampled_token_ids[int(i)].clear()
        else:
            valid_sampled_token_ids = []
            invalid_req_indices = list(discard_sampled_tokens_req_indices)
            invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
            invalid_req_indices_set = set(invalid_req_indices)
            assert sampled_token_ids.shape[-1] == 1
@@ -2229,6 +2237,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        with record_function_or_nullcontext("Sample"):
            sampler_output = self._sample(logits, spec_decode_metadata)

        def propose_draft_token_ids(sampled_token_ids):
            assert spec_decode_common_attn_metadata is not None
            with record_function_or_nullcontext("Draft"):
                self._draft_token_ids = self.propose_draft_token_ids(
                    scheduler_output,
                    sampled_token_ids,
                    self.input_batch.sampling_metadata,
                    hidden_states,
                    sample_hidden_states,
                    aux_hidden_states,
                    spec_decode_metadata,
                    spec_decode_common_attn_metadata,
                )

        use_padded_batch_for_eagle = self.speculative_config and \
            self.speculative_config.use_eagle() and \
            not self.speculative_config.disable_padded_drafter_batch
        if use_padded_batch_for_eagle:
            # EAGLE speculative decoding can use the GPU sampled tokens
            # as inputs, and does not need to wait for bookkeeping to finish.
            propose_draft_token_ids(sampler_output.sampled_token_ids)

        with record_function_or_nullcontext("Bookkeep"):
            (
                num_nans_in_logits,
@@ -2242,19 +2272,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                logits, hidden_states,
                num_scheduled_tokens)

        if self.speculative_config:
            assert spec_decode_common_attn_metadata is not None
            with record_function_or_nullcontext("Draft"):
                self._draft_token_ids = self.propose_draft_token_ids(
                    scheduler_output,
                    valid_sampled_token_ids,
                    self.input_batch.sampling_metadata,
                    hidden_states,
                    sample_hidden_states,
                    aux_hidden_states,
                    spec_decode_metadata,
                    spec_decode_common_attn_metadata,
                )
        if self.speculative_config and not use_padded_batch_for_eagle:
            # ngram and other speculative decoding methods use the sampled
            # tokens on the CPU, so they are run after bookkeeping.
            propose_draft_token_ids(valid_sampled_token_ids)

        with record_function_or_nullcontext("EPLB"):
            self.eplb_step()
@@ -2294,7 +2315,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    def propose_draft_token_ids(
        self,
        scheduler_output: "SchedulerOutput",
        sampled_token_ids: list[list[int]],
        sampled_token_ids: Union[torch.Tensor, list[list[int]]],
        sampling_metadata: SamplingMetadata,
        hidden_states: torch.Tensor,
        sample_hidden_states: torch.Tensor,
@@ -2304,11 +2325,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
    ) -> Union[list[list[int]], torch.Tensor]:
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        if self.speculative_config.method == "ngram":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, NgramProposer)
            draft_token_ids = self.propose_ngram_draft_token_ids(
                sampled_token_ids)
        elif self.speculative_config.method == "medusa":
            assert isinstance(sampled_token_ids, list)
            assert isinstance(self.drafter, MedusaProposer)

            if sample_hidden_states.shape[0] == len(sampled_token_ids):
                # The input to the target model does not include draft tokens.
                hidden_states = sample_hidden_states
@@ -2329,27 +2353,37 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            )
        elif self.speculative_config.use_eagle():
            assert isinstance(self.drafter, EagleProposer)
            # TODO(woosuk): Refactor the loop.
            req_ids = self.input_batch.req_ids
            next_token_ids: list[int] = []
            for i, token_ids in enumerate(sampled_token_ids):
                if token_ids:
                    # Common case.
                    next_token_id = token_ids[-1]

            if self.speculative_config.disable_padded_drafter_batch:
                # When padded-batch is disabled, the sampled_token_ids should be
                # the cpu-side list[list[int]] of valid sampled tokens for each
                # request, with invalid requests having empty lists.
                assert isinstance(sampled_token_ids, list), \
                    "sampled_token_ids should be a python list when" \
                    "padded-batch is disabled."
                next_token_ids = self.drafter.prepare_next_token_ids_cpu(
                    sampled_token_ids, self.requests, self.input_batch,
                    scheduler_output.num_scheduled_tokens)
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = self.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])
                next_token_id = req_state.get_token_id(seq_len)
                next_token_ids.append(next_token_id)
            next_token_ids = torch.tensor(next_token_ids,
                                          dtype=torch.int32,
                                          device=self.device)
                # When using padded-batch, the sampled_token_ids should be
                # the gpu tensor of sampled tokens for each request, of shape
                # (num_reqs, num_spec_tokens + 1) with rejected tokens having
                # value -1.
                assert isinstance(sampled_token_ids, torch.Tensor), \
                    "sampled_token_ids should be a torch.Tensor when" \
                    "padded-batch is enabled."
                next_token_ids, valid_sampled_tokens_count = \
                    self.drafter.prepare_next_token_ids_padded(
                        common_attn_metadata,
                        sampled_token_ids,
                        self.requests,
                        self.input_batch,
                        self.discard_request_indices.gpu,
                        self.num_discarded_requests
                    )

            if spec_decode_metadata is None:
                token_indices_to_sample = None
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
                # TODO(woosuk): Support M-RoPE.
@@ -2361,17 +2395,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
            else:
                target_hidden_states = hidden_states[:num_scheduled_tokens]
        else:
            # TODO(woosuk): Refactor this.
            num_draft_tokens = spec_decode_metadata.num_draft_tokens
            num_rejected_tokens = [
                n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ]
            num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens,
                                                   dtype=torch.int32)
            if self.speculative_config.disable_padded_drafter_batch:
                token_indices_to_sample = None
                common_attn_metadata, token_indices =\
                    self.drafter.prepare_inputs(
                        common_attn_metadata, num_rejected_tokens_cpu)
                        common_attn_metadata,
                        sampled_token_ids,
                        spec_decode_metadata.num_draft_tokens)
            else:
                common_attn_metadata, token_indices, \
                    token_indices_to_sample =\
                    self.drafter.prepare_inputs_padded(
                        common_attn_metadata,
                        spec_decode_metadata,
                        valid_sampled_tokens_count)

            target_token_ids = self.input_ids.gpu[token_indices]
            # TODO(woosuk): Support M-RoPE.
@@ -2391,6 +2428,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                target_positions=target_positions,
                target_hidden_states=target_hidden_states,
                next_token_ids=next_token_ids,
                last_token_indices=token_indices_to_sample,
                sampling_metadata=sampling_metadata,
                common_attn_metadata=common_attn_metadata,
                mm_embeds=mm_embeds,