[bugfix][deepseek] fix flashmla kernel selection (#25956)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
youkaichao 2025-10-01 00:30:36 +08:00 committed by simon-mo
parent d0b178cef1
commit 83f3c9beae


@@ -136,7 +136,7 @@ def flash_mla_with_kvcache(
         descale_k is None
     ), "descale_q and descale_k should be both None or both not None"
 
-    if (descale_q is not None) and (descale_k is not None):
+    if indices is None and q.element_size() == 1:
         out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8(
             q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale,
             causal, tile_scheduler_metadata, num_splits, descale_q, descale_k)
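
For context, a minimal sketch of the selection logic this hunk changes (the helper name and the non-FP8 kernel label below are illustrative, not part of the vLLM source): after the fix, the FP8 extension kernel fwd_kvcache_mla_fp8 is chosen when the query tensor itself is FP8 (1-byte elements) and no sparse indices are given, instead of whenever descale_q/descale_k happen to be passed.

import torch

def select_flashmla_kernel(q: torch.Tensor, indices) -> str:
    """Illustrative helper: pick which FlashMLA kernel a query would route to."""
    # element_size() == 1 holds for FP8 dtypes such as torch.float8_e4m3fn.
    if indices is None and q.element_size() == 1:
        return "fwd_kvcache_mla_fp8"   # dense FP8 decode kernel
    return "default_fwd_kvcache_mla"   # any other dtype, or the sparse/indexed path

# An FP8 query without sparse indices routes to the FP8 kernel.
q_fp8 = torch.empty(1, 128, 576, dtype=torch.float8_e4m3fn)
assert select_flashmla_kernel(q_fp8, indices=None) == "fwd_kvcache_mla_fp8"

# A bf16 query, even if descale tensors were supplied, no longer selects it.
q_bf16 = torch.empty(1, 128, 576, dtype=torch.bfloat16)
assert select_flashmla_kernel(q_bf16, indices=None) == "default_fwd_kvcache_mla"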