[Attention] Use sparse prefill kernel for fp8 kv-cache in DeepSeek-v3.2 (#27532)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 91401c7a26
commit 3e41992fec

csrc/cache.h | 12
@ -1,6 +1,7 @@
#pragma once

#include <torch/all.h>
#include <c10/util/Optional.h>

#include <map>
#include <vector>
@ -58,6 +59,15 @@ void cp_gather_cache(
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);

// Gather and upconvert FP8 KV cache to BF16 workspace
void cp_gather_and_upconvert_fp8_kv_cache(
    torch::Tensor const& src_cache,         // [NUM_BLOCKS, BLOCK_SIZE, 656]
    torch::Tensor const& dst,               // [TOT_TOKENS, 576]
    torch::Tensor const& block_table,       // [BATCH, BLOCK_INDICES]
    torch::Tensor const& seq_lens,          // [BATCH]
    torch::Tensor const& workspace_starts,  // [BATCH]
    int64_t batch_size);

// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
    torch::Tensor& k,  // [num_tokens, head_dim]
@ -72,4 +82,4 @@ void cp_gather_indexer_k_quant_cache(
    torch::Tensor& dst_k,      // [num_tokens, head_dim]
    torch::Tensor& dst_scale,  // [num_tokens, head_dim / quant_block_size * 4]
    const torch::Tensor& block_table,   // [batch_size, num_blocks]
    const torch::Tensor& cu_seq_lens);  // [batch_size + 1]
@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Optional.h>

#include "cuda_utils.h"
#include "cuda_compat.h"
@ -514,7 +515,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
    const int quant_block_size,  // quantization block size
    const int cache_block_size,  // cache block size
    const int cache_stride,      // stride for each token in kv_cache
    const bool use_ue8m0         // use ue8m0 scale format
) {
  constexpr int VEC_SIZE = 4;
  const int64_t token_idx = blockIdx.x;
@ -1061,6 +1063,82 @@ void gather_and_maybe_dequant_cache(
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Gather and upconvert FP8 KV cache tokens to BF16 workspace
|
||||
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
|
||||
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
|
||||
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
|
||||
const int32_t* __restrict__ seq_lens, // [BATCH]
|
||||
const int32_t* __restrict__ workspace_starts, // [BATCH]
|
||||
const int32_t block_size, const int32_t head_dim,
|
||||
const int64_t block_table_stride, const int64_t cache_block_stride,
|
||||
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
|
||||
const int64_t bid = blockIdx.x; // Batch ID
|
||||
const int32_t num_splits = gridDim.y;
|
||||
const int32_t split = blockIdx.y;
|
||||
const int32_t seq_start = workspace_starts[bid];
|
||||
const int32_t seq_len = seq_lens[bid];
|
||||
const int32_t tot_slots = seq_len;
|
||||
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
|
||||
|
||||
const int32_t split_start = split * split_slots;
|
||||
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
|
||||
|
||||
const bool is_active_split = (split_start < tot_slots);
|
||||
|
||||
if (!is_active_split) return;
|
||||
|
||||
// Adjust the pointer for the block_table for this batch
|
||||
const int32_t batch_offset = bid * block_table_stride;
|
||||
int32_t offset = split_start;
|
||||
int32_t offset_div = offset / block_size;
|
||||
offset = offset % block_size;
|
||||
const int32_t* batch_block_table = block_table + batch_offset;
|
||||
|
||||
// Adjust dst pointer based on the cumulative sequence lengths
|
||||
dst += seq_start * dst_entry_stride;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
// Process each token in this split
|
||||
for (int pid = split_start; pid < split_end; ++pid) {
|
||||
auto block_id = batch_block_table[offset_div];
|
||||
const uint8_t* token_ptr =
|
||||
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
|
||||
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
|
||||
|
||||
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
|
||||
const uint8_t* no_pe_ptr = token_ptr;
|
||||
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
|
||||
const __nv_bfloat16* rope_ptr =
|
||||
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
|
||||
|
||||
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
|
||||
if (tid < 512) {
|
||||
// FP8 dequantization
|
||||
const int tile = tid >> 7; // each tile is 128 elements
|
||||
const float scale = scales_ptr[tile];
|
||||
const uint8_t val = no_pe_ptr[tid];
|
||||
dst_ptr[tid] =
|
||||
fp8::scaled_convert<__nv_bfloat16, uint8_t,
|
||||
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
|
||||
} else if (tid < 576) {
|
||||
// Rope copy (64 bf16 elements)
|
||||
const int rope_idx = tid - 512;
|
||||
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
|
||||
}
|
||||
|
||||
// Move to next token
|
||||
offset += 1;
|
||||
if (offset == block_size) {
|
||||
offset_div += 1;
|
||||
offset = 0;
|
||||
}
|
||||
}
|
||||
}
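For reference, a minimal PyTorch sketch of the per-token conversion the kernel above performs (illustrative only; it assumes fp8::scaled_convert simply multiplies the decoded E4M3 value by its per-tile float32 scale, and that a 656-byte cache entry is laid out exactly as the comments describe):

import torch

def upconvert_fp8_token(token_bytes: torch.Tensor) -> torch.Tensor:
    """Reference for one 656-byte fp8_ds_mla entry -> 576 BF16 values.

    Layout assumed from the kernel above: 512 fp8-e4m3 "no-PE" values,
    then 4 float32 per-128-element tile scales (16 bytes), then
    64 bf16 rope values (128 bytes). Requires a PyTorch build with
    float8 support.
    """
    assert token_bytes.dtype == torch.uint8 and token_bytes.numel() == 656
    no_pe = token_bytes[:512].view(torch.float8_e4m3fn).to(torch.float32)
    scales = token_bytes[512:528].view(torch.float32)   # [4], one per tile
    rope = token_bytes[528:656].view(torch.bfloat16)    # [64]
    dequant = no_pe * scales.repeat_interleave(128)     # per-tile dequant
    return torch.cat([dequant.to(torch.bfloat16), rope])  # [576]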
|
||||
|
||||
template <typename scalar_t>
|
||||
// Note(hc): The cp_gather_cache allows seq_starts to no longer be divisible by
|
||||
// block_size.
|
||||
@ -1202,6 +1280,57 @@ void cp_gather_cache(
|
||||
}
|
||||
}
|
||||
|
||||
void cp_gather_and_upconvert_fp8_kv_cache(
|
||||
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
|
||||
torch::Tensor const& dst, // [TOT_TOKENS, 576]
|
||||
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
|
||||
torch::Tensor const& seq_lens, // [BATCH]
|
||||
torch::Tensor const& workspace_starts, // [BATCH]
|
||||
int64_t batch_size) {
|
||||
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
int32_t block_size = src_cache.size(1);
|
||||
int32_t head_dim = dst.size(1);
|
||||
|
||||
TORCH_CHECK(block_table.dtype() == torch::kInt32,
|
||||
"block_table must be int32");
|
||||
TORCH_CHECK(seq_lens.dtype() == torch::kInt32, "seq_lens must be int32");
|
||||
TORCH_CHECK(workspace_starts.dtype() == torch::kInt32,
|
||||
"workspace_starts must be int32");
|
||||
|
||||
TORCH_CHECK(src_cache.device() == dst.device(),
|
||||
"src_cache and dst must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == block_table.device(),
|
||||
"src_cache and block_table must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == seq_lens.device(),
|
||||
"src_cache and seq_lens must be on the same device");
|
||||
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
|
||||
"src_cache and workspace_starts must be on the same device");
|
||||
|
||||
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
|
||||
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
|
||||
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
|
||||
|
||||
int64_t block_table_stride = block_table.stride(0);
|
||||
int64_t cache_block_stride = src_cache.stride(0);
|
||||
int64_t cache_entry_stride = src_cache.stride(1);
|
||||
int64_t dst_entry_stride = dst.stride(0);
|
||||
|
||||
// Decide on the number of splits based on the batch size
|
||||
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
|
||||
dim3 grid(batch_size, num_splits);
|
||||
dim3 block(576);
|
||||
|
||||
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
|
||||
src_cache.data_ptr<uint8_t>(),
|
||||
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
|
||||
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
|
||||
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
|
||||
block_table_stride, cache_block_stride, cache_entry_stride,
|
||||
dst_entry_stride);
|
||||
}
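The split heuristic in the launcher above can be summarized as follows (a sketch mirroring the ternary expression, not additional vLLM API):

def num_splits_for_batch(batch_size: int) -> int:
    # More splits along grid.y when the batch is small, so short batches
    # still expose enough thread blocks to the GPU.
    return 2 if batch_size > 128 else 4 if batch_size > 64 else 16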
|
||||
|
||||
// Macro to dispatch the kernel based on the data type.
|
||||
#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::indexer_k_quant_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
|
||||
@ -754,6 +754,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
|
||||
cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"cp_gather_and_upconvert_fp8_kv_cache(Tensor src_cache, Tensor! dst, "
|
||||
"Tensor block_table, Tensor seq_lens, Tensor workspace_starts, int "
|
||||
"batch_size) -> ()");
|
||||
cache_ops.impl("cp_gather_and_upconvert_fp8_kv_cache", torch::kCUDA,
|
||||
&cp_gather_and_upconvert_fp8_kv_cache);
|
||||
|
||||
cache_ops.def(
|
||||
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache, Tensor "
|
||||
"slot_mapping, "
|
||||
|
||||
@ -202,6 +202,27 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.fixture
def workspace_init():
    """Initialize the workspace manager for tests that need it.

    This fixture initializes the workspace manager with a CUDA device
    if available, and resets it after the test completes. Tests that
    create a full vLLM engine should NOT use this fixture as the engine
    will initialize the workspace manager itself.
    """
    from vllm.v1.worker.workspace import (
        init_workspace_manager,
        reset_workspace_manager,
    )

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        init_workspace_manager(device)
    yield
    reset_workspace_manager()
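A hypothetical test using this fixture might look like the sketch below (illustrative only: test_uses_workspace is not part of this PR, and the get_simultaneous call just mirrors how code elsewhere in this diff requests scratch space; it assumes a CUDA machine):

import torch
from vllm.v1.worker.workspace import current_workspace_manager

def test_uses_workspace(workspace_init):  # fixture injected by name
    # The workspace manager is initialized on cuda:0 for the duration of
    # the test and reset automatically afterwards.
    a, b = current_workspace_manager().get_simultaneous(
        ((128, 576), torch.bfloat16),
        ((128, 4), torch.uint8),
    )
    assert a.shape == (128, 576) and b.shape == (128, 4)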
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def dynamo_reset():
|
||||
yield
|
||||
|
||||
@ -27,7 +27,7 @@ BLOCK_SIZE = [128, 128]
|
||||
@pytest.mark.parametrize("N", [512, 1024]) # intermediate dim per expert
|
||||
@pytest.mark.parametrize("topk", [2, 4])
|
||||
def test_batched_deepgemm_vs_triton(
|
||||
E: int, T: int, K: int, N: int, topk: int, monkeypatch
|
||||
E: int, T: int, K: int, N: int, topk: int, monkeypatch, workspace_init
|
||||
):
|
||||
"""Compare BatchedDeepGemmExperts to BatchedTritonExperts."""
|
||||
|
||||
|
||||
@ -248,6 +248,7 @@ def test_fused_moe_batched_experts(
|
||||
per_act_token_quant: bool,
|
||||
block_shape: list[int] | None,
|
||||
input_scales: bool,
|
||||
workspace_init,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
|
||||
@ -137,7 +137,7 @@ def setup_cuda():
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch
|
||||
M, N, K, E, topk, block_size, dtype, seed, monkeypatch, workspace_init
|
||||
):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
@ -274,6 +274,7 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size: int | None = None,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
@ -329,6 +330,7 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
@ -385,9 +387,19 @@ def test_cutlass_moe_8_bit_EP(
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
|
||||
@ -419,9 +431,19 @@ def test_cutlass_moe_8_bit_EP_large(
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
test_cutlass_moe_8_bit_no_graph(
|
||||
m, n, k, e, topk, per_act_token, per_out_channel, monkeypatch, ep_size
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
e,
|
||||
topk,
|
||||
per_act_token,
|
||||
per_out_channel,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
ep_size,
|
||||
)
|
||||
|
||||
|
||||
@ -445,6 +467,7 @@ def test_run_cutlass_moe_fp8(
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
|
||||
@ -29,6 +29,7 @@ from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -363,6 +364,9 @@ def _test_deepep_deepgemm_moe(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
w1 = w1.to(device=torch.cuda.current_device())
|
||||
@ -445,6 +449,7 @@ def test_ht_deepep_deepgemm_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
disable_deepgemm_ue8m0,
|
||||
workspace_init,
|
||||
):
|
||||
"""
|
||||
Tests for High-Throughput DeepEP + DeepGemm integration.
|
||||
@ -518,6 +523,7 @@ def test_ll_deepep_deepgemm_moe(
|
||||
block_size: list[int],
|
||||
world_dp_size: tuple[int, int],
|
||||
disable_deepgemm_ue8m0,
|
||||
workspace_init,
|
||||
):
|
||||
"""
|
||||
Tests for Low-Latency DeepEP + DeepGemm integration.
|
||||
|
||||
@ -22,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import has_deep_ep
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -342,6 +343,9 @@ def _deep_ep_moe(
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
if not low_latency_mode:
|
||||
assert not use_fp8_dispatch, (
|
||||
"FP8 dispatch interface is available only in low-latency mode"
|
||||
@ -437,6 +441,7 @@ def test_deep_ep_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
per_act_token_quant: bool,
|
||||
workspace_init,
|
||||
):
|
||||
low_latency_mode = False
|
||||
use_fp8_dispatch = False
|
||||
@ -492,6 +497,7 @@ def test_low_latency_deep_ep_moe(
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_fp8_dispatch: bool,
|
||||
workspace_init,
|
||||
):
|
||||
low_latency_mode = True
|
||||
|
||||
|
||||
@ -143,7 +143,7 @@ NUM_EXPERTS = [32]
|
||||
@pytest.mark.parametrize("topk", TOPKS)
|
||||
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
|
||||
@pytest.mark.skipif(not is_deep_gemm_supported(), reason="Requires deep_gemm kernels")
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch):
|
||||
def test_deepgemm_vs_triton(m, n, k, topk, num_experts, monkeypatch, workspace_init):
|
||||
with monkeypatch.context() as mp:
|
||||
mp.setenv("VLLM_USE_DEEP_GEMM", "1")
|
||||
|
||||
|
||||
@ -206,6 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
topk: int,
|
||||
activation: str,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
|
||||
@ -51,7 +51,14 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, activation: str
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
activation: str,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
|
||||
@ -269,7 +269,7 @@ class Case:
|
||||
)
|
||||
@pytest.mark.parametrize("num_token", [2])
|
||||
@pytest.mark.parametrize("tp", [1, 2, 4, 8])
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp):
|
||||
def test_equiv(num_token, a_dtype, w_dtype, tp, workspace_init):
|
||||
from triton_kernels.tensor_details import layout
|
||||
|
||||
if not hasattr(layout, "make_default_matmul_mxfp4_w_layout"):
|
||||
|
||||
@ -16,6 +16,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
from vllm.utils.import_utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from .modular_kernel_tools.common import (
|
||||
Config,
|
||||
@ -77,6 +78,10 @@ def rank_worker(
|
||||
weights: WeightTensors,
|
||||
verbose: bool,
|
||||
):
|
||||
# Initialize workspace manager in child process
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
@ -300,6 +305,7 @@ def test_modular_kernel_combinations_singlegpu(
|
||||
chunk_size: int | None,
|
||||
world_size: int,
|
||||
pytestconfig,
|
||||
workspace_init,
|
||||
):
|
||||
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
|
||||
and those tests will be skipped on unsupported hardware."""
|
||||
|
||||
@ -209,6 +209,7 @@ def test_oai_triton_moe(
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
unfused: bool,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
(
|
||||
|
||||
@ -231,6 +231,7 @@ def test_fused_moe(
|
||||
padding: bool,
|
||||
chunk_size: int,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@torch.inference_mode()
|
||||
def test_cutlass_fp4_moe_no_graph(
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
|
||||
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype, workspace_init
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
|
||||
@ -46,6 +46,7 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import round_up
|
||||
from vllm.v1.worker.workspace import init_workspace_manager
|
||||
|
||||
from ...utils import multi_gpu_test
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
@ -181,6 +182,7 @@ def test_fused_moe_batched_experts(
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
workspace_init,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
@ -863,6 +865,9 @@ def _pplx_test_loop(
|
||||
make_weights: bool,
|
||||
test_fn: Callable,
|
||||
):
|
||||
device = torch.device(f"cuda:{pgi.local_rank}")
|
||||
init_workspace_manager(device)
|
||||
|
||||
def format_result(msg, ex=None):
|
||||
if ex is not None:
|
||||
x = str(ex)
|
||||
|
||||
@ -22,10 +22,14 @@ from tests.v1.attention.utils import (
|
||||
)
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops import flashmla
|
||||
from vllm.config import set_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import FlashMLASparseBackend
|
||||
from vllm.v1.attention.backends.mla.indexer import split_prefill_chunks
|
||||
from vllm.v1.attention.backends.mla.flashmla_sparse import (
|
||||
FlashMLASparseBackend,
|
||||
triton_convert_req_index_to_global_index,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import split_prefill_chunks
|
||||
|
||||
SPARSE_BACKEND_BATCH_SPECS = {
|
||||
name: BATCH_SPECS[name]
|
||||
@ -114,8 +118,12 @@ def _quantize_dequantize_fp8_ds_mla(
|
||||
@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys()))
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"])
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1, 2, 4])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_sparse_backend_decode_correctness(
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size
|
||||
dist_init, batch_name, kv_cache_dtype, tensor_parallel_size, workspace_init
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA is required for sparse MLA decode test")
|
||||
@ -320,28 +328,29 @@ def test_sparse_backend_decode_correctness(
|
||||
mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous())
|
||||
|
||||
impl_cls = FlashMLASparseBackend.get_impl_cls()
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer,
|
||||
)
|
||||
with set_current_vllm_config(vllm_config):
|
||||
impl = impl_cls(
|
||||
num_heads=num_heads,
|
||||
head_size=head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=1,
|
||||
alibi_slopes=None,
|
||||
sliding_window=None,
|
||||
kv_cache_dtype=vllm_config.cache_config.cache_dtype,
|
||||
logits_soft_cap=None,
|
||||
attn_type="decoder",
|
||||
kv_sharing_target_layer_name=None,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=kv_lora_rank,
|
||||
qk_nope_head_dim=qk_nope_head_dim,
|
||||
qk_rope_head_dim=qk_rope_head_dim,
|
||||
qk_head_dim=qk_nope_head_dim + qk_rope_head_dim,
|
||||
v_head_dim=v_head_dim,
|
||||
kv_b_proj=mock_kv_b_proj,
|
||||
indexer=mock_indexer,
|
||||
)
|
||||
|
||||
impl.process_weights_after_loading(dtype)
|
||||
impl.process_weights_after_loading(dtype)
|
||||
|
||||
layer = MockAttentionLayer(device)
|
||||
out_buffer = torch.empty(
|
||||
@ -366,22 +375,192 @@ def test_sparse_backend_decode_correctness(
|
||||
torch.testing.assert_close(backend_output, sdpa_reference, rtol=0.5, atol=0.5)
|
||||
|
||||
|
||||
def _triton_convert_reference_impl(
|
||||
req_ids: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
token_indices: torch.Tensor,
|
||||
block_size: int,
|
||||
num_topk_tokens: int,
|
||||
HAS_PREFILL_WORKSPACE: bool = False,
|
||||
prefill_workspace_request_ids: torch.Tensor | None = None,
|
||||
prefill_workspace_starts: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Reference implementation for triton_convert_req_index_to_global_index."""
|
||||
num_tokens = req_ids.shape[0]
|
||||
max_blocks_per_req = block_table.shape[1]
|
||||
result = torch.empty(
|
||||
num_tokens, num_topk_tokens, dtype=torch.int32, device=req_ids.device
|
||||
)
|
||||
|
||||
for token_id in range(num_tokens):
|
||||
req_id = req_ids[token_id].item()
|
||||
|
||||
# Determine if this token uses workspace or paged cache
|
||||
use_prefill_workspace = False
|
||||
workspace_start = 0
|
||||
if HAS_PREFILL_WORKSPACE and prefill_workspace_request_ids is not None:
|
||||
assert prefill_workspace_starts is not None
|
||||
prefill_req_id = prefill_workspace_request_ids[token_id].item()
|
||||
if prefill_req_id >= 0:
|
||||
use_prefill_workspace = True
|
||||
workspace_start = prefill_workspace_starts[prefill_req_id].item()
|
||||
|
||||
for idx_id in range(num_topk_tokens):
|
||||
token_idx = token_indices[token_id, idx_id].item()
|
||||
|
||||
if token_idx == -1:
|
||||
result[token_id, idx_id] = -1
|
||||
elif use_prefill_workspace:
|
||||
# Prefill + using prefill workspace: map to workspace offset
|
||||
result[token_id, idx_id] = workspace_start + token_idx
|
||||
else:
|
||||
# Decode: map to paged cache
|
||||
block_id = token_idx // block_size
|
||||
if block_id >= max_blocks_per_req:
|
||||
result[token_id, idx_id] = -1
|
||||
else:
|
||||
block_num = block_table[req_id, block_id].item()
|
||||
offset = token_idx % block_size
|
||||
result[token_id, idx_id] = block_num * block_size + offset
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16, 64, 128])
|
||||
@pytest.mark.parametrize("num_topk_tokens", [128, 256, 512])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_triton_convert_req_index_to_global_index_decode_only(
|
||||
block_size, num_topk_tokens
|
||||
):
|
||||
device = torch.device("cuda")
|
||||
num_tokens = 8
|
||||
num_requests = 4
|
||||
max_blocks_per_req = 10
|
||||
|
||||
req_id = torch.randint(
|
||||
0, num_requests, (num_tokens,), dtype=torch.int32, device=device
|
||||
)
|
||||
block_table = torch.randint(
|
||||
0, 100, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
token_indices = torch.randint(
|
||||
0,
|
||||
block_size * max_blocks_per_req,
|
||||
(num_tokens, num_topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set some to -1 to test masking
|
||||
token_indices[0, :10] = -1
|
||||
token_indices[3, 50:60] = -1
|
||||
|
||||
# Set some to out of bounds
|
||||
token_indices[2, 100:110] = max_blocks_per_req * block_size
|
||||
token_indices[6, 150:160] = max_blocks_per_req * block_size
|
||||
|
||||
result = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
)
|
||||
|
||||
reference_result = _triton_convert_reference_impl(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
block_size,
|
||||
num_topk_tokens,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("block_size", [16])
|
||||
@pytest.mark.skipif(
|
||||
torch.cuda.get_device_capability() < (9, 0),
|
||||
reason="FlashMLASparseBackend requires CUDA 9.0 or higher",
|
||||
)
|
||||
def test_triton_convert_req_index_to_global_index_with_prefill_workspace(block_size):
|
||||
device = torch.device("cuda")
|
||||
num_requests = 4
|
||||
max_blocks_per_req = 8
|
||||
num_topk_tokens = 128
|
||||
|
||||
# First 6 tokens are decode (reqs 0, 1), last 6 are prefill (reqs 2, 3)
|
||||
req_id = torch.tensor(
|
||||
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], dtype=torch.int32, device=device
|
||||
)
|
||||
prefill_workspace_request_ids = torch.tensor(
|
||||
[-1, -1, -1, -1, -1, -1, 0, 0, 0, 1, 1, 1], dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
# Workspace starts for the 2 prefill reqs: req 2 starts at 0, req 3 starts at 100
|
||||
prefill_workspace_starts = torch.tensor([0, 100], dtype=torch.int32, device=device)
|
||||
|
||||
block_table = torch.randint(
|
||||
0, 50, (num_requests, max_blocks_per_req), dtype=torch.int32, device=device
|
||||
)
|
||||
token_indices = torch.randint(
|
||||
0,
|
||||
block_size * max_blocks_per_req,
|
||||
(req_id.shape[0], num_topk_tokens),
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Set some to -1 to test masking
|
||||
token_indices[0, :10] = -1
|
||||
token_indices[3, 50:60] = -1
|
||||
|
||||
# Set some to out of bounds
|
||||
token_indices[2, 100:110] = max_blocks_per_req * block_size
|
||||
token_indices[6, 150:160] = max_blocks_per_req * block_size
|
||||
|
||||
result = triton_convert_req_index_to_global_index(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
BLOCK_SIZE=block_size,
|
||||
NUM_TOPK_TOKENS=num_topk_tokens,
|
||||
HAS_PREFILL_WORKSPACE=True,
|
||||
prefill_workspace_request_ids=prefill_workspace_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
reference_result = _triton_convert_reference_impl(
|
||||
req_id,
|
||||
block_table,
|
||||
token_indices,
|
||||
block_size,
|
||||
num_topk_tokens,
|
||||
HAS_PREFILL_WORKSPACE=True,
|
||||
prefill_workspace_request_ids=prefill_workspace_request_ids,
|
||||
prefill_workspace_starts=prefill_workspace_starts,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(result, reference_result, rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens,max_buf,start,expected",
|
||||
"seq_lens,max_buf,expected",
|
||||
[
|
||||
# Basic split: totals per chunk ≤ max_buf
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 0, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Non-zero start index
|
||||
(torch.tensor([2, 3, 4, 2]), 5, 1, [(1, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would
|
||||
# overflow
|
||||
(torch.tensor([5, 5, 5]), 5, 0, [(0, 1), (1, 2), (2, 3)]),
|
||||
(torch.tensor([2, 3, 4, 2]), 5, [(0, 2), (2, 3), (3, 4)]),
|
||||
# Exact fits should split between items when adding the next would overflow
|
||||
(torch.tensor([5, 5, 5]), 5, [(0, 1), (1, 2), (2, 3)]),
|
||||
# All requests fit in a single chunk
|
||||
(torch.tensor([1, 1, 1]), 10, 0, [(0, 3)]),
|
||||
# Large buffer with non-zero start
|
||||
(torch.tensor([4, 4, 4]), 100, 1, [(1, 3)]),
|
||||
(torch.tensor([1, 1, 1]), 10, [(0, 3)]),
|
||||
# Large buffer
|
||||
(torch.tensor([4, 4, 4]), 100, [(0, 3)]),
|
||||
],
|
||||
)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, start, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf, start)
|
||||
def test_split_prefill_chunks(seq_lens, max_buf, expected):
|
||||
out = split_prefill_chunks(seq_lens, max_buf)
|
||||
assert out == expected
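For readers following the API change, a reference sketch that reproduces the expected chunking above (this is not the vLLM implementation of split_prefill_chunks, just a greedy model of its contract):

def split_prefill_chunks_reference(seq_lens, max_buf):
    # Greedy chunking: pack consecutive requests until adding the next one
    # would exceed max_buf, then start a new chunk. Matches the (start, end)
    # pairs expected in the parametrization above.
    chunks, start, total = [], 0, 0
    for i, sl in enumerate(seq_lens.tolist()):
        if total and total + sl > max_buf:
            chunks.append((start, i))
            start, total = i, 0
        total += sl
    chunks.append((start, len(seq_lens)))
    return chunks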
|
||||
|
||||
@ -2403,6 +2403,29 @@ def cp_gather_cache(
|
||||
)
|
||||
|
||||
|
||||
def cp_gather_and_upconvert_fp8_kv_cache(
    src_cache: torch.Tensor,
    dst: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    workspace_starts: torch.Tensor,
    batch_size: int,
) -> None:
    """Gather and upconvert FP8 KV cache to BF16 workspace.

    Args:
        src_cache: FP8 KV cache [num_blocks, block_size, 656]
        dst: BF16 output workspace [total_tokens, 576]
        block_table: Block indices [num_reqs, max_blocks]
        seq_lens: Sequence lengths [num_reqs]
        workspace_starts: Workspace start offsets [num_reqs]
        batch_size: Number of requests
    """
    torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache(
        src_cache, dst, block_table, seq_lens, workspace_starts, batch_size
    )
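A hedged usage sketch based only on the docstring above; all shapes and values are placeholders (4 requests, 128-token blocks, zeroed cache contents):

import torch
from vllm import _custom_ops as ops

# Hypothetical setup: 4 requests, 128-token blocks, zeroed fp8_ds_mla entries.
num_blocks, block_size = 32, 128
src_cache = torch.zeros(num_blocks, block_size, 656, dtype=torch.uint8, device="cuda")
block_table = torch.zeros(4, 8, dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([100, 40, 300, 7], dtype=torch.int32, device="cuda")
workspace_starts = torch.zeros(4, dtype=torch.int32, device="cuda")
workspace_starts[1:] = torch.cumsum(seq_lens[:-1], dim=0)

dst = torch.empty(int(seq_lens.sum()), 576, dtype=torch.bfloat16, device="cuda")
ops.cp_gather_and_upconvert_fp8_kv_cache(
    src_cache, dst, block_table, seq_lens, workspace_starts, batch_size=4
)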
|
||||
|
||||
|
||||
def indexer_k_quant_and_cache(
|
||||
k: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
|
||||
@ -239,6 +239,7 @@ if TYPE_CHECKING:
|
||||
VLLM_NCCL_INCLUDE_PATH: str | None = None
|
||||
VLLM_USE_FBGEMM: bool = False
|
||||
VLLM_GC_DEBUG: str = ""
|
||||
VLLM_DEBUG_WORKSPACE: bool = False
|
||||
VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
|
||||
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
|
||||
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
|
||||
@ -1537,6 +1538,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# - VLLM_GC_DEBUG='{"top_objects":5}': enable GC debugger with
|
||||
# top 5 collected objects
|
||||
"VLLM_GC_DEBUG": lambda: os.getenv("VLLM_GC_DEBUG", ""),
|
||||
# Debug workspace allocations: enables
# logging of workspace resize operations.
|
||||
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
|
||||
# Disables parallel execution of shared_experts via separate cuda stream
|
||||
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
|
||||
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
|
||||
|
||||
@ -22,12 +22,12 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_current_ubatch_id,
|
||||
dbo_enabled,
|
||||
dbo_maybe_run_recv_hook,
|
||||
dbo_register_recv_hook,
|
||||
dbo_yield,
|
||||
)
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -661,25 +661,6 @@ def _slice_scales(
|
||||
return None
|
||||
|
||||
|
||||
class SharedResizableBuffer:
|
||||
def __init__(self):
|
||||
self.buffer = None
|
||||
|
||||
def get(
|
||||
self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
assert shape != ()
|
||||
shape_numel = prod(shape)
|
||||
if (
|
||||
self.buffer is None
|
||||
or self.buffer.numel() < shape_numel
|
||||
or self.buffer.device != device
|
||||
or self.buffer.dtype != dtype
|
||||
):
|
||||
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
|
||||
return self.buffer[:shape_numel].view(*shape)
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
@ -694,22 +675,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
objects.
|
||||
"""
|
||||
|
||||
class SharedBuffers:
|
||||
def __init__(self) -> None:
|
||||
self.fused_out = SharedResizableBuffer()
|
||||
self.workspace13 = SharedResizableBuffer()
|
||||
self.workspace2 = SharedResizableBuffer()
|
||||
|
||||
# Persistent buffers that are shared across `FusedMoEModularKernel`
|
||||
# instances (layers), to save memory and allocations.
|
||||
#
|
||||
# We have two sets of buffers to support dual batch overlap (DBO) where each
|
||||
# microbatch (ubatch) should use its own set of buffers to avoid
|
||||
# cross-ubatch contamination.
|
||||
# NOTE that memory is lazily allocated for these buffers, meaning that if
|
||||
# DBO isn't being used, the second SharedBuffers will be empty.
|
||||
shared_buffers: list[SharedBuffers] = [SharedBuffers(), SharedBuffers()]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
@ -806,10 +771,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
assert M_full > 0 and M_chunk > 0
|
||||
|
||||
num_chunks, _ = self._chunk_info(M_full)
|
||||
|
||||
# select per-ubatch buffers to avoid cross-ubatch reuse under DBO
|
||||
ubatch_idx = dbo_current_ubatch_id()
|
||||
buffers = self.shared_buffers[ubatch_idx]
|
||||
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
|
||||
|
||||
# Force worst-case allocation in profiling run for
|
||||
@ -832,14 +793,11 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
expert_tokens_meta,
|
||||
)
|
||||
)
|
||||
buffers.workspace13.get(
|
||||
max_workspace_13, device=device, dtype=workspace_dtype
|
||||
)
|
||||
buffers.workspace2.get(
|
||||
max_workspace_2, device=device, dtype=workspace_dtype
|
||||
)
|
||||
buffers.fused_out.get(
|
||||
max_fused_out_shape, device=device, dtype=workspace_dtype
|
||||
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(max_workspace_13, workspace_dtype),
|
||||
(max_workspace_2, workspace_dtype),
|
||||
(max_fused_out_shape, out_dtype),
|
||||
)
|
||||
|
||||
# Get intermediate workspace shapes based off the chunked M size.
|
||||
@ -866,22 +824,23 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the
|
||||
# time we need cache3, we're done with cache1.
|
||||
workspace13 = buffers.workspace13.get(
|
||||
workspace13_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
workspace2 = buffers.workspace2.get(
|
||||
workspace2_shape, device=device, dtype=workspace_dtype
|
||||
)
|
||||
|
||||
# Construct the entire output that can then be processed in chunks.
|
||||
# Reuse workspace13 for the output in the non-chunked case as long
|
||||
# as it is large enough. This will not always be the case for standard
|
||||
# format experts and with experts that have empty workspaces.
|
||||
if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape):
|
||||
workspace13, workspace2 = current_workspace_manager().get_simultaneous(
|
||||
(workspace13_shape, workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
)
|
||||
fused_out = _resize_cache(workspace13, fused_out_shape)
|
||||
else:
|
||||
fused_out = buffers.fused_out.get(
|
||||
fused_out_shape, device=device, dtype=out_dtype
|
||||
workspace13, workspace2, fused_out = (
|
||||
current_workspace_manager().get_simultaneous(
|
||||
(workspace13_shape, workspace_dtype),
|
||||
(workspace2_shape, workspace_dtype),
|
||||
(fused_out_shape, out_dtype),
|
||||
)
|
||||
)
|
||||
|
||||
return workspace13, workspace2, fused_out
|
||||
|
||||
@ -83,6 +83,7 @@ from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
|
||||
from .utils import (
|
||||
@ -616,8 +617,15 @@ def sparse_attn_indexer(
|
||||
# careful! this will be None in dummy run
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
|
||||
# assert isinstance(attn_metadata, dict)
|
||||
if not isinstance(attn_metadata, dict):
|
||||
# Reserve workspace for indexer during profiling run
|
||||
current_workspace_manager().get_simultaneous(
|
||||
((total_seq_lens, head_dim), torch.float8_e4m3fn),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
|
||||
return sparse_attn_indexer_fake(
|
||||
hidden_states,
|
||||
k_cache_prefix,
|
||||
@ -651,17 +659,17 @@ def sparse_attn_indexer(
|
||||
topk_indices_buffer[: hidden_states.shape[0]] = -1
|
||||
if has_prefill:
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
|
||||
# Get the full shared workspace buffers once (will allocate on first use)
|
||||
workspace_manager = current_workspace_manager()
|
||||
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
|
||||
((total_seq_lens, head_dim), fp8_dtype),
|
||||
((total_seq_lens, 4), torch.uint8),
|
||||
)
|
||||
|
||||
for chunk in prefill_metadata.chunks:
|
||||
k_fp8 = torch.empty(
|
||||
[chunk.total_seq_lens, head_dim],
|
||||
device=k.device,
|
||||
dtype=fp8_dtype,
|
||||
)
|
||||
k_scale = torch.empty(
|
||||
[chunk.total_seq_lens, 4],
|
||||
device=k.device,
|
||||
dtype=torch.uint8,
|
||||
)
|
||||
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
|
||||
k_scale = k_scale_full[: chunk.total_seq_lens]
|
||||
ops.cp_gather_indexer_k_quant_cache(
|
||||
kv_cache,
|
||||
k_fp8,
|
||||
@ -777,15 +785,6 @@ def sparse_attn_indexer_fake(
|
||||
total_seq_lens: int,
|
||||
topk_indices_buffer: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
# profile run
|
||||
# NOTE(Chen): create the max possible flattened_kv. So that
|
||||
# profile_run can get correct memory usage.
|
||||
_flattened_kv = torch.empty(
|
||||
[total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
|
||||
)
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
_k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
|
||||
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
|
||||
return topk_indices_buffer
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ from vllm.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
@ -30,13 +30,31 @@ from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
split_decodes_and_prefills,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.v1.worker.workspace import current_workspace_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# For FP8 sparse attention we have two implementations:
# 1. Mixed batch mode: use the FP8 decode kernel for both prefill and decode;
#    this is done by treating all tokens as a single batch.
# 2. Separate prefill and decode mode: use the BF16 prefill kernel for prefill
#    (upconverting the FP8 cache to BF16 then calling the prefill kernel) and
#    the FP8 decode kernel for decode.
# Currently we use #1 when the number of heads per rank is low (i.e. high TP),
# since the BF16 prefill kernel requires padding the number of heads to 128
# while the decode kernel does not; so when the per-rank head count is below
# MIN_HEADS_FOR_BF16_PREFILL we use the mixed batch mode (#1).
MIN_HEADS_FOR_BF16_PREFILL = 32
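A condensed sketch of the dispatch rule described above (the real decision is made in the metadata builder later in this file; the helper name here is illustrative):

def use_mixed_batch_fp8(num_heads_per_rank: int) -> bool:
    # Mode #1 (mixed batch, FP8 decode kernel for everything) avoids padding
    # the head dimension to 128, so prefer it when TP leaves few heads per rank.
    return num_heads_per_rank < MIN_HEADS_FOR_BF16_PREFILL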
|
||||
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
@ -127,19 +145,72 @@ class FlashMLASparseMetadata:
|
||||
dummy_block_table: torch.Tensor
|
||||
cache_lens: torch.Tensor
|
||||
|
||||
fp8_extra_metadata: FP8KernelMetadata | None = None
|
||||
@dataclass
|
||||
class FP8SeperatePrefillDecode:
|
||||
@dataclass
|
||||
class Decode:
|
||||
kernel_metadata: "FlashMLASparseMetadata.FP8KernelMetadata"
|
||||
decode_query_len: int # needed for reshape in spec decode
|
||||
|
||||
@dataclass
|
||||
class Prefill:
|
||||
# Sequence lengths (context + query) for prefill requests
|
||||
# Shape: [num_prefill_reqs]
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
# Request ID for each token: -1 for decode tokens, request index
|
||||
# (0, 1, 2, ...) for prefill tokens.
|
||||
# Shape: [num_actual_tokens]
|
||||
request_ids: torch.Tensor
|
||||
|
||||
# Workspace start offsets for all prefill requests
|
||||
# Shape: [num_prefill_reqs], adjusted in-place per chunk to be
|
||||
# 0-indexed within each chunk. Used to map prefill tokens to workspace
|
||||
# offsets in convert_logical_index_to_physical_index
|
||||
workspace_starts: torch.Tensor
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
"""Metadata for a chunk of prefill requests.
|
||||
|
||||
Prefill requests may be chunked to fit within the fixed workspace size.
|
||||
"""
|
||||
|
||||
seq_lens: torch.Tensor
|
||||
tokens_slice: slice
|
||||
block_table: torch.Tensor
|
||||
req_start_idx: int
|
||||
workspace_starts: torch.Tensor
|
||||
chunk_tot_seqlen: int
|
||||
|
||||
chunks: list[Chunk]
|
||||
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
num_prefill_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
|
||||
decode: Decode | None = None
|
||||
prefill: Prefill | None = None
|
||||
|
||||
fp8_extra_metadata: FP8SeperatePrefillDecode | FP8KernelMetadata | None = None
|
||||
fp8_use_mixed_batch: bool = False
|
||||
|
||||
|
||||
# Kernel with prefill workspace support
|
||||
@triton.jit
|
||||
def _convert_req_index_to_global_index_kernel(
|
||||
req_id_ptr, # int32 [num_tokens]
|
||||
block_table_ptr, # int32 [num_requests, max_num_blocks_per_req]
|
||||
token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS]
|
||||
prefill_request_id_ptr, # int32 [num_tokens], -1 for decode, >=0 for prefill
|
||||
workspace_starts_ptr, # int32 [num_prefill_reqs+1] or nullptr
|
||||
# shapes (compile-time where possible)
|
||||
max_num_blocks_per_req: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr, # tile width along columns
|
||||
HAS_PREFILL: tl.constexpr,
|
||||
# strides (in elements)
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
@ -165,7 +236,10 @@ def _convert_req_index_to_global_index_kernel(
|
||||
|
||||
# Only token == -1 should propagate as -1
|
||||
is_invalid_tok = tok < 0
|
||||
|
||||
is_prefill = False
|
||||
if HAS_PREFILL:
|
||||
prefill_req_id = tl.load(prefill_request_id_ptr + token_id)
|
||||
is_prefill = prefill_req_id >= 0
|
||||
# Compute block id and in-block offset
|
||||
block_id = tok // BLOCK_SIZE
|
||||
inblock_off = tok % BLOCK_SIZE
|
||||
@ -173,12 +247,18 @@ def _convert_req_index_to_global_index_kernel(
|
||||
# Guard block_table access
|
||||
valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
|
||||
bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
|
||||
base = tl.load(bt_ptr, mask=valid_block, other=0)
|
||||
is_invalid_tok |= ~valid_block
|
||||
base = tl.load(bt_ptr, mask=valid_block & ~is_prefill, other=0)
|
||||
out_val = base * BLOCK_SIZE + inblock_off
|
||||
|
||||
# If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset
|
||||
out_val = tl.where(
|
||||
is_invalid_tok | (~valid_block), -1, base * BLOCK_SIZE + inblock_off
|
||||
)
|
||||
# Override with prefill output if prefill is enabled
|
||||
if HAS_PREFILL:
|
||||
workspace_start = tl.load(
|
||||
workspace_starts_ptr + prefill_req_id, mask=is_prefill, other=0
|
||||
)
|
||||
prefill_out = workspace_start + tok
|
||||
out_val = tl.where(is_prefill, prefill_out, out_val)
|
||||
out_val = tl.where(is_invalid_tok, -1, out_val)
|
||||
|
||||
# Store results
|
||||
out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
|
||||
@ -192,6 +272,9 @@ def triton_convert_req_index_to_global_index(
|
||||
BLOCK_SIZE: int = 64,
|
||||
NUM_TOPK_TOKENS: int = 2048,
|
||||
BLOCK_N: int = 128, # tile width along columns
|
||||
HAS_PREFILL_WORKSPACE: bool = False,
|
||||
prefill_workspace_request_ids: torch.Tensor | None = None,
|
||||
prefill_workspace_starts: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
out[token_id, indice_id] =
|
||||
@ -202,17 +285,32 @@ def triton_convert_req_index_to_global_index(
|
||||
Only when token_indices[token_id, indice_id] == -1 do we output -1.
|
||||
For safety, we also output -1 if the derived block_id would be
|
||||
out-of-bounds.
|
||||
|
||||
When HAS_PREFILL_WORKSPACE is True, prefill tokens are mapped to workspace offsets
|
||||
instead of global cache slots. prefill_workspace_request_ids and
|
||||
prefill_workspace_starts must be provided.
|
||||
|
||||
prefill_workspace_request_ids: int32 [num_tokens], -1 for decode else
|
||||
prefill request index (maps to prefill_workspace_starts)
|
||||
prefill_workspace_starts: int32 [num_prefills], 0-indexed workspace
|
||||
starts for each prefill request
|
||||
"""
|
||||
assert req_id.dtype == torch.int32
|
||||
assert block_table.dtype == torch.int32
|
||||
assert token_indices.dtype == torch.int32
|
||||
assert token_indices.shape[1] == NUM_TOPK_TOKENS
|
||||
assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible byBLOCK_N ({BLOCK_N})"
|
||||
f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by BLOCK_N ({BLOCK_N})"
|
||||
)
|
||||
|
||||
if HAS_PREFILL_WORKSPACE:
|
||||
assert prefill_workspace_request_ids is not None
|
||||
assert prefill_workspace_starts is not None
|
||||
assert prefill_workspace_request_ids.dtype == torch.int32
|
||||
assert prefill_workspace_starts.dtype == torch.int32
|
||||
|
||||
num_tokens = req_id.shape[0]
|
||||
num_requests, max_num_blocks_per_req = block_table.shape
|
||||
max_num_blocks_per_req = block_table.shape[1]
|
||||
tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N
|
||||
|
||||
# Ensure contiguous tensors on the same device
|
||||
@ -226,6 +324,13 @@ def triton_convert_req_index_to_global_index(
|
||||
ti_stride0, ti_stride1 = token_indices_c.stride()
|
||||
out_stride0, out_stride1 = out.stride()
|
||||
|
||||
# Prepare prefill pointers
|
||||
if HAS_PREFILL_WORKSPACE:
|
||||
assert prefill_workspace_request_ids is not None # for mypy
|
||||
assert prefill_workspace_starts is not None # for mypy
|
||||
assert prefill_workspace_request_ids.is_contiguous()
|
||||
assert prefill_workspace_starts.is_contiguous()
|
||||
|
||||
# Exact 2D grid: tokens × column tiles
|
||||
grid = (num_tokens, tiles_per_row)
|
||||
|
||||
@ -234,10 +339,13 @@ def triton_convert_req_index_to_global_index(
|
||||
block_table_c,
|
||||
token_indices_c,
|
||||
out,
|
||||
prefill_workspace_request_ids,
|
||||
prefill_workspace_starts,
|
||||
# shapes / constexprs
|
||||
max_num_blocks_per_req,
|
||||
BLOCK_SIZE,
|
||||
BLOCK_N,
|
||||
HAS_PREFILL_WORKSPACE,
|
||||
# strides
|
||||
bt_stride0,
|
||||
bt_stride1,
|
||||
@ -249,7 +357,16 @@ def triton_convert_req_index_to_global_index(
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
def get_prefill_workspace_size(max_model_len: int):
    # NOTE(Lucas): 5 is a magic number for controlling the prefill buffer size.
    # May be tuned later.
    # Memory usage: 5 * max_model_len * 576 * 2 bytes
    # Example: DeepSeek-V3.2 with max_model_len=163840 ->
    #   5 * 163840 * 576 * 2 = ~900 MB
    # This fits nicely below the typical MoE workspace size of >2GB so this is "free"
    return max_model_len * 5
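Working through the numbers in the comment (a quick sanity check, not new behavior):

# 5 * max_model_len tokens, 576 BF16 values per token, 2 bytes per value
workspace_tokens = 5 * 163840                  # 819,200 tokens
workspace_bytes = workspace_tokens * 576 * 2   # 943,718,400 bytes ~= 900 MB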
|
||||
|
||||
|
||||
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
|
||||
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
@ -259,29 +376,42 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
) -> None:
|
||||
self.vllm_config = vllm_config
|
||||
self.layer_names = layer_names
|
||||
cache_config = vllm_config.cache_config
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.model_config = vllm_config.model_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.device = device
|
||||
|
||||
# Treat requests with query length <= 1 as decodes to match the
|
||||
# DeepGEMM indexer constraint (fp8_paged_mqa_logits only supports next_n <= 2)
|
||||
self._init_reorder_batch_threshold(1, supports_spec_as_decode=True)
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
sm_count = props.multi_processor_count
|
||||
|
||||
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
|
||||
self.mla_dims = get_mla_dims(self.model_config)
|
||||
|
||||
self.topk_tokens = vllm_config.model_config.hf_config.index_topk
|
||||
self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"
|
||||
self.topk_tokens_tensor = torch.tensor(
|
||||
[self.topk_tokens], device=device, dtype=torch.int32
|
||||
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
|
||||
# Shape: [max_num_seqs], all elements = topk_tokens (constant for full-CG)
|
||||
self.topk_tokens_tensor = torch.full(
|
||||
(max_num_seqs,), self.topk_tokens, device=device, dtype=torch.int32
|
||||
)
|
||||
self.max_model_len_tensor = torch.tensor(
|
||||
[self.model_config.max_model_len], device=device, dtype=torch.int32
|
||||
# Shape: [max_num_seqs], all elements = max_model_len
|
||||
self.max_model_len_tensor = torch.full(
|
||||
(max_num_seqs,),
|
||||
self.model_config.max_model_len,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
# this is ignored by `flash_mla_with_kvcache` if indices not None
|
||||
self.dummy_block_table = torch.empty(
|
||||
(1, 1), dtype=torch.int32, device=self.device
|
||||
(max_num_seqs, 1), dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
# Equation taken from FlashMLA/csrc/pybind.cpp
|
||||
@ -299,10 +429,9 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
# Sized for per-request batching (num_decodes + 1)
|
||||
self.num_splits_buffer = torch.empty(
|
||||
# We pack all the tokens into one batch for sparse attention.
|
||||
# Otherwise, we can exceed the sm of `get_mla_metadata`.
|
||||
(2,),
(max_num_seqs + 1,),
dtype=torch.int32,
device=device,
)
@ -312,30 +441,171 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
device=device,
)

def build(
def _build_fp8_mixed_decode_prefill(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashMLASparseMetadata:
num_tokens = common_attn_metadata.num_actual_tokens
starts = np.asarray(common_attn_metadata.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]
) -> "FlashMLASparseMetadata.FP8KernelMetadata":
"""Build FP8 metadata treating all tokens as one mixed batch.

This matches main branch's approach and avoids the BF16 prefill kernel
which has head padding overhead when num_heads is small (high TP case).
"""
num_tokens = common_attn_metadata.num_actual_tokens

# Build metadata for all tokens as a single batch
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor[:1], # Single batch
num_q_tokens_per_head_k=num_tokens * self.num_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_k=1,
is_fp8_kvcache=True,
)

num_sm_parts = tile_scheduler_metadata.size(0)
tile_scheduler_metadata_buffer = self.tile_scheduler_metadata_buffer[
:num_sm_parts
]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
num_splits_view = self.num_splits_buffer[:2]
num_splits_view.copy_(num_splits)

fp8_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=num_splits_view,
cache_lens=self.max_model_len_tensor[:1],
dummy_block_table=self.dummy_block_table[:1],
)

return fp8_metadata

def _build_fp8_separate_prefill_decode(
self,
common_attn_metadata: CommonAttentionMetadata,
) -> "FlashMLASparseMetadata.FP8SeperatePrefillDecode":
num_tokens = common_attn_metadata.num_actual_tokens

(num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold or 1,
require_uniform=True,
)
)

FP8Meta = FlashMLASparseMetadata.FP8SeperatePrefillDecode
fp8_metadata = FP8Meta(
num_decodes=num_decodes,
num_prefills=num_prefills,
num_decode_tokens=num_decode_tokens,
num_prefill_tokens=num_prefill_tokens,
)

# Extract prefill sequence lengths (context + query, not just query)
# Decode requests come first in the batch, prefill requests follow
prefill_seq_lens = None
prefill_request_id = None
prefill_workspace_starts = None
prefill_chunks = None

# For pure decode batches, prefill_request_id will be None
# For mixed batches, it will have -1 for decode and request_id for prefill
if num_prefills > 0:
seq_lens_cpu = common_attn_metadata.seq_lens_cpu
seq_lens = common_attn_metadata.seq_lens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

prefill_seq_lens_cpu = seq_lens_cpu[num_decodes:]
prefill_seq_lens = seq_lens[num_decodes:]

# Build prefill_request_id: -1 for decode, request index for
# prefill. This enables a single
# convert_logical_index_to_physical_index call for all tokens
prefill_request_id = torch.full(
(num_tokens,), -1, dtype=torch.int32, device=self.device
)
# Map prefill tokens to their request IDs (0, 1, 2, ...)
for req_idx in range(num_prefills):
# Get query token range for this prefill request
global_req_idx = num_decodes + req_idx
req_query_start = query_start_loc_cpu[global_req_idx]
req_query_end = query_start_loc_cpu[global_req_idx + 1]
prefill_request_id[req_query_start:req_query_end] = req_idx

# will be adjusted by chunk loop
prefill_workspace_starts_cpu = torch.zeros(
num_prefills, dtype=torch.int32, pin_memory=True
)
prefill_workspace_starts_cpu[1:] = torch.cumsum(
prefill_seq_lens_cpu[:-1], dim=0
)
# populated by non-blocking copy after prefill_workspace_starts_cpu is
# updated by each chunk
prefill_workspace_starts = torch.empty(
num_prefills, dtype=torch.int32, device=self.device
)

# Chunk prefill requests to fit within workspace size
max_prefill_buffer_size = get_prefill_workspace_size(
self.vllm_config.model_config.max_model_len
)
chunk_bounds = split_prefill_chunks(
prefill_seq_lens_cpu, max_prefill_buffer_size
)

prefill_chunks = []
for chunk_start, chunk_end in chunk_bounds:
# Adjust workspace_starts in-place per chunk to be
# 0-indexed within each chunk
# Example: seq_lens=[10,15,20,5], chunks=[[0,2],[2,4]]
# Initial: workspace_starts=[0,10,25,45]
# After: workspace_starts=[0,10,0,20]
# (chunk 0 starts at 0, chunk 1 starts at 0)
offset = prefill_workspace_starts_cpu[chunk_start].item()
prefill_workspace_starts_cpu[chunk_start:chunk_end] -= offset

chunk_seq_lens = prefill_seq_lens[chunk_start:chunk_end]
chunk_tot_seqlen = prefill_seq_lens_cpu[chunk_start:chunk_end].sum()
token_start = query_start_loc_cpu[num_decodes + chunk_start].item()
token_end = query_start_loc_cpu[num_decodes + chunk_end].item()
tokens_slice = slice(token_start, token_end)

# Create chunk view of gpu tensor
chunk_workspace_starts = prefill_workspace_starts[chunk_start:chunk_end]
chunk_block_table = common_attn_metadata.block_table_tensor[
num_decodes + chunk_start : num_decodes + chunk_end
]

prefill_chunks.append(
FP8Meta.Prefill.Chunk(
seq_lens=chunk_seq_lens,
tokens_slice=tokens_slice,
block_table=chunk_block_table,
req_start_idx=chunk_start,
workspace_starts=chunk_workspace_starts,
chunk_tot_seqlen=chunk_tot_seqlen,
)
)

prefill_workspace_starts.copy_(
prefill_workspace_starts_cpu, non_blocking=True
)

fp8_metadata.prefill = FP8Meta.Prefill(
seq_lens=prefill_seq_lens,
request_ids=prefill_request_id,
workspace_starts=prefill_workspace_starts,
chunks=prefill_chunks,
)

if num_decodes > 0:
# Compute decode_query_len for spec decode (uniform due to require_uniform)
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
decode_query_len = (query_start_loc_cpu[1] - query_start_loc_cpu[0]).item()

fp8_extra_metadata = None
if self.use_fp8_kv_cache:
tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens=self.topk_tokens_tensor,
num_q_tokens_per_head_k=num_tokens * self.num_heads,
cache_seqlens=self.topk_tokens_tensor[:num_decodes],
num_q_tokens_per_head_k=decode_query_len * self.num_heads,
topk=self.topk_tokens,
num_heads_q=self.num_heads,
num_heads_k=1,
@ -348,33 +618,70 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
:num_sm_parts
]
tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata)
self.num_splits_buffer.copy_(num_splits)
# num_splits has size [num_decodes + 1]
num_splits_view = self.num_splits_buffer[: num_decodes + 1]
num_splits_view.copy_(num_splits)

fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
kernel_meta = FlashMLASparseMetadata.FP8KernelMetadata(
scheduler_metadata=tile_scheduler_metadata_buffer,
num_splits=self.num_splits_buffer,
# cache_lens and block_table are basically unused in sparse case
# but the decode kernel will treat -1 and indices >= cache_lens
# as invalid so we make sure cache_lens is large enough to not
# accidentally mark indices invalid, we will use -1 exclusively
# to mark invalid indices
cache_lens=self.max_model_len_tensor,
dummy_block_table=self.dummy_block_table,
num_splits=num_splits_view,
dummy_block_table=self.dummy_block_table[:num_decodes],
cache_lens=self.max_model_len_tensor[:num_decodes],
)
fp8_metadata.decode = FP8Meta.Decode(
kernel_metadata=kernel_meta,
decode_query_len=decode_query_len,
)

return fp8_metadata
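
The workspace_starts re-basing in the chunk loop above can be illustrated with a small standalone torch snippet; the values are copied from the in-code example comment and do not reflect any real builder state:

    import torch

    # Example from the comment: seq_lens=[10,15,20,5], chunks=[(0,2),(2,4)]
    prefill_seq_lens_cpu = torch.tensor([10, 15, 20, 5], dtype=torch.int32)
    chunk_bounds = [(0, 2), (2, 4)]  # assumed output of split_prefill_chunks

    workspace_starts = torch.zeros(4, dtype=torch.int32)
    workspace_starts[1:] = torch.cumsum(prefill_seq_lens_cpu[:-1], dim=0)
    assert workspace_starts.tolist() == [0, 10, 25, 45]

    for chunk_start, chunk_end in chunk_bounds:
        # Re-base each chunk so its first request starts at offset 0 in the
        # shared BF16 workspace.
        offset = workspace_starts[chunk_start].item()
        workspace_starts[chunk_start:chunk_end] -= offset

    assert workspace_starts.tolist() == [0, 10, 0, 20]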

def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> FlashMLASparseMetadata:
cm = common_attn_metadata
num_tokens = cm.num_actual_tokens
starts = np.asarray(cm.query_start_loc_cpu, dtype=np.int32)
seg_lengths = np.diff(starts)
req_id_per_token = np.repeat(
np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths
)
# Zero-fill for cudagraphs
self.req_id_per_token_buffer.fill_(0)
self.req_id_per_token_buffer[: req_id_per_token.shape[0]].copy_(
torch.from_numpy(req_id_per_token), non_blocking=True
)
req_id_per_token = self.req_id_per_token_buffer[:num_tokens]

fp8_extra_metadata: (
FlashMLASparseMetadata.FP8SeperatePrefillDecode
| FlashMLASparseMetadata.FP8KernelMetadata
| None
) = None
fp8_use_mixed_batch = self.num_heads < MIN_HEADS_FOR_BF16_PREFILL
if self.use_fp8_kv_cache:
if fp8_use_mixed_batch:
fp8_extra_metadata = self._build_fp8_mixed_decode_prefill(cm)
else:
fp8_extra_metadata = self._build_fp8_separate_prefill_decode(cm)

metadata = FlashMLASparseMetadata(
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
block_table=common_attn_metadata.block_table_tensor,
num_reqs=cm.num_reqs,
max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens,
query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor,
req_id_per_token=req_id_per_token,
block_size=self.kv_cache_spec.block_size,
topk_tokens=self.topk_tokens,
fp8_extra_metadata=fp8_extra_metadata,
fp8_use_mixed_batch=fp8_use_mixed_batch,
)

return metadata


@ -414,12 +721,204 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
self.topk_indices_buffer = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability(100) else 64

if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
vllm_config = get_current_vllm_config()
assert vllm_config is not None and vllm_config.model_config is not None
prefill_workspace_size = get_prefill_workspace_size(
vllm_config.model_config.max_model_len
)
self.prefill_workspace_shape = (prefill_workspace_size, head_size)
(self.prefill_bf16_workspace,) = (
current_workspace_manager().get_simultaneous(
(self.prefill_workspace_shape, torch.bfloat16)
)
)

def _forward_bf16_kv(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
)

return self._bf16_flash_mla_kernel(q, kv_c_and_k_pe_cache, topk_indices)

def _forward_fp8_kv_separate_prefill_decode(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)
num_decodes = fp8_metadata.num_decodes

prefill_request_ids = None
prefill_workspace_starts = None
has_prefill_workspace = False
if fp8_metadata.prefill is not None:
prefill_request_ids = fp8_metadata.prefill.request_ids
prefill_workspace_starts = fp8_metadata.prefill.workspace_starts
has_prefill_workspace = True

# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
# For FP8 cache: prefill uses workspace mapping (upconverted to BF16)
# For BF16 cache: always use global cache slots (no workspace)
# prefill_workspace_starts has been adjusted in-place per chunk so
# prefill indices automatically come out chunk-local
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
HAS_PREFILL_WORKSPACE=has_prefill_workspace,
prefill_workspace_request_ids=prefill_request_ids,
prefill_workspace_starts=prefill_workspace_starts,
)

fp8_metadata = attn_metadata.fp8_extra_metadata
assert isinstance(fp8_metadata, FlashMLASparseMetadata.FP8SeperatePrefillDecode)

def _fp8_decode(q: torch.Tensor, topk_indices: torch.Tensor) -> torch.Tensor:
# Reshape q: (num_decode_tokens, num_heads, head_dim)
# -> (num_decodes, seq_len, num_heads, head_dim)
q = reshape_query_for_spec_decode(q, num_decodes)
seq_len = q.shape[1]
# Reshape topk_indices: (num_decode_tokens, topk)
# -> (num_decodes, seq_len, topk)
topk_indices = topk_indices.view(num_decodes, seq_len, -1)
assert fp8_metadata.decode is not None
attn_out, _ = self._fp8_flash_mla_kernel(
q=q,
kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
topk_indices=topk_indices,
kernel_metadata=fp8_metadata.decode.kernel_metadata,
)
# Reshape output: (num_decodes, seq_len, num_heads, head_dim_v)
# -> (num_decode_tokens, num_heads, head_dim_v)
return reshape_attn_output_for_spec_decode(attn_out)

num_decode_tokens = fp8_metadata.num_decode_tokens
num_prefill_tokens = fp8_metadata.num_prefill_tokens

# Pure decode: direct call without allocation
if num_decode_tokens > 0 and num_prefill_tokens == 0:
assert fp8_metadata.decode is not None
attn_out = _fp8_decode(q, topk_indices)
else:
# Mixed or pure prefill: allocate output tensor
attn_out = q.new_empty(
(attn_metadata.num_actual_tokens, self.num_heads, self.kv_lora_rank),
dtype=q.dtype,
device=q.device,
)

if num_decode_tokens > 0:
attn_out[:num_decode_tokens] = _fp8_decode(
q[:num_decode_tokens], topk_indices[:num_decode_tokens]
)

assert fp8_metadata.prefill is not None
for chunk in fp8_metadata.prefill.chunks:
chunk_workspace = self.prefill_bf16_workspace[: chunk.chunk_tot_seqlen]
ops.cp_gather_and_upconvert_fp8_kv_cache(
kv_c_and_k_pe_cache,
chunk_workspace,
chunk.block_table,
chunk.seq_lens,
chunk.workspace_starts,
len(chunk.block_table),
)

chunk_q = q[chunk.tokens_slice]
chunk_topk_indices_workspace = topk_indices[chunk.tokens_slice]

attn_out[chunk.tokens_slice] = self._bf16_flash_mla_kernel(
chunk_q,
chunk_workspace,
chunk_topk_indices_workspace,
)

return attn_out

def _forward_fp8_kv_mixed_batch(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
"""Mixed batch FP8 forward path that treats all tokens as one batch.

This is equivalent to main branch's approach and avoids the BF16
prefill kernel which has head padding overhead when num_heads is small.
Used when use_mixed_batch is True.
"""
# Convert per-request indices to global slots (decode) or workspace
# offsets (prefill).
topk_indices = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=topk_indices.shape[1],
)

assert attn_metadata.fp8_extra_metadata is not None
assert isinstance(
attn_metadata.fp8_extra_metadata, FlashMLASparseMetadata.FP8KernelMetadata
)
fp8_metadata = attn_metadata.fp8_extra_metadata

_attn_out, _ = self._fp8_flash_mla_kernel(
q=q.unsqueeze(0), # unsqueeze to add batch_dim: (T, H, D) -> (1, T, H, D)
kv_c_and_k_pe_cache=kv_c_and_k_pe_cache,
topk_indices=topk_indices.unsqueeze(0), # (T, topk) -> (1, T, topk)
kernel_metadata=fp8_metadata,
)

# Output is (1, T, H, D_v), squeeze back to (T, H, D_v)
return _attn_out.squeeze(0)

def _fp8_flash_mla_kernel(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
kernel_metadata: FlashMLASparseMetadata.FP8KernelMetadata,
) -> torch.Tensor:
return flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
block_table=kernel_metadata.dummy_block_table,
head_dim_v=512,
cache_seqlens=kernel_metadata.cache_lens,
tile_scheduler_metadata=kernel_metadata.scheduler_metadata,
num_splits=kernel_metadata.num_splits,
is_fp8_kvcache=True,
indices=topk_indices,
softmax_scale=self.softmax_scale,
)

def _bf16_flash_mla_kernel(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
) -> torch.Tensor:
num_tokens = q.shape[0]
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
@ -445,31 +944,6 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
output = output[:, : self.num_heads, :]
return output

def _forward_fp8_kv(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
topk_indices: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
) -> torch.Tensor:
assert attn_metadata.fp8_extra_metadata is not None
extra_metadata = attn_metadata.fp8_extra_metadata

_attn_out, _ = flash_mla_with_kvcache(
q=q.unsqueeze(0), # unsqueeze to add batch_dim
k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
block_table=extra_metadata.dummy_block_table,
head_dim_v=512,
cache_seqlens=extra_metadata.cache_lens,
tile_scheduler_metadata=extra_metadata.scheduler_metadata,
num_splits=extra_metadata.num_splits,
is_fp8_kvcache=True,
indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim
softmax_scale=self.softmax_scale,
)

return _attn_out

def forward(
self,
layer: AttentionLayer,
@ -477,7 +951,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: FlashMLASparseMetadata,
attn_metadata: FlashMLASparseMetadata | None,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
@ -493,6 +967,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
)

if attn_metadata is None:
# Dummy run - no need to allocate buffers
# The zero fill is required when used with DP + EP
# to ensure all ranks within a DP group compute the
# same expert outputs.
@ -505,6 +980,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
topk_indices = self.topk_indices_buffer[:num_actual_toks]

q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
@ -514,16 +990,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)

topk_indices = self.topk_indices_buffer[:num_actual_toks]

# TODO: handle index / kv_cache correctly
topk_indices_global = triton_convert_req_index_to_global_index(
attn_metadata.req_id_per_token,
attn_metadata.block_table,
topk_indices,
BLOCK_SIZE=attn_metadata.block_size,
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
)
use_fp8_cache = self.kv_cache_dtype == "fp8_ds_mla"

q = torch.cat([ql_nope, q_pe], dim=-1)

@ -538,13 +1005,15 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
scale=layer._k_scale,
)

if self.kv_cache_dtype != "fp8_ds_mla":
attn_out = self._forward_bf16_kv(
q, kv_cache, topk_indices_global, attn_metadata
if not use_fp8_cache:
attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices, attn_metadata)
elif attn_metadata.fp8_use_mixed_batch:
attn_out = self._forward_fp8_kv_mixed_batch(
q, kv_cache, topk_indices, attn_metadata
)
else:
attn_out = self._forward_fp8_kv(
q, kv_cache, topk_indices_global, attn_metadata
attn_out = self._forward_fp8_kv_separate_prefill_decode(
q, kv_cache, topk_indices, attn_metadata
)

self._v_up_proj(attn_out, out=output[:num_actual_toks])

@ -18,6 +18,7 @@ from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
split_decodes_and_prefills,
split_prefill_chunks,
)

logger = init_logger(__name__)
@ -176,40 +177,15 @@ def kv_spans_from_batches(

def get_max_prefill_buffer_size(vllm_config: VllmConfig):
max_model_len = vllm_config.model_config.max_model_len
# NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
# May be tuned later.
return max_model_len * 2


def split_prefill_chunks(
seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int
) -> list[tuple[int, int]]:
"""
Split the prefill chunks into a list of tuples of (reqs_start, reqs_end)
such that the total sequence length of each chunk is less than the
maximum prefill buffer size.

Args:
seq_lens_cpu: The sequence lengths of the prefill requests.
max_prefill_buffer_size: The maximum prefill buffer size.
reqs_start: The start index of the prefill requests.

Returns:
A list of tuples of (reqs_start, reqs_end).
"""
chunk_seq_ids = []
total_seq_lens = 0
for i in range(reqs_start, len(seq_lens_cpu)):
cur_seq_len = seq_lens_cpu[i].item()
assert cur_seq_len <= max_prefill_buffer_size
total_seq_lens += cur_seq_len
if total_seq_lens > max_prefill_buffer_size:
chunk_seq_ids.append((reqs_start, i))
reqs_start = i
total_seq_lens = cur_seq_len
if total_seq_lens > 0:
chunk_seq_ids.append((reqs_start, len(seq_lens_cpu)))
return chunk_seq_ids
# NOTE(Chen): 40 is a magic number for controlling the prefill buffer size.
# Each entry is 128 fp8 bytes and 4 scale bytes for a total of 132 bytes.
# The flashmla_sparse backend uses a workspace size of 5 * max_model_len.
# The memory usage of the workspace there is 576 * 2 bytes; so we size this as
# (576 * 2 // 132) * 5 = 40 to maximize this workspace size while still fitting
# within the flashmla_sparse workspace.
# For DeepSeek-V3.2, the max_model_len is 163840.
# 40 * 163840 * 132 = 865075200 bytes = 825 MB
return max_model_len * 40
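
As a sanity check, the sizing arithmetic in the comment above can be reproduced directly; every number below is taken from the comment itself and is purely illustrative:

    # 128 fp8 bytes + 4 scale bytes per cached token entry
    bytes_per_fp8_entry = 128 + 4                      # 132
    bytes_per_bf16_workspace_entry = 576 * 2           # BF16 entry in flashmla_sparse
    multiplier = (bytes_per_bf16_workspace_entry // bytes_per_fp8_entry) * 5
    assert multiplier == 40

    max_model_len = 163840                             # DeepSeek-V3.2
    total_bytes = multiplier * max_model_len * bytes_per_fp8_entry
    assert total_bytes == 865075200                    # = 825 MiB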


class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
@ -302,9 +278,9 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.seq_lens_cpu[num_decodes:],
self.max_prefill_buffer_size,
num_decodes,
request_offset=num_decodes,
)
chunks = [
self.build_one_prefill_chunk(

@ -937,6 +937,33 @@ def split_decodes_and_prefills(
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)


def split_prefill_chunks(
seq_lens_cpu: torch.Tensor, workspace_size: int, request_offset: int = 0
) -> list[tuple[int, int]]:
"""
Split the prefill requests into chunks such that the total sequence length
of each chunk is less than or equal to the workspace size.

Args:
seq_lens_cpu: The sequence lengths of the prefill requests on CPU.
workspace_size: The maximum workspace size (in tokens) per chunk.
request_offset: The offset to add to the request indices.
Returns:
A list of tuples of (reqs_start, reqs_end) representing chunk boundaries.
"""
chunk_bounds = []
i, n = 0, len(seq_lens_cpu)
assert torch.all(seq_lens_cpu <= workspace_size).item()

while i < n:
start, chunk_total = i, 0
while i < n and (chunk_total + (s := seq_lens_cpu[i].item())) <= workspace_size:
chunk_total += s
i += 1
chunk_bounds.append((start + request_offset, i + request_offset))
return chunk_bounds
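
A small, hypothetical driver shows how the greedy packing and request_offset interact; the values are chosen for the example only and are not part of the commit:

    import torch

    seq_lens = torch.tensor([10, 15, 20, 5], dtype=torch.int32)
    # Two decode requests precede the prefills, hence request_offset=2.
    chunks = split_prefill_chunks(seq_lens, workspace_size=30, request_offset=2)
    # Requests 2-3 fit in one chunk (10 + 15 = 25 <= 30) and requests 4-5 in
    # the next (20 + 5 = 25 <= 30), so:
    assert chunks == [(2, 4), (4, 6)]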


def reorder_batch_to_split_decodes_and_prefills(
input_batch: "InputBatch",
scheduler_output: "SchedulerOutput",

@ -162,6 +162,7 @@ from vllm.v1.worker.ubatch_utils import (
maybe_create_ubatch_slices,
)
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.workspace import lock_workspace

from .utils import (
AttentionGroup,
@ -297,6 +298,7 @@ class GPUModelRunner(
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype

self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
cache_config.cache_dtype, self.model_config
)
@ -4597,6 +4599,10 @@ class GPUModelRunner(
# after here.
set_cudagraph_capturing_enabled(False)

# Lock workspace to prevent resizing during execution.
# Max workspace sizes should have been captured during warmup/profiling.
lock_workspace()

end_time = time.perf_counter()
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory

@ -54,6 +54,7 @@ from vllm.v1.outputs import (
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.utils import is_residual_scattered_for_sp
from vllm.v1.worker.worker_base import WorkerBase
from vllm.v1.worker.workspace import init_workspace_manager

logger = init_logger(__name__)

@ -255,6 +256,10 @@ class Worker(WorkerBase):
else:
raise RuntimeError(f"Not support device type: {self.device_config.device}")

# Initialize workspace manager
num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1
init_workspace_manager(self.device, num_ubatches)

# Construct the model runner
if self.use_v2_model_runner:
from vllm.v1.worker.gpu.model_runner import (

245
vllm/v1/worker/workspace.py
245
vllm/v1/worker/workspace.py
@ -0,0 +1,245 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import inspect
import os
from itertools import accumulate
from math import prod
from typing import Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.math_utils import round_up
from vllm.v1.worker.ubatching import dbo_current_ubatch_id

logger = init_logger(__name__)


def _compute_bytes(shape: tuple[int, ...], dtype: torch.dtype) -> int:
return prod(shape) * dtype.itemsize


# Constants
_MB = 1024**2
_GiB = 1024**3

# Global workspace manager instance
_manager: Optional["WorkspaceManager"] = None


class WorkspaceManager:
"""Manager for workspace allocation.

Manages workspace buffers for DBO (Dual Batch Overlap) execution.
Can be locked to prevent further growth during execution.
"""

def __init__(self, device: torch.device, num_ubatches: int | None = None):
self._device = device
# Cache num ubatches at init based on configuration (default to 1)
self._num_ubatches = num_ubatches if num_ubatches is not None else 1
self._current_workspaces: list[torch.Tensor | None] = [None, None]
self._locked: bool = False

@staticmethod
def _workspace_size_bytes(workspace: torch.Tensor | None) -> int:
"""Get size of workspace in bytes."""
if workspace is None:
return 0
return workspace.numel() * workspace.element_size()

def lock(self) -> None:
"""Lock the workspace to prevent further growth.

After locking, any attempt to allocate a larger workspace will raise
an assertion error. This ensures workspace size is fixed during execution.
"""
self._locked = True
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Workspace locked. Current sizes: %s",
[
self._workspace_size_bytes(ws) / _MB
for ws in self._current_workspaces
if ws is not None
],
)

def is_locked(self) -> bool:
"""Check if workspace is locked."""
return self._locked

def get_simultaneous(
self, *shapes_and_dtypes: tuple[tuple[int, ...], torch.dtype]
) -> list[torch.Tensor]:
"""Get multiple workspace tensors simultaneously from a single allocation.

Args:
*shapes_and_dtypes: One or more (shape, dtype) tuples.

Returns:
List of tensor views into the workspace buffer, one per shape/dtype pair.
"""
actual_bytes = [_compute_bytes(s, d) for s, d in shapes_and_dtypes]
aligned_bytes = [round_up(actual, 256) for actual in actual_bytes]
total_bytes = sum(aligned_bytes)

# Calculate cumulative offsets using itertools.accumulate
offsets = list(accumulate([0] + aligned_bytes[:-1]))

current_workspace = self._ensure_workspace_size(total_bytes)

return [
current_workspace[offsets[i] : offsets[i] + actual_bytes[i]]
.view(shapes_and_dtypes[i][1])
.reshape(shapes_and_dtypes[i][0])
for i in range(len(shapes_and_dtypes))
]
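
A rough usage sketch, assuming a CUDA device and the module-level helpers defined further down in this file: a caller can carve several aligned views out of one backing allocation, which is how the FlashMLA sparse FP8 path above obtains its BF16 prefill workspace.

    import torch

    # Hypothetical shapes, for illustration only.
    init_workspace_manager(torch.device("cuda:0"), num_ubatches=1)
    mgr = current_workspace_manager()
    scores, partial_sums = mgr.get_simultaneous(
        ((4096, 128), torch.bfloat16),
        ((4096,), torch.float32),
    )
    # Both tensors are views into a single uint8 buffer; each region starts at
    # a 256-byte-aligned offset, so one request never misaligns the next.
    assert scores.shape == (4096, 128) and partial_sums.dtype == torch.float32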

def _ensure_workspace_size(self, required_bytes: int) -> torch.Tensor:
"""Ensure workspace is allocated and large enough, return current workspace.

Args:
required_bytes: The number of bytes required.

Returns:
The current workspace tensor.
"""
ubatch_id = dbo_current_ubatch_id()
current_workspace = self._current_workspaces[ubatch_id]
current_size = self._workspace_size_bytes(current_workspace)

if current_size < required_bytes:

def get_caller_info() -> str:
"""Find first frame outside WorkspaceManager."""
curr_frame = inspect.currentframe()
if curr_frame is None:
return "unknown"
# Walk up the stack skipping WorkspaceManager frames
curr_frame = curr_frame.f_back
while curr_frame is not None:
# TODO: This only catches instance methods (self), missing
# classmethods and staticmethods. Once Python 3.11+ is the
# minimum supported version, use co_qualname instead:
# qualname = curr_frame.f_code.co_qualname
# if qualname.startswith("WorkspaceManager."):
if isinstance(curr_frame.f_locals.get("self"), WorkspaceManager):
curr_frame = curr_frame.f_back
continue
filename = os.path.basename(curr_frame.f_code.co_filename)
return (
f"{filename}:{curr_frame.f_lineno}:{curr_frame.f_code.co_name}"
)
return "unknown"

if self._locked:
raise AssertionError(
f"Workspace is locked but allocation from '{get_caller_info()}' "
f"requires {required_bytes / _MB:.2f} MB, current size is "
f"{current_size / _MB:.2f} MB. "
"Workspace growth is not allowed after locking."
)

for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id]
if current_workspace is None:
self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device
)
elif self._workspace_size_bytes(current_workspace) < required_bytes:
current_workspace.resize_(required_bytes)

if envs.VLLM_DEBUG_WORKSPACE:
logger.info(
"[WORKSPACE DEBUG] Resized workspace from '%s': %.2f MB -> "
"%.2f MB (%d ubatches, total memory %.2f MB)",
get_caller_info(),
current_size / _MB,
required_bytes / _MB,
self._num_ubatches,
required_bytes * self._num_ubatches / _MB,
)

current_workspace = self._current_workspaces[dbo_current_ubatch_id()]

return current_workspace


def is_workspace_manager_initialized() -> bool:
"""Check if workspace manager has been initialized.

Returns:
True if workspace manager is initialized, False otherwise.
"""
return _manager is not None


def current_workspace_manager() -> "WorkspaceManager":
"""Get the current workspace manager instance.

Raises:
AssertionError: If workspace manager has not been initialized.
"""
assert _manager is not None, (
"WorkspaceManager not initialized. Call init_workspace_manager() "
"with a device before using workspace functions."
)
return _manager


def init_workspace_manager(
device: torch.device, num_ubatches: int | None = None
) -> None:
"""Initialize the workspace manager with a device.

Must be called before using any workspace functions. Typically called
from GPUModelRunner.__init__.

Args:
device: The device to allocate workspace on.
num_ubatches: Number of micro-batches. Defaults to 1.
"""
global _manager
if _manager is not None:
logger.warning(
"WorkspaceManager already initialized on device %s, "
"reinitializing on device %s",
_manager._device,
device,
)
_manager = WorkspaceManager(device, num_ubatches)


def lock_workspace() -> None:
"""Lock the workspace to prevent further growth.

After calling this function, any attempt to allocate a workspace larger
than the current size will raise an AssertionError. This ensures that
workspace size is fixed during execution and prevents unexpected memory
allocations in the hot path.

Example:
# During initialization
init_workspace_manager(device)
reserve_workspace(shape1, dtype1)
reserve_workspace(shape2, dtype2)

# Lock after warmup/profiling
lock_workspace()

# Now all get_workspace calls must fit in pre-allocated size
"""
current_workspace_manager().lock()


def reset_workspace_manager() -> None:
"""Reset the workspace manager to uninitialized state.

This is primarily intended for testing purposes to allow tests
to reinitialize the workspace manager cleanly.
"""
global _manager
_manager = None