diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 6b3480091f28..84c2345b44d8 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -16,7 +16,7 @@ #include #include -#include // FLT_MIN +#include #ifdef USE_ROCM #include @@ -479,6 +479,7 @@ __global__ void concat_and_cache_ds_mla_kernel( // Compute the scale for the tile float tile_scale = max_abs / 448.f; + tile_scale = fmaxf(tile_scale, FLT_MIN); // The first lane of each half-warp writes the scale to kv_cache if ((lane_idx == 0) || (lane_idx == 16)) {