diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake
index 6291475164ba..ee6768bce26c 100644
--- a/cmake/external_projects/flashmla.cmake
+++ b/cmake/external_projects/flashmla.cmake
@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
         flashmla
         GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-        GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
+        GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
         GIT_PROGRESS TRUE
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""
@@ -37,9 +37,9 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     set(FlashMLA_SOURCES
         ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
-        ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)
+        ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
+        ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
 
     set(FlashMLA_INCLUDES
         ${flashmla_SOURCE_DIR}/csrc/cutlass/include
diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py
index b85f27ac417c..1af26dfc3daa 100644
--- a/vllm/attention/ops/flashmla.py
+++ b/vllm/attention/ops/flashmla.py
@@ -91,7 +91,6 @@ def flash_mla_with_kvcache(
     out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
         q,
         k_cache,
-        None,
         head_dim_v,
         cache_seqlens,
         block_table,
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index b5aecff9937f..2b0f52cf80bf 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -70,6 +70,22 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.cg_buf_tile_scheduler_metadata = None
         self.cg_buf_num_splits = None
 
+        device_properties = torch.cuda.get_device_properties(self.device)
+        num_sms = device_properties.multi_processor_count
+
+        if self.compilation_config.full_cuda_graph:
+            self.cg_buf_tile_scheduler_metadata = torch.zeros(
+                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
+                # TileSchedulerMetaDataSize = 8
+                (num_sms, 8),
+                device=self.device,
+                dtype=torch.int32,
+            )
+            self.cg_buf_num_splits = torch.empty(
+                (vllm_config.scheduler_config.max_num_seqs + 1),
+                device=self.device,
+                dtype=torch.int32)
+
     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
@@ -80,28 +96,28 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         )
 
         if self.compilation_config.full_cuda_graph:
-            # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_tile_scheduler_metadata is None:
-                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
-                self.cg_buf_num_splits = num_splits
-            else:
-                assert self.cg_buf_num_splits is not None
+            assert self.cg_buf_tile_scheduler_metadata is not None
+            assert self.cg_buf_num_splits is not None
 
-                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata.size() ==
-                        tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata.\
-                    copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
+            sm_parts = tile_scheduler_metadata.size(0)
+            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
+            assert sm_parts <= self.cg_buf_tile_scheduler_metadata.size(0)
+            tile_scheduler_metadata_view = \
+                self.cg_buf_tile_scheduler_metadata[:sm_parts]
+            tile_scheduler_metadata_view.copy_(tile_scheduler_metadata)
+            tile_scheduler_metadata = tile_scheduler_metadata_view
 
-                # Num splits is per-batch, varying size (batch_size,)
-                n = num_splits.size(0)
-                # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits.size(0)
-                num_splits_view = self.cg_buf_num_splits[:n]
-                num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
-                num_splits = num_splits_view
+            # Num splits is per-batch, varying size (batch_size,)
+            n = num_splits.size(0)
+            # make sure static buffer is large enough
+            assert n <= self.cg_buf_num_splits.size(0)
+            num_splits_view = self.cg_buf_num_splits[:n]
+            num_splits_view.copy_(num_splits)
+            # Num splits needs to be monotonically increasing
+            # (with https://github.com/vllm-project/FlashMLA/pull/3; otherwise
+            # it needs to monotonically increase by 1)
+            self.cg_buf_num_splits[n:].fill_(num_splits[-1])
+            num_splits = num_splits_view
 
         return FlashMLADecodeMetadata(
             block_table=block_table_tensor,
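For context on the `_build_decode` change: the patch pre-allocates fixed-size buffers in `__init__` (sized by SM count and `max_num_seqs`) and, on every build, copies the varying-size FlashMLA metadata into views of those buffers, so CUDA graph replays always read from the same device pointers. The key behavioral change is that the unused tail of the num-splits buffer is now filled with the last cumulative value instead of zeros, keeping the whole array monotonically non-decreasing as the updated kernels expect. Below is a minimal, CPU-runnable sketch of that staging pattern under stated assumptions: `stage_decode_metadata` and the module-level constants are hypothetical stand-ins, not vLLM's actual API.

```python
import torch

# Assumed stand-ins for values vLLM reads from the device and scheduler
# config (multi_processor_count and scheduler_config.max_num_seqs).
NUM_SMS = 4
TILE_SCHEDULER_METADATA_SIZE = 8
MAX_NUM_SEQS = 8

# Fixed-size buffers allocated once up front, mirroring the new __init__
# logic (allocation no longer happens lazily at CUDA graph capture time).
cg_buf_tile_scheduler_metadata = torch.zeros(
    (NUM_SMS, TILE_SCHEDULER_METADATA_SIZE), dtype=torch.int32)
cg_buf_num_splits = torch.empty((MAX_NUM_SEQS + 1,), dtype=torch.int32)


def stage_decode_metadata(tile_scheduler_metadata: torch.Tensor,
                          num_splits: torch.Tensor):
    """Copy varying-size kernel metadata into the static buffers and return
    views of them, so graph replays keep reading from stable pointers."""
    # Tile scheduler metadata: per-SM-partition rows, bounded by #SMs.
    sm_parts = tile_scheduler_metadata.size(0)
    assert sm_parts <= cg_buf_tile_scheduler_metadata.size(0)
    ts_view = cg_buf_tile_scheduler_metadata[:sm_parts]
    ts_view.copy_(tile_scheduler_metadata)

    # Num splits: cumulative per-batch counts, bounded by max_num_seqs + 1.
    n = num_splits.size(0)
    assert n <= cg_buf_num_splits.size(0)
    ns_view = cg_buf_num_splits[:n]
    ns_view.copy_(num_splits)
    # Pad the tail with the last cumulative value rather than zeros, so the
    # buffer as a whole stays monotonically non-decreasing.
    cg_buf_num_splits[n:].fill_(num_splits[-1].item())
    return ts_view, ns_view


# Example: a decode batch of 4 sequences split across 2 SM partitions.
splits = torch.tensor([0, 2, 5, 5, 7], dtype=torch.int32)
meta = torch.zeros((2, TILE_SCHEDULER_METADATA_SIZE), dtype=torch.int32)
stage_decode_metadata(meta, splits)
print(cg_buf_num_splits)  # tensor([0, 2, 5, 5, 7, 7, 7, 7, 7]) -- non-decreasing
```

Returning views of the static buffers, rather than the freshly computed tensors, is what lets a captured CUDA graph keep seeing up-to-date metadata after each eager-mode rebuild.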