[ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention (#17139)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Commit 32aa74c09c (parent 7377dd0307)
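In outline, the change threads an optional fp8_out_scale tensor from the Python wrapper paged_attention_rocm, through the _rocm_C op schema and the C++ paged_attention entry point, down to paged_attention_custom_launcher and the paged_attention_ll4mi_reduce_kernel. When the scale is present, output quantization is fused into the kernel epilogue: the accumulator is multiplied by 1.0f / (*fp8_out_scale_ptr) right after the softmax normalization, and the launcher macros instantiate the kernel with uint8_t as the output type OUTT so the result is stored as FP8 bytes rather than in the query dtype T. On gfx90a the FP8 output path is rejected with a TORCH_CHECK; when no scale is passed, behaviour is unchanged (out_scale stays 1.0f and OUTT stays T).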
@@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                            // max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   const auto num_heads = gridDim.x;
   const auto head_idx = blockIdx.x;
   const auto seq_idx = blockIdx.y;
@@ -1465,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
 
   const float inv_global_exp_sum =
       __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
+  const float out_scale =
+      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
   acc *= inv_global_exp_sum;
+  acc *= out_scale;
   const int64_t query_start_off = static_cast<int64_t>(
       query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
   OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
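For intuition, the fused epilogue is just one extra multiply on the already-normalized accumulator: dividing by the output scale (q = x / s) before the value is converted to the kernel's output type OUTT. A minimal host-side C++ sketch of that arithmetic with made-up numbers; the actual FP8 conversion happens in the kernel's store path and is not shown here:

    // Sketch of the fused FP8 output epilogue arithmetic (illustration only).
    #include <cstdio>

    int main() {
      float acc = 123.4f;                  // accumulated attention output element
      float shared_global_exp_sum = 7.5f;  // global softmax denominator
      const float scale_value = 0.05f;     // contents of the 1-element fp8_out_scale tensor (assumed shape)
      const float* fp8_out_scale_ptr = &scale_value;  // nullptr when FP8 output is disabled

      const float inv_global_exp_sum = 1.0f / (shared_global_exp_sum + 1e-6f);
      const float out_scale =
          (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;

      acc *= inv_global_exp_sum;  // finish the softmax normalization
      acc *= out_scale;           // fold in the output quantization: q = x / s

      // With OUTT == uint8_t the kernel stores this value as FP8; with the
      // default OUTT == T it is written in the query dtype and out_scale == 1.
      std::printf("value to be stored: %f\n", acc);
      return 0;
    }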
@@ -1548,7 +1550,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
     const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
     const int* __restrict__ context_lens,         // [num_seqs]
     const int* __restrict__ query_start_loc_ptr,  // [num_seqs]
-    const int max_num_partitions) {
+    const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
   UNREACHABLE_CODE
 }
 // clang-format on
@@ -1582,7 +1584,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
                                          PARTITION_SIZE, NPAR_LOOPS>         \
       <<<reduce_grid, reduce_block, 0, stream>>>(                            \
           out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                \
-          context_lens_ptr, query_start_loc_ptr, max_num_partitions);
+          context_lens_ptr, query_start_loc_ptr, max_num_partitions,         \
+          fp8_out_scale_ptr);
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1594,7 +1597,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& block_tables, torch::Tensor& context_lens,
     const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
   int num_seqs = block_tables.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -1626,6 +1629,11 @@ void paged_attention_custom_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
   const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
+  // NOTE: fp8_out_scale is optional.
+  const auto fp8_out_scale_ptr =
+      fp8_out_scale
+          ? static_cast<const float*>(fp8_out_scale.value().data_ptr())
+          : nullptr;
   OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
 
   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
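The launcher lowers the optional tensor to a nullable raw pointer, mirroring how k_scale and v_scale are already consumed as const float*. A hedged caller-side sketch of preparing such a scale, assuming a single-element float32 tensor on the GPU (the exact dtype and placement requirements are not spelled out in this diff):

    // Sketch only: building an optional one-element output-scale tensor and
    // lowering it to the nullable pointer form the launcher produces above.
    // The float32 dtype and CUDA-device placement are assumptions.
    #include <cstdio>
    #include <optional>
    #include <torch/torch.h>

    int main() {
      std::optional<torch::Tensor> fp8_out_scale;  // empty -> FP8 output disabled

      if (torch::cuda::is_available()) {  // ROCm builds also report the device as CUDA
        fp8_out_scale = torch::full(
            {1}, 0.05, torch::dtype(torch::kFloat32).device(torch::kCUDA));
      }

      const float* fp8_out_scale_ptr =
          fp8_out_scale
              ? static_cast<const float*>(fp8_out_scale.value().data_ptr())
              : nullptr;

      // nullptr keeps the original non-quantized output path in the kernel.
      std::printf("fp8_out_scale_ptr is %s\n",
                  fp8_out_scale_ptr ? "set" : "nullptr");
      return 0;
    }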
@@ -1736,29 +1744,50 @@ void paged_attention_custom_launcher(
   }
 }
 
-#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE,     \
-                             ALIBI_ENABLED)                                    \
-  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,    \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT,      \
+                             PSIZE, ALIBI_ENABLED)                             \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
                                   PSIZE, ALIBI_ENABLED>(                       \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,       \
       num_kv_heads, scale, block_tables, context_lens, query_start_loc,        \
-      max_context_len, alibi_slopes, k_scale, v_scale);
+      max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
 
 #define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,      \
-                                   PSIZE)                                      \
+                                   OUTT, PSIZE)                                \
   if (alibi_slopes) {                                                          \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true);  \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,   \
+                         true);                                                \
   } else {                                                                     \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
+    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE,   \
+                         false);                                               \
   }
 
+#if defined(__HIPCC__) && defined(__gfx90a__)
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)      \
+    if (fp8_out_scale) {                                                       \
+      TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a");              \
+    } else {                                                                   \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,     \
+                                 256);                                         \
+    }
+#else
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)      \
+    if (fp8_out_scale) {                                                       \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE,        \
+                                 uint8_t, 256);                                \
+    } else {                                                                   \
+      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T,     \
+                                 256);                                         \
+    }
+#endif
 
 #define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)                  \
   switch (block_size) {                                                        \
     case 16:                                                                   \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256);        \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE);               \
       break;                                                                   \
     case 32:                                                                   \
-      CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256);        \
+      CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE);               \
       break;                                                                   \
     default:                                                                   \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);              \
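Stripped of the macro plumbing, the new CALL_CUSTOM_LAUNCHER_OUT layer is an output-type dispatch: a present fp8_out_scale selects uint8_t (FP8 bytes) as OUTT, otherwise the query dtype T is kept, and on gfx90a the FP8 path is refused. A hedged, macro-free C++ sketch of that decision; launch_reduce, dispatch_alibi and dispatch_out_type are illustrative names, not the real launcher symbols:

    // Illustrative only: the dispatch performed by the CALL_CUSTOM_LAUNCHER_OUT /
    // CALL_CUSTOM_LAUNCHER_ALIBI macro chain, written without macros.
    #include <cstdint>
    #include <cstdio>
    #include <stdexcept>

    template <typename T, typename OUTT, int PSIZE, bool ALIBI_ENABLED>
    void launch_reduce() {
      // Stand-in for paged_attention_custom_launcher<..., OUTT, PSIZE, ALIBI_ENABLED>(...).
      std::printf("OUTT bytes=%zu PSIZE=%d alibi=%d\n", sizeof(OUTT), PSIZE,
                  ALIBI_ENABLED ? 1 : 0);
    }

    template <typename T, typename OUTT>
    void dispatch_alibi(bool has_alibi) {
      // Matches CALL_CUSTOM_LAUNCHER_ALIBI: alibi presence picks ALIBI_ENABLED.
      if (has_alibi) {
        launch_reduce<T, OUTT, 256, true>();
      } else {
        launch_reduce<T, OUTT, 256, false>();
      }
    }

    template <typename T>
    void dispatch_out_type(bool has_fp8_out_scale, bool has_alibi, bool is_gfx90a) {
      // Matches CALL_CUSTOM_LAUNCHER_OUT: the optional output scale selects OUTT.
      if (has_fp8_out_scale) {
        if (is_gfx90a) {
          // Mirrors TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a").
          throw std::runtime_error("fp8 out scale unsupported for gfx90a");
        }
        dispatch_alibi<T, uint8_t>(has_alibi);  // FP8 output path: OUTT = uint8_t
      } else {
        dispatch_alibi<T, T>(has_alibi);        // default path: OUTT = T (query dtype)
      }
    }

    int main() {
      dispatch_out_type<float>(/*has_fp8_out_scale=*/false, /*has_alibi=*/false,
                               /*is_gfx90a=*/false);
      dispatch_out_type<float>(/*has_fp8_out_scale=*/true, /*has_alibi=*/true,
                               /*is_gfx90a=*/false);
      return 0;
    }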
@@ -1795,7 +1824,8 @@ void paged_attention(
     int64_t block_size, int64_t max_context_len,
     const std::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    torch::Tensor& v_scale,
+    const c10::optional<torch::Tensor>& fp8_out_scale) {
   // clang-format on
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
@@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
 void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
                at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
 
-void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
-                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                     torch::Tensor& query, torch::Tensor& key_cache,
-                     torch::Tensor& value_cache, int64_t num_kv_heads,
-                     double scale, torch::Tensor& block_tables,
-                     torch::Tensor& context_lens,
-                     const std::optional<torch::Tensor>& query_start_loc,
-                     int64_t block_size, int64_t max_context_len,
-                     const std::optional<torch::Tensor>& alibi_slopes,
-                     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-                     torch::Tensor& v_scale);
+void paged_attention(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
+    int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale);
@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       " int max_context_len,"
       " Tensor? alibi_slopes,"
       " str kv_cache_dtype,"
-      " Tensor k_scale, Tensor v_scale) -> ()");
+      " Tensor k_scale, Tensor v_scale,"
+      " Tensor? fp8_out_scale) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
 }
 
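On the registration side, the schema gains a trailing "Tensor? fp8_out_scale", which surfaces as an optional tensor argument of torch.ops._rocm_C.paged_attention; the Python wrapper below defaults it to None, so existing callers are unaffected.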
@@ -117,13 +117,14 @@ def paged_attention_rocm(
     kv_cache_dtype: str,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
+    fp8_out_scale: Optional[torch.Tensor] = None,
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                       key_cache, value_cache, num_kv_heads,
                                       scale, block_tables, seq_lens,
                                       query_start_loc, block_size, max_seq_len,
                                       alibi_slopes, kv_cache_dtype, k_scale,
-                                      v_scale)
+                                      v_scale, fp8_out_scale)
 
 
 def mla_decode_kvcache_cpu(