Avoid division by zero in cache DS MLA kernel (#26174)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Matthew Bonanni 2025-10-03 13:35:17 -04:00 committed by yewentao256
parent 2d68bba3cd
commit 13e211bbbc

View File

@ -16,7 +16,7 @@
#include <algorithm>
#include <cassert>
#include <cfloat> // FLT_MIN
#include <cfloat>
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
@ -479,6 +479,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
// Compute the scale for the tile
float tile_scale = max_abs / 448.f;
tile_scale = fmaxf(tile_scale, FLT_MIN);
// The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) {