mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 06:45:00 +08:00
[Kernel][Attention] Separate Attention.kv_scale into k_scale and v_scale (#6081)
This commit is contained in:
parent
160e1d8c99
commit
978aed5300
@ -100,7 +100,7 @@ def main(
|
|||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
# Using default kv_scale
|
# Using default kv_scale
|
||||||
kv_scale = 1.0
|
k_scale = v_scale = 1.0
|
||||||
|
|
||||||
for _ in range(num_iters):
|
for _ in range(num_iters):
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
@ -117,7 +117,8 @@ def main(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
ops.paged_attention_v2(
|
ops.paged_attention_v2(
|
||||||
@ -136,7 +137,8 @@ def main(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid version: {version}")
|
raise ValueError(f"Invalid version: {version}")
|
||||||
|
|||||||
@ -105,9 +105,9 @@ __device__ void paged_attention_kernel(
|
|||||||
const int max_num_blocks_per_seq,
|
const int max_num_blocks_per_seq,
|
||||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||||
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
|
const float k_scale, const float v_scale, const int tp_rank,
|
||||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||||
const int blocksparse_head_sliding_step) {
|
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||||
const int seq_idx = blockIdx.y;
|
const int seq_idx = blockIdx.y;
|
||||||
const int partition_idx = blockIdx.z;
|
const int partition_idx = blockIdx.z;
|
||||||
const int max_num_partitions = gridDim.z;
|
const int max_num_partitions = gridDim.z;
|
||||||
@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
|
|||||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
|
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
|
||||||
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||||
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
|
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
|
||||||
k_vec_quant, kv_scale);
|
k_vec_quant, k_scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
|
|||||||
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||||
// Vector conversion from V_quant_vec to V_vec.
|
// Vector conversion from V_quant_vec to V_vec.
|
||||||
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
|
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
|
||||||
kv_scale);
|
v_scale);
|
||||||
}
|
}
|
||||||
if (block_idx == num_seq_blocks - 1) {
|
if (block_idx == num_seq_blocks - 1) {
|
||||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the
|
// NOTE(woosuk): When v_vec contains the tokens that are out of the
|
||||||
@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
|
|||||||
const int max_num_blocks_per_seq,
|
const int max_num_blocks_per_seq,
|
||||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||||
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
|
const float k_scale, const float v_scale, const int tp_rank,
|
||||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||||
const int blocksparse_head_sliding_step) {
|
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||||
KV_DTYPE, IS_BLOCK_SPARSE>(
|
KV_DTYPE, IS_BLOCK_SPARSE>(
|
||||||
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
|
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
|
||||||
v_cache, num_kv_heads, scale, block_tables, seq_lens,
|
v_cache, num_kv_heads, scale, block_tables, seq_lens,
|
||||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
|
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
|
||||||
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
|
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
|
||||||
blocksparse_vert_stride, blocksparse_block_size,
|
blocksparse_vert_stride, blocksparse_block_size,
|
||||||
blocksparse_head_sliding_step);
|
blocksparse_head_sliding_step);
|
||||||
}
|
}
|
||||||
@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
|
|||||||
const int max_num_blocks_per_seq,
|
const int max_num_blocks_per_seq,
|
||||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||||
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
|
const float k_scale, const float v_scale, const int tp_rank,
|
||||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||||
const int blocksparse_head_sliding_step) {
|
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||||
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
|
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
|
||||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||||
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
|
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
|
||||||
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
|
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
|
||||||
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
|
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
|
||||||
blocksparse_head_sliding_step);
|
blocksparse_head_sliding_step);
|
||||||
}
|
}
|
||||||
@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel(
|
|||||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
||||||
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||||
kv_scale, tp_rank, blocksparse_local_blocks, \
|
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||||
blocksparse_vert_stride, blocksparse_block_size, \
|
blocksparse_vert_stride, blocksparse_block_size, \
|
||||||
blocksparse_head_sliding_step);
|
blocksparse_head_sliding_step);
|
||||||
|
|
||||||
@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
|
|||||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
|
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||||
const int tp_rank, const int blocksparse_local_blocks,
|
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||||
const int blocksparse_head_sliding_step) {
|
const int blocksparse_head_sliding_step) {
|
||||||
int num_seqs = query.size(0);
|
int num_seqs = query.size(0);
|
||||||
@ -770,7 +770,7 @@ void paged_attention_v1_launcher(
|
|||||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
|
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
|
||||||
IS_BLOCK_SPARSE>( \
|
IS_BLOCK_SPARSE>( \
|
||||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||||
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
|
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
|
||||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||||
blocksparse_block_size, blocksparse_head_sliding_step);
|
blocksparse_block_size, blocksparse_head_sliding_step);
|
||||||
|
|
||||||
@ -815,8 +815,8 @@ void paged_attention_v1(
|
|||||||
torch::Tensor& seq_lens, // [num_seqs]
|
torch::Tensor& seq_lens, // [num_seqs]
|
||||||
int64_t block_size, int64_t max_seq_len,
|
int64_t block_size, int64_t max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||||
const int64_t blocksparse_local_blocks,
|
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step) {
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||||
@ -833,7 +833,7 @@ void paged_attention_v1(
|
|||||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
||||||
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||||
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
|
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
|
||||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||||
blocksparse_block_size, blocksparse_head_sliding_step); \
|
blocksparse_block_size, blocksparse_head_sliding_step); \
|
||||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
||||||
@ -850,8 +850,8 @@ void paged_attention_v2_launcher(
|
|||||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
|
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||||
const int tp_rank, const int blocksparse_local_blocks,
|
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||||
const int blocksparse_head_sliding_step) {
|
const int blocksparse_head_sliding_step) {
|
||||||
int num_seqs = query.size(0);
|
int num_seqs = query.size(0);
|
||||||
@ -932,8 +932,9 @@ void paged_attention_v2_launcher(
|
|||||||
IS_BLOCK_SPARSE>( \
|
IS_BLOCK_SPARSE>( \
|
||||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||||
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
|
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
|
||||||
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
|
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||||
blocksparse_block_size, blocksparse_head_sliding_step);
|
blocksparse_vert_stride, blocksparse_block_size, \
|
||||||
|
blocksparse_head_sliding_step);
|
||||||
|
|
||||||
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||||
switch (is_block_sparse) { \
|
switch (is_block_sparse) { \
|
||||||
@ -980,8 +981,8 @@ void paged_attention_v2(
|
|||||||
torch::Tensor& seq_lens, // [num_seqs]
|
torch::Tensor& seq_lens, // [num_seqs]
|
||||||
int64_t block_size, int64_t max_seq_len,
|
int64_t block_size, int64_t max_seq_len,
|
||||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||||
const int64_t blocksparse_local_blocks,
|
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step) {
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||||
|
|||||||
@ -18,8 +18,8 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
|||||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||||
torch::Tensor& slot_mapping,
|
torch::Tensor& slot_mapping,
|
||||||
const std::string& kv_cache_dtype,
|
const std::string& kv_cache_dtype, const double k_scale,
|
||||||
const double kv_scale);
|
const double v_scale);
|
||||||
|
|
||||||
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||||
torch::Tensor& key_cache,
|
torch::Tensor& key_cache,
|
||||||
|
|||||||
@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
// block_size]
|
// block_size]
|
||||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||||
const int key_stride, const int value_stride, const int num_heads,
|
const int key_stride, const int value_stride, const int num_heads,
|
||||||
const int head_size, const int block_size, const int x,
|
const int head_size, const int block_size, const int x, const float k_scale,
|
||||||
const float kv_scale) {
|
const float v_scale) {
|
||||||
const int64_t token_idx = blockIdx.x;
|
const int64_t token_idx = blockIdx.x;
|
||||||
const int64_t slot_idx = slot_mapping[token_idx];
|
const int64_t slot_idx = slot_mapping[token_idx];
|
||||||
if (slot_idx < 0) {
|
if (slot_idx < 0) {
|
||||||
@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
|
|||||||
value_cache[tgt_value_idx] = tgt_value;
|
value_cache[tgt_value_idx] = tgt_value;
|
||||||
} else {
|
} else {
|
||||||
key_cache[tgt_key_idx] =
|
key_cache[tgt_key_idx] =
|
||||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
|
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||||
value_cache[tgt_value_idx] =
|
value_cache[tgt_value_idx] =
|
||||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
|
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -248,7 +248,7 @@ __global__ void reshape_and_cache_flash_kernel(
|
|||||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||||
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
||||||
num_heads, head_size, block_size, x, kv_scale);
|
num_heads, head_size, block_size, x, k_scale, v_scale);
|
||||||
|
|
||||||
void reshape_and_cache(
|
void reshape_and_cache(
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
@ -258,7 +258,8 @@ void reshape_and_cache(
|
|||||||
torch::Tensor&
|
torch::Tensor&
|
||||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||||
torch::Tensor& slot_mapping, // [num_tokens]
|
torch::Tensor& slot_mapping, // [num_tokens]
|
||||||
const std::string& kv_cache_dtype, const double kv_scale) {
|
const std::string& kv_cache_dtype, const double k_scale,
|
||||||
|
const double v_scale) {
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int num_heads = key.size(1);
|
||||||
int head_size = key.size(2);
|
int head_size = key.size(2);
|
||||||
@ -318,13 +319,13 @@ namespace vllm {
|
|||||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||||
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
||||||
Tout* __restrict__ dst_cache,
|
Tout* __restrict__ dst_cache,
|
||||||
const float kv_scale,
|
const float scale,
|
||||||
const int64_t block_stride) {
|
const int64_t block_stride) {
|
||||||
const int64_t block_idx = blockIdx.x;
|
const int64_t block_idx = blockIdx.x;
|
||||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||||
int64_t idx = block_idx * block_stride + i;
|
int64_t idx = block_idx * block_stride + i;
|
||||||
dst_cache[idx] =
|
dst_cache[idx] =
|
||||||
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
|
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,11 +334,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
|||||||
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
|
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
|
||||||
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
||||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
|
reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
|
||||||
|
|
||||||
// Only for testing.
|
// Only for testing.
|
||||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||||
const double kv_scale, const std::string& kv_cache_dtype) {
|
const double scale, const std::string& kv_cache_dtype) {
|
||||||
torch::Device src_device = src_cache.device();
|
torch::Device src_device = src_cache.device();
|
||||||
torch::Device dst_device = dst_cache.device();
|
torch::Device dst_device = dst_cache.device();
|
||||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||||
|
|||||||
@ -423,11 +423,11 @@ void paged_attention_v1(
|
|||||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||||
const int64_t blocksparse_local_blocks,
|
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step) {
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
TORCH_CHECK(kv_scale == 1.0f);
|
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||||
"CPU backend does not support blocksparse attention yet.");
|
"CPU backend does not support blocksparse attention yet.");
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||||
@ -742,11 +742,11 @@ void paged_attention_v2(
|
|||||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||||
const int64_t blocksparse_local_blocks,
|
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step) {
|
const int64_t blocksparse_head_sliding_step) {
|
||||||
TORCH_CHECK(kv_scale == 1.0f);
|
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||||
"CPU backend does not support blocksparse attention yet.");
|
"CPU backend does not support blocksparse attention yet.");
|
||||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||||
|
|||||||
@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
|||||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||||
torch::Tensor& slot_mapping,
|
torch::Tensor& slot_mapping,
|
||||||
const std::string& kv_cache_dtype, double kv_scale) {
|
const std::string& kv_cache_dtype, double k_scale,
|
||||||
TORCH_CHECK(kv_scale == 1.0f);
|
double v_scale) {
|
||||||
|
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||||
|
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int num_heads = key.size(1);
|
||||||
|
|||||||
@ -16,8 +16,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
" int max_seq_len, Tensor? alibi_slopes,"
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||||
" int blocksparse_local_blocks,"
|
" int tp_rank, int blocksparse_local_blocks,"
|
||||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
||||||
@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
" int max_seq_len, Tensor? alibi_slopes,"
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||||
" int blocksparse_local_blocks,"
|
" int tp_rank, int blocksparse_local_blocks,"
|
||||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
|
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
|
||||||
@ -103,7 +103,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
" Tensor! key_cache, Tensor! value_cache,"
|
" Tensor! key_cache, Tensor! value_cache,"
|
||||||
" Tensor slot_mapping,"
|
" Tensor slot_mapping,"
|
||||||
" str kv_cache_dtype,"
|
" str kv_cache_dtype,"
|
||||||
" float kv_scale) -> ()");
|
" float k_scale, float v_scale) -> ()");
|
||||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -8,8 +8,8 @@ void paged_attention_v1(
|
|||||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||||
const int64_t blocksparse_local_blocks,
|
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step);
|
const int64_t blocksparse_head_sliding_step);
|
||||||
|
|
||||||
@ -19,8 +19,8 @@ void paged_attention_v2(
|
|||||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||||
const int64_t blocksparse_local_blocks,
|
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||||
const int64_t blocksparse_head_sliding_step);
|
const int64_t blocksparse_head_sliding_step);
|
||||||
|
|
||||||
|
|||||||
@ -27,8 +27,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
" int max_seq_len, Tensor? alibi_slopes,"
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||||
" int blocksparse_local_blocks,"
|
" int tp_rank, int blocksparse_local_blocks,"
|
||||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
|
ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1);
|
||||||
@ -41,8 +41,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|||||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||||
" int max_seq_len, Tensor? alibi_slopes,"
|
" int max_seq_len, Tensor? alibi_slopes,"
|
||||||
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||||
" int blocksparse_local_blocks,"
|
" int tp_rank, int blocksparse_local_blocks,"
|
||||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||||
" int blocksparse_head_sliding_step) -> ()");
|
" int blocksparse_head_sliding_step) -> ()");
|
||||||
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2);
|
||||||
@ -223,7 +223,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
|||||||
" Tensor! key_cache, Tensor! value_cache,"
|
" Tensor! key_cache, Tensor! value_cache,"
|
||||||
" Tensor slot_mapping,"
|
" Tensor slot_mapping,"
|
||||||
" str kv_cache_dtype,"
|
" str kv_cache_dtype,"
|
||||||
" float kv_scale) -> ()");
|
" float k_scale, float v_scale) -> ()");
|
||||||
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
|
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
|
||||||
|
|
||||||
// Reshape the key and value tensors and cache them.
|
// Reshape the key and value tensors and cache them.
|
||||||
|
|||||||
@ -175,7 +175,7 @@ def test_paged_attention(
|
|||||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
# Using default kv_scale
|
# Using default kv_scale
|
||||||
kv_scale = 1.0
|
k_scale = v_scale = 1.0
|
||||||
|
|
||||||
# Call the paged attention kernel.
|
# Call the paged attention kernel.
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
@ -193,7 +193,8 @@ def test_paged_attention(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
||||||
@ -224,7 +225,8 @@ def test_paged_attention(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise AssertionError(f"Unknown version: {version}")
|
raise AssertionError(f"Unknown version: {version}")
|
||||||
|
|||||||
@ -212,7 +212,7 @@ def test_paged_attention(
|
|||||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
# Using default kv_scale
|
# Using default kv_scale
|
||||||
kv_scale = 1.0
|
k_scale = v_scale = 1.0
|
||||||
tp_rank = 0
|
tp_rank = 0
|
||||||
|
|
||||||
# Call the paged attention kernel.
|
# Call the paged attention kernel.
|
||||||
@ -231,7 +231,8 @@ def test_paged_attention(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
blocksparse_local_blocks=blocksparse_local_blocks,
|
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||||
blocksparse_vert_stride=blocksparse_vert_stride,
|
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||||
@ -267,7 +268,8 @@ def test_paged_attention(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
blocksparse_local_blocks=blocksparse_local_blocks,
|
blocksparse_local_blocks=blocksparse_local_blocks,
|
||||||
blocksparse_vert_stride=blocksparse_vert_stride,
|
blocksparse_vert_stride=blocksparse_vert_stride,
|
||||||
|
|||||||
@ -155,11 +155,11 @@ def test_reshape_and_cache(
|
|||||||
cloned_value_cache = value_cache.clone()
|
cloned_value_cache = value_cache.clone()
|
||||||
|
|
||||||
# Using default kv_scale
|
# Using default kv_scale
|
||||||
kv_scale = 1.0
|
k_scale = v_scale = 1.0
|
||||||
|
|
||||||
# Call the reshape_and_cache kernel.
|
# Call the reshape_and_cache kernel.
|
||||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
|
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
|
||||||
kv_cache_dtype, kv_scale)
|
kv_cache_dtype, k_scale, v_scale)
|
||||||
|
|
||||||
if kv_cache_dtype == "fp8":
|
if kv_cache_dtype == "fp8":
|
||||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||||
|
|||||||
@ -7,19 +7,49 @@ import torch
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
from vllm.model_executor.layers.quantization.fp8 import (Fp8KVCacheMethod,
|
||||||
|
Fp8LinearMethod)
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8",
|
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
|
||||||
"nm-testing/Phi-3-mini-128k-instruct-FP8",
|
"nm-testing/Phi-3-mini-128k-instruct-FP8",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||||
reason="FP8 is not supported on this GPU type.")
|
reason="FP8 is not supported on this GPU type.")
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model_id", MODELS)
|
||||||
def test_model_load_and_run(vllm_runner, model: str):
|
def test_model_load_and_run(vllm_runner, model_id: str):
|
||||||
with vllm_runner(model) as llm:
|
with vllm_runner(model_id) as llm:
|
||||||
|
# note: this does not test accuracy, just that we can run through
|
||||||
|
# see lm-eval tests for accuracy
|
||||||
|
outputs = llm.generate_greedy(prompts=["Hello my name is"],
|
||||||
|
max_tokens=10)
|
||||||
|
print(outputs[0][1])
|
||||||
|
|
||||||
|
|
||||||
|
KV_CACHE_MODELS = [
|
||||||
|
# Deprecated AutoFP8 format using .kv_scale
|
||||||
|
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV",
|
||||||
|
# AutoFP8 format using separate .k_scale and .v_scale
|
||||||
|
"nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||||
|
reason="FP8 is not supported on this GPU type.")
|
||||||
|
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
|
||||||
|
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
|
||||||
|
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
|
||||||
|
|
||||||
|
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
||||||
|
attn = model.model.layers[0].self_attn.attn
|
||||||
|
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
|
||||||
|
# NOTE: it is valid for scales to be 1.0 (default value), but we know
|
||||||
|
# these checkpoints have scales < 1.0
|
||||||
|
assert 0.0 < attn._k_scale < 1.0
|
||||||
|
assert 0.0 < attn._v_scale < 1.0
|
||||||
|
|
||||||
# note: this does not test accuracy, just that we can run through
|
# note: this does not test accuracy, just that we can run through
|
||||||
# see lm-eval tests for accuracy
|
# see lm-eval tests for accuracy
|
||||||
outputs = llm.generate_greedy(prompts=["Hello my name is"],
|
outputs = llm.generate_greedy(prompts=["Hello my name is"],
|
||||||
|
|||||||
@ -84,7 +84,8 @@ def paged_attention_v1(
|
|||||||
max_seq_len: int,
|
max_seq_len: int,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
blocksparse_local_blocks: int = 0,
|
blocksparse_local_blocks: int = 0,
|
||||||
blocksparse_vert_stride: int = 0,
|
blocksparse_vert_stride: int = 0,
|
||||||
@ -94,8 +95,9 @@ def paged_attention_v1(
|
|||||||
torch.ops._C.paged_attention_v1(
|
torch.ops._C.paged_attention_v1(
|
||||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
|
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
|
||||||
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
|
seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
|
||||||
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
|
k_scale, v_scale, tp_rank, blocksparse_local_blocks,
|
||||||
blocksparse_block_size, blocksparse_head_sliding_step)
|
blocksparse_vert_stride, blocksparse_block_size,
|
||||||
|
blocksparse_head_sliding_step)
|
||||||
|
|
||||||
|
|
||||||
def paged_attention_v2(
|
def paged_attention_v2(
|
||||||
@ -114,7 +116,8 @@ def paged_attention_v2(
|
|||||||
max_seq_len: int,
|
max_seq_len: int,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
blocksparse_local_blocks: int = 0,
|
blocksparse_local_blocks: int = 0,
|
||||||
blocksparse_vert_stride: int = 0,
|
blocksparse_vert_stride: int = 0,
|
||||||
@ -124,7 +127,7 @@ def paged_attention_v2(
|
|||||||
torch.ops._C.paged_attention_v2(
|
torch.ops._C.paged_attention_v2(
|
||||||
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
|
out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
|
||||||
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
|
num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
|
||||||
alibi_slopes, kv_cache_dtype, kv_scale, tp_rank,
|
alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
|
||||||
blocksparse_local_blocks, blocksparse_vert_stride,
|
blocksparse_local_blocks, blocksparse_vert_stride,
|
||||||
blocksparse_block_size, blocksparse_head_sliding_step)
|
blocksparse_block_size, blocksparse_head_sliding_step)
|
||||||
|
|
||||||
@ -374,11 +377,12 @@ def reshape_and_cache(
|
|||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
|
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
|
||||||
value_cache, slot_mapping,
|
value_cache, slot_mapping,
|
||||||
kv_cache_dtype, kv_scale)
|
kv_cache_dtype, k_scale, v_scale)
|
||||||
|
|
||||||
|
|
||||||
def reshape_and_cache_flash(
|
def reshape_and_cache_flash(
|
||||||
|
|||||||
@ -59,7 +59,8 @@ class ipex_ops:
|
|||||||
max_context_len: int,
|
max_context_len: int,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
blocksparse_local_blocks: int = 0,
|
blocksparse_local_blocks: int = 0,
|
||||||
blocksparse_vert_stride: int = 0,
|
blocksparse_vert_stride: int = 0,
|
||||||
@ -99,7 +100,8 @@ class ipex_ops:
|
|||||||
max_context_len: int,
|
max_context_len: int,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
blocksparse_local_blocks: int = 0,
|
blocksparse_local_blocks: int = 0,
|
||||||
blocksparse_vert_stride: int = 0,
|
blocksparse_vert_stride: int = 0,
|
||||||
@ -227,7 +229,8 @@ class ipex_ops:
|
|||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert kv_cache_dtype == "auto"
|
assert kv_cache_dtype == "auto"
|
||||||
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
ipex.llm.modules.PagedAttention.reshape_and_cache(
|
||||||
|
|||||||
@ -134,7 +134,8 @@ class AttentionImpl(ABC, Generic[T]):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: T,
|
attn_metadata: T,
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -327,7 +327,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: BlocksparseFlashAttentionMetadata,
|
attn_metadata: BlocksparseFlashAttentionMetadata,
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
@ -368,7 +369,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
if prefill_meta := attn_metadata.prefill_metadata:
|
if prefill_meta := attn_metadata.prefill_metadata:
|
||||||
@ -405,7 +407,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
blocksparse_local_blocks=self.local_blocks,
|
blocksparse_local_blocks=self.local_blocks,
|
||||||
blocksparse_vert_stride=self.vert_stride,
|
blocksparse_vert_stride=self.vert_stride,
|
||||||
|
|||||||
@ -256,7 +256,8 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: FlashAttentionMetadata,
|
attn_metadata: FlashAttentionMetadata,
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention.
|
"""Forward pass with FlashAttention.
|
||||||
@ -277,7 +278,8 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
"FlashAttentionImpl")
|
"FlashAttentionImpl")
|
||||||
|
|
||||||
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
|
||||||
assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
|
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||||
|
"key/v_scale is not supported in FlashAttention.")
|
||||||
|
|
||||||
num_tokens, hidden_size = query.shape
|
num_tokens, hidden_size = query.shape
|
||||||
# Reshape the query, key, and value tensors.
|
# Reshape the query, key, and value tensors.
|
||||||
|
|||||||
@ -223,10 +223,12 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: FlashInferMetadata,
|
attn_metadata: FlashInferMetadata,
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert kv_scale == 1.0
|
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||||
|
"key/v_scale is not supported in FlashInfer.")
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
"encoder/decoder cross-attention "
|
"encoder/decoder cross-attention "
|
||||||
|
|||||||
@ -156,7 +156,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: IpexAttnMetadata, # type: ignore
|
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
||||||
@ -170,7 +171,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert kv_scale == 1.0
|
assert k_scale == 1.0 and v_scale == 1.0
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
"encoder/decoder cross-attention "
|
"encoder/decoder cross-attention "
|
||||||
@ -192,7 +193,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping.flatten(),
|
attn_metadata.slot_mapping.flatten(),
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
if attn_metadata.is_prompt:
|
if attn_metadata.is_prompt:
|
||||||
@ -273,7 +275,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Run PagedAttention V2.
|
# Run PagedAttention V2.
|
||||||
@ -305,7 +308,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
|||||||
@ -131,7 +131,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
|
kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
|
||||||
attn_metadata: PallasMetadata,
|
attn_metadata: PallasMetadata,
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with Pallas attention.
|
"""Forward pass with Pallas attention.
|
||||||
@ -146,7 +147,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [batch_size, seq_len, num_heads * head_size]
|
shape = [batch_size, seq_len, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert kv_scale == 1.0
|
assert k_scale == 1.0 and v_scale == 1.0
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
"encoder/decoder cross-attention "
|
"encoder/decoder cross-attention "
|
||||||
|
|||||||
@ -296,7 +296,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: ROCmFlashAttentionMetadata,
|
attn_metadata: ROCmFlashAttentionMetadata,
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with FlashAttention and PagedAttention.
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
@ -336,7 +337,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||||
@ -456,7 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
|||||||
@ -144,7 +144,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: TorchSDPAMetadata, # type: ignore
|
attn_metadata: TorchSDPAMetadata, # type: ignore
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with torch SDPA and PagedAttention.
|
"""Forward pass with torch SDPA and PagedAttention.
|
||||||
@ -158,7 +159,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert kv_scale == 1.0
|
assert k_scale == 1.0 and v_scale == 1.0
|
||||||
if attn_type != AttentionType.DECODER:
|
if attn_type != AttentionType.DECODER:
|
||||||
raise NotImplementedError("Encoder self-attention and "
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
"encoder/decoder cross-attention "
|
"encoder/decoder cross-attention "
|
||||||
@ -176,7 +177,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
PagedAttention.write_to_paged_cache(key, value, key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping,
|
attn_metadata.slot_mapping,
|
||||||
self.kv_cache_dtype, kv_scale)
|
self.kv_cache_dtype, k_scale,
|
||||||
|
v_scale)
|
||||||
|
|
||||||
if attn_metadata.is_prompt:
|
if attn_metadata.is_prompt:
|
||||||
assert attn_metadata.seq_lens is not None
|
assert attn_metadata.seq_lens is not None
|
||||||
@ -239,7 +241,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
|||||||
@ -427,7 +427,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
value: Optional[torch.Tensor],
|
value: Optional[torch.Tensor],
|
||||||
kv_cache: Optional[torch.Tensor],
|
kv_cache: Optional[torch.Tensor],
|
||||||
attn_metadata: "XFormersMetadata",
|
attn_metadata: "XFormersMetadata",
|
||||||
kv_scale: float = 1.0,
|
k_scale: float = 1.0,
|
||||||
|
v_scale: float = 1.0,
|
||||||
attn_type: AttentionType = AttentionType.DECODER,
|
attn_type: AttentionType = AttentionType.DECODER,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Forward pass with xFormers and PagedAttention.
|
"""Forward pass with xFormers and PagedAttention.
|
||||||
@ -531,7 +532,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
value_cache,
|
value_cache,
|
||||||
updated_slot_mapping,
|
updated_slot_mapping,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
kv_scale)
|
k_scale, v_scale)
|
||||||
|
|
||||||
if attn_type != AttentionType.ENCODER:
|
if attn_type != AttentionType.ENCODER:
|
||||||
# Decoder self-attention supports chunked prefill.
|
# Decoder self-attention supports chunked prefill.
|
||||||
@ -620,7 +621,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
self.num_kv_heads,
|
self.num_kv_heads,
|
||||||
self.scale,
|
self.scale,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape the output tensor.
|
# Reshape the output tensor.
|
||||||
|
|||||||
@ -47,13 +47,14 @@ class Attention(nn.Module):
|
|||||||
if num_kv_heads is None:
|
if num_kv_heads is None:
|
||||||
num_kv_heads = num_heads
|
num_kv_heads = num_heads
|
||||||
|
|
||||||
# The default kv_scale is set to 1.0. This is ignored
|
# The default k/v_scale is set to 1.0. This is ignored
|
||||||
# when kv-cache is not fp8, and should be used with
|
# when kv-cache is not fp8, and should be used with
|
||||||
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
|
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
|
||||||
# expect the pre-quantized kv_scale to be loaded along
|
# expect the pre-quantized k/v_scale to be loaded along
|
||||||
# with the model weights.
|
# with the model weights.
|
||||||
self.kv_cache_dtype = kv_cache_dtype
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
self._kv_scale = 1.0
|
self._k_scale = 1.0
|
||||||
|
self._v_scale = 1.0
|
||||||
quant_method = quant_config.get_quant_method(
|
quant_method = quant_config.get_quant_method(
|
||||||
self) if quant_config else None
|
self) if quant_config else None
|
||||||
if quant_method is not None:
|
if quant_method is not None:
|
||||||
@ -66,8 +67,8 @@ class Attention(nn.Module):
|
|||||||
"fp8 checkpoints.")
|
"fp8 checkpoints.")
|
||||||
# When FP8 quantization is enabled, we make a parameter
|
# When FP8 quantization is enabled, we make a parameter
|
||||||
# "kv_scale" so that it can be loaded from FP8 checkpoint.
|
# "kv_scale" so that it can be loaded from FP8 checkpoint.
|
||||||
# The kv_scale will then be converted back to self._kv_scale
|
# The k/v_scale will then be converted back to
|
||||||
# in a native float32 value after weight loading.
|
# self._kv_scale in a native float32 value after weight loading
|
||||||
self.quant_method = quant_method
|
self.quant_method = quant_method
|
||||||
self.quant_method.create_weights(self)
|
self.quant_method.create_weights(self)
|
||||||
|
|
||||||
@ -98,7 +99,8 @@ class Attention(nn.Module):
|
|||||||
value,
|
value,
|
||||||
kv_cache,
|
kv_cache,
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
self._kv_scale,
|
self._k_scale,
|
||||||
|
self._v_scale,
|
||||||
attn_type=attn_type)
|
attn_type=attn_type)
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
|
|||||||
@ -45,7 +45,8 @@ class PagedAttention:
|
|||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
*args,
|
*args,
|
||||||
) -> None:
|
) -> None:
|
||||||
ipex_modules.PagedAttention.reshape_and_cache(
|
ipex_modules.PagedAttention.reshape_and_cache(
|
||||||
@ -64,7 +65,8 @@ class PagedAttention:
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
*args,
|
*args,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
|
|||||||
@ -66,7 +66,8 @@ class PagedAttention:
|
|||||||
value_cache: torch.Tensor,
|
value_cache: torch.Tensor,
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
kv_cache_dtype: str,
|
kv_cache_dtype: str,
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
ops.reshape_and_cache(
|
ops.reshape_and_cache(
|
||||||
key,
|
key,
|
||||||
@ -75,7 +76,8 @@ class PagedAttention:
|
|||||||
value_cache,
|
value_cache,
|
||||||
slot_mapping.flatten(),
|
slot_mapping.flatten(),
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -90,7 +92,8 @@ class PagedAttention:
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
kv_scale: float,
|
k_scale: float,
|
||||||
|
v_scale: float,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
blocksparse_local_blocks: int = 0,
|
blocksparse_local_blocks: int = 0,
|
||||||
blocksparse_vert_stride: int = 0,
|
blocksparse_vert_stride: int = 0,
|
||||||
@ -135,7 +138,8 @@ class PagedAttention:
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
blocksparse_local_blocks,
|
blocksparse_local_blocks,
|
||||||
blocksparse_vert_stride,
|
blocksparse_vert_stride,
|
||||||
@ -172,7 +176,8 @@ class PagedAttention:
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
blocksparse_local_blocks,
|
blocksparse_local_blocks,
|
||||||
blocksparse_vert_stride,
|
blocksparse_vert_stride,
|
||||||
|
|||||||
@ -196,6 +196,15 @@ class ReplicatedLinear(LinearBase):
|
|||||||
else:
|
else:
|
||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||||
|
# If the weight on disk does not have a shape, give it one
|
||||||
|
# (such scales for AutoFp8).
|
||||||
|
if len(loaded_weight.shape) == 0:
|
||||||
|
loaded_weight = loaded_weight.reshape(1)
|
||||||
|
|
||||||
|
assert param.size() == loaded_weight.size()
|
||||||
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|||||||
@ -407,31 +407,56 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
|
|||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module):
|
def create_weights(self, layer: torch.nn.Module):
|
||||||
"""Create "weight" (aka kv_scale) for an attention layer.
|
"""Create "weight" (aka k_scale and v_scale) for an attention layer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
layer: The layer that is using the QuantizeMethodBase factory.
|
layer: The layer that is using the QuantizeMethodBase factory.
|
||||||
"""
|
"""
|
||||||
# Initialize the KV cache scale to 1.0 as the default value.
|
# Initialize the KV cache scales to -1.0, which is an invalid value.
|
||||||
# If the kv_scale appears in the checkpoint, it will be
|
# If the k/v_scale appears in the checkpoint, it will be
|
||||||
# overwritten when loading weights.
|
# overwritten when loading weights.
|
||||||
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
|
layer.k_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||||
|
layer.v_scale = Parameter(torch.tensor(-1.0), requires_grad=False)
|
||||||
|
|
||||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||||
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
|
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
|
# If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0
|
||||||
# regardless whether the kv-scale is available in the checkpoint.
|
# regardless whether the kv-scale is available in the checkpoint.
|
||||||
if layer.kv_cache_dtype != "auto":
|
if layer.kv_cache_dtype != "auto":
|
||||||
kv_scale = layer.kv_scale.to("cpu").tolist()
|
if layer.k_scale > 0.0 and layer.v_scale > 0.0:
|
||||||
if not isinstance(kv_scale, float):
|
# We prefer to use separate k_scale and v_scale if present
|
||||||
|
k_scale = layer.k_scale.to("cpu").tolist()
|
||||||
|
v_scale = layer.v_scale.to("cpu").tolist()
|
||||||
|
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
|
||||||
|
# If no scales were loaded (both scales are invalid negative
|
||||||
|
# values), use the default value of 1.0
|
||||||
|
k_scale = Parameter(torch.tensor(1.0), requires_grad=False)
|
||||||
|
v_scale = Parameter(torch.tensor(1.0), requires_grad=False)
|
||||||
|
else:
|
||||||
|
# If we find a single kv_scale in the checkpoint, we remap
|
||||||
|
# kv_scale to k_scale during weight loading, and duplicate
|
||||||
|
# k_scale to v_scale here
|
||||||
|
assert layer.k_scale > 0.0
|
||||||
|
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
|
||||||
|
k_scale = scale_to_duplicate.to("cpu").tolist()
|
||||||
|
v_scale = scale_to_duplicate.to("cpu").tolist()
|
||||||
|
|
||||||
|
if not isinstance(k_scale, float) or not isinstance(
|
||||||
|
v_scale, float):
|
||||||
raise ValueError("Only support per-tensor scaling factor "
|
raise ValueError("Only support per-tensor scaling factor "
|
||||||
"for fp8 KV cache")
|
"for fp8 KV cache")
|
||||||
layer._kv_scale = kv_scale
|
|
||||||
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
|
# These are used in the final Attention.forward()
|
||||||
|
layer._k_scale = k_scale
|
||||||
|
layer._v_scale = v_scale
|
||||||
|
if (layer._k_scale == 1.0 and layer._v_scale == 1.0
|
||||||
|
and "e5m2" not in layer.kv_cache_dtype):
|
||||||
print_warning_once(
|
print_warning_once(
|
||||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
|
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
|
||||||
"cause accuracy issues. Please make sure kv-cache scaling "
|
"may cause accuracy issues. Please make sure k/v_scale "
|
||||||
"factor is available in the fp8 checkpoint.")
|
"scaling factors are available in the fp8 checkpoint.")
|
||||||
del layer.kv_scale
|
|
||||||
|
del layer.k_scale
|
||||||
|
del layer.v_scale
|
||||||
|
|||||||
@ -22,6 +22,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||||
get_quantization_config)
|
get_quantization_config)
|
||||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||||
|
from vllm.utils import print_warning_once
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -431,11 +432,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
|||||||
def default_weight_loader(param: torch.Tensor,
|
def default_weight_loader(param: torch.Tensor,
|
||||||
loaded_weight: torch.Tensor) -> None:
|
loaded_weight: torch.Tensor) -> None:
|
||||||
"""Default weight loader."""
|
"""Default weight loader."""
|
||||||
# If the weight on disk does not have a shape, give it one
|
|
||||||
# (such scales for AutoFp8).
|
|
||||||
if len(loaded_weight.shape) == 0:
|
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
|
||||||
|
|
||||||
assert param.size() == loaded_weight.size()
|
assert param.size() == loaded_weight.size()
|
||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
@ -462,3 +458,55 @@ def initialize_dummy_weights(
|
|||||||
param.data.copy_(tmp_param)
|
param.data.copy_(tmp_param)
|
||||||
else:
|
else:
|
||||||
param.uniform_(low, high)
|
param.uniform_(low, high)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
||||||
|
"""Remap the name of FP8 k/v_scale parameters.
|
||||||
|
|
||||||
|
This function handles the remapping of FP8 k/v_scale parameter names.
|
||||||
|
It detects if the given name ends with a suffix and attempts to remap
|
||||||
|
it to the expected name format in the model. If the remapped name is not
|
||||||
|
found in the params_dict, a warning is printed and None is returned.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): The original loaded checkpoint parameter name.
|
||||||
|
params_dict (dict): Dictionary containing the model's named parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The remapped parameter name if successful, or the original name
|
||||||
|
if no remapping is needed.
|
||||||
|
None: If the remapped name is not found in params_dict.
|
||||||
|
"""
|
||||||
|
if name.endswith(".kv_scale"):
|
||||||
|
print_warning_once(
|
||||||
|
"DEPRECATED. Found kv_scale in the checkpoint. "
|
||||||
|
"This format is deprecated in favor of separate k_scale and "
|
||||||
|
"v_scale tensors and will be removed in a future release. "
|
||||||
|
"Functionally, we will remap kv_scale to k_scale and duplicate "
|
||||||
|
"k_scale to v_scale")
|
||||||
|
# NOTE: we remap the deprecated kv_scale to k_scale
|
||||||
|
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
|
||||||
|
if remapped_name not in params_dict:
|
||||||
|
print_warning_once(
|
||||||
|
f"Found kv_scale in the checkpoint (e.g. {name}), "
|
||||||
|
"but not found the expected name in the model "
|
||||||
|
f"(e.g. {remapped_name}). kv_scale is "
|
||||||
|
"not loaded.")
|
||||||
|
return None
|
||||||
|
return remapped_name
|
||||||
|
|
||||||
|
possible_scale_names = [".k_scale", ".v_scale"]
|
||||||
|
for scale_name in possible_scale_names:
|
||||||
|
if name.endswith(scale_name):
|
||||||
|
remapped_name = name.replace(scale_name, f".attn{scale_name}")
|
||||||
|
if remapped_name not in params_dict:
|
||||||
|
print_warning_once(
|
||||||
|
f"Found {scale_name} in the checkpoint (e.g. {name}), "
|
||||||
|
"but not found the expected name in the model "
|
||||||
|
f"(e.g. {remapped_name}). {scale_name} is "
|
||||||
|
"not loaded.")
|
||||||
|
return None
|
||||||
|
return remapped_name
|
||||||
|
|
||||||
|
# If there were no matches, return the untouched param name
|
||||||
|
return name
|
||||||
|
|||||||
@ -44,10 +44,10 @@ from vllm.model_executor.layers.sampler import Sampler
|
|||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader, kv_cache_scales_loader)
|
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
from vllm.utils import is_hip, print_warning_once
|
from vllm.utils import is_hip
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA
|
from .interfaces import SupportsLoRA
|
||||||
from .utils import is_pp_missing_parameter, make_layers
|
from .utils import is_pp_missing_parameter, make_layers
|
||||||
@ -460,18 +460,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
# Remapping the name of FP8 kv-scale.
|
# Remapping the name of FP8 kv-scale.
|
||||||
if name.endswith("kv_scale"):
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
remapped_kv_scale_name = name.replace(
|
if name is None:
|
||||||
".kv_scale", ".attn.kv_scale")
|
continue
|
||||||
if remapped_kv_scale_name not in params_dict:
|
|
||||||
print_warning_once(
|
|
||||||
f"Found kv scale in the checkpoint (e.g. {name}), "
|
|
||||||
"but not found the expected name in the model "
|
|
||||||
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
|
|
||||||
"not loaded.")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
name = remapped_kv_scale_name
|
|
||||||
|
|
||||||
if is_pp_missing_parameter(name, self):
|
if is_pp_missing_parameter(name, self):
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -42,10 +42,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
|||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
from vllm.utils import print_warning_once
|
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA
|
from .interfaces import SupportsLoRA
|
||||||
|
|
||||||
@ -415,19 +415,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
# Remapping the name of FP8 kv-scale.
|
# Remapping the name of FP8 kv-scale.
|
||||||
if name.endswith("kv_scale"):
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
remapped_kv_scale_name = name.replace(
|
if name is None:
|
||||||
".kv_scale", ".attn.kv_scale")
|
continue
|
||||||
if remapped_kv_scale_name not in params_dict:
|
|
||||||
print_warning_once(
|
|
||||||
"Found kv scale in the checkpoint "
|
|
||||||
f"(e.g. {name}), but not found the expected "
|
|
||||||
f"name in the model "
|
|
||||||
f"(e.g. {remapped_kv_scale_name}). "
|
|
||||||
"kv-scale is not loaded.")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
name = remapped_kv_scale_name
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|||||||
@ -43,10 +43,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
|||||||
from vllm.model_executor.layers.sampler import Sampler
|
from vllm.model_executor.layers.sampler import Sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader, maybe_remap_kv_scale_name)
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||||
from vllm.utils import print_warning_once
|
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA
|
from .interfaces import SupportsLoRA
|
||||||
|
|
||||||
@ -382,18 +382,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
|
|||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
# Remapping the name of FP8 kv-scale.
|
# Remapping the name of FP8 kv-scale.
|
||||||
if name.endswith("kv_scale"):
|
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||||
remapped_kv_scale_name = name.replace(
|
if name is None:
|
||||||
".kv_scale", ".attn.kv_scale")
|
continue
|
||||||
if remapped_kv_scale_name not in params_dict:
|
|
||||||
print_warning_once(
|
|
||||||
f"Found kv scale in the checkpoint (e.g. {name}), "
|
|
||||||
"but not found the expected name in the model "
|
|
||||||
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
|
|
||||||
"not loaded.")
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
name = remapped_kv_scale_name
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
default_weight_loader)
|
default_weight_loader)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user