[Perf] Optimize EAGLE prepare_inputs_padded with triton kernels (#28597)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>

parent 3461e7efd8
commit 1986de1375
@@ -103,16 +103,23 @@ def test_prepare_next_token_ids():
        mock_request.num_computed_tokens = 0
        mock_requests[req_id] = mock_request

    # explicitly discard the last request
    discarded_req_mask = torch.tensor(
        [False, False, False, True], dtype=torch.bool, device=device
    )
    sampled_token_ids = [
        [0, 1, -1, -1, -1],  # 1 accepted, 3 rejected, "1" sampled
        [0, 1, 2, 3, 4],  # all accepted, "4" sampled
        [-1, -1, -1, -1, -1],  # sampling skipped, use backup token "30"
        [-1, -1, -1, -1, -1],  # this request will be discarded
        [0, 1, 2, -1, -1],  # explicitly discarded, sampling should be ignored
    ]
    sampled_token_ids_tensor = torch.tensor(
        sampled_token_ids, dtype=torch.int32, device=device
    )
    sampled_token_ids_cpu = [[i for i in seq if i != -1] for seq in sampled_token_ids]
    for i in range(len(sampled_token_ids_cpu)):
        if discarded_req_mask[i]:
            sampled_token_ids_cpu[i] = []

    expected_next_token_ids_cpu = [1, 4, 30, 40]
    expected_next_token_ids_tensor = torch.tensor(
@@ -136,9 +143,6 @@ def test_prepare_next_token_ids():
        device=device,
    )

    discarded_req_indices = torch.tensor([3], dtype=torch.int64, device=device)
    num_discarded_reqs = 1

    expected_valid_sampled_tokens_count = torch.tensor(
        [2, 5, 0, 0], dtype=torch.int32, device=device
    )
@@ -149,8 +153,7 @@ def test_prepare_next_token_ids():
            sampled_token_ids_tensor,
            mock_requests,
            mock_input_batch,
            discarded_req_indices,
            num_discarded_reqs,
            discarded_req_mask,
        )
    )
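Note on the selection rule this test encodes: a sampled slot is "valid" when it is not -1 and lies inside the vocabulary, a discarded request contributes zero valid tokens, and the next token is the last valid slot or the precomputed backup token when nothing valid remains. A minimal CPU reference sketch of that rule (not the production kernel; the backup ids 99/98 and vocab_size=100 are illustrative assumptions, while 30/40 mirror the test's backup tokens):

def reference_next_tokens(sampled_token_ids, backup_token_ids, discard_mask, vocab_size):
    # Per-row reference: keep tokens that are not -1 and below vocab_size,
    # drop everything for discarded rows, then take the last survivor or
    # fall back to the backup token.
    next_ids, valid_counts = [], []
    for row, backup, discarded in zip(sampled_token_ids, backup_token_ids, discard_mask):
        valid = [] if discarded else [t for t in row if t != -1 and t < vocab_size]
        next_ids.append(valid[-1] if valid else backup)
        valid_counts.append(len(valid))
    return next_ids, valid_counts

rows = [[0, 1, -1, -1, -1], [0, 1, 2, 3, 4], [-1, -1, -1, -1, -1], [0, 1, 2, -1, -1]]
print(reference_next_tokens(rows, [99, 98, 30, 40], [False, False, False, True], 100))
# -> ([1, 4, 30, 40], [2, 5, 0, 0])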
@@ -256,11 +259,6 @@ def test_prepare_inputs_padded():
    - Request 3: query_len = 3, rejected = 2

    Expected outputs:
    token_indices: [0, 1, 2,
                    3, 4, 5,
                    6, 7, 8]
    Reason: Deferred computation should not disturb the original indices.

    token_indices_to_sample: [1, 5, 6]
    Reason: After accounting for rejections, these are the valid token positions
    from the original indices to sample from.
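The expected positions follow from "last token of the request minus its rejected count". A worked version of that arithmetic, assuming for illustration a 3/3/3 token split (query_start_loc = [0, 3, 6, 9]) and per-request rejection counts [1, 0, 2]; only request 3's query_len/rejected values are stated explicitly above:

import torch

query_start_loc = torch.tensor([0, 3, 6, 9])
num_rejected = torch.tensor([1, 0, 2])

last_token_per_req = query_start_loc[1:] - 1          # [2, 5, 8]
token_indices_to_sample = last_token_per_req - num_rejected
print(token_indices_to_sample.tolist())               # [1, 5, 6]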
@@ -268,9 +266,6 @@ def test_prepare_inputs_padded():

    device = torch.device(current_platform.device_type)

    expected_token_indices = torch.tensor(
        [0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32, device=device
    )
    expected_token_indices_to_sample = torch.tensor(
        [1, 5, 6], dtype=torch.int32, device=device
    )
@@ -305,15 +300,12 @@ def test_prepare_inputs_padded():

    proposer = _create_proposer("eagle", num_speculative_tokens)

    output_metadata, token_indices, token_indices_to_sample = (
        proposer.prepare_inputs_padded(
            common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
        )
    output_metadata, token_indices_to_sample = proposer.prepare_inputs_padded(
        common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
    )

    assert output_metadata.max_query_len == 3
    assert torch.equal(output_metadata.query_start_loc, expected_query_start_loc)
    assert torch.equal(token_indices, expected_token_indices)
    assert torch.equal(token_indices_to_sample, expected_token_indices_to_sample)
@@ -25,6 +25,7 @@ from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.tree_attn import (
@@ -40,6 +41,10 @@ from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.sampler import _SAMPLING_EPS
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.utils import (
    eagle_prepare_inputs_padded_kernel,
    eagle_prepare_next_token_padded_kernel,
)
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -555,20 +560,15 @@ class EagleProposer:
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_indices: torch.Tensor,
        num_discarded_requests: int,
        discard_request_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.
        is not sampled and comes from `request.get_token_id()` instead. This is denoted
        the "backup" token id. It also counts rejected tokens via `sampled_token_ids`.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        self.backup_next_token_ids.np[:num_reqs] = np.array(
@@ -577,44 +577,39 @@ class EagleProposer:
                    common_attn_metadata.seq_lens_cpu[i].item()
                )
                for i in range(num_reqs)
            ]
            ],
            dtype=np.int32,
        )
        self.backup_next_token_ids.copy_to_gpu(num_reqs)
        backup_tokens_gpu = self.backup_next_token_ids.gpu

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = discard_request_indices[
            :num_discarded_requests
        ]
        batch_size, num_tokens = sampled_token_ids.shape
        device = sampled_token_ids.device

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1
        assert discard_request_mask.dtype == torch.bool
        assert backup_tokens_gpu.dtype == torch.int32

        next_token_ids = torch.empty((batch_size,), dtype=torch.int32, device=device)
        valid_sampled_tokens_count = torch.empty(
            (batch_size,), dtype=torch.int32, device=device
        )

        # Generate a mask for all valid tokens within those requests
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
        )
        # Kernel grid: one program per request (row)
        grid = (batch_size,)

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
        ).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size],
        # Find the next power of 2 for block sizes
        BLOCK_SIZE_TOKENS = triton.next_power_of_2(num_tokens)
        eagle_prepare_next_token_padded_kernel[grid](
            sampled_token_ids,
            discard_request_mask,
            backup_tokens_gpu,
            next_token_ids,
            valid_sampled_tokens_count,
            gpu_input_batch.vocab_size,
            num_tokens,
            batch_size,
            sampled_token_ids.stride(0),
            BLOCK_SIZE_TOKENS=BLOCK_SIZE_TOKENS,
        )

        return next_token_ids, valid_sampled_tokens_count
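For reference, the kernel launch above replaces the clone/index_fill_/gather/where chain that used to run here. A minimal host-side sketch of the launch, following the argument order used above (assumes a CUDA device with Triton available; the tensor values and vocab_size=32000 are illustrative):

import torch

from vllm.triton_utils import triton
from vllm.v1.spec_decode.utils import eagle_prepare_next_token_padded_kernel

device = torch.device("cuda")
sampled = torch.tensor(
    [[0, 1, -1, -1, -1], [0, 1, 2, 3, 4]], dtype=torch.int32, device=device
)
discard_mask = torch.tensor([False, False], device=device)
backup = torch.tensor([30, 40], dtype=torch.int32, device=device)

batch_size, num_tokens = sampled.shape
next_ids = torch.empty(batch_size, dtype=torch.int32, device=device)
valid_counts = torch.empty(batch_size, dtype=torch.int32, device=device)

# One program per request; BLOCK_SIZE_TOKENS is rounded up to a power of two
# so tl.arange can cover every sampled-token slot in a row.
eagle_prepare_next_token_padded_kernel[(batch_size,)](
    sampled,
    discard_mask,
    backup,
    next_ids,
    valid_counts,
    32000,  # vocab_size (illustrative)
    num_tokens,
    batch_size,
    sampled.stride(0),
    BLOCK_SIZE_TOKENS=triton.next_power_of_2(num_tokens),
)
print(next_ids.tolist(), valid_counts.tolist())  # expect [1, 4] and [2, 5]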
@@ -624,35 +619,35 @@ class EagleProposer:
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.
        """
        num_draft_tokens_gpu = torch.cat(
            [
                spec_decode_metadata.cu_num_draft_tokens[0:1],
                spec_decode_metadata.cu_num_draft_tokens[1:]
                - spec_decode_metadata.cu_num_draft_tokens[:-1],
            ]
        num_reqs = common_attn_metadata.num_reqs
        device = valid_sampled_tokens_count.device

        token_indices_to_sample = torch.empty(
            (num_reqs,), dtype=torch.int32, device=device
        )

        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu),
        # Kernel grid: one program per request (row)
        grid = (num_reqs,)
        eagle_prepare_inputs_padded_kernel[grid](
            spec_decode_metadata.cu_num_draft_tokens,
            valid_sampled_tokens_count,
            common_attn_metadata.query_start_loc,
            token_indices_to_sample,
            num_reqs,
        )

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
@@ -665,16 +660,12 @@ class EagleProposer:
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            slot_mapping=common_attn_metadata.slot_mapping[:total_num_tokens],
            causal=True,
            dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        token_indices_to_sample = (
            common_attn_metadata.query_start_loc[1:] - 1 - num_rejected_tokens_gpu
        )

        return spec_common_attn_metadata, token_indices, token_indices_to_sample
        return spec_common_attn_metadata, token_indices_to_sample

    def propose_tree(
        self,
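The kernel computes, per request, the same quantity the removed torch code produced: the position of the last accepted token, i.e. query_start_loc[1:] - 1 - num_rejected, with num_rejected = num_draft + 1 - valid_count wherever drafts exist. A small CPU sketch of that arithmetic with made-up sizes (three requests, three drafts each, so every padded query is four tokens long):

import torch

cu_num_draft_tokens = torch.tensor([3, 6, 9])          # inclusive cumulative sum of drafts
valid_sampled_tokens_count = torch.tensor([2, 4, 1])   # 1 + accepted tokens per request
query_start_loc = torch.tensor([0, 4, 8, 12])          # padded: all draft slots kept as input

num_draft = torch.diff(cu_num_draft_tokens, prepend=torch.tensor([0]))
num_rejected = torch.where(
    num_draft > 0, num_draft + 1 - valid_sampled_tokens_count, torch.zeros_like(num_draft)
)
token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected
print(token_indices_to_sample.tolist())                # [1, 7, 8]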
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton

_SAMPLING_EPS = 1e-5

@@ -14,3 +15,107 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool:
        or sampling_params.min_p > _SAMPLING_EPS
        or sampling_params.logprobs is not None
    )

@triton.jit
def eagle_prepare_inputs_padded_kernel(
    cu_num_draft_tokens_ptr,  # [num_reqs]
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_reqs,  # tl.int32
):
    """
    Fused kernel for Eagle prepare_inputs_padded. This kernel computes the
    token index to sample for each request, taking into account the number
    of draft tokens and the number of valid sampled tokens (which is one more than
    the number of accepted tokens).
    """
    req_idx = tl.program_id(axis=0)
    if req_idx >= num_reqs:
        return

    # Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
    # cumulative sum (first entry is the first value, not zero).
    cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + req_idx)

    num_draft_tokens = 0
    if req_idx == 0:
        num_draft_tokens = cu_draft_curr
    else:
        cu_draft_prev = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
        num_draft_tokens = cu_draft_curr - cu_draft_prev

    valid_count = tl.load(valid_sampled_tokens_count_ptr + req_idx)
    num_rejected_tokens = num_draft_tokens + 1 - valid_count
    num_rejected_tokens = tl.where(num_draft_tokens > 0, num_rejected_tokens, 0)

    # query_start_loc[req_idx + 1] is the start position of the next request,
    # which is one past the last token of this request.
    q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + req_idx + 1) - 1

    index_to_sample = q_last_tok_idx - num_rejected_tokens
    tl.store(token_indices_to_sample_ptr + req_idx, index_to_sample)
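A minimal host-side launch sketch for this kernel, reusing the illustrative numbers from the earlier arithmetic sketch (assumes a CUDA device with Triton available):

import torch

from vllm.v1.spec_decode.utils import eagle_prepare_inputs_padded_kernel

device = torch.device("cuda")
cu_num_draft_tokens = torch.tensor([3, 6, 9], dtype=torch.int32, device=device)
valid_sampled_tokens_count = torch.tensor([2, 4, 1], dtype=torch.int32, device=device)
query_start_loc = torch.tensor([0, 4, 8, 12], dtype=torch.int32, device=device)

num_reqs = cu_num_draft_tokens.numel()
token_indices_to_sample = torch.empty(num_reqs, dtype=torch.int32, device=device)

# One program per request; each program writes one sample index.
eagle_prepare_inputs_padded_kernel[(num_reqs,)](
    cu_num_draft_tokens,
    valid_sampled_tokens_count,
    query_start_loc,
    token_indices_to_sample,
    num_reqs,
)
print(token_indices_to_sample.tolist())  # [1, 7, 8]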

@triton.jit
def eagle_prepare_next_token_padded_kernel(
    sampled_token_ids_ptr,  # [num_reqs, num_sampled_tokens_per_req]
    discard_request_mask_ptr,  # [num_reqs]
    backup_next_token_ids_ptr,  # [num_reqs]
    next_token_ids_ptr,  # [num_reqs] (output)
    valid_sampled_tokens_count_ptr,  # [num_reqs] (output)
    vocab_size,  # tl.int32
    num_sampled_tokens_per_req,  # tl.int32 (num_spec_tokens + 1)
    num_reqs,  # tl.int32
    stride_sampled_token_ids,  # tl.int32 (stride for dim 0)
    BLOCK_SIZE_TOKENS: tl.constexpr,  # Power-of-2 >= num_sampled_tokens_per_req
):
    """
    Fused kernel for Eagle prepare_next_token_ids_padded. This kernel computes the
    number of valid (1 + accepted) tokens for each request, and the corresponding
    "next" token id to sample from during speculative decoding. This is the
    "last accepted token" from the sampled tokens, or the backup token if no
    tokens were accepted or if the request is marked as discarded.
    """
    req_idx = tl.program_id(axis=0)
    if req_idx >= num_reqs:
        return

    # Check if this request is discarded.
    is_discarded = tl.load(discard_request_mask_ptr + req_idx)

    if is_discarded:
        backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
        valid_count = tl.full((), 0, dtype=tl.uint32)
        tl.store(next_token_ids_ptr + req_idx, backup_token)
        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
    else:
        # Count the number of valid tokens among the sampled tokens.
        token_offs = tl.arange(0, BLOCK_SIZE_TOKENS)
        token_mask = token_offs < num_sampled_tokens_per_req

        row_ptr = sampled_token_ids_ptr + req_idx * stride_sampled_token_ids
        token_ids = tl.load(row_ptr + token_offs, mask=token_mask, other=-1)

        # Rejected tokens are -1, valid tokens are in [0, vocab_size)
        is_valid_mask = (token_ids != -1) & (token_ids < vocab_size) & token_mask
        valid_count = tl.sum(is_valid_mask)

        if valid_count > 0:
            # Guaranteed to be well-defined since
            # valid_count > 0 implies is_valid_mask is not empty
            last_valid_index = tl.max(tl.where(is_valid_mask, token_offs, -1))

            # Select the token at that index, using a sum trick since
            # we don't want to load again to access token_ids[last_valid_index].
            last_valid_token = tl.sum(
                tl.where(token_offs == last_valid_index, token_ids, 0)
            )
            tl.store(next_token_ids_ptr + req_idx, last_valid_token)
        else:
            # No valid tokens found, use backup token
            backup_token = tl.load(backup_next_token_ids_ptr + req_idx)
            tl.store(next_token_ids_ptr + req_idx, backup_token)

        tl.store(valid_sampled_tokens_count_ptr + req_idx, valid_count)
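The two reductions in the non-discarded branch can be checked in plain torch: mapping invalid slots to -1 before a max yields the rightmost valid position, and summing a row masked to that single position yields the token itself without a second load. Illustrative row:

import torch

token_ids = torch.tensor([5, 17, 9, -1, -1])
token_offs = torch.arange(token_ids.numel())
is_valid = token_ids != -1

# Rightmost valid position: invalid slots are pushed to -1 before the max.
last_valid_index = torch.where(is_valid, token_offs, torch.full_like(token_offs, -1)).max()

# Only one slot matches that index, so the masked sum picks out exactly that token.
last_valid_token = torch.where(
    token_offs == last_valid_index, token_ids, torch.zeros_like(token_ids)
).sum()
print(int(last_valid_index), int(last_valid_token))  # 2 9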
@@ -488,11 +488,9 @@ class GPUModelRunner(
            self.max_num_tokens, self.hidden_size, dtype=self.dtype, numpy=False
        )
        self.is_token_ids = self._make_buffer(self.max_num_tokens, dtype=torch.bool)
        self.discard_request_indices = self._make_buffer(
            self.max_num_reqs, dtype=torch.int64
        self.discard_request_mask = self._make_buffer(
            self.max_num_reqs, dtype=torch.bool
        )
        self.num_discarded_requests = 0

        self.num_decode_draft_tokens = self._make_buffer(
            self.max_num_reqs, dtype=torch.int32
        )
@@ -1369,16 +1367,12 @@ class GPUModelRunner(
        num_tokens = [self.requests[r].num_tokens for r in self.input_batch.req_ids]
        num_tokens_np = np.array(num_tokens, dtype=np.int32)

        # Record the index of requests that should not be sampled,
        # Record which requests should not be sampled,
        # so that we could clear the sampled tokens before returning
        discard_requests_mask = self.seq_lens.np[:num_reqs] < num_tokens_np
        discard_request_indices = np.nonzero(discard_requests_mask)[0]
        self.num_discarded_requests = len(discard_request_indices)
        self.discard_request_indices.np[: self.num_discarded_requests] = (
            discard_request_indices
        self.discard_request_mask.np[:num_reqs] = (
            self.seq_lens.np[:num_reqs] < num_tokens_np
        )

        self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
        self.discard_request_mask.copy_to_gpu(num_reqs)

        # Copy the tensors to the GPU.
        self._prepare_input_ids(
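The mask-based bookkeeping above carries the same information as the old (indices, count) pair; when indices are still needed on the CPU side they are recovered with np.nonzero, as in the hunk below. Illustrative values:

import numpy as np

seq_lens = np.array([7, 12, 4, 9], dtype=np.int32)
num_tokens = np.array([7, 12, 6, 9], dtype=np.int32)

discard_request_mask = seq_lens < num_tokens                    # [False, False, True, False]
discard_request_indices = np.nonzero(discard_request_mask)[0]   # [2]
num_discarded_requests = len(discard_request_indices)           # 1
print(discard_request_mask.tolist(), discard_request_indices.tolist(), num_discarded_requests)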
@@ -2548,9 +2542,10 @@ class GPUModelRunner(
        if envs.VLLM_COMPUTE_NANS_IN_LOGITS:
            num_nans_in_logits = self._get_nans_in_logits(logits)

        discard_sampled_tokens_req_indices = self.discard_request_indices.np[
            : self.num_discarded_requests
        ]
        num_reqs = self.input_batch.num_reqs
        discard_sampled_tokens_req_indices = np.nonzero(
            self.discard_request_mask.np[:num_reqs]
        )[0]
        for i in discard_sampled_tokens_req_indices:
            gen = self.input_batch.generators.get(int(i))
            if gen is not None:
@@ -3131,8 +3126,7 @@ class GPUModelRunner(
                sampled_token_ids,
                self.requests,
                self.input_batch,
                self.discard_request_indices.gpu,
                self.num_discarded_requests,
                self.discard_request_mask.gpu,
            )
        )
        self._copy_valid_sampled_token_count(
@@ -3335,8 +3329,7 @@ class GPUModelRunner(
                sampled_token_ids,
                self.requests,
                self.input_batch,
                self.discard_request_indices.gpu,
                self.num_discarded_requests,
                self.discard_request_mask.gpu,
            )
        )
        self._copy_valid_sampled_token_count(
@@ -3363,24 +3356,34 @@ class GPUModelRunner(
                sampled_token_ids,
                spec_decode_metadata.num_draft_tokens,
            )
            target_token_ids = self.input_ids.gpu[token_indices]
            target_positions = self._get_positions(token_indices)
            if self.use_aux_hidden_state_outputs:
                assert aux_hidden_states is not None
                target_hidden_states = torch.cat(
                    [h[token_indices] for h in aux_hidden_states], dim=-1
                )
            else:
                target_hidden_states = hidden_states[token_indices]
        else:
            common_attn_metadata, token_indices, token_indices_to_sample = (
            common_attn_metadata, token_indices_to_sample = (
                self.drafter.prepare_inputs_padded(
                    common_attn_metadata,
                    spec_decode_metadata,
                    valid_sampled_tokens_count,
                )
            )

            target_token_ids = self.input_ids.gpu[token_indices]
            target_positions = self._get_positions(token_indices)
            if self.use_aux_hidden_state_outputs:
                assert aux_hidden_states is not None
                target_hidden_states = torch.cat(
                    [h[token_indices] for h in aux_hidden_states], dim=-1
                )
            else:
                target_hidden_states = hidden_states[token_indices]
            total_num_tokens = common_attn_metadata.num_actual_tokens
            # When padding the batch, token_indices is just a range
            target_token_ids = self.input_ids.gpu[:total_num_tokens]
            target_positions = self._get_positions(total_num_tokens)
            if self.use_aux_hidden_state_outputs:
                assert aux_hidden_states is not None
                target_hidden_states = torch.cat(
                    [h[:total_num_tokens] for h in aux_hidden_states], dim=-1
                )
            else:
                target_hidden_states = hidden_states[:total_num_tokens]

        if self.supports_mm_inputs:
            mm_embed_inputs = self._gather_mm_embeddings(