mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 14:34:32 +08:00
[Model Runner V2] Support spec decoding [1/N] (#29274)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
7f12c82fa6
commit
b004c00418
@ -35,6 +35,7 @@ class InputBuffers:
|
||||
self.positions = torch.zeros(max_num_tokens, dtype=torch.int64, device=device)
|
||||
self.query_start_loc = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||
self.seq_lens = torch.zeros(max_num_reqs, dtype=torch.int32, device=device)
|
||||
self.cu_num_logits = self._make_buffer(max_num_reqs + 1, dtype=torch.int32)
|
||||
|
||||
# Structured outputs.
|
||||
self.bitmask_indices = self._make_buffer(max_num_reqs, dtype=torch.int32)
|
||||
@ -64,6 +65,7 @@ class InputBatch:
|
||||
# sum(num_scheduled_tokens)
|
||||
num_tokens: int
|
||||
num_tokens_after_padding: int
|
||||
num_draft_tokens: int
|
||||
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor
|
||||
@ -80,8 +82,10 @@ class InputBatch:
|
||||
# layer_name -> Metadata
|
||||
attn_metadata: dict[str, Any]
|
||||
|
||||
# [num_reqs]
|
||||
# [total_num_logits]
|
||||
logits_indices: torch.Tensor
|
||||
# [num_reqs + 1]
|
||||
cu_num_logits: torch.Tensor
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
@ -118,6 +122,7 @@ class InputBatch:
|
||||
positions = input_buffers.positions[:num_tokens]
|
||||
# attn_metadata = defaultdict(lambda: None)
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
|
||||
return cls(
|
||||
req_ids=req_ids,
|
||||
num_reqs=num_reqs,
|
||||
@ -126,6 +131,7 @@ class InputBatch:
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens,
|
||||
num_draft_tokens=0,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
@ -134,6 +140,7 @@ class InputBatch:
|
||||
positions=positions,
|
||||
attn_metadata=None, # type: ignore
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
)
|
||||
|
||||
|
||||
@ -279,19 +286,53 @@ def _combine_sampled_and_draft_tokens_kernel(
|
||||
query_start_loc_ptr,
|
||||
seq_lens_ptr,
|
||||
prefill_len_ptr,
|
||||
draft_tokens_ptr,
|
||||
draft_tokens_stride,
|
||||
cu_num_logits_ptr,
|
||||
logits_indices_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
batch_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
|
||||
|
||||
# Get the number of logits and draft tokens.
|
||||
cu_num_logits_start = tl.load(cu_num_logits_ptr + batch_idx)
|
||||
cu_num_logits_end = tl.load(cu_num_logits_ptr + batch_idx + 1)
|
||||
num_logits = cu_num_logits_end - cu_num_logits_start
|
||||
num_draft_tokens = num_logits - 1
|
||||
|
||||
# Compute the logits indices.
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
logits_start = query_end - num_logits
|
||||
tl.store(
|
||||
logits_indices_ptr + cu_num_logits_start + block,
|
||||
logits_start + block,
|
||||
mask=block < num_logits,
|
||||
)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + batch_idx)
|
||||
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
|
||||
if seq_len <= prefill_len:
|
||||
# Handling prefill tokens.
|
||||
# Handling prefill tokens. No sampled or draft tokens.
|
||||
return
|
||||
|
||||
# Write the last sampled token ID to input_ids.
|
||||
last_token_id = tl.load(last_sampled_tokens_ptr + req_state_idx)
|
||||
end = tl.load(query_start_loc_ptr + batch_idx + 1)
|
||||
tl.store(input_ids_ptr + end - 1, last_token_id)
|
||||
tl.store(input_ids_ptr + query_end - num_logits, last_token_id)
|
||||
|
||||
# Write the draft tokens (if any) to input_ids.
|
||||
if num_draft_tokens > 0:
|
||||
mask = block < num_draft_tokens
|
||||
draft_tokens = tl.load(
|
||||
draft_tokens_ptr + req_state_idx * draft_tokens_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
input_ids_ptr + query_end - num_draft_tokens + block,
|
||||
draft_tokens,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
|
||||
def combine_sampled_and_draft_tokens(
|
||||
@ -301,8 +342,18 @@ def combine_sampled_and_draft_tokens(
|
||||
query_start_loc: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
prefill_len: torch.Tensor,
|
||||
draft_tokens: torch.Tensor,
|
||||
cu_num_logits: torch.Tensor,
|
||||
num_logits: int,
|
||||
) -> torch.Tensor:
|
||||
num_reqs = seq_lens.shape[0]
|
||||
num_speculative_steps = draft_tokens.shape[-1]
|
||||
|
||||
logits_indices = torch.empty(
|
||||
num_logits,
|
||||
dtype=torch.int64,
|
||||
device=input_ids.device,
|
||||
)
|
||||
_combine_sampled_and_draft_tokens_kernel[(num_reqs,)](
|
||||
input_ids,
|
||||
idx_mapping,
|
||||
@ -310,35 +361,80 @@ def combine_sampled_and_draft_tokens(
|
||||
query_start_loc,
|
||||
seq_lens,
|
||||
prefill_len,
|
||||
draft_tokens,
|
||||
draft_tokens.stride(0),
|
||||
cu_num_logits,
|
||||
logits_indices,
|
||||
# NOTE(woosuk): Add 1 to ensure the block can cover the last sampled token
|
||||
# in addition to all draft tokens.
|
||||
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
|
||||
)
|
||||
return input_ids
|
||||
return logits_indices
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _update_num_computed_tokens_kernel(
|
||||
def _post_update_kernel(
|
||||
idx_mapping_ptr,
|
||||
num_computed_tokens_ptr,
|
||||
last_sampled_tokens_ptr,
|
||||
sampled_tokens_ptr,
|
||||
sampled_tokens_stride,
|
||||
num_sampled_ptr,
|
||||
query_start_loc_ptr,
|
||||
cu_num_logits_ptr,
|
||||
):
|
||||
req_id = tl.program_id(0)
|
||||
req_state_idx = tl.load(idx_mapping_ptr + req_id)
|
||||
|
||||
start = tl.load(query_start_loc_ptr + req_id)
|
||||
end = tl.load(query_start_loc_ptr + req_id + 1)
|
||||
query_len = end - start
|
||||
num_sampled = tl.load(num_sampled_ptr + req_id)
|
||||
if num_sampled > 0:
|
||||
token_id = tl.load(
|
||||
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
|
||||
)
|
||||
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
|
||||
|
||||
n = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
tl.store(num_computed_tokens_ptr + req_state_idx, n + query_len)
|
||||
query_start = tl.load(query_start_loc_ptr + req_id)
|
||||
query_end = tl.load(query_start_loc_ptr + req_id + 1)
|
||||
query_len = query_end - query_start
|
||||
|
||||
num_computed = tl.load(num_computed_tokens_ptr + req_state_idx)
|
||||
num_computed += query_len
|
||||
# Consider the rejected tokens in spec decoding.
|
||||
if num_sampled > 0:
|
||||
# NOTE(woosuk): We must skip num_sampled == 0 to account for chunked prefills.
|
||||
logits_start = tl.load(cu_num_logits_ptr + req_id)
|
||||
logits_end = tl.load(cu_num_logits_ptr + req_id + 1)
|
||||
num_logits = logits_end - logits_start
|
||||
num_rejected = num_logits - num_sampled
|
||||
num_computed -= num_rejected
|
||||
tl.store(num_computed_tokens_ptr + req_state_idx, num_computed)
|
||||
|
||||
|
||||
def update_num_computed_tokens(
|
||||
def post_update(
|
||||
# [num_reqs]
|
||||
idx_mapping: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
num_computed_tokens: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
last_sampled_tokens: torch.Tensor,
|
||||
# [num_reqs, num_speculative_steps + 1]
|
||||
sampled_tokens: torch.Tensor,
|
||||
# [num_reqs]
|
||||
num_sampled: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
query_start_loc: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
cu_num_logits: torch.Tensor,
|
||||
) -> None:
|
||||
num_reqs = idx_mapping.shape[0]
|
||||
_update_num_computed_tokens_kernel[(num_reqs,)](
|
||||
_post_update_kernel[(num_reqs,)](
|
||||
idx_mapping,
|
||||
num_computed_tokens,
|
||||
last_sampled_tokens,
|
||||
sampled_tokens,
|
||||
sampled_tokens.stride(0),
|
||||
num_sampled,
|
||||
query_start_loc,
|
||||
cu_num_logits,
|
||||
num_warps=1,
|
||||
)
|
||||
|
||||
@ -40,11 +40,12 @@ from vllm.v1.worker.gpu.input_batch import (
|
||||
InputBatch,
|
||||
InputBuffers,
|
||||
combine_sampled_and_draft_tokens,
|
||||
post_update,
|
||||
prepare_pos_seq_lens,
|
||||
prepare_prefill_inputs,
|
||||
update_num_computed_tokens,
|
||||
)
|
||||
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
|
||||
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
|
||||
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
|
||||
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
@ -100,10 +101,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.input_prep_event = None
|
||||
self.structured_outputs_event = None
|
||||
|
||||
if self.speculative_config is not None:
|
||||
self.do_spec_decode = True
|
||||
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
|
||||
else:
|
||||
self.do_spec_decode = False
|
||||
self.num_speculative_steps = 0
|
||||
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
num_speculative_steps=self.num_speculative_steps,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
@ -427,6 +436,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
idx_mapping_np = idx_mapping.np[:num_reqs]
|
||||
idx_mapping = idx_mapping.copy_to_gpu(num_reqs)
|
||||
|
||||
# Get the number of draft tokens for each request.
|
||||
if not scheduler_output.scheduled_spec_decode_tokens:
|
||||
# No draft token scheduled (common case).
|
||||
total_num_draft_tokens = 0
|
||||
total_num_logits = num_reqs
|
||||
cu_num_logits = torch.arange(
|
||||
num_reqs + 1, device=self.device, dtype=torch.int32
|
||||
)
|
||||
else:
|
||||
draft_tokens = scheduler_output.scheduled_spec_decode_tokens
|
||||
num_draft_tokens = np.array(
|
||||
[
|
||||
len(draft_tokens[req_id]) if req_id in draft_tokens else 0
|
||||
for req_id in req_ids
|
||||
],
|
||||
dtype=np.int32,
|
||||
)
|
||||
total_num_draft_tokens = int(num_draft_tokens.sum())
|
||||
total_num_logits = num_reqs + total_num_draft_tokens
|
||||
|
||||
np.cumsum(
|
||||
num_draft_tokens + 1,
|
||||
out=self.input_buffers.cu_num_logits.np[1 : num_reqs + 1],
|
||||
)
|
||||
cu_num_logits = self.input_buffers.cu_num_logits.copy_to_gpu(num_reqs + 1)
|
||||
|
||||
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
|
||||
block_tables = self.block_tables.gather_block_tables(idx_mapping)
|
||||
|
||||
@ -456,14 +491,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
seq_lens = self.input_buffers.seq_lens[:num_reqs]
|
||||
|
||||
# Some input token ids are directly read from the last sampled tokens
|
||||
# and draft tokens.
|
||||
combine_sampled_and_draft_tokens(
|
||||
# and draft tokens. Also, get the logits indices to sample tokens from.
|
||||
logits_indices = combine_sampled_and_draft_tokens(
|
||||
self.input_buffers.input_ids.gpu,
|
||||
idx_mapping,
|
||||
self.req_states.last_sampled_tokens,
|
||||
query_start_loc_gpu,
|
||||
seq_lens,
|
||||
self.req_states.prefill_len.gpu,
|
||||
self.req_states.draft_tokens,
|
||||
cu_num_logits,
|
||||
total_num_logits,
|
||||
)
|
||||
|
||||
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
|
||||
@ -471,9 +509,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
query_start_loc_gpu, self.input_buffers.positions[:num_tokens]
|
||||
)
|
||||
|
||||
# Logits indices to sample next token from.
|
||||
logits_indices = query_start_loc_gpu[1:] - 1
|
||||
|
||||
# Get num_computed_tokens.
|
||||
# HACK(woosuk): Here, we use num_computed_tokens on GPU instead of
|
||||
# num_computed_tokens_cpu. This works for most cases.
|
||||
@ -508,6 +543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_after_padding=num_tokens_after_padding,
|
||||
num_draft_tokens=total_num_draft_tokens,
|
||||
query_start_loc=query_start_loc_gpu,
|
||||
query_start_loc_np=query_start_loc_np,
|
||||
seq_lens=seq_lens,
|
||||
@ -516,6 +552,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
positions=positions,
|
||||
attn_metadata=attn_metadata,
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
)
|
||||
|
||||
def sample(
|
||||
@ -530,6 +567,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if grammar_output is not None:
|
||||
# Apply grammar bitmask to the logits in-place.
|
||||
# TODO(woosuk): Make compatible with spec decoding.
|
||||
assert input_batch.num_draft_tokens == 0
|
||||
with async_barrier(self.structured_outputs_event):
|
||||
apply_grammar_bitmask(
|
||||
logits,
|
||||
@ -539,12 +577,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.input_buffers,
|
||||
)
|
||||
|
||||
# Sample tokens and compute logprobs (if needed).
|
||||
sampler_output = self.sampler(logits, sampling_metadata)
|
||||
|
||||
# Get the number of sampled tokens.
|
||||
# 0 if chunked-prefilling, 1 if not.
|
||||
prefill_len = self.req_states.prefill_len.gpu[input_batch.idx_mapping]
|
||||
is_chunked_prefilling = input_batch.seq_lens < prefill_len
|
||||
num_sampled = (~is_chunked_prefilling).int()
|
||||
if input_batch.num_draft_tokens == 0:
|
||||
# No draft tokens (common case).
|
||||
# 0 if chunked-prefilling, 1 if not.
|
||||
num_sampled = (~is_chunked_prefilling).int()
|
||||
else:
|
||||
# Draft tokens for spec decoding.
|
||||
input_ids = input_batch.input_ids[input_batch.logits_indices]
|
||||
sampled_tokens, num_sampled = rejection_sample(
|
||||
sampler_output.sampled_token_ids,
|
||||
input_ids,
|
||||
input_batch.cu_num_logits,
|
||||
self.num_speculative_steps,
|
||||
)
|
||||
num_sampled *= ~is_chunked_prefilling
|
||||
sampler_output.sampled_token_ids = sampled_tokens
|
||||
# TODO(woosuk): Support logprobs with spec decoding.
|
||||
return sampler_output, num_sampled
|
||||
|
||||
def compute_prompt_logprobs(
|
||||
@ -653,11 +707,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_sampled: torch.Tensor,
|
||||
) -> None:
|
||||
# Update the number of computed tokens.
|
||||
update_num_computed_tokens(
|
||||
post_update(
|
||||
input_batch.idx_mapping,
|
||||
self.req_states.num_computed_tokens,
|
||||
self.req_states.last_sampled_tokens,
|
||||
sampled_tokens,
|
||||
num_sampled,
|
||||
input_batch.query_start_loc,
|
||||
input_batch.cu_num_logits,
|
||||
)
|
||||
|
||||
# Update the number of computed prefill tokens.
|
||||
idx_mapping_np = input_batch.idx_mapping_np
|
||||
computed_prefill = self.req_states.num_computed_prefill_tokens
|
||||
# TODO(woosuk): Simplify this.
|
||||
@ -666,10 +726,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.req_states.prefill_len.np[idx_mapping_np],
|
||||
)
|
||||
|
||||
# Store the last sampled token ids.
|
||||
last_sampled = sampled_tokens
|
||||
self.req_states.last_sampled_tokens[input_batch.idx_mapping] = last_sampled
|
||||
|
||||
def get_cudagraph_and_dp_padding(
|
||||
self,
|
||||
scheduler_output: SchedulerOutput,
|
||||
@ -761,6 +817,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
sampling_metadata = self.req_states.make_sampling_metadata(
|
||||
input_batch.idx_mapping_np, pos
|
||||
)
|
||||
if input_batch.num_draft_tokens > 0:
|
||||
sampling_metadata = self.req_states.expand_sampling_metadata(
|
||||
sampling_metadata, input_batch.cu_num_logits
|
||||
)
|
||||
|
||||
if self.lora_config:
|
||||
# Activate LoRA adapters.
|
||||
|
||||
0
vllm/v1/worker/gpu/spec_decode/__init__.py
Normal file
0
vllm/v1/worker/gpu/spec_decode/__init__.py
Normal file
71
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
Normal file
71
vllm/v1/worker/gpu/spec_decode/rejection_sample.py
Normal file
@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _rejection_sample_kernel(
|
||||
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
|
||||
sampled_stride,
|
||||
num_sampled_ptr, # [num_reqs]
|
||||
target_sampled_ptr, # [num_draft_tokens + num_reqs]
|
||||
input_ids_ptr, # [num_draft_tokens + num_reqs]
|
||||
cu_num_logits_ptr, # [num_reqs + 1]
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
num_sampled = 0
|
||||
rejected = False
|
||||
for i in range(num_tokens - 1):
|
||||
if not rejected:
|
||||
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
|
||||
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
|
||||
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
|
||||
num_sampled += 1
|
||||
if target_sampled != draft_sampled:
|
||||
rejected = True
|
||||
if not rejected:
|
||||
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
|
||||
tl.store(
|
||||
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
|
||||
)
|
||||
num_sampled += 1
|
||||
tl.store(num_sampled_ptr + req_idx, num_sampled)
|
||||
|
||||
|
||||
def rejection_sample(
|
||||
# [num_draft_tokens + num_reqs]
|
||||
target_sampled: torch.Tensor,
|
||||
# [num_draft_tokens + num_reqs]
|
||||
input_ids: torch.Tensor,
|
||||
# [num_reqs + 1]
|
||||
cu_num_logits: torch.Tensor,
|
||||
num_speculative_steps: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
sampled = torch.empty(
|
||||
num_reqs,
|
||||
num_speculative_steps + 1,
|
||||
dtype=target_sampled.dtype,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
num_sampled = torch.empty(
|
||||
num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=target_sampled.device,
|
||||
)
|
||||
_rejection_sample_kernel[(num_reqs,)](
|
||||
sampled,
|
||||
sampled.stride(0),
|
||||
num_sampled,
|
||||
target_sampled,
|
||||
input_ids,
|
||||
cu_num_logits,
|
||||
num_warps=1,
|
||||
)
|
||||
return sampled, num_sampled
|
||||
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
@ -63,6 +64,7 @@ class RequestState:
|
||||
max_num_reqs: int,
|
||||
max_model_len: int,
|
||||
max_num_batched_tokens: int,
|
||||
num_speculative_steps: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
pin_memory: bool,
|
||||
@ -70,6 +72,7 @@ class RequestState:
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.num_speculative_steps = num_speculative_steps
|
||||
self.vocab_size = vocab_size
|
||||
self.device = device
|
||||
self.pin_memory = pin_memory
|
||||
@ -100,6 +103,14 @@ class RequestState:
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Draft tokens.
|
||||
self.draft_tokens = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# LoRA.
|
||||
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
|
||||
self.lora_ids.fill(NO_LORA_ID)
|
||||
@ -226,6 +237,17 @@ class RequestState:
|
||||
max_num_logprobs=max_num_logprobs,
|
||||
)
|
||||
|
||||
def expand_sampling_metadata(
|
||||
self,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
cu_num_logits: torch.Tensor,
|
||||
) -> SamplingMetadata:
|
||||
# For draft tokens, we need to expand the sampling param tensors as
|
||||
# each request samples multiple tokens in each step.
|
||||
return expand_sampling_metadata(
|
||||
sampling_metadata, cu_num_logits, self.num_speculative_steps
|
||||
)
|
||||
|
||||
def make_lora_inputs(
|
||||
self,
|
||||
req_ids: list[str],
|
||||
@ -270,3 +292,75 @@ class Param:
|
||||
class ExtraData:
|
||||
lora_request: LoRARequest | None
|
||||
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
|
||||
|
||||
|
||||
# NOTE(woosuk): Re-compilation can happen at runtime since top_p and top_k can be None.
|
||||
@triton.jit
|
||||
def _expand_sampling_metadata_kernel(
|
||||
temp_ptr,
|
||||
expanded_temp_ptr,
|
||||
top_p_ptr,
|
||||
expanded_top_p_ptr,
|
||||
top_k_ptr,
|
||||
expanded_top_k_ptr,
|
||||
seeds_ptr,
|
||||
expanded_seeds_ptr,
|
||||
cu_num_logits_ptr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
start_idx = tl.load(cu_num_logits_ptr + req_idx)
|
||||
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
|
||||
num_tokens = end_idx - start_idx
|
||||
|
||||
block = tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < num_tokens
|
||||
|
||||
temp = tl.load(temp_ptr + req_idx)
|
||||
tl.store(expanded_temp_ptr + start_idx + block, temp, mask=mask)
|
||||
|
||||
if top_p_ptr is not None:
|
||||
top_p = tl.load(top_p_ptr + req_idx)
|
||||
tl.store(expanded_top_p_ptr + start_idx + block, top_p, mask=mask)
|
||||
|
||||
if top_k_ptr is not None:
|
||||
top_k = tl.load(top_k_ptr + req_idx)
|
||||
tl.store(expanded_top_k_ptr + start_idx + block, top_k, mask=mask)
|
||||
|
||||
seed = tl.load(seeds_ptr + req_idx)
|
||||
tl.store(expanded_seeds_ptr + start_idx + block, seed, mask=mask)
|
||||
|
||||
|
||||
def expand_sampling_metadata(
|
||||
sampling_metadata: SamplingMetadata,
|
||||
cu_num_logits: torch.Tensor,
|
||||
num_speculative_steps: int,
|
||||
) -> SamplingMetadata:
|
||||
total_num_logits = sampling_metadata.pos.shape[0]
|
||||
create_empty = lambda x: x.new_empty(total_num_logits) if x is not None else None
|
||||
expanded_temp = create_empty(sampling_metadata.temperature)
|
||||
expanded_top_p = create_empty(sampling_metadata.top_p)
|
||||
expanded_top_k = create_empty(sampling_metadata.top_k)
|
||||
expanded_seeds = create_empty(sampling_metadata.seeds)
|
||||
|
||||
num_reqs = cu_num_logits.shape[0] - 1
|
||||
_expand_sampling_metadata_kernel[(num_reqs,)](
|
||||
sampling_metadata.temperature,
|
||||
expanded_temp,
|
||||
sampling_metadata.top_p,
|
||||
expanded_top_p,
|
||||
sampling_metadata.top_k,
|
||||
expanded_top_k,
|
||||
sampling_metadata.seeds,
|
||||
expanded_seeds,
|
||||
cu_num_logits,
|
||||
BLOCK_SIZE=triton.next_power_of_2(num_speculative_steps + 1),
|
||||
)
|
||||
return SamplingMetadata(
|
||||
temperature=expanded_temp,
|
||||
top_p=expanded_top_p,
|
||||
top_k=expanded_top_k,
|
||||
seeds=expanded_seeds,
|
||||
pos=sampling_metadata.pos,
|
||||
max_num_logprobs=sampling_metadata.max_num_logprobs,
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user