Updates for FA3 and other changes

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2025-02-03 20:58:48 +00:00
parent d151b63b8b
commit 230730c34d
5 changed files with 50 additions and 48 deletions

View File

@ -29,13 +29,10 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_flash_full_cuda(
torch::Tensor& tokenshape,
torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
const double k_scale, const double v_scale);
torch::Tensor& tokenshape, torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping, const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,

View File

@ -262,8 +262,8 @@ __global__ void reshape_and_cache_flash_full_cuda_kernel(
const int64_t token_idx = blockIdx.x;
int32_t unpadded_num_tokens = tensorshape[0];
if(token_idx >= unpadded_num_tokens) {
return;
if (token_idx >= unpadded_num_tokens) {
return;
}
const int64_t slot_idx = slot_mapping[token_idx];
@ -437,27 +437,28 @@ void reshape_and_cache_flash(
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA(KV_T, CACHE_T, KV_DTYPE)\
vllm::reshape_and_cache_flash_full_cuda_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<int32_t*>(tokenshape.data_ptr()), \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
#define CALL_RESHAPE_AND_CACHE_FLASH_FULL_CUDA(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
value_stride, num_heads, head_size, block_size, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache_flash_full_cuda(
torch::Tensor& tokenshape, // true num_tokens at first entry.
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& tokenshape, // true num_tokens at first entry.
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor&
value_cache, // [num_blocks, block_size, num_heads, head_size]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, const double k_scale,
const double v_scale) {
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int padded_num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);

View File

@ -478,7 +478,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor! value_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_flash_full_cuda", torch::kCUDA,
&reshape_and_cache_flash_full_cuda);

View File

@ -215,16 +215,16 @@ class FlashAttentionImpl(AttentionImpl):
# Regular attention (common case).
num_actual_tokens = attn_metadata.num_actual_tokens
batch_size = attn_metadata.block_table.shape[0]
#TODO: Do we need to slice by [:batch_size+1]?
flash_attn_varlen_func(
q=query[:num_padded_tokens],
k=key_cache,
v=value_cache,
out=output[:num_padded_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1],
cu_seqlens_q=attn_metadata.query_start_loc[:batch_size + 1],
max_seqlen_q=attn_metadata.max_query_len,
seqused_k=attn_metadata.seq_lens[:batch_size+1],
seqused_k=attn_metadata.seq_lens[:batch_size],
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,

View File

@ -122,9 +122,9 @@ class GPUModelRunner:
vocab_size=model_config.get_vocab_size(),
)
# self.use_cuda_graph = (self.vllm_config.compilation_config.level
# == CompilationLevel.PIECEWISE
# and not self.model_config.enforce_eager)
# self.use_cuda_graph = (self.vllm_config.compilation_config.level
# == CompilationLevel.PIECEWISE
# and not self.model_config.enforce_eager)
self.use_cuda_graph = not self.model_config.enforce_eager
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
@ -221,9 +221,10 @@ class GPUModelRunner:
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.tokenshape_cpu = torch.zeros(4, dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.tokenshape_cpu = torch.zeros(4,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
@ -459,16 +460,18 @@ class GPUModelRunner:
self.query_start_loc[:num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
self.slot_mapping[:total_num_scheduled_tokens].copy_(
self.slot_mapping_cpu[:total_num_scheduled_tokens],
non_blocking=True)
self.tokenshape_cpu[0] = total_num_scheduled_tokens # Actual number of tokens to process
self.tokenshape_cpu[1] = num_reqs # Number of requests
self.tokenshape_cpu[2] = max_num_scheduled_tokens # Maximum query length
self.tokenshape_cpu[3] = max_seq_len # Maximum sequence length
self.tokenshape_cpu[
0] = total_num_scheduled_tokens # Tokens to process
self.tokenshape_cpu[1] = num_reqs # Number of requests
self.tokenshape_cpu[
2] = max_num_scheduled_tokens # Maximum query length
self.tokenshape_cpu[3] = max_seq_len # Maximum sequence length
self.tokenshape.copy_(self.tokenshape_cpu, non_blocking=True)
# Prepare for cascade attention if needed.
@ -908,7 +911,7 @@ class GPUModelRunner:
inputs_embeds=inputs_embeds,
)
return hidden_states
def metadata_for_dummy_run(self, num_tokens) -> FlashAttentionMetadata:
# Create placeholder metadata
num_reqs = num_tokens
@ -919,16 +922,17 @@ class GPUModelRunner:
max_query_len=max_query_len,
query_start_loc=self.query_start_loc,
max_seq_len=max_seq_len,
seq_start_loc=self.seq_start_loc,
block_table=self.input_batch.block_table[:num_reqs],
seq_lens=self.seq_lens,
block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=self.slot_mapping,
tokenshape=self.tokenshape,
# Cascade stuff. Non-piecewise CUDA graphs NYI
use_cascade=None,
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
cu_prefix_kv_lens=None,
cu_suffix_kv_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
def profile_run(self) -> None:
@ -1058,8 +1062,8 @@ class GPUModelRunner:
attn_metadata = self.metadata_for_dummy_run(num_tokens)
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(self.model, num_tokens, attn_metadata=attn_metadata)
self._dummy_run(self.model, num_tokens, attn_metadata=attn_metadata)
self._dummy_run(num_tokens, attn_metadata=attn_metadata)
self._dummy_run(num_tokens, attn_metadata=attn_metadata)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]