diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index a6f8602a05882..d06eac2b3d4fe 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -375,7 +375,7 @@ void reshape_and_cache( torch::Tensor& slot_mapping, // [num_tokens] const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale) { - int num_tokens = key.size(0); + int num_tokens = slot_mapping.size(0); int num_heads = key.size(1); int head_size = key.size(2); int block_size = key_cache.size(3);