mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-24 03:24:28 +08:00
[Bugfix] fixes the decoding metadata of dense mla's fp8 kvcache. (#27144)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
647214f3d5
commit
250fb1b8ea
@ -19,7 +19,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
flashmla
|
||||
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
|
||||
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
|
||||
GIT_TAG 28417e516fcbf6257a422ba117ef5b6f44da5682
|
||||
GIT_PROGRESS TRUE
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
@ -66,6 +66,7 @@ if(FLASH_MLA_ARCHS)
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
|
||||
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_metadata.cu
|
||||
)
|
||||
|
||||
set(FlashMLA_INCLUDES
|
||||
|
||||
@ -102,6 +102,12 @@ def get_mla_metadata(
|
||||
(num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
|
||||
- num_splits: (batch_size + 1), dtype torch.int32.
|
||||
"""
|
||||
if is_fp8_kvcache and topk is None:
|
||||
return torch.ops._flashmla_extension_C.get_mla_decoding_metadata_dense_fp8(
|
||||
cache_seqlens,
|
||||
num_q_tokens_per_head_k,
|
||||
num_heads_k,
|
||||
)
|
||||
return torch.ops._flashmla_C.get_mla_decoding_metadata(
|
||||
cache_seqlens,
|
||||
num_q_tokens_per_head_k,
|
||||
|
||||
@ -91,6 +91,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
self.cg_buf_tile_scheduler_metadata = None
|
||||
self.cg_buf_num_splits = None
|
||||
self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")
|
||||
|
||||
device_properties = torch.cuda.get_device_properties(self.device)
|
||||
num_sms = device_properties.multi_processor_count
|
||||
@ -123,6 +124,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
seq_lens_device,
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
is_fp8_kvcache=self.is_fp8_kvcache,
|
||||
)
|
||||
|
||||
# TODO: we can disambiguate between decode and mixed-prefill decode here
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user