diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index dc524650f554c..b2c3d035a62cd 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -475,7 +475,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         max_query_len: int,
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
-        ubatch_id: int = 0
     ) -> M:
         num_reqs = req_slice.stop - req_slice.start
         num_tokens = token_slice.stop - token_slice.start
@@ -587,7 +586,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         decode_metadata = self._build_decode(
             block_table_tensor=block_table_tensor[:num_decodes, ...],
             seq_lens=seq_lens[:num_decodes],
-            ubatch_id=ubatch_id
         )
 
         return self.metadata_cls(
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index ac6389e9efd61..ed0a35f1b460f 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -63,12 +63,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.num_q_heads = self.runner.model_config.get_num_attention_heads(
             self.runner.parallel_config)
 
-        self.cg_buf_tile_scheduler_metadata = [None, None]
-        self.cg_buf_num_splits = [None, None]
+        self.cg_buf_tile_scheduler_metadata = None
+        self.cg_buf_num_splits = None
 
     def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor, ubatch_id = 0) -> FlashMLADecodeMetadata:
-        assert ubatch_id < 2
+                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
                 seq_lens,
@@ -76,31 +75,30 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
                 1,  # MQA for the decode path
             )
 
-        # logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
         if self.runner.full_cuda_graph:
             n = num_splits.size(0)
             # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_num_splits[ubatch_id] is None:
-                self.cg_buf_num_splits[ubatch_id] = num_splits
-                self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
-            elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
-                assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
+            if self.cg_buf_num_splits is None:
+                self.cg_buf_num_splits = num_splits
+                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
+            elif n <= self.cg_buf_num_splits.size(0):
+                assert self.cg_buf_tile_scheduler_metadata is not None
 
                 # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
+                assert (self.cg_buf_tile_scheduler_metadata.size() ==
                         tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata[ubatch_id].\
+                self.cg_buf_tile_scheduler_metadata.\
                     copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
+                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
 
                 # Num splits is per-batch, varying size (batch_size,)
                 n = num_splits.size(0)
                 # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
                 # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
-                num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
+                assert n <= self.cg_buf_num_splits.size(0)
+                num_splits_view = self.cg_buf_num_splits[:n]
                 num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[ubatch_id][n:].fill_(0)  # fill the rest with 0s
+                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
                 num_splits = num_splits_view
 
         return FlashMLADecodeMetadata(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c43917baaa9eb..34e62531a9c4f 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -851,7 +851,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     max_query_len=max(tokens[req_slice]),
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
-                    ubatch_id=ubid
                 ))
             for layer_name in kv_cache_group_spec.layer_names:
                 assert type(attn_metadata) is list
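
Aside: the change above reverts to a single static buffer per builder for full CUDA graph mode, rather than one buffer per micro-batch indexed by ubatch_id. The sketch below is only an illustration of the capture/replay buffer pattern the kept code relies on, assuming the buffer name from the diff; the class is a simplified stand-in, not the real FlashMLAMetadataBuilder.

    from typing import Optional

    import torch


    class _StaticNumSplitsBuffer:
        """Reuse one persistent num_splits tensor across CUDA graph replays (sketch)."""

        def __init__(self) -> None:
            self.cg_buf_num_splits: Optional[torch.Tensor] = None

        def update(self, num_splits: torch.Tensor) -> torch.Tensor:
            if self.cg_buf_num_splits is None:
                # Capture time: adopt the first (largest) tensor as the static buffer.
                self.cg_buf_num_splits = num_splits
                return num_splits
            # Replay time: the captured graph keeps reading the same memory, so copy
            # the new values into a prefix view of the buffer and zero the unused tail.
            n = num_splits.size(0)
            assert n <= self.cg_buf_num_splits.size(0)
            num_splits_view = self.cg_buf_num_splits[:n]
            num_splits_view.copy_(num_splits)
            self.cg_buf_num_splits[n:].fill_(0)
            return num_splits_view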