diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 83a2d42e6d46f..883f1e8150143 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -345,7 +345,7 @@ void single_query_cached_kv_attention_launcher( constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(T); + int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); int shared_mem_size = std::max(logits_size, outputs_size);