mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:04:58 +08:00

[Core] draft_model_runner: Implement prepare_inputs on GPU for advance_step (#6338)

This commit is contained in:
  parent 5f0b9933e6
  commit e76466dde2
@@ -151,6 +151,7 @@ set(VLLM_EXT_SRC
  "csrc/quantization/fp8/common.cu"
  "csrc/cuda_utils_kernels.cu"
  "csrc/moe_align_block_size_kernels.cu"
  "csrc/prepare_inputs/advance_step.cu"
  "csrc/torch_bindings.cpp")

if(VLLM_GPU_LANG STREQUAL "CUDA")
@@ -52,6 +52,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);

void gelu_quick(torch::Tensor& out, torch::Tensor& input);

void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);

#ifndef USE_ROCM
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
csrc/prepare_inputs/advance_step.cu  (new file, 131 lines)
@@ -0,0 +1,131 @@
/*
 * The goal of this GPU kernel is to advance input tensors on the GPU directly
 * PR: https://github.com/vllm-project/vllm/pull/6338
 * Current restrictions:
 *     1. Specialized for DraftModelRunner
 *     2. Supports flash_attn only
 */

#include "advance_step.cuh"

namespace prepare_inputs {

//
template <int const num_threads>
__global__ void advance_step_kernel(int num_seqs, int num_queries,
                                    int block_size, long* input_tokens_ptr,
                                    long const* sampled_token_ids_ptr,
                                    long* input_positions_ptr,
                                    int* seq_lens_ptr, long* slot_mapping_ptr,
                                    int const* block_tables_ptr,
                                    int64_t const block_tables_stride) {
  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x >= num_query_blocks) {
    return;
  }

  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

  if (cur_query_id >= num_queries) {
    return;
  }

  // Update input_tokens
  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

  int seq_len = seq_lens_ptr[cur_query_id];
  int next_seq_len = seq_len + 1;
  int next_input_pos = next_seq_len - 1;

  // Update seq_lens
  seq_lens_ptr[cur_query_id] = next_seq_len;
  // Update input_positions
  input_positions_ptr[cur_query_id] = next_input_pos;

  int const* seq_block_tables_ptr =
      block_tables_ptr + block_tables_stride * cur_query_id;

  int block_index = next_input_pos / block_size;
  int block_offset = next_input_pos % block_size;

  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
  // Update slot_mapping
  slot_mapping_ptr[cur_query_id] = slot_num;
}

inline void verify_tensor(std::string const& name, torch::Tensor& t,
                          int64_t const size_0, int64_t const size_1,
                          c10::ScalarType const type) {
  bool size_0_cond = true;
  if (size_0 != -1) {
    size_0_cond = t.size(0) == size_0;
  }

  bool size_1_cond = true;
  if (size_1 != -1) {
    size_1_cond = t.size(1) == size_1;
  }

  bool is_contiguous = t.is_contiguous();
  bool same_type = t.dtype() == type;

  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
  if (!pass) {
    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
                " is not as expected: shape = [", size_0, ", ", size_1,
                "], type = ", type);
  }
}

void advance_step(int num_seqs, int num_queries, int block_size,
                  torch::Tensor& input_tokens,       // type: long
                  torch::Tensor& sampled_token_ids,  // type: long
                  torch::Tensor& input_positions,    // type: long
                  torch::Tensor& seq_lens,           // type: int
                  torch::Tensor& slot_mapping,       // type: long
                  torch::Tensor& block_tables) {     // type: int

  if (logging) {
    printf("advance_step:\n");
    printf("  num_seqs = %d\n", num_seqs);
    printf("  num_queries = %d\n", num_queries);
    printf("  block_size = %d\n", block_size);
  }
  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
                at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
      num_seqs, num_queries, block_size,
      reinterpret_cast<long*>(input_tokens.data_ptr()),
      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
      reinterpret_cast<long*>(input_positions.data_ptr()),
      reinterpret_cast<int*>(seq_lens.data_ptr()),
      reinterpret_cast<long*>(slot_mapping.data_ptr()),
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0));
}

}  // namespace prepare_inputs

void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
                               sampled_token_ids, input_positions, seq_lens,
                               slot_mapping, block_tables);
}
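The per-query update performed by advance_step_kernel can be read as the following CPU-side Python sketch. This is illustrative only: the function name advance_step_reference and the toy tensors are not part of the commit, and it assumes nothing beyond PyTorch. Each loop iteration corresponds to one CUDA thread handling one query.

import torch

def advance_step_reference(block_size, input_tokens, sampled_token_ids,
                           input_positions, seq_lens, slot_mapping,
                           block_tables):
    num_queries = sampled_token_ids.shape[0]
    for q in range(num_queries):
        # Feed the token sampled in the previous step back as the next input.
        input_tokens[q] = sampled_token_ids[q, 0]
        # The sequence grows by one token; its position index is len - 1.
        next_seq_len = int(seq_lens[q]) + 1
        next_input_pos = next_seq_len - 1
        seq_lens[q] = next_seq_len
        input_positions[q] = next_input_pos
        # Translate the logical position into a physical KV-cache slot.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        slot_mapping[q] = (int(block_tables[q, block_index]) * block_size
                           + block_offset)

# Tiny example: 2 queries, block_size 16, sequences of length 16 and 17.
block_size = 16
input_tokens = torch.zeros(2, dtype=torch.long)
sampled = torch.tensor([[11], [22]], dtype=torch.long)
positions = torch.zeros(2, dtype=torch.long)
seq_lens = torch.tensor([16, 17], dtype=torch.int)
slot_mapping = torch.zeros(2, dtype=torch.long)
block_tables = torch.tensor([[3, 7], [4, 9]], dtype=torch.int)
advance_step_reference(block_size, input_tokens, sampled, positions,
                       seq_lens, slot_mapping, block_tables)
print(input_tokens.tolist(), positions.tolist(), slot_mapping.tolist())
# -> [11, 22] [16, 17] [112, 145]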
csrc/prepare_inputs/advance_step.cuh  (new file, 19 lines)
@@ -0,0 +1,19 @@
#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace prepare_inputs {

static constexpr int max_threads = 256;
static constexpr bool logging = false;

constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

}  // namespace prepare_inputs
@@ -72,6 +72,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("gelu_quick(Tensor! out, Tensor input) -> ()");
  ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);

  // prepare_inputs advance_step
  ops.def("advance_step", &advance_step);
  ops.impl("advance_step", torch::kCUDA, &advance_step);

  // Layernorm
  // Apply Root Mean Square (RMS) Normalization to the input tensor.
  ops.def(
@@ -227,6 +227,7 @@ def get_output_from_llm_generator(
    maybe_assert_ngram_worker(llm)

    outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

    token_ids = [output.outputs[0].token_ids for output in outputs]
    tokens = [output.outputs[0].text for output in outputs]
@@ -642,3 +642,51 @@ def test_draft_proposals_mixed_k():
    assert proposals.proposal_lens.tolist() == [
        k for _ in range(expected_num_proposal_seqs - 1)
    ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]


@torch.inference_mode()
def test_use_draft_model_runner_advance_step():
    """Verify that draft model runner triggers advance step
    when applicable.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    k = 5
    batch_size = 32
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
        model_runner_cls=TP1DraftModelRunner,
    )

    # Mock "_gpu_advance_step" to raise an exception when called.
    exception_secret = "artificial stop"
    worker.model_runner._gpu_advance_step = MagicMock()
    worker.model_runner._gpu_advance_step.side_effect = ValueError(
        exception_secret)

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    # Fallback (should not call) when num_steps=1.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        num_steps=1)
    worker.execute_model(execute_model_req=execute_model_req)

    # Expect exception if _gpu_advance_step is called.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k,
        num_steps=k)

    with pytest.raises(ValueError, match=exception_secret):
        worker.execute_model(execute_model_req=execute_model_req)
    call_args_list = worker.model_runner._gpu_advance_step.call_args_list
    assert len(call_args_list) == 1
@@ -166,6 +166,18 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)


def advance_step(num_seqs: int, num_queries: int, block_size: int,
                 input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
                 input_positions: torch.Tensor, seq_lens: torch.Tensor,
                 slot_mapping: torch.Tensor,
                 block_tables: torch.Tensor) -> None:
    """Advance a step on GPU for existing inputs for a multi-step runner"""
    return torch.ops._C.advance_step(num_seqs, num_queries, block_size,
                                     input_tokens, sampled_token_ids,
                                     input_positions, seq_lens, slot_mapping,
                                     block_tables)


# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
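For orientation, a hedged usage sketch of the new wrapper. The shapes and dtypes follow the verify_tensor checks in advance_step.cu above: 1-D long tensors of length num_seqs, sampled_token_ids of shape [num_queries, 1], and int32 seq_lens and block_tables. It assumes a CUDA build of vLLM that contains this op; the toy values are illustrative.

import torch
from vllm import _custom_ops as ops

num_seqs = num_queries = 4
block_size = 16
device = "cuda"
input_tokens = torch.zeros(num_seqs, dtype=torch.long, device=device)
sampled_token_ids = torch.randint(0, 32000, (num_queries, 1),
                                  dtype=torch.long, device=device)
input_positions = torch.zeros(num_seqs, dtype=torch.long, device=device)
seq_lens = torch.full((num_seqs, ), 16, dtype=torch.int32, device=device)
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device=device)
block_tables = torch.arange(num_seqs * 2, dtype=torch.int32,
                            device=device).reshape(num_seqs, 2)

ops.advance_step(num_seqs=num_seqs, num_queries=num_queries,
                 block_size=block_size, input_tokens=input_tokens,
                 sampled_token_ids=sampled_token_ids,
                 input_positions=input_positions, seq_lens=seq_lens,
                 slot_mapping=slot_mapping, block_tables=block_tables)
# seq_lens is now 17 everywhere, input_positions is 16, and slot_mapping
# points at the first slot of each sequence's second cache block.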
@@ -47,6 +47,32 @@ class Sampler(nn.Module):
        # speculative decoding.
        self.include_gpu_probs_tensor = False

    def _init_sampling_tensors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ):
        """The goal here is to reuse sampling tensors between similar decode
        runs. This is possible because sampling logic does not change between
        decodes of the same sequences.
        """
        _, vocab_size = logits.shape

        # First free any existing stored sampling tensors.
        # This is necessary because some sampling tensors may
        # have pinned memory.
        self._sampling_tensors = None

        # Initialize new sampling tensors
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)

        self._sampling_tensors = sampling_tensors
        self._do_penalties = do_penalties
        self._do_top_p_top_k = do_top_p_top_k
        self._do_min_p = do_min_p

    def forward(
        self,
        logits: torch.Tensor,
@@ -60,12 +86,23 @@ class Sampler(nn.Module):
        assert logits is not None
        _, vocab_size = logits.shape

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Prepare sampling tensors with pinned memory to avoid blocking.
        (sampling_tensors, do_penalties, do_top_p_top_k,
         do_min_p) = SamplingTensors.from_sampling_metadata(
             sampling_metadata, vocab_size, logits.device, logits.dtype)
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
@@ -77,7 +114,7 @@ class Sampler(nn.Module):

        # Apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
@@ -109,13 +146,19 @@ class Sampler(nn.Module):
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs, sample_logprobs = _get_logprobs(
            logprobs, sampling_metadata, sample_results)
        return _build_sampler_output(sample_results,
                                     sampling_metadata,
                                     prompt_logprobs,
                                     sample_logprobs,
                                     on_device_tensors=on_device_tensors)
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            prompt_logprobs, sample_logprobs = _get_logprobs(
                logprobs, sampling_metadata, sample_results)

        return _build_sampler_output(
            sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
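The reuse contract that _init_sampling_tensors establishes can be summarized in a small standalone sketch (plain Python, illustrative names, not the vLLM classes): build the per-batch sampling tensors on the first decode step, reuse them on later steps of the same sequences, and rebuild whenever penalties depend on the growing output_tokens.

class TinySampler:
    def __init__(self):
        self._cached = None
        self._do_penalties = False

    def _init_sampling_tensors(self, batch_params):
        # Stands in for SamplingTensors.from_sampling_metadata(...).
        self._cached = dict(batch_params)
        self._do_penalties = batch_params.get("presence_penalty", 0.0) != 0.0

    def forward(self, batch_params, reuse_sampling_tensors):
        if not reuse_sampling_tensors:
            self._init_sampling_tensors(batch_params)
        elif self._do_penalties:
            # output_tokens change every step, so the cache is stale.
            self._init_sampling_tensors(batch_params)
        return self._cached

sampler = TinySampler()
params = {"temperature": 0.8, "presence_penalty": 0.0}
first = sampler.forward(params, reuse_sampling_tensors=False)   # builds
second = sampler.forward(params, reuse_sampling_tensors=True)   # reuses
assert first is second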
@@ -535,24 +578,29 @@ def _sample_with_torch(

    # GPU<->CPU sync happens in the loop below.
    # This also converts the sample output to Python objects.
    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(seq_groups, greedy_samples)
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(seq_groups,
                                            multinomial_samples[sampling_type])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))
    if not sampling_metadata.skip_sampler_cpu_output:
        for sampling_type in SamplingType:
            if sampling_type not in sample_metadata:
                continue
            (seq_group_id, seq_groups) = sample_metadata[sampling_type]
            if sampling_type == SamplingType.GREEDY:
                sample_results = _greedy_sample(seq_groups, greedy_samples)
            elif sampling_type in (SamplingType.RANDOM,
                                   SamplingType.RANDOM_SEED):
                sample_results = _random_sample(
                    seq_groups, multinomial_samples[sampling_type])
            elif sampling_type == SamplingType.BEAM:
                sample_results = _beam_search_sample(seq_groups,
                                                     beam_search_logprobs)
            sample_results_dict.update(zip(seq_group_id, sample_results))

        sample_results = [
            sample_results_dict.get(i, ([], []))
            for i in range(len(sampling_metadata.seq_groups))
        ]
    else:
        sample_results = []

    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results, sampled_token_ids_tensor
@@ -997,10 +1045,11 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
def _build_sampler_output(
    sample_results: SampleResultType,
    sampling_metadata: SamplingMetadata,
    prompt_logprobs: List[Optional[PromptLogprobs]],
    sample_logprobs: List[SampleLogprobs],
    prompt_logprobs: Optional[List[Optional[PromptLogprobs]]],
    sample_logprobs: Optional[List[SampleLogprobs]],
    on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor]],
    skip_sampler_cpu_output: bool = False,
) -> SamplerOutput:
    """Construct Python objects with the output of sampling.

@@ -1010,22 +1059,26 @@
    allows post-processing without copies to CPU/serialization, e.g. in
    speculative decoding rejection sampling.
    """

    sampler_output: List[CompletionSequenceGroupOutput] = []
    for (seq_group, sample_result, group_prompt_logprobs,
         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                       sample_results, prompt_logprobs,
                                       sample_logprobs):
        seq_ids = seq_group.seq_ids
        next_token_ids, parent_ids = sample_result
        seq_outputs: List[SequenceOutput] = []
        for parent_id, next_token_id, logprobs in zip(parent_ids,
                                                      next_token_ids,
                                                      group_sample_logprobs):
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id, logprobs))
        sampler_output.append(
            CompletionSequenceGroupOutput(seq_outputs, group_prompt_logprobs))
    if not skip_sampler_cpu_output:
        assert prompt_logprobs is not None
        assert sample_logprobs is not None

        for (seq_group, sample_result, group_prompt_logprobs,
             group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                           sample_results, prompt_logprobs,
                                           sample_logprobs):
            seq_ids = seq_group.seq_ids
            next_token_ids, parent_ids = sample_result
            seq_outputs: List[SequenceOutput] = []
            for parent_id, next_token_id, logprobs in zip(
                    parent_ids, next_token_ids, group_sample_logprobs):
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   logprobs))
            sampler_output.append(
                CompletionSequenceGroupOutput(seq_outputs,
                                              group_prompt_logprobs))

    # If not specified, store None values in SamplerOutput.
    if on_device_tensors is not None:
@@ -87,6 +87,12 @@ class SamplingMetadata:
        The first tuple is [1, 2] (sampled index within original logit),
        and the second tuple is [0, 1] (sampled index within pruned logit).
        num_prompts: Number of prompt sequence groups in seq_groups.
        skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU
            serialization of token outputs.
        reuse_sampling_tensors: Indicates if we want to reuse sampling
            tensors that are part of the sampler forward pass. Currently,
            it is mainly used for multi-step decode.

    """

    def __init__(
@@ -95,11 +101,15 @@ class SamplingMetadata:
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        num_prompts: int,
        skip_sampler_cpu_output: bool = False,
        reuse_sampling_tensors: bool = False,
    ) -> None:
        self.seq_groups = seq_groups
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        self.num_prompts = num_prompts
        self.skip_sampler_cpu_output = skip_sampler_cpu_output
        self.reuse_sampling_tensors = reuse_sampling_tensors

    @staticmethod
    def prepare(
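A small sketch of how the two new flags are meant to be driven across a multi-step decode, which is what TP1DraftModelRunner does later in this diff: skip_sampler_cpu_output stays on for speculative steps, and reuse_sampling_tensors is flipped on after the first step. The constructor fields other than the two flags are empty dummies, so this is illustrative rather than a working pipeline.

import torch
from vllm.model_executor.sampling_metadata import SamplingMetadata

metadata = SamplingMetadata(
    seq_groups=[],
    selected_token_indices=torch.empty(0, dtype=torch.long),
    categorized_sample_indices={},
    num_prompts=0,
    skip_sampler_cpu_output=True,   # keep sampled ids on the GPU
    reuse_sampling_tensors=False,   # the first step still builds the tensors
)

for step in range(3):
    # sampler(logits, metadata) would run here
    # Every later step samples the same decode sequences, so reuse.
    metadata.reuse_sampling_tensors = True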
@@ -2,17 +2,22 @@ from typing import List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, MultiModalConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.sequence import (IntermediateTensors, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                           SamplerOutput)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
                                      ModelRunner)

logger = init_logger(__name__)

debug_advance_input = False
enable_gpu_advance_step = True


class TP1DraftModelRunner(ModelRunner):
    """Specialized model runner for speculative decoding draft model.
@@ -21,18 +26,9 @@ class TP1DraftModelRunner(ModelRunner):
    we could get rid of most CPU-GPU synchronization and data transfer
    overheads by keeping model input and output tensors on GPU all the time.

    This runner is still under development so there's no performance gain
    at this moment. Currently we adopt a temporary solution that caches the
    seq_group_metadata_list for multi-step execution, so that we can
    leverage existing prepare_model_input to be compatible with the current
    execution flow, but we plan to remove this cache and avoid calling
    prepare_model_input in execute_model at all.

    The detail development plan includes:
    1. Use "update_model_input" to update existing model_input without
       creating a new one.
    2. Improve the performance of "update_model_input" with a GPU kernel.
    3. Support TP > 1 (this requires some designs because we do not expect
    TODOs:
    1. Currently supports only flash-attn, add support for other attn_backends.
    2. Support TP > 1 (this requires some designs because we do not expect
       any broadcasting inside execute_model).
    """
@@ -71,51 +67,156 @@ class TP1DraftModelRunner(ModelRunner):
            return_hidden_states=return_hidden_states,
        )

        # TODO: Remove this cache when we are able to update model_input
        # directly in advance_step.
        self.cached_seq_group_metadata_list: Optional[
            List[SequenceGroupMetadata]] = None
    def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
                                    num_queries):
        assert isinstance(attn_metadata, FlashAttentionMetadata)

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForGPUWithSamplingMetadata:
        """A temporary solution that caches the seq_group_metadata_list
        for multi-step execution.
        TODO: In-place update model_input and remove this function.
        """
        self.cached_seq_group_metadata_list = seq_group_metadata_list
        return super().prepare_model_input(
            seq_group_metadata_list,
            finished_requests_ids=finished_requests_ids)
        if num_seqs != num_queries:
            assert num_seqs > num_queries
            assert attn_metadata.use_cuda_graph

    def update_model_input(
        assert attn_metadata.num_prefills == 0
        assert attn_metadata.num_prefill_tokens == 0
        assert attn_metadata.num_decode_tokens == num_seqs
        assert attn_metadata.slot_mapping.shape == (num_seqs, )

        assert len(attn_metadata.seq_lens) == num_seqs
        assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
        assert attn_metadata.max_query_len == 1
        assert attn_metadata.max_prefill_seq_len == 0
        assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)

        assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
        assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )

        assert attn_metadata.context_lens_tensor.shape == (num_queries, )

        assert attn_metadata.block_tables.shape[0] == num_seqs

        # Update query lengths. Note that we update only queries and not seqs,
        # since tensors may be padded due to captured cuda graph batch size
        for i in range(num_queries):
            attn_metadata.seq_lens[i] += 1
        attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)

    def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                  num_queries):

        assert sampling_metadata.num_prompts == 0
        assert len(sampling_metadata.seq_groups) == num_queries
        assert sampling_metadata.selected_token_indices.shape == (
            num_queries, )
        # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501

        # Verify that all sequences are decodes
        for i in range(num_queries):
            seq_group = sampling_metadata.seq_groups[i]

            assert seq_group.is_prompt is False  # No prompt
            assert seq_group.prompt_logprob_indices == []  # No prompt
            assert seq_group.sample_indices == [i]  # Simple
            assert seq_group.seq_len is None  # Decode
            assert seq_group.query_len is None  # Decode

    def _gpu_advance_step(
            self, model_input: ModelInputForGPUWithSamplingMetadata,
            last_output: SamplerOutput
    ) -> ModelInputForGPUWithSamplingMetadata:
        """Prepare the model inputs for the next step.
        TODO: In-place update model_input instead of calling
            prepare_model_input.
        # Currently, we expect "decode mode" only
        assert not model_input.is_prompt

        # Get num_seqs
        num_seqs = len(model_input.seq_lens)
        num_queries = len(model_input.query_lens)

        # Get output tokens GPU tensor
        sampled_token_ids = last_output.sampled_token_ids
        assert sampled_token_ids is not None

        # Update attn_metadata
        attn_metadata = model_input.attn_metadata
        assert isinstance(attn_metadata, FlashAttentionMetadata)
        self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)

        # Update GPU tensors
        ops.advance_step(num_seqs=num_seqs,
                         num_queries=num_queries,
                         block_size=self.block_size,
                         input_tokens=model_input.input_tokens,
                         sampled_token_ids=sampled_token_ids,
                         input_positions=model_input.input_positions,
                         seq_lens=attn_metadata.seq_lens_tensor,
                         slot_mapping=attn_metadata.slot_mapping,
                         block_tables=attn_metadata.block_tables)

        # Update sampling_metadata
        sampling_metadata = model_input.sampling_metadata
        self._update_sampling_metadata(sampling_metadata, num_seqs,
                                       num_queries)

        # Create new input
        new_model_input = self._model_input_cls(
            input_tokens=model_input.input_tokens,
            input_positions=model_input.input_positions,
            attn_metadata=attn_metadata,
            seq_lens=attn_metadata.seq_lens,
            query_lens=model_input.query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            sampling_metadata=model_input.sampling_metadata,
            is_prompt=False,
        )

        # Ensure we skip CPU samples
        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
        # We can reuse sampling tensors since every decode iteration is the same
        new_model_input.sampling_metadata.reuse_sampling_tensors = True

        if debug_advance_input:
            logger.debug("NEW INPUT: ")
            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
            logger.debug("  input_positions = %s",
                         new_model_input.input_positions)
            logger.debug("  seq_lens = %d", new_model_input.seq_lens)
            logger.debug("  query_lens = %d", new_model_input.query_lens)
            logger.debug("  attn_metadata:")
            logger.debug("    seq_lens_tensor: %s",
                         attn_metadata.seq_lens_tensor)
            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
            logger.debug("    block_tables: %s", attn_metadata.block_tables)

        return new_model_input

    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
        """Determines if draft_model_runner GPU multi-step can be used.
        Currently required conditions are:
            1. Only decodes
            2. Only flash-attn
            3. No LORA
            4. No prompt_adapter_config
        """
        if not enable_gpu_advance_step:
            return False

        # Append the output token to the sequence data.
        assert self.cached_seq_group_metadata_list is not None
        for seq_group_metadata, sequence_group_outputs in zip(
                self.cached_seq_group_metadata_list, last_output.outputs):
            seq_group_metadata.is_prompt = False
        # We allow multi-step GPU only in decode mode
        for seq_group in execute_model_req.seq_group_metadata_list:
            if seq_group.is_prompt:
                return False

            for seq_output in sequence_group_outputs.samples:
                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
        # TODO: Add support for other attn backends
        if self.attn_backend.get_name() != "flash-attn":
            return False

                token_id = seq_output.output_token
                token_logprob = seq_output.logprobs[token_id]
        # TODO: Add support for LORA
        if self.lora_config:
            return False

                seq.append_token_id(token_id, token_logprob.logprob)
                seq.update_num_computed_tokens(1)
        # TODO: Add soft-tuning prompt adapter support
        if self.prompt_adapter_config:
            return False

        return self.prepare_model_input(self.cached_seq_group_metadata_list)
        return True

    @torch.inference_mode()
    def execute_model(
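A note on the num_seqs / num_queries split that _update_flash_attn_metadata relies on, as a minimal illustrative sketch in plain Python (no vLLM objects): when CUDA graphs are used, the batch is padded up to a captured size, so seq_lens has num_seqs entries but only the first num_queries belong to live decode queries, and only those advance.

def bump_live_seq_lens(seq_lens, num_queries):
    # Only the live queries grow by one token; padded entries stay put.
    for i in range(num_queries):
        seq_lens[i] += 1
    return max(seq_lens)

num_queries = 3                   # live decode queries
seq_lens = [17, 20, 33, 1, 1]     # padded to a captured graph batch of 5
max_decode_seq_len = bump_live_seq_lens(seq_lens, num_queries)
print(seq_lens, max_decode_seq_len)   # [18, 21, 34, 1, 1] 34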
@@ -125,42 +226,86 @@ class TP1DraftModelRunner(ModelRunner):
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        # Since we do not broadcast data inside execute_model anymore,
        # we need to figure out the best way to support TP > 1 in this
        # case, because we will at least need to broadcast the sampled
        # tokens to all workers.
        if not self.is_driver_worker:
            raise ValueError("TP1DraftModelRunner only supports TP=1.")
        """Executes num_steps forward passes with advancement of input tensors
        on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)
        Optimizations used:
            1. Input tensors are updated on the GPU directly
            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
                them since we do batch expansion later that uses GPU outputs)
            3. Reuses sampling tensors (since we run only decodes and they have
                a repeating sampling logic)
        """

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)
        # When num_steps == 1, we execute the fallback here for the GPU
        # advance_step, which runs prepare_inputs on CPU and for each spec
        # iteration invokes this function only once
        # (Look at multi-step-worker code)
        is_fallback = num_steps == 1
        if not is_fallback:
            # Since we do not broadcast data inside execute_model anymore,
            # we need to figure out the best way to support TP > 1 in this
            # case, because we will at least need to broadcast the sampled
            # tokens to all workers.
            if not self.is_driver_worker:
                raise ValueError("TP1DraftModelRunner only supports TP=1.")

            # Sanity
            if self.lora_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for LORA")
            if self.prompt_adapter_config is not None:
                raise ValueError("TP1DraftModelRunner has no support for "
                                 "prompt_adapter_config")
            if model_input.multi_modal_kwargs:
                raise ValueError(
                    "TP1DraftModelRunner has no support for multi_modal_kwargs"
                )
        else:
            if self.lora_config:
                assert model_input.lora_requests is not None
                assert model_input.lora_mapping is not None
                self.set_active_loras(model_input.lora_requests,
                                      model_input.lora_mapping)

            if self.prompt_adapter_config:
                assert model_input.prompt_adapter_requests is not None
                assert model_input.prompt_adapter_mapping is not None
                self.set_active_prompt_adapters(
                    model_input.prompt_adapter_requests,
                    model_input.prompt_adapter_mapping)

        # Detect exec mode
        assert model_input.attn_metadata is not None
        use_cuda_graph = False
        if model_input.attn_metadata.num_prefills > 0:
            # In this case, execute_model(..) was called directly
            if num_steps > 1:
                raise ValueError(
                    "execute_model(..) of draft_model_runner can be called "
                    "directly only with a single-step prefill")
        else:
            # We can skip CPU samples for spec token generation.
            # (We do allow CPU samples for num_steps == 1 to support the
            # fallback case, where supports_gpu_multi_step(..) does not pass)
            model_input.sampling_metadata.skip_sampler_cpu_output = (
                not is_fallback)

            # Attn attr defines if we use cuda graphs
            use_cuda_graph = model_input.attn_metadata.use_cuda_graph

        # Get model
        if use_cuda_graph:
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = (self.graph_runners[model_input.virtual_engine]
                                [graph_batch_size])
        else:
            model_executable = self.model

        virtual_engine = model_input.virtual_engine
        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            # Currently cuda graph is only supported by the decode phase.
            assert model_input.attn_metadata is not None
            prefill_meta = model_input.attn_metadata.prefill_metadata
            decode_meta = model_input.attn_metadata.decode_metadata
            if prefill_meta is None and decode_meta.use_cuda_graph:
                assert model_input.input_tokens is not None
                graph_batch_size = model_input.input_tokens.shape[0]
                model_executable = (
                    self.graph_runners[virtual_engine][graph_batch_size])
            else:
                model_executable = self.model

            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

            # Run model
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
@@ -181,8 +326,8 @@ class TP1DraftModelRunner(ModelRunner):
                    sampling_metadata=model_input.sampling_metadata,
                ))

            # Prepare the inputs for the next step.
            # Prepare inputs for the next step
            if step != num_steps - 1:
                model_input = self.update_model_input(model_input, outputs[-1])
                model_input = self._gpu_advance_step(model_input, outputs[-1])

        return outputs
@@ -67,14 +67,23 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
        expanded_request, indices_of_seq_with_bonus_tokens =\
            self._expand_execute_model_request(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

        # Run model sample_len times.
        model_outputs: List[SamplerOutput] = []
        if isinstance(self.model_runner, TP1DraftModelRunner):
        if isinstance(
                self.model_runner, TP1DraftModelRunner
        ) and self.model_runner.supports_gpu_multi_step(expanded_request):
            # Here we run the draft_model_runner with multi-step prepare
            # on the GPU directly
            expanded_request.num_steps = sample_len
            model_outputs = self.execute_model(
                execute_model_req=expanded_request)
        else:
            # TODO: Remove this branch once DraftModelRunner supports TP>1.
            # Here we run multi-step directly, with every step prepared
            # on the CPU.
            # TODO: Remove this branch once DraftModelRunner supports TP>1
            # and other restrictions that are part of DraftModelRunner's
            # supports_gpu_multi_step(..)
            for _ in range(sample_len):
                model_output: List[SamplerOutput] = super().execute_model(
                    execute_model_req=expanded_request)
@@ -171,7 +180,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
                outputs=[
                    expanded_batch_output.outputs[i]
                    for i in output_indices_to_retain
                ],
                ] if len(expanded_batch_output.outputs) > 0 else [],
                sampled_token_probs=(
                    expanded_batch_output.
                    sampled_token_probs[output_indices_to_retain]
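For context, the whole path in this commit is exercised when speculative decoding runs with a small draft model, as in the new test above: MultiStepWorker picks the GPU multi-step branch, TP1DraftModelRunner loops over the speculative steps, and ops.advance_step prepares each next step on the GPU. A minimal, hedged end-to-end sketch; the engine flags shown (for example use_v2_block_manager) and the target/draft model pairing are assumptions that may differ between vLLM versions.

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",             # target model (illustrative choice)
    speculative_model="JackFram/llama-68m",  # draft model, as in the new test
    num_speculative_tokens=5,                # k speculative tokens per step
    use_v2_block_manager=True,               # assumed requirement for spec decode
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)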