mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-09 01:58:43 +08:00
[Model Runner V2] Implement multi-step Eagle with CUDA graph (#29559)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
43c5792592
commit
da3222f371
@ -233,10 +233,11 @@ def prepare_inputs_to_capture(
|
||||
query_start_loc.np[num_reqs:] = num_tokens
|
||||
query_start_loc.copy_to_gpu()
|
||||
seq_lens_np = np.full(num_reqs, max_model_len, dtype=np.int32)
|
||||
# HACK(woosuk): To optimize warmup time, we use 1 (instead of max_model_len)
|
||||
# for seq_lens. This leads to a mismatch between seq_lens (GPU) and
|
||||
# seq_lens_np (CPU), which might cause issues in some attention backends.
|
||||
input_buffers.seq_lens[:num_reqs] = 1
|
||||
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
|
||||
# rather than max_model_len. This introduces a discrepancy between
|
||||
# seq_lens (on GPU) and seq_lens_np (on CPU), which may cause issues for
|
||||
# certain attention backends.
|
||||
input_buffers.seq_lens[:num_reqs] = num_tokens
|
||||
input_buffers.seq_lens[num_reqs:] = 0
|
||||
|
||||
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
|
||||
|
||||
@ -140,10 +140,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
|
||||
|
||||
# CUDA graphs.
|
||||
self.cudagraph_manager = CudaGraphManager(
|
||||
vllm_config=self.vllm_config,
|
||||
device=self.device,
|
||||
)
|
||||
self.cudagraph_manager = CudaGraphManager(self.vllm_config, self.device)
|
||||
|
||||
def get_supported_tasks(self) -> tuple[str]:
|
||||
return ("generate",)
|
||||
@ -203,6 +200,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.vllm_config,
|
||||
self.device,
|
||||
)
|
||||
if self.do_spec_decode:
|
||||
# HACK(woosuk)
|
||||
self.speculator.set_attn(
|
||||
self.kv_cache_config,
|
||||
self.attn_metadata_builders,
|
||||
self.block_tables,
|
||||
)
|
||||
|
||||
# TODO(woosuk): Support other backends.
|
||||
if not all(b.get_name() == "FLASH_ATTN" for b in self.attn_backends.values()):
|
||||
raise NotImplementedError("Only FLASH_ATTN backend is supported currently.")
|
||||
@ -297,35 +302,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
self.sampler(logits, sampling_metadata)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_speculator_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
aux_hidden_states: list[torch.Tensor] | None,
|
||||
) -> None:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
input_batch = InputBatch.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens,
|
||||
input_buffers=self.input_buffers,
|
||||
device=self.device,
|
||||
)
|
||||
sampling_metadata = SamplingMetadata.make_dummy(
|
||||
num_reqs=num_reqs,
|
||||
device=self.device,
|
||||
)
|
||||
num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
|
||||
num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
|
||||
self.propose_draft(
|
||||
input_batch=input_batch,
|
||||
sampling_metadata=sampling_metadata,
|
||||
last_hidden_states=hidden_states,
|
||||
aux_hidden_states=aux_hidden_states,
|
||||
num_sampled=num_sampled,
|
||||
num_rejected=num_rejected,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
hidden_states, sample_hidden_states = self._dummy_run(
|
||||
@ -334,7 +310,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
)
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
if self.do_spec_decode:
|
||||
self._dummy_speculator_run(hidden_states, None)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(
|
||||
self.dp_size, self.max_num_tokens
|
||||
)
|
||||
self.speculator.run_model(
|
||||
self.max_num_tokens,
|
||||
attn_metadata=None,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
del hidden_states, sample_hidden_states
|
||||
gc.collect()
|
||||
@ -368,6 +351,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
if self.do_spec_decode:
|
||||
self.speculator.capture_model()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
@ -1,17 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.sampler import gumbel_sample
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
|
||||
from vllm.v1.worker.gpu.states import SamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EagleSpeculator:
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
@ -27,13 +39,48 @@ class EagleSpeculator:
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
# We need to get the hidden size from the draft model config because
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
self.vocab_size = self.draft_model_config.get_vocab_size()
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
|
||||
self.input_ids = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int32, device=device
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
hidden_size=self.hidden_size,
|
||||
vocab_size=self.vocab_size,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
self.positions = torch.zeros(
|
||||
self.max_num_tokens, dtype=torch.int64, device=device
|
||||
self.hidden_states = torch.zeros(
|
||||
self.max_num_tokens,
|
||||
self.hidden_size,
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
self.temperature = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
self.seeds = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
self.draft_tokens = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
@ -49,6 +96,91 @@ class EagleSpeculator:
|
||||
del self.model.lm_head
|
||||
self.model.lm_head = target_model.lm_head
|
||||
|
||||
def set_attn(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
block_tables: BlockTables,
|
||||
) -> None:
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.attn_metadata_builders = attn_metadata_builders
|
||||
self.block_tables = block_tables
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_model(
|
||||
self,
|
||||
num_tokens: int,
|
||||
attn_metadata: dict[str, Any],
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
with set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=self.input_buffers.input_ids.gpu[:num_tokens],
|
||||
positions=self.input_buffers.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
)
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
return last_hidden_states, hidden_states
|
||||
|
||||
def generate_draft(
|
||||
self,
|
||||
num_reqs: int,
|
||||
attn_metadata: dict[str, Any],
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
) -> None:
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
query_start_loc = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
|
||||
for step in range(1, self.num_speculative_steps):
|
||||
# Run the eagle model.
|
||||
last_hidden_states, hidden_states = self.run_model(
|
||||
num_reqs, attn_metadata, num_tokens_across_dp
|
||||
)
|
||||
logits = self.model.compute_logits(last_hidden_states)
|
||||
|
||||
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
|
||||
# used for draft and target sampling.
|
||||
draft_tokens = gumbel_sample(
|
||||
logits,
|
||||
self.temperature[:num_reqs],
|
||||
self.seeds[:num_reqs],
|
||||
pos + 1,
|
||||
apply_temperature=True,
|
||||
)
|
||||
self.draft_tokens[:num_reqs, step] = draft_tokens
|
||||
|
||||
if step < self.num_speculative_steps - 1:
|
||||
# Update the inputs for the next step.
|
||||
update_eagle_inputs(
|
||||
draft_tokens,
|
||||
hidden_states,
|
||||
self.input_buffers,
|
||||
self.hidden_states,
|
||||
self.max_model_len,
|
||||
)
|
||||
self.block_tables.compute_slot_mappings(query_start_loc, pos)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
if self.num_speculative_steps == 1:
|
||||
return
|
||||
logger.info("Capturing model for Eagle speculator...")
|
||||
self.cudagraph_manager.capture(
|
||||
self.generate_draft,
|
||||
self.input_buffers,
|
||||
self.block_tables,
|
||||
self.attn_metadata_builders,
|
||||
self.kv_cache_config,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def propose(
|
||||
self,
|
||||
@ -80,64 +212,110 @@ class EagleSpeculator:
|
||||
)
|
||||
else:
|
||||
hidden_states = last_hidden_states
|
||||
num_tokens = input_batch.num_tokens_after_padding
|
||||
self.hidden_states[:num_tokens] = hidden_states
|
||||
|
||||
# Get the input ids and last token indices for the speculator.
|
||||
last_token_indices = prepare_eagle_inputs(
|
||||
self.input_ids,
|
||||
self.input_buffers,
|
||||
input_batch,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
last_sampled,
|
||||
next_prefill_tokens,
|
||||
)
|
||||
input_ids = self.input_ids[: input_batch.num_tokens_after_padding]
|
||||
|
||||
# Prefill: Run the eagle speculator with eager mode.
|
||||
with set_forward_context(
|
||||
# TODO(woosuk): Support CUDA graph for prefill.
|
||||
last_hidden_states, hidden_states = self.run_model(
|
||||
num_tokens,
|
||||
input_batch.attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch.num_tokens_after_padding,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=input_batch.positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
num_tokens_across_dp=None, # FIXME
|
||||
)
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
|
||||
num_reqs = input_batch.num_reqs
|
||||
cu_num_logits = input_batch.cu_num_logits[:num_reqs]
|
||||
temperature = sampling_metadata.temperature[cu_num_logits]
|
||||
seed = sampling_metadata.seeds[cu_num_logits]
|
||||
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
|
||||
# used for draft and target sampling.
|
||||
pos = input_batch.positions[last_token_indices] + 1
|
||||
# NOTE(woosuk): For draft sampling, we only consider the temperature
|
||||
# and ignore the other sampling parameters such as top_k and top_p,
|
||||
# for simplicity and performance.
|
||||
# While this may slightly degrade the acceptance rate, it does not
|
||||
# affect the output distribution after rejection sampling.
|
||||
temperature = self.temperature[:num_reqs]
|
||||
seeds = self.seeds[:num_reqs]
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
# Gather the values and copy them to the pre-allocated buffers.
|
||||
torch.gather(sampling_metadata.temperature, 0, cu_num_logits, out=temperature)
|
||||
torch.gather(sampling_metadata.seeds, 0, cu_num_logits, out=seeds)
|
||||
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
|
||||
# NOTE(woosuk): We must add 1 to the positions to match the Gumbel noise
|
||||
# used for draft and target sampling.
|
||||
draft_tokens = gumbel_sample(
|
||||
logits, temperature, seed, pos, apply_temperature=True
|
||||
logits, temperature, seeds, pos + 1, apply_temperature=True
|
||||
)
|
||||
if self.num_speculative_steps == 1:
|
||||
# Early exit.
|
||||
return draft_tokens.view(-1, 1)
|
||||
raise NotImplementedError("num_speculative_steps > 1 is not supported yet.")
|
||||
|
||||
# Save the draft tokens for the first step.
|
||||
self.draft_tokens[:num_reqs, 0] = draft_tokens
|
||||
# Prepare the inputs for the decode steps.
|
||||
prepare_eagle_decode(
|
||||
draft_tokens,
|
||||
hidden_states,
|
||||
last_token_indices,
|
||||
input_batch.seq_lens,
|
||||
num_rejected,
|
||||
self.input_buffers,
|
||||
self.hidden_states,
|
||||
self.max_model_len,
|
||||
self.max_num_reqs,
|
||||
)
|
||||
query_start_loc = self.input_buffers.query_start_loc
|
||||
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
|
||||
slot_mappings = self.block_tables.compute_slot_mappings(
|
||||
query_start_loc_gpu, pos
|
||||
)
|
||||
|
||||
cudagraph_size = self.cudagraph_manager.get_cudagraph_size(num_reqs)
|
||||
if cudagraph_size is not None:
|
||||
# Run CUDA graph.
|
||||
self.cudagraph_manager.run(cudagraph_size)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
# Run eager mode.
|
||||
query_start_loc.np[: num_reqs + 1] = np.arange(num_reqs + 1)
|
||||
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
|
||||
# HACK(woosuk)
|
||||
seq_lens_np = np.full(num_reqs, self.max_model_len, dtype=np.int32)
|
||||
block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables]
|
||||
|
||||
# FIXME(woosuk): This is UNSAFE!!
|
||||
attn_metadata = build_attn_metadata(
|
||||
attn_metadata_builders=self.attn_metadata_builders,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_reqs,
|
||||
query_start_loc_gpu=query_start_loc_gpu,
|
||||
query_start_loc_cpu=query_start_loc_cpu,
|
||||
seq_lens=self.input_buffers.seq_lens[:num_reqs],
|
||||
seq_lens_np=seq_lens_np,
|
||||
num_computed_tokens_cpu=None, # FIXME
|
||||
block_tables=block_tables,
|
||||
slot_mappings=slot_mappings,
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
)
|
||||
self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_eagle_inputs_kernel(
|
||||
last_token_indices_ptr,
|
||||
eagle_input_ids_ptr,
|
||||
eagle_positions_ptr,
|
||||
target_input_ids_ptr,
|
||||
target_positions_ptr,
|
||||
idx_mapping_ptr,
|
||||
last_sampled_ptr,
|
||||
next_prefill_tokens_ptr,
|
||||
@ -175,9 +353,16 @@ def _prepare_eagle_inputs_kernel(
|
||||
tl.store(last_token_indices_ptr + batch_idx, last_token_index)
|
||||
tl.store(eagle_input_ids_ptr + last_token_index, next_token)
|
||||
|
||||
# Copy positions.
|
||||
for i in range(0, query_len, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < query_len
|
||||
target_pos = tl.load(target_positions_ptr + query_start + block, mask=mask)
|
||||
tl.store(eagle_positions_ptr + query_start + block, target_pos, mask=mask)
|
||||
|
||||
|
||||
def prepare_eagle_inputs(
|
||||
eagle_input_ids: torch.Tensor,
|
||||
input_buffers: InputBuffers,
|
||||
input_batch: InputBatch,
|
||||
# [num_reqs]
|
||||
num_sampled: torch.Tensor,
|
||||
@ -192,12 +377,14 @@ def prepare_eagle_inputs(
|
||||
last_token_indices = torch.empty(
|
||||
num_reqs,
|
||||
dtype=torch.int64,
|
||||
device=eagle_input_ids.device,
|
||||
device=num_sampled.device,
|
||||
)
|
||||
_prepare_eagle_inputs_kernel[(num_reqs,)](
|
||||
last_token_indices,
|
||||
eagle_input_ids,
|
||||
input_buffers.input_ids.gpu,
|
||||
input_buffers.positions,
|
||||
input_batch.input_ids,
|
||||
input_batch.positions,
|
||||
input_batch.idx_mapping,
|
||||
last_sampled,
|
||||
next_prefill_tokens,
|
||||
@ -207,3 +394,174 @@ def prepare_eagle_inputs(
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
return last_token_indices
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _prepare_eagle_docode_kernel(
|
||||
draft_tokens_ptr,
|
||||
output_hidden_states_ptr,
|
||||
output_hidden_states_stride,
|
||||
last_token_indices_ptr,
|
||||
target_seq_lens_ptr,
|
||||
num_rejected_ptr,
|
||||
input_ids_ptr,
|
||||
positions_ptr,
|
||||
input_hidden_states_ptr,
|
||||
input_hidden_states_stride,
|
||||
query_start_loc_ptr,
|
||||
seq_lens_ptr,
|
||||
hidden_size,
|
||||
max_model_len,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
num_reqs = tl.num_programs(0) - 1
|
||||
if req_idx == num_reqs:
|
||||
# Compute query_start_loc. Pad it with the last query_start_loc
|
||||
# for CUDA graphs.
|
||||
for i in range(0, max_num_reqs + 1, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
q = tl.where(block < num_reqs, block, num_reqs)
|
||||
mask = block < max_num_reqs + 1
|
||||
tl.store(query_start_loc_ptr + block, q, mask=mask)
|
||||
# Pad seq_lens for CUDA graphs.
|
||||
for i in range(req_idx, max_num_reqs, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < max_num_reqs
|
||||
tl.store(seq_lens_ptr + block, 0, mask=mask)
|
||||
return
|
||||
|
||||
# draft token -> input id.
|
||||
draft_token = tl.load(draft_tokens_ptr + req_idx)
|
||||
tl.store(input_ids_ptr + req_idx, draft_token)
|
||||
|
||||
# output hidden states -> input hidden states.
|
||||
src_idx = tl.load(last_token_indices_ptr + req_idx)
|
||||
for i in range(0, hidden_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < hidden_size
|
||||
output_hidden_states = tl.load(
|
||||
output_hidden_states_ptr + src_idx * output_hidden_states_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
|
||||
output_hidden_states,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# Compute position and seq_lens.
|
||||
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
|
||||
# if they reach the max model length.
|
||||
position = tl.load(positions_ptr + req_idx)
|
||||
position = tl.minimum(position + 1, max_model_len - 1)
|
||||
tl.store(positions_ptr + req_idx, position)
|
||||
|
||||
target_seq_len = tl.load(target_seq_lens_ptr + req_idx)
|
||||
num_rejected = tl.load(num_rejected_ptr + req_idx)
|
||||
seq_len = target_seq_len - num_rejected
|
||||
seq_len = tl.minimum(seq_len + 1, max_model_len)
|
||||
tl.store(seq_lens_ptr + req_idx, seq_len)
|
||||
|
||||
|
||||
def prepare_eagle_decode(
|
||||
draft_tokens: torch.Tensor,
|
||||
output_hidden_states: torch.Tensor,
|
||||
last_token_indices: torch.Tensor,
|
||||
target_seq_lens: torch.Tensor,
|
||||
num_rejected: torch.Tensor,
|
||||
input_buffers: InputBuffers,
|
||||
input_hidden_states: torch.Tensor,
|
||||
max_model_len: int,
|
||||
max_num_reqs: int,
|
||||
):
|
||||
num_reqs = draft_tokens.shape[0]
|
||||
hidden_size = output_hidden_states.shape[-1]
|
||||
_prepare_eagle_docode_kernel[(num_reqs + 1,)](
|
||||
draft_tokens,
|
||||
output_hidden_states,
|
||||
output_hidden_states.stride(0),
|
||||
last_token_indices,
|
||||
target_seq_lens,
|
||||
num_rejected,
|
||||
input_buffers.input_ids.gpu,
|
||||
input_buffers.positions,
|
||||
input_hidden_states,
|
||||
input_hidden_states.stride(0),
|
||||
input_buffers.query_start_loc.gpu,
|
||||
input_buffers.seq_lens,
|
||||
hidden_size,
|
||||
max_model_len,
|
||||
max_num_reqs,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _update_eagle_inputs_kernel(
|
||||
input_ids_ptr,
|
||||
positions_ptr,
|
||||
input_hidden_states_ptr,
|
||||
input_hidden_states_stride,
|
||||
seq_lens_ptr,
|
||||
max_model_len,
|
||||
draft_tokens_ptr,
|
||||
output_hidden_states_ptr,
|
||||
output_hidden_states_stride,
|
||||
hidden_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
|
||||
# Draft token -> Input ID.
|
||||
draft_token = tl.load(draft_tokens_ptr + req_idx)
|
||||
tl.store(input_ids_ptr + req_idx, draft_token)
|
||||
|
||||
# Output hidden states -> Input hidden states.
|
||||
for i in range(0, hidden_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < hidden_size
|
||||
output_hidden_states = tl.load(
|
||||
output_hidden_states_ptr + req_idx * output_hidden_states_stride + block,
|
||||
mask=mask,
|
||||
)
|
||||
tl.store(
|
||||
input_hidden_states_ptr + req_idx * input_hidden_states_stride + block,
|
||||
output_hidden_states,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# Increment position and seq_lens.
|
||||
# NOTE(woosuk): To prevent out-of-range access, we clamp these values
|
||||
# if they reach the max model length.
|
||||
position = tl.load(positions_ptr + req_idx)
|
||||
position = tl.minimum(position + 1, max_model_len - 1)
|
||||
tl.store(positions_ptr + req_idx, position)
|
||||
|
||||
seq_len = tl.load(seq_lens_ptr + req_idx)
|
||||
seq_len = tl.minimum(seq_len + 1, max_model_len)
|
||||
tl.store(seq_lens_ptr + req_idx, seq_len)
|
||||
|
||||
|
||||
def update_eagle_inputs(
|
||||
draft_tokens: torch.Tensor,
|
||||
output_hidden_states: torch.Tensor,
|
||||
input_buffers: InputBuffers,
|
||||
hidden_states: torch.Tensor,
|
||||
max_model_len: int,
|
||||
):
|
||||
num_reqs, hidden_size = output_hidden_states.shape
|
||||
_update_eagle_inputs_kernel[(num_reqs,)](
|
||||
input_buffers.input_ids.gpu,
|
||||
input_buffers.positions,
|
||||
hidden_states,
|
||||
hidden_states.stride(0),
|
||||
input_buffers.seq_lens,
|
||||
max_model_len,
|
||||
draft_tokens,
|
||||
output_hidden_states,
|
||||
output_hidden_states.stride(0),
|
||||
hidden_size,
|
||||
BLOCK_SIZE=1024,
|
||||
)
|
||||
|
||||
112
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
Normal file
112
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
Normal file
@ -0,0 +1,112 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.cudagraph_utils import (
|
||||
capture_graphs,
|
||||
get_cudagraph_sizes,
|
||||
prepare_inputs_to_capture,
|
||||
)
|
||||
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
|
||||
from vllm.v1.worker.gpu.input_batch import InputBuffers
|
||||
|
||||
|
||||
class EagleCudaGraphManager:
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
self.vllm_config = vllm_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.device = device
|
||||
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
assert self.compilation_config is not None
|
||||
|
||||
if self.compilation_config.cudagraph_mode is None:
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
else:
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
if self.cudagraph_mode == CUDAGraphMode.FULL:
|
||||
# NOTE(woosuk): For Eagle, we only use CUDA graphs for decode.
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY
|
||||
|
||||
self.cudagraph_sizes = get_cudagraph_sizes(
|
||||
self.compilation_config.cudagraph_capture_sizes,
|
||||
self.max_num_reqs,
|
||||
self.max_num_tokens,
|
||||
self.cudagraph_mode,
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
|
||||
def get_cudagraph_size(self, num_tokens: int) -> int | None:
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
def capture_graph(
|
||||
self,
|
||||
num_tokens: int,
|
||||
generate_fn: Callable,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
num_reqs = min(num_tokens, self.max_num_reqs)
|
||||
attn_metadata = prepare_inputs_to_capture(
|
||||
num_reqs,
|
||||
num_tokens,
|
||||
input_buffers,
|
||||
block_tables,
|
||||
attn_metadata_builders,
|
||||
self.max_model_len,
|
||||
kv_cache_config,
|
||||
)
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
|
||||
|
||||
# Warm up.
|
||||
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
|
||||
|
||||
# Capture the graph.
|
||||
assert num_tokens not in self.graphs
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, self.pool):
|
||||
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp)
|
||||
self.graphs[num_tokens] = graph
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture(
|
||||
self,
|
||||
generate_fn: Callable,
|
||||
input_buffers: InputBuffers,
|
||||
block_tables: BlockTables,
|
||||
attn_metadata_builders: list[AttentionMetadataBuilder],
|
||||
kv_cache_config: KVCacheConfig,
|
||||
) -> None:
|
||||
capture_graphs(
|
||||
self.cudagraph_sizes,
|
||||
self.device,
|
||||
self.capture_graph,
|
||||
generate_fn=generate_fn,
|
||||
input_buffers=input_buffers,
|
||||
block_tables=block_tables,
|
||||
attn_metadata_builders=attn_metadata_builders,
|
||||
kv_cache_config=kv_cache_config,
|
||||
)
|
||||
|
||||
def run(self, num_tokens: int) -> None:
|
||||
assert num_tokens in self.graphs
|
||||
self.graphs[num_tokens].replay()
|
||||
Loading…
x
Reference in New Issue
Block a user