mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 21:57:08 +08:00
WIP initial working version
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
76732ff701
commit
d4c9448b26
@ -28,6 +28,15 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
|||||||
const std::string& kv_cache_dtype,
|
const std::string& kv_cache_dtype,
|
||||||
const double k_scale, const double v_scale);
|
const double k_scale, const double v_scale);
|
||||||
|
|
||||||
|
void reshape_and_cache_flash_full_cuda(
|
||||||
|
torch::Tensor& tokenshape,
|
||||||
|
torch::Tensor& key, torch::Tensor& value,
|
||||||
|
torch::Tensor& key_cache,
|
||||||
|
torch::Tensor& value_cache,
|
||||||
|
torch::Tensor& slot_mapping,
|
||||||
|
const std::string& kv_cache_dtype,
|
||||||
|
const double k_scale, const double v_scale);
|
||||||
|
|
||||||
// Just for unittest
|
// Just for unittest
|
||||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||||
const double scale, const std::string& kv_cache_dtype);
|
const double scale, const std::string& kv_cache_dtype);
|
||||||
|
|||||||
@ -245,6 +245,52 @@ __global__ void reshape_and_cache_flash_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||||
|
__global__ void reshape_and_cache_flash_full_cuda_kernel(
|
||||||
|
const int32_t* __restrict__ tensorshape,
|
||||||
|
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||||
|
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||||
|
cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
|
||||||
|
// head_size]
|
||||||
|
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
|
||||||
|
// head_size]
|
||||||
|
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||||
|
const int block_stride, const int key_stride, const int value_stride,
|
||||||
|
const int num_heads, const int head_size, const int block_size,
|
||||||
|
const float k_scale, const float v_scale) {
|
||||||
|
const int64_t token_idx = blockIdx.x;
|
||||||
|
|
||||||
|
int32_t unpadded_num_tokens = tensorshape[0];
|
||||||
|
if(token_idx >= unpadded_num_tokens) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t slot_idx = slot_mapping[token_idx];
|
||||||
|
const int64_t block_idx = slot_idx / block_size;
|
||||||
|
const int64_t block_offset = slot_idx % block_size;
|
||||||
|
const int n = num_heads * head_size;
|
||||||
|
for (int i = threadIdx.x; i < n; i += blockDim.x) {
|
||||||
|
const int64_t src_key_idx = token_idx * key_stride + i;
|
||||||
|
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||||
|
const int head_idx = i / head_size;
|
||||||
|
const int head_offset = i % head_size;
|
||||||
|
const int64_t tgt_key_value_idx = block_idx * block_stride +
|
||||||
|
block_offset * num_heads * head_size +
|
||||||
|
head_idx * head_size + head_offset;
|
||||||
|
scalar_t tgt_key = key[src_key_idx];
|
||||||
|
scalar_t tgt_value = value[src_value_idx];
|
||||||
|
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||||
|
key_cache[tgt_key_value_idx] = tgt_key;
|
||||||
|
value_cache[tgt_key_value_idx] = tgt_value;
|
||||||
|
} else {
|
||||||
|
key_cache[tgt_key_value_idx] =
|
||||||
|
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||||
|
value_cache[tgt_key_value_idx] =
|
||||||
|
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
} // namespace vllm
|
} // namespace vllm
|
||||||
|
|
||||||
// KV_T is the stored data type of kv-cache.
|
// KV_T is the stored data type of kv-cache.
|
||||||
@ -339,6 +385,49 @@ void reshape_and_cache_flash(
|
|||||||
CALL_RESHAPE_AND_CACHE_FLASH);
|
CALL_RESHAPE_AND_CACHE_FLASH);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// KV_T is the stored data type of kv-cache.
|
||||||
|
// CACHE_T is the data type of key and value tensors.
|
||||||
|
// KV_DTYPE is the real data type of kv-cache.
|
||||||
|
#define CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA(KV_T, CACHE_T, KV_DTYPE)\
|
||||||
|
vllm::reshape_and_cache_flash_full_cuda_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||||
|
<<<grid, block, 0, stream>>>( \
|
||||||
|
reinterpret_cast<int32_t*>(tokenshape.data_ptr()), \
|
||||||
|
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||||
|
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||||
|
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||||
|
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||||
|
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||||
|
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
|
||||||
|
|
||||||
|
void reshape_and_cache_flash_full_cuda(
|
||||||
|
torch::Tensor& tokenshape, // true num_tokens at first entry.
|
||||||
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||||
|
torch::Tensor&
|
||||||
|
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||||
|
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||||
|
const std::string& kv_cache_dtype, const double k_scale,
|
||||||
|
const double v_scale) {
|
||||||
|
int padded_num_tokens = slot_mapping.size(0);
|
||||||
|
int num_heads = key.size(1);
|
||||||
|
int head_size = key.size(2);
|
||||||
|
int block_size = key_cache.size(1);
|
||||||
|
|
||||||
|
int key_stride = key.stride(0);
|
||||||
|
int value_stride = value.stride(0);
|
||||||
|
int block_stride = key_cache.stride(0);
|
||||||
|
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
||||||
|
|
||||||
|
dim3 grid(padded_num_tokens);
|
||||||
|
dim3 block(std::min(num_heads * head_size, 512));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||||
|
CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA);
|
||||||
|
}
|
||||||
|
|
||||||
namespace vllm {
|
namespace vllm {
|
||||||
|
|
||||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
|
|||||||
@ -460,6 +460,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
||||||
&reshape_and_cache_flash);
|
&reshape_and_cache_flash);
|
||||||
|
|
||||||
|
// Reshape the key and value tensors and cache them.
|
||||||
|
cache_ops.def(
|
||||||
|
"reshape_and_cache_flash_full_cuda(Tensor tensorshape,"
|
||||||
|
" Tensor key, Tensor value,"
|
||||||
|
" Tensor! key_cache,"
|
||||||
|
" Tensor! value_cache,"
|
||||||
|
" Tensor slot_mapping,"
|
||||||
|
" str kv_cache_dtype,"
|
||||||
|
" float k_scale, float v_scale) -> ()");
|
||||||
|
cache_ops.impl("reshape_and_cache_flash_full_cuda", torch::kCUDA,
|
||||||
|
&reshape_and_cache_flash_full_cuda);
|
||||||
|
|
||||||
// Convert the key and value cache to fp8 data type.
|
// Convert the key and value cache to fp8 data type.
|
||||||
cache_ops.def(
|
cache_ops.def(
|
||||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||||
|
|||||||
@ -514,6 +514,7 @@ class VllmBackend:
|
|||||||
|
|
||||||
if not self.compilation_config.use_cudagraph or \
|
if not self.compilation_config.use_cudagraph or \
|
||||||
not self.compilation_config.cudagraph_copy_inputs:
|
not self.compilation_config.cudagraph_copy_inputs:
|
||||||
|
# return self.graph
|
||||||
return self.split_gm
|
return self.split_gm
|
||||||
|
|
||||||
# if we need to copy input buffers for cudagraph
|
# if we need to copy input buffers for cudagraph
|
||||||
|
|||||||
@ -2705,7 +2705,7 @@ class CompilationConfig(BaseModel):
|
|||||||
custom_ops: List[str] = Field(default_factory=list)
|
custom_ops: List[str] = Field(default_factory=list)
|
||||||
splitting_ops: List[str] = Field(default=None) # type: ignore
|
splitting_ops: List[str] = Field(default=None) # type: ignore
|
||||||
|
|
||||||
use_inductor: bool = True
|
use_inductor: bool = False
|
||||||
candidate_compile_sizes: Optional[List[int]] = Field(default=None)
|
candidate_compile_sizes: Optional[List[int]] = Field(default=None)
|
||||||
inductor_compile_config: Dict = Field(default_factory=dict)
|
inductor_compile_config: Dict = Field(default_factory=dict)
|
||||||
inductor_passes: Dict[str, str] = Field(default_factory=dict)
|
inductor_passes: Dict[str, str] = Field(default_factory=dict)
|
||||||
@ -3181,8 +3181,7 @@ class VllmConfig:
|
|||||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||||
self.compilation_config.pass_config.enable_fusion = False
|
self.compilation_config.pass_config.enable_fusion = False
|
||||||
self.compilation_config.pass_config.enable_reshape = False
|
self.compilation_config.pass_config.enable_reshape = False
|
||||||
# self.compilation_config.level = CompilationLevel.PIECEWISE
|
self.compilation_config.level = CompilationLevel.PIECEWISE
|
||||||
self.compilation_config.level = CompilationLevel.NO_COMPILATION
|
|
||||||
|
|
||||||
self._set_cudagraph_sizes()
|
self._set_cudagraph_sizes()
|
||||||
|
|
||||||
@ -3263,8 +3262,7 @@ class VllmConfig:
|
|||||||
batch_size_capture_list = []
|
batch_size_capture_list = []
|
||||||
if self.model_config is not None and \
|
if self.model_config is not None and \
|
||||||
not self.model_config.enforce_eager:
|
not self.model_config.enforce_eager:
|
||||||
batch_size_capture_list = [1, 2, 4
|
batch_size_capture_list = [1, 2, 4] + [i for i in range(8, 513, 8)]
|
||||||
] + [i for i in range(8, 513, 8)]
|
|
||||||
|
|
||||||
self.compilation_config.init_with_cudagraph_sizes(
|
self.compilation_config.init_with_cudagraph_sizes(
|
||||||
batch_size_capture_list)
|
batch_size_capture_list)
|
||||||
|
|||||||
@ -65,6 +65,9 @@ class FlashAttentionMetadata:
|
|||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
slot_mapping: torch.Tensor
|
slot_mapping: torch.Tensor
|
||||||
|
|
||||||
|
# [num_actual_tokens, batch_size, max_query_len, max_seq_len]
|
||||||
|
tokenshape: torch.Tensor
|
||||||
|
|
||||||
# For cascade attention.
|
# For cascade attention.
|
||||||
use_cascade: bool
|
use_cascade: bool
|
||||||
common_prefix_len: int
|
common_prefix_len: int
|
||||||
@ -155,7 +158,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# Profiling run.
|
# Dynamic shape profiling run.
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# IMPORTANT!
|
# IMPORTANT!
|
||||||
@ -167,19 +170,17 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
# Whenever making a change in this method, please benchmark the
|
# Whenever making a change in this method, please benchmark the
|
||||||
# performance to make sure it does not introduce any overhead.
|
# performance to make sure it does not introduce any overhead.
|
||||||
|
|
||||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
tokenshape = attn_metadata.tokenshape
|
||||||
|
num_padded_tokens = key.shape[0]
|
||||||
# Reshape the input keys and values and store them in the cache.
|
# Reshape the input keys and values and store them in the cache.
|
||||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
|
||||||
# not padded. However, we don't need to do key[:num_actual_tokens] and
|
|
||||||
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
|
|
||||||
# the slot_mapping's shape to determine the number of actual tokens.
|
|
||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
torch.ops._C_cache_ops.reshape_and_cache_flash_full_cuda(
|
||||||
|
tokenshape,
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping[:num_actual_tokens],
|
attn_metadata.slot_mapping[:num_padded_tokens],
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
@ -188,13 +189,15 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
# Compute attention and update output up to `num_actual_tokens`.
|
# Compute attention and update output up to `num_actual_tokens`.
|
||||||
if not attn_metadata.use_cascade:
|
if not attn_metadata.use_cascade:
|
||||||
# Regular attention (common case).
|
# Regular attention (common case).
|
||||||
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
batch_size = attn_metadata.block_table.shape[0]
|
batch_size = attn_metadata.block_table.shape[0]
|
||||||
|
print(f"q, k v shapes: {query.shape}")
|
||||||
|
|
||||||
flash_attn_varlen_func(
|
flash_attn_varlen_func(
|
||||||
q=query[:num_actual_tokens],
|
q=query[:num_padded_tokens],
|
||||||
k=key_cache,
|
k=key_cache,
|
||||||
v=value_cache,
|
v=value_cache,
|
||||||
out=output[:num_actual_tokens],
|
out=output[:num_padded_tokens],
|
||||||
cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1],
|
cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1],
|
||||||
max_seqlen_q=attn_metadata.max_query_len,
|
max_seqlen_q=attn_metadata.max_query_len,
|
||||||
cu_seqlens_k=attn_metadata.seq_start_loc[:batch_size+1],
|
cu_seqlens_k=attn_metadata.seq_start_loc[:batch_size+1],
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import gc
|
import gc
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -110,9 +110,10 @@ class GPUModelRunner:
|
|||||||
vocab_size=model_config.get_vocab_size(),
|
vocab_size=model_config.get_vocab_size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
# self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
||||||
== CompilationLevel.PIECEWISE
|
# == CompilationLevel.PIECEWISE
|
||||||
and not self.model_config.enforce_eager)
|
# and not self.model_config.enforce_eager)
|
||||||
|
self.use_cuda_graph = not self.model_config.enforce_eager
|
||||||
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
||||||
# The convention is different.
|
# The convention is different.
|
||||||
# self.cudagraph_batch_sizes sorts in ascending order.
|
# self.cudagraph_batch_sizes sorts in ascending order.
|
||||||
@ -149,6 +150,7 @@ class GPUModelRunner:
|
|||||||
# this one must be int64
|
# this one must be int64
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
self.tokenshape = torch.zeros(4, dtype=torch.int32, device=self.device)
|
||||||
|
|
||||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||||
self.arange_np = np.arange(max(self.max_num_reqs + 1,
|
self.arange_np = np.arange(max(self.max_num_reqs + 1,
|
||||||
@ -183,6 +185,10 @@ class GPUModelRunner:
|
|||||||
pin_memory=self.pin_memory)
|
pin_memory=self.pin_memory)
|
||||||
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
|
self.seq_start_loc_np = self.seq_start_loc_cpu.numpy()
|
||||||
|
|
||||||
|
self.tokenshape_cpu = torch.zeros(4, dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=self.pin_memory)
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
# Remove stopped requests from the cached states.
|
# Remove stopped requests from the cached states.
|
||||||
# Keep the states of the pre-empted requests.
|
# Keep the states of the pre-empted requests.
|
||||||
@ -379,6 +385,12 @@ class GPUModelRunner:
|
|||||||
self.slot_mapping_cpu[:total_num_scheduled_tokens],
|
self.slot_mapping_cpu[:total_num_scheduled_tokens],
|
||||||
non_blocking=True)
|
non_blocking=True)
|
||||||
|
|
||||||
|
self.tokenshape_cpu[0] = total_num_scheduled_tokens # Actual number of tokens to process
|
||||||
|
self.tokenshape_cpu[1] = num_reqs # Number of requests
|
||||||
|
self.tokenshape_cpu[2] = max_num_scheduled_tokens # Maximum query length
|
||||||
|
self.tokenshape_cpu[3] = max_seq_len # Maximum sequence length
|
||||||
|
self.tokenshape.copy_(self.tokenshape_cpu, non_blocking=True)
|
||||||
|
|
||||||
# Prepare for cascade attention if needed.
|
# Prepare for cascade attention if needed.
|
||||||
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
|
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
|
||||||
self.block_size)
|
self.block_size)
|
||||||
@ -468,6 +480,7 @@ class GPUModelRunner:
|
|||||||
seq_start_loc=self.seq_start_loc,
|
seq_start_loc=self.seq_start_loc,
|
||||||
block_table=self.input_batch.block_table[:num_reqs],
|
block_table=self.input_batch.block_table[:num_reqs],
|
||||||
slot_mapping=self.slot_mapping,
|
slot_mapping=self.slot_mapping,
|
||||||
|
tokenshape=self.tokenshape,
|
||||||
# Cascade stuff
|
# Cascade stuff
|
||||||
use_cascade=use_cascade,
|
use_cascade=use_cascade,
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
@ -710,6 +723,7 @@ class GPUModelRunner:
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
kv_caches: List[torch.Tensor],
|
kv_caches: List[torch.Tensor],
|
||||||
|
attn_metadata: Optional[FlashAttentionMetadata],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
@ -717,7 +731,7 @@ class GPUModelRunner:
|
|||||||
else:
|
else:
|
||||||
input_ids = self.input_ids[:num_tokens]
|
input_ids = self.input_ids[:num_tokens]
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
with set_forward_context(None, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = model(
|
hidden_states = model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=self.positions[:num_tokens],
|
positions=self.positions[:num_tokens],
|
||||||
@ -726,6 +740,28 @@ class GPUModelRunner:
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
def metadata_for_dummy_run(self, num_tokens) -> FlashAttentionMetadata:
|
||||||
|
# Create placeholder metadata
|
||||||
|
num_reqs = num_tokens
|
||||||
|
max_query_len = num_tokens
|
||||||
|
max_seq_len = num_tokens
|
||||||
|
return FlashAttentionMetadata(
|
||||||
|
num_actual_tokens=num_tokens,
|
||||||
|
max_query_len=max_query_len,
|
||||||
|
query_start_loc=self.query_start_loc,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
seq_start_loc=self.seq_start_loc,
|
||||||
|
block_table=self.input_batch.block_table[:num_reqs],
|
||||||
|
slot_mapping=self.slot_mapping,
|
||||||
|
tokenshape=self.tokenshape,
|
||||||
|
# Cascade stuff. Non-piecewise CUDA graphs NYI
|
||||||
|
use_cascade=None,
|
||||||
|
common_prefix_len=0,
|
||||||
|
cu_prefix_query_lens=None,
|
||||||
|
cu_prefix_kv_lens=None,
|
||||||
|
cu_suffix_kv_lens=None,
|
||||||
|
)
|
||||||
|
|
||||||
def profile_run(self) -> None:
|
def profile_run(self) -> None:
|
||||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||||
@ -831,7 +867,7 @@ class GPUModelRunner:
|
|||||||
|
|
||||||
# Trigger compilation for general shape.
|
# Trigger compilation for general shape.
|
||||||
hidden_states = self._dummy_run(self.model, self.max_num_tokens,
|
hidden_states = self._dummy_run(self.model, self.max_num_tokens,
|
||||||
dummy_kv_caches)
|
dummy_kv_caches, None)
|
||||||
logits = self.model.compute_logits(hidden_states, None)
|
logits = self.model.compute_logits(hidden_states, None)
|
||||||
logits = logits[:self.max_num_tokens]
|
logits = logits[:self.max_num_tokens]
|
||||||
# TODO(woosuk): Consider the memory usage of the sampler.
|
# TODO(woosuk): Consider the memory usage of the sampler.
|
||||||
@ -849,10 +885,11 @@ class GPUModelRunner:
|
|||||||
# can reuse the memory pool allocated for the large shapes.
|
# can reuse the memory pool allocated for the large shapes.
|
||||||
with graph_capture():
|
with graph_capture():
|
||||||
for num_tokens in reversed(self.cudagraph_batch_sizes):
|
for num_tokens in reversed(self.cudagraph_batch_sizes):
|
||||||
|
attn_metadata = self.metadata_for_dummy_run(num_tokens)
|
||||||
for _ in range(self.vllm_config.compilation_config.
|
for _ in range(self.vllm_config.compilation_config.
|
||||||
cudagraph_num_of_warmups):
|
cudagraph_num_of_warmups):
|
||||||
self._dummy_run(self.model, num_tokens, self.kv_caches)
|
self._dummy_run(self.model, num_tokens, self.kv_caches, attn_metadata)
|
||||||
self._dummy_run(self.model, num_tokens, self.kv_caches)
|
self._dummy_run(self.model, num_tokens, self.kv_caches, attn_metadata)
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user