[Bugfix] fixes the decoding metadata of dense mla's fp8 kvcache. (#27144)

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Tao He 2025-10-22 02:27:03 +08:00 committed by GitHub
parent 647214f3d5
commit 250fb1b8ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 10 additions and 1 deletions

View File

@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
GIT_TAG 28417e516fcbf6257a422ba117ef5b6f44da5682
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS)
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
)
set(FlashMLA_INCLUDES

View File

@ -102,6 +102,12 @@ def get_mla_metadata(
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
- num_splits: (batch_size + 1), dtype torch.int32.
"""
if is_fp8_kvcache and topk is None:
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
cache_seqlens,
num_q_tokens_per_head_k,
num_heads_k,
)
return torch.ops._flashmla_C.get_mla_decoding_metadata(
cache_seqlens,
num_q_tokens_per_head_k,

View File

@ -91,6 +91,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
device_properties = torch.cuda.get_device_properties(self.device)
num_sms = device_properties.multi_processor_count
@ -123,6 +124,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
seq_lens_device,
self.num_q_heads,
1, # MQA for the decode path
is_fp8_kvcache=self.is_fp8_kvcache,
)
# TODO: we can disambiguate between decode and mixed-prefill decode here