diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index d3e5300dbbd6b..51e77bc6bfa41 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -67,6 +67,20 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.cg_buf_tile_scheduler_metadata = None
         self.cg_buf_num_splits = None
 
+        device_properties = torch.cuda.get_device_properties(self.device)
+        num_sms = device_properties.multi_processor_count
+
+        if self.compilation_config.full_cuda_graph:
+            self.cg_buf_tile_scheduler_metadata = torch.empty(
+                (num_sms, 8),  # TileSchedulerMetaDataSize == 8
+                device=self.device,
+                dtype=torch.int32,
+            )
+            self.cg_buf_num_splits = torch.empty(
+                (vllm_config.scheduler_config.max_num_seqs + 1),
+                device=self.device,
+                dtype=torch.int32)
+
     def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
@@ -77,28 +91,24 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         )
 
         if self.compilation_config.full_cuda_graph:
-            # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_tile_scheduler_metadata is None:
-                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
-                self.cg_buf_num_splits = num_splits
-            else:
-                assert self.cg_buf_num_splits is not None
+            assert self.cg_buf_tile_scheduler_metadata is not None
+            assert self.cg_buf_num_splits is not None
 
-                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata.size() ==
-                        tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata.\
-                    copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
+            # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
+            assert (self.cg_buf_tile_scheduler_metadata.size() ==
+                    tile_scheduler_metadata.size())
+            self.cg_buf_tile_scheduler_metadata.\
+                copy_(tile_scheduler_metadata)
+            tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
 
-                # Num splits is per-batch, varying size (batch_size,)
-                n = num_splits.size(0)
-                # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits.size(0)
-                num_splits_view = self.cg_buf_num_splits[:n]
-                num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
-                num_splits = num_splits_view
+            # Num splits is per-batch, varying size (batch_size,)
+            n = num_splits.size(0)
+            # make sure static buffer is large enough
+            assert n <= self.cg_buf_num_splits.size(0)
+            num_splits_view = self.cg_buf_num_splits[:n]
+            num_splits_view.copy_(num_splits)
+            self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
+            num_splits = num_splits_view
 
         return FlashMLADecodeMetadata(
             block_table=block_table_tensor,