[Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step (#6338)

Alexander Matveev 2024-07-17 17:30:28 -04:00 committed by GitHub
parent 5f0b9933e6
commit e76466dde2
12 changed files with 568 additions and 130 deletions

View File

@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")
if(VLLM_GPU_LANG STREQUAL "CUDA")

View File

@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables);
#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,

View File

@@ -0,0 +1,131 @@
/*
* The goal of this GPU kernel is to advance input tensors on the GPU directly
* PR: https://github.com/vllm-project/vllm/pull/6338
* Current restrictions:
* 1. Specialized for DraftModelRunner
* 2. Supports flash_attn only
*/
#include "advance_step.cuh"
namespace prepare_inputs {
//
template <int const num_threads>
__global__ void advance_step_kernel(int num_seqs, int num_queries,
int block_size, long* input_tokens_ptr,
long const* sampled_token_ids_ptr,
long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr,
int const* block_tables_ptr,
int64_t const block_tables_stride) {
int num_query_blocks = div_ceil(num_queries, num_threads);
if (blockIdx.x >= num_query_blocks) {
return;
}
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
if (cur_query_id >= num_queries) {
return;
}
// Update input_tokens
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
int seq_len = seq_lens_ptr[cur_query_id];
int next_seq_len = seq_len + 1;
int next_input_pos = next_seq_len - 1;
// Update seq_lens
seq_lens_ptr[cur_query_id] = next_seq_len;
// Update input_positions
input_positions_ptr[cur_query_id] = next_input_pos;
int const* seq_block_tables_ptr =
block_tables_ptr + block_tables_stride * cur_query_id;
int block_index = next_input_pos / block_size;
int block_offset = next_input_pos % block_size;
int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
// Update slot_mapping
slot_mapping_ptr[cur_query_id] = slot_num;
}
inline void verify_tensor(std::string const& name, torch::Tensor& t,
int64_t const size_0, int64_t const size_1,
c10::ScalarType const type) {
bool size_0_cond = true;
if (size_0 != -1) {
size_0_cond = t.size(0) == size_0;
}
bool size_1_cond = true;
if (size_1 != -1) {
size_1_cond = t.size(1) == size_1;
}
bool is_contiguous = t.is_contiguous();
bool same_type = t.dtype() == type;
bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
if (!pass) {
TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
" is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
" is not as expected: shape = [", size_0, ", ", size_1,
"], type = ", type);
}
}
void advance_step(int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
torch::Tensor& seq_lens, // type: int
torch::Tensor& slot_mapping, // type: long
torch::Tensor& block_tables) { // type: int
if (logging) {
printf("advance_step:\n");
printf(" num_seqs = %d\n", num_seqs);
printf(" num_queries = %d\n", num_queries);
printf(" block_size = %d\n", block_size);
}
// Verify all tensors
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
at::kLong);
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
int dev = sampled_token_ids.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
num_seqs, num_queries, block_size,
reinterpret_cast<long*>(input_tokens.data_ptr()),
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
reinterpret_cast<long*>(input_positions.data_ptr()),
reinterpret_cast<int*>(seq_lens.data_ptr()),
reinterpret_cast<long*>(slot_mapping.data_ptr()),
reinterpret_cast<int const*>(block_tables.data_ptr()),
block_tables.stride(0));
}
} // namespace prepare_inputs
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
torch::Tensor& input_positions, torch::Tensor& seq_lens,
torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
sampled_token_ids, input_positions, seq_lens,
slot_mapping, block_tables);
}
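For reference, the per-query update that advance_step_kernel performs can be sketched on the host in Python roughly as follows (an illustrative, hypothetical helper, not part of this commit; the real work happens in the CUDA kernel above):

import torch

def advance_step_reference(block_size: int, num_queries: int,
                           input_tokens: torch.Tensor,       # long, [num_seqs]
                           sampled_token_ids: torch.Tensor,  # long, [num_queries, 1]
                           input_positions: torch.Tensor,    # long, [num_seqs]
                           seq_lens: torch.Tensor,           # int,  [num_seqs]
                           slot_mapping: torch.Tensor,       # long, [num_seqs]
                           block_tables: torch.Tensor) -> None:  # int, [num_seqs, max_blocks]
    for q in range(num_queries):
        # The token sampled in the previous step becomes the next input token.
        input_tokens[q] = sampled_token_ids[q, 0]
        # The sequence grows by one; the new token's position is seq_len - 1.
        next_seq_len = int(seq_lens[q]) + 1
        next_input_pos = next_seq_len - 1
        seq_lens[q] = next_seq_len
        input_positions[q] = next_input_pos
        # Translate the logical position into a physical KV-cache slot
        # through the sequence's block table.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        slot_mapping[q] = (int(block_tables[q, block_index]) * block_size
                           + block_offset)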

View File

@@ -0,0 +1,19 @@
#pragma once
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
namespace prepare_inputs {
static constexpr int max_threads = 256;
static constexpr bool logging = false;
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
} // namespace prepare_inputs
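The launch geometry implied by these constants can be illustrated with a small sketch (assumptions: one thread block per SM as launched in advance_step.cu, and an example SM count; illustration only):

def div_ceil(a: int, b: int) -> int:
    return (a + b - 1) // b

max_threads = 256    # threads per block, as defined above
sm_count = 108       # example value (queried via cudaDevAttrMultiProcessorCount)
num_queries = 32

# Blocks that have work to do; the wrapper launches sm_count blocks, and any
# block with blockIdx.x >= num_query_blocks returns immediately. Each
# remaining thread handles one query:
#   cur_query_id = blockIdx.x * max_threads + threadIdx.x
num_query_blocks = div_ceil(num_queries, max_threads)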

View File

@@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
// prepare_inputs advance_step
ops.def("advance_step", &advance_step);
ops.impl("advance_step", torch::kCUDA, &advance_step);
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(

View File

@@ -227,6 +227,7 @@ def get_output_from_llm_generator(
maybe_assert_ngram_worker(llm)
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
token_ids = [output.outputs[0].token_ids for output in outputs]
tokens = [output.outputs[0].text for output in outputs]

View File

@@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k():
assert proposals.proposal_lens.tolist() == [
k for _ in range(expected_num_proposal_seqs - 1)
] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
"""Verify that draft model runner triggers advance step
when applicable.
"""
seed = 100
model_name = 'JackFram/llama-68m'
k = 5
batch_size = 32
block_size = 32
num_gpu_blocks = 2048 // block_size
worker = create_worker(
MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
)
# Mock "_gpu_advance_step" to raise an exception when called.
exception_secret = "artificial stop"
worker.model_runner._gpu_advance_step = MagicMock()
worker.model_runner._gpu_advance_step.side_effect = ValueError(
exception_secret)
seq_group_metadata_list, _, _ = create_batch(batch_size, k)
# Fallback (should not call) when num_steps=1.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=1)
worker.execute_model(execute_model_req=execute_model_req)
# Expect exception if _gpu_advance_step is called.
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
num_lookahead_slots=k,
num_steps=k)
with pytest.raises(ValueError, match=exception_secret):
worker.execute_model(execute_model_req=execute_model_req)
call_args_list = worker.model_runner._gpu_advance_step.call_args_list
assert len(call_args_list) == 1

View File

@@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
def advance_step(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor, seq_lens: torch.Tensor,
slot_mapping: torch.Tensor,
block_tables: torch.Tensor) -> None:
"""Advance a step on GPU for existing inputs for a multi-step runner"""
return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping,
block_tables)
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
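A hypothetical usage sketch of the new advance_step wrapper above; the shapes and dtypes follow the verify_tensor checks in advance_step.cu, but the tensor values are synthetic and purely illustrative:

import torch
from vllm import _custom_ops as ops

num_seqs = num_queries = 4
block_size, max_blocks = 16, 8
device = "cuda"

input_tokens = torch.zeros(num_seqs, dtype=torch.long, device=device)
sampled_token_ids = torch.randint(0, 32000, (num_queries, 1),
                                  dtype=torch.long, device=device)
input_positions = torch.zeros(num_seqs, dtype=torch.long, device=device)
seq_lens = torch.full((num_seqs, ), 10, dtype=torch.int, device=device)
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device=device)
block_tables = torch.arange(num_seqs * max_blocks, dtype=torch.int,
                            device=device).reshape(num_seqs, max_blocks)

# input_tokens, input_positions, seq_lens and slot_mapping are updated in
# place on the GPU, so the next decode step needs no CPU round trip to
# rebuild its inputs.
ops.advance_step(num_seqs=num_seqs, num_queries=num_queries,
                 block_size=block_size, input_tokens=input_tokens,
                 sampled_token_ids=sampled_token_ids,
                 input_positions=input_positions, seq_lens=seq_lens,
                 slot_mapping=slot_mapping, block_tables=block_tables)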

View File

@@ -47,6 +47,32 @@ class Sampler(nn.Module):
# speculative decoding.
self.include_gpu_probs_tensor = False
def _init_sampling_tensors(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
):
"""The goal here is to reuse sampling tensors between similar decode
runs. This is possible because sampling logic does not change between
decodes of the same sequences.
"""
_, vocab_size = logits.shape
# First free any existing stored sampling tensors.
# This is necessary because some sampling tensors may
# have pinned memory.
self._sampling_tensors = None
# Initialize new sampling tensors
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
sampling_metadata, vocab_size, logits.device, logits.dtype)
self._sampling_tensors = sampling_tensors
self._do_penalties = do_penalties
self._do_top_p_top_k = do_top_p_top_k
self._do_min_p = do_min_p
def forward(
self,
logits: torch.Tensor,
@@ -60,12 +86,23 @@ class Sampler(nn.Module):
assert logits is not None
_, vocab_size = logits.shape
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
sampling_metadata, vocab_size, logits.device, logits.dtype)
if not sampling_metadata.reuse_sampling_tensors:
self._init_sampling_tensors(logits, sampling_metadata)
elif self._do_penalties:
# In this case, the sampling tensors logic depends on
# "output_tokens" of a sequence. As a result, we cannot
# reuse sampling tensors, since "output_tokens" changes
# between decode runs.
self._init_sampling_tensors(logits, sampling_metadata)
assert self._sampling_tensors is not None
sampling_tensors = self._sampling_tensors
do_penalties = self._do_penalties
do_top_p_top_k = self._do_top_p_top_k
do_min_p = self._do_min_p
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Apply presence and frequency penalties.
if do_penalties:
@@ -77,7 +114,7 @@ class Sampler(nn.Module):
# Apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
if do_top_p_top_k:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
@@ -109,13 +146,19 @@ class Sampler(nn.Module):
on_device_tensors = None
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors)
prompt_logprobs = None
sample_logprobs = None
if not sampling_metadata.skip_sampler_cpu_output:
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(
sample_results,
sampling_metadata,
prompt_logprobs,
sample_logprobs,
on_device_tensors=on_device_tensors,
skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)
@property
def _should_modify_greedy_probs_inplace(self) -> bool:
@@ -535,24 +578,29 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))
if not sampling_metadata.skip_sampler_cpu_output:
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
else:
sample_results = []
sample_results = [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results, sampled_token_ids_tensor
@@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output(
sample_results: SampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
sample_logprobs: Optional[List[SampleLogprobs]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
@@ -1010,22 +1059,26 @@ def _build_sampler_output(
allows post-processing without copies to CPU/serialization, e.g. in
speculative decoding rejection sampling.
"""
sampler_output: List[CompletionSequenceGroupOutput] = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(parent_ids,
next_token_ids,
group_sample_logprobs):
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
sampler_output.append(
CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
if not skip_sampler_cpu_output:
assert prompt_logprobs is not None
assert sample_logprobs is not None
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs: List[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(
parent_ids, next_token_ids, group_sample_logprobs):
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id,
logprobs))
sampler_output.append(
CompletionSequenceGroupOutput(seq_outputs,
group_prompt_logprobs))
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
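Condensed, the effect of the two new SamplingMetadata flags on the sampler can be summarized by the following sketch (simplified pseudologic, not the actual control flow of Sampler.forward):

def sampler_flag_effects(sampling_metadata, do_penalties: bool):
    # Sampling tensors are rebuilt on the first decode step, or whenever
    # penalties are active (penalties read output_tokens, which change every
    # step); otherwise the cached tensors from the previous step are reused.
    rebuild_sampling_tensors = (not sampling_metadata.reuse_sampling_tensors
                                or do_penalties)

    # With skip_sampler_cpu_output set, the GPU->CPU copy of sampled tokens
    # and the construction of Python sample results / logprobs are skipped;
    # only on-device tensors are kept for the later batch expansion used in
    # speculative decoding.
    build_cpu_outputs = not sampling_metadata.skip_sampler_cpu_output
    return rebuild_sampling_tensors, build_cpu_outputs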

View File

@@ -87,6 +87,12 @@ class SamplingMetadata:
The first tuple is [1, 2] (sampled index within original logit),
and the second tuple is [0, 1] (sampled index within pruned logit).
num_prompts: Number of prompt sequence groups in seq_groups.
skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
serialization of token outputs.
reuse_sampling_tensors: Indicates if we want to reuse sampling
tensors that are part of the sampler forward pass. Currently,
it is mainly used for multi-step decode.
"""
def __init__(
@@ -95,11 +101,15 @@ class SamplingMetadata:
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
num_prompts: int,
skip_sampler_cpu_output: bool = False,
reuse_sampling_tensors: bool = False,
) -> None:
self.seq_groups = seq_groups
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.num_prompts = num_prompts
self.skip_sampler_cpu_output = skip_sampler_cpu_output
self.reuse_sampling_tensors = reuse_sampling_tensors
@staticmethod
def prepare(

View File

@@ -2,17 +2,22 @@ from typing import List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
logger = init_logger(__name__)
debug_advance_input = False
enable_gpu_advance_step = True
class TP1DraftModelRunner(ModelRunner):
"""Specialized model runner for speculative decoding draft model.
@@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner):
we could get rid of most CPU-GPU synchronization and data transfer
overheads by keeping model input and output tensors on GPU all the time.
This runner is still under development so there's no performance gain
at this moment. Currently we adopt a temporary solution that caches the
seq_group_metadata_list for multi-step execution, so that we can
leverage existing prepare_model_input to be compatible with the current
execution flow, but we plan to remove this cache and avoid calling
prepare_model_input in execute_model at all.
The detail development plan includes:
1. Use "update_model_input" to update existing model_input without
creating a new one.
2. Improve the performance of "update_model_input" with a GPU kernel.
3. Support TP > 1 (this requires some designs because we do not expect
TODOs:
1. Currently supports only flash-attn, add support for other attn_backends.
2. Support TP > 1 (this requires some designs because we do not expect
any broadcasting inside execute_model).
"""
@@ -71,51 +67,156 @@ class TP1DraftModelRunner(ModelRunner):
return_hidden_states=return_hidden_states,
)
# TODO: Remove this cache when we are able to update model_input
# directly in advance_step.
self.cached_seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
num_queries):
assert isinstance(attn_metadata, FlashAttentionMetadata)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForGPUWithSamplingMetadata:
"""A temporary solution that caches the seq_group_metadata_list
for multi-step execution.
TODO: In-place update model_input and remove this function.
"""
self.cached_seq_group_metadata_list = seq_group_metadata_list
return super().prepare_model_input(
seq_group_metadata_list,
finished_requests_ids=finished_requests_ids)
if num_seqs != num_queries:
assert num_seqs > num_queries
assert attn_metadata.use_cuda_graph
def update_model_input(
assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
assert attn_metadata.num_decode_tokens == num_seqs
assert attn_metadata.slot_mapping.shape == (num_seqs, )
assert len(attn_metadata.seq_lens) == num_seqs
assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
assert attn_metadata.max_query_len == 1
assert attn_metadata.max_prefill_seq_len == 0
assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)
assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )
assert attn_metadata.context_lens_tensor.shape == (num_queries, )
assert attn_metadata.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
attn_metadata.seq_lens[i] += 1
attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
assert sampling_metadata.num_prompts == 0
assert len(sampling_metadata.seq_groups) == num_queries
assert sampling_metadata.selected_token_indices.shape == (
num_queries, )
# assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501
# Verify that all sequences are decodes
for i in range(num_queries):
seq_group = sampling_metadata.seq_groups[i]
assert seq_group.is_prompt is False # No prompt
assert seq_group.prompt_logprob_indices == [] # No prompt
assert seq_group.sample_indices == [i] # Simple
assert seq_group.seq_len is None # Decode
assert seq_group.query_len is None # Decode
def _gpu_advance_step(
self, model_input: ModelInputForGPUWithSamplingMetadata,
last_output: SamplerOutput
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model inputs for the next step.
TODO: In-place update model_input instead of calling
prepare_model_input.
# Currently, we expect "decode mode" only
assert not model_input.is_prompt
# Get num_seqs
num_seqs = len(model_input.seq_lens)
num_queries = len(model_input.query_lens)
# Get output tokens GPU tensor
sampled_token_ids = last_output.sampled_token_ids
assert sampled_token_ids is not None
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
# Update GPU tensors
ops.advance_step(num_seqs=num_seqs,
num_queries=num_queries,
block_size=self.block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=attn_metadata.seq_lens_tensor,
slot_mapping=attn_metadata.slot_mapping,
block_tables=attn_metadata.block_tables)
# Update sampling_metadata
sampling_metadata = model_input.sampling_metadata
self._update_sampling_metadata(sampling_metadata, num_seqs,
num_queries)
# Create new input
new_model_input = self._model_input_cls(
input_tokens=model_input.input_tokens,
input_positions=model_input.input_positions,
attn_metadata=attn_metadata,
seq_lens=attn_metadata.seq_lens,
query_lens=model_input.query_lens,
lora_mapping=model_input.lora_mapping,
lora_requests=model_input.lora_requests,
multi_modal_kwargs=model_input.multi_modal_kwargs,
sampling_metadata=model_input.sampling_metadata,
is_prompt=False,
)
# Ensure we skip CPU samples
assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
# We can reuse sampling tensors since every decode iteration is the same
new_model_input.sampling_metadata.reuse_sampling_tensors = True
if debug_advance_input:
logger.debug("NEW INPUT: ")
logger.debug(" input_tokens = %s", new_model_input.input_tokens)
logger.debug(" input_positions = %s",
new_model_input.input_positions)
logger.debug(" seq_lens = %d", new_model_input.seq_lens)
logger.debug(" query_lens = %d", new_model_input.query_lens)
logger.debug(" attn_metadata:")
logger.debug(" seq_lens_tensor: %s",
attn_metadata.seq_lens_tensor)
logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping)
logger.debug(" block_tables: %s", attn_metadata.block_tables)
return new_model_input
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
"""Determines if draft_model_runner GPU multi-step can be used.
Currently required conditions are:
1. Only decodes
2. Only flash-attn
3. No LORA
4. No prompt_adapter_config
"""
if not enable_gpu_advance_step:
return False
# Append the output token to the sequence data.
assert self.cached_seq_group_metadata_list is not None
for seq_group_metadata, sequence_group_outputs in zip(
self.cached_seq_group_metadata_list, last_output.outputs):
seq_group_metadata.is_prompt = False
# We allow multi-step GPU only in decode mode
for seq_group in execute_model_req.seq_group_metadata_list:
if seq_group.is_prompt:
return False
for seq_output in sequence_group_outputs.samples:
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
# TODO: Add support for other attn backends
if self.attn_backend.get_name() != "flash-attn":
return False
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
# TODO: Add support for LORA
if self.lora_config:
return False
seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)
# TODO: Add soft-tuning prompt adapter support
if self.prompt_adapter_config:
return False
return self.prepare_model_input(self.cached_seq_group_metadata_list)
return True
@torch.inference_mode()
def execute_model(
@@ -125,42 +226,86 @@ class TP1DraftModelRunner(ModelRunner):
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
"""Executes num_steps forward passes with advacement of input tensors
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
Optimizations used:
1. Input tensors are updated on the GPU directly
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
them since we do batch expansion later that uses GPU outputs)
3. Reuses sampling tensors (since we run only decodes and they have
a repeating sampling logic)
"""
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# When num_steps == 1, we execute the fallback here for the GPU
# advance_step, which runs prepare_inputs on CPU and for each spec
# iteration invokes this function only once
# (Look at multi-step-worker code)
is_fallback = num_steps == 1
if not is_fallback:
# Since we do not broadcast data inside execute_model anymore,
# we need to figure out the best way to support TP > 1 in this
# case, because we will at least need to broadcast the sampled
# tokens to all workers.
if not self.is_driver_worker:
raise ValueError("TP1DraftModelRunner only supports TP=1.")
# Sanity
if self.lora_config is not None:
raise ValueError("TP1DraftModelRunner has no support for LORA")
if self.prompt_adapter_config is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"prompt_adapter_config")
if model_input.multi_modal_kwargs:
raise ValueError(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
)
else:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)
# Detect exec mode
assert model_input.attn_metadata is not None
use_cuda_graph = False
if model_input.attn_metadata.num_prefills > 0:
# In this case, execute_model(..) was called directly
if num_steps > 1:
raise ValueError(
"execute_model(..) of draft_model_runner can be called "
"directly only with a single-step prefill")
else:
# We can skip CPU samples for spec token generation.
# (We do allow CPU samples for num_steps == 1 to support the
# fallback case, where supports_gpu_multi_step(..) does not pass)
model_input.sampling_metadata.skip_sampler_cpu_output = (
not is_fallback)
# Attn attr defines if we use cuda graphs
use_cuda_graph = model_input.attn_metadata.use_cuda_graph
# Get model
if use_cuda_graph:
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (self.graph_runners[model_input.virtual_engine]
[graph_batch_size])
else:
model_executable = self.model
virtual_engine = model_input.virtual_engine
outputs: List[SamplerOutput] = []
for step in range(num_steps):
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[virtual_engine][graph_batch_size])
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
# Run model
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
@@ -181,8 +326,8 @@ class TP1DraftModelRunner(ModelRunner):
sampling_metadata=model_input.sampling_metadata,
))
# Prepare the inputs for the next step.
# Prepare inputs for the next step
if step != num_steps - 1:
model_input = self.update_model_input(model_input, outputs[-1])
model_input = self._gpu_advance_step(model_input, outputs[-1])
return outputs
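Putting the pieces together, the multi-step decode path of TP1DraftModelRunner.execute_model can be summarized by the following simplified sketch (argument plumbing and cuda-graph selection elided; run_forward_pass is a hypothetical stand-in for calling the model or its captured graph; not the actual implementation):

def multi_step_decode(runner, model_input, num_steps):
    outputs = []
    # prepare_model_input ran once on the CPU before this call; from here on,
    # every per-step input update stays on the GPU.
    model_input.sampling_metadata.skip_sampler_cpu_output = True
    for step in range(num_steps):
        # Forward pass (model or cuda graph), then sample on the GPU.
        hidden_states = run_forward_pass(runner, model_input)
        logits = runner.model.compute_logits(hidden_states,
                                             model_input.sampling_metadata)
        outputs.append(runner.model.sample(
            logits=logits, sampling_metadata=model_input.sampling_metadata))
        if step != num_steps - 1:
            # advance_step kernel + flash-attn metadata bookkeeping, no CPU sync.
            model_input = runner._gpu_advance_step(model_input, outputs[-1])
    return outputs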

View File

@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
expanded_request, indices_of_seq_with_bonus_tokens =\
self._expand_execute_model_request(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
# Run model sample_len times.
model_outputs: List[SamplerOutput] = []
if isinstance(self.model_runner, TP1DraftModelRunner):
if isinstance(
self.model_runner, TP1DraftModelRunner
) and self.model_runner.supports_gpu_multi_step(expanded_request):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
# TODO: Remove this branch once DraftModelRunner supports TP>1.
# Here we run multi-step directly, with every step prepared
# on the CPU.
# TODO: Remove this branch once DraftModelRunner supports TP>1
# and other restrictions that are part of DraftModelRunner's
# supports_gpu_multi_step(..)
for _ in range(sample_len):
model_output: List[SamplerOutput] = super().execute_model(
execute_model_req=expanded_request)
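The worker-side dispatch above can be condensed into this simplified sketch (hypothetical helper, not the actual MultiStepWorker code):

def run_draft_steps(worker, expanded_request, sample_len):
    runner = worker.model_runner
    if isinstance(runner, TP1DraftModelRunner) and \
            runner.supports_gpu_multi_step(expanded_request):
        # Single call: the runner loops num_steps times internally and
        # advances its inputs on the GPU between steps.
        expanded_request.num_steps = sample_len
        return worker.execute_model(execute_model_req=expanded_request)
    # Fallback: one CPU-prepared step per call, appending each sampled token
    # to the sequences before the next step.
    model_outputs = []
    for _ in range(sample_len):
        step_output = worker.execute_model(execute_model_req=expanded_request)
        assert len(step_output) == 1
        model_outputs.append(step_output[0])
        # ... append new tokens to seq_group_metadata before the next step ...
    return model_outputs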
@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
outputs=[
expanded_batch_output.outputs[i]
for i in output_indices_to_retain
],
] if len(expanded_batch_output.outputs) > 0 else [],
sampled_token_probs=(
expanded_batch_output.
sampled_token_probs[output_indices_to_retain]