diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index b2c3d035a62cd..dc524650f554c 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -475,6 +475,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         max_query_len: int,
         common_prefix_len: int,
         common_attn_metadata: CommonAttentionMetadata,
+        ubatch_id: int = 0,
     ) -> M:
         num_reqs = req_slice.stop - req_slice.start
         num_tokens = token_slice.stop - token_slice.start
@@ -586,6 +587,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             decode_metadata = self._build_decode(
                 block_table_tensor=block_table_tensor[:num_decodes, ...],
                 seq_lens=seq_lens[:num_decodes],
+                ubatch_id=ubatch_id,
             )
 
         return self.metadata_cls(
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 53ad8dc6be70c..ac6389e9efd61 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -63,11 +63,13 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.num_q_heads = self.runner.model_config.get_num_attention_heads(
             self.runner.parallel_config)
 
-        self.cg_buf_tile_scheduler_metadata = None
-        self.cg_buf_num_splits = None
+        self.cg_buf_tile_scheduler_metadata = [None, None]
+        self.cg_buf_num_splits = [None, None]
 
     def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
+                      seq_lens: torch.Tensor,
+                      ubatch_id: int = 0) -> FlashMLADecodeMetadata:
+        assert ubatch_id < 2
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
                 seq_lens,
@@ -79,27 +81,27 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         if self.runner.full_cuda_graph:
             n = num_splits.size(0)
             # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_num_splits is None:
-                self.cg_buf_num_splits = num_splits
-                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
-            elif n <= self.cg_buf_num_splits.size(0):
-                assert self.cg_buf_tile_scheduler_metadata is not None
+            if self.cg_buf_num_splits[ubatch_id] is None:
+                self.cg_buf_num_splits[ubatch_id] = num_splits
+                self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
+            elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
+                assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
 
                 # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata.size() ==
+                assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
                         tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata.\
+                self.cg_buf_tile_scheduler_metadata[ubatch_id].\
                     copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
+                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
 
                 # Num splits is per-batch, varying size (batch_size,)
                 n = num_splits.size(0)
                 # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
                 # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits.size(0)
-                num_splits_view = self.cg_buf_num_splits[:n]
+                assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
+                num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
                 num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
+                self.cg_buf_num_splits[ubatch_id][n:].fill_(0)  # fill the rest with 0s
                 num_splits = num_splits_view
 
         return FlashMLADecodeMetadata(
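The flashmla.py change above converts the two CUDA-graph static buffers into per-micro-batch slots: each ubatch must copy its `tile_scheduler_metadata` / `num_splits` into its own stable allocation, since graph replay reads from fixed addresses. A minimal sketch of that pattern, assuming two micro-batches as in the patch (`StaticBufferPool` and `NUM_UBATCHES` are illustrative names, not vLLM APIs):

```python
from typing import Optional

import torch

NUM_UBATCHES = 2  # mirrors the `assert ubatch_id < 2` in _build_decode


class StaticBufferPool:
    """One lazily-allocated static buffer per micro-batch (hypothetical)."""

    def __init__(self) -> None:
        self.bufs: list[Optional[torch.Tensor]] = [None] * NUM_UBATCHES

    def stage(self, ubatch_id: int, values: torch.Tensor) -> torch.Tensor:
        assert ubatch_id < NUM_UBATCHES
        n = values.size(0)
        if self.bufs[ubatch_id] is None:
            # First call (CUDA-graph capture): adopt the tensor as the
            # static buffer, fixing the address the graph will replay.
            self.bufs[ubatch_id] = values
            return values
        buf = self.bufs[ubatch_id]
        assert n <= buf.size(0), "static buffer too small for this batch"
        buf[:n].copy_(values)  # write into the captured address
        buf[n:].fill_(0)       # zero the tail so stale entries are inert
        return buf[:n]
```

Keying the pool by `ubatch_id` is what keeps the two in-flight micro-batches from clobbering each other's metadata between capture and replay.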
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 8aa78e7018129..3695796b70dec 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -17,6 +17,7 @@ import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.utils import (
     get_kv_connector_cache_layout)
 from vllm.logger import init_logger
+from vllm.v1.worker.block_table import BlockTable
 
 logger = init_logger(__name__)
 
@@ -29,6 +30,8 @@ class CommonAttentionMetadata:
     """
 
     query_start_loc: torch.Tensor
+
+    # query_start_loc_cpu: torch.Tensor
     """(batch_size + 1,), the start location of each request in query Tensor"""
     seq_lens: torch.Tensor
     """(batch_size,), the length of each request including both computed tokens
@@ -41,6 +44,46 @@ class CommonAttentionMetadata:
     max_query_len: int
     """Longest query in batch"""
 
+    # block_table: BlockTable
+
+    # def compute_request_slice(self, token_slice: slice) -> slice:
+    #     """
+    #     Map a token slice back to the slice of requests that own those
+    #     tokens, using query_start_loc to find request boundaries.
+    #     """
+    #     if self.max_query_len == 1:
+    #         # Pure decode: one token per request, so the slices coincide.
+    #         return token_slice
+    #     else:
+    #         # Find the first query_start_loc at or past token_slice.start.
+    #         first_request = (self.query_start_loc_cpu >=
+    #                          token_slice.start).int().argmax(dim=-1).item()
+    #         last_request = (self.query_start_loc_cpu <
+    #                         token_slice.stop).int().argmax(dim=-1).item()
+    #         return slice(first_request, last_request)
+
+    # # Slice the current CommonAttentionMetadata into two
+    # def _slice(self, token_slice: slice) -> CommonAttentionMetadata:
+    #     request_slice = self.compute_request_slice(token_slice)
+    #     query_start_loc = slice_query_start_locs(
+    #         self.query_start_loc, request_slice)
+
+    #     seq_lens = self.seq_lens[request_slice]
+    #     num_requests = request_slice.stop - request_slice.start
+    #     num_actual_tokens = token_slice.stop - token_slice.start
+    #     # TODO(Sage): update this for prefill
+    #     max_query_len = 1
+
+    #     block_table = self.block_table
+    #     block_table_tensor = block_table.get_device_tensor()[request_slice]
+    #     block_table.slot_mapping[token_slice].copy_(
+    #         block_table.slot_mapping_cpu[token_slice],
+    #         non_blocking=True)
+    #     block_table.slot_mapping[token_slice.stop:].fill_(-1)
+    #     slot_mapping = block_table.slot_mapping[token_slice]
+
+    #     pass
+
 
 M = TypeVar("M")
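The utils.py hunk is scaffolding: a commented-out draft of splitting one `CommonAttentionMetadata` in two by mapping a token range back to the requests that own it via `query_start_loc`. For comparison, the request-slice computation can be done robustly with `torch.searchsorted`, which also handles a token boundary that falls mid-request; this is a self-contained illustration of the idea, not code from the patch:

```python
import torch


def compute_request_slice(query_start_loc_cpu: torch.Tensor,
                          token_slice: slice) -> slice:
    """Return the slice of requests whose tokens overlap token_slice."""
    # Last request starting at or before token_slice.start.
    first = int(torch.searchsorted(query_start_loc_cpu, token_slice.start,
                                   right=True).item()) - 1
    # First request starting at or after token_slice.stop.
    last = int(torch.searchsorted(query_start_loc_cpu, token_slice.stop,
                                  right=False).item())
    return slice(max(first, 0), last)


# Three requests with 4, 1, and 2 tokens -> query_start_loc = [0, 4, 5, 7].
qsl = torch.tensor([0, 4, 5, 7])
assert compute_request_slice(qsl, slice(0, 4)) == slice(0, 1)
assert compute_request_slice(qsl, slice(4, 7)) == slice(1, 3)
```

Note that the draft's `(mask).int().argmax()` idiom returns the first matching index (and 0 when nothing matches), so the `last_request` line as written would always pick request 0 rather than the last overlapping request, which is presumably one reason the helper is still commented out.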
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1d981feb373fb..5642a2e9456af 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -839,6 +839,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     max_query_len=max(tokens[req_slice]),
                     common_prefix_len=common_prefix_len,
                     common_attn_metadata=common_attn_metadata,
+                    ubatch_id=ubid,
                 ))
             for layer_name in kv_cache_group_spec.layer_names:
                 assert type(attn_metadata) is list
@@ -1583,7 +1584,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _make_ubatch_contexts(ubatch_slices,
                               attn_metadata,
                               compute_stream,
-                              is_dummy_run,
                               num_tokens_across_dp,
                               skip_cuda_graphs) -> list[UBatchContext]:
         ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
@@ -1623,7 +1623,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             ubatch_slices=ubatch_slices,
             attn_metadata=attn_metadata,
             compute_stream=compute_stream,
-            is_dummy_run=is_dummy_run,
             num_tokens_across_dp=num_tokens_across_dp,
             skip_cuda_graphs=skip_cuda_graphs
         )
@@ -2369,7 +2368,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # _dummy_run doesn't go through _prepare_inputs so
             # we synchronize with other DP ranks here
             # logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
-            should_ubatch = self.should_ubatch(allow_microbatching)
+            should_ubatch = self.should_ubatch(should_ubatch)
             # Padding for DP
             # logger.info("PADDING DUMMY")
             num_tokens_across_dp = None
@@ -2451,6 +2450,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     max_query_len=max_query_len,
                     common_prefix_len=0,
                     common_attn_metadata=common_attn_metadata,
+                    ubatch_id=ubid,
                 ))
             for layer_name in kv_cache_group_spec.layer_names:
                 assert type(attn_metadata) is list
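On the runner side, each metadata build is now tagged with its micro-batch index (`ubid`) in both the `_prepare_inputs` path and the dummy-run/capture path, so FlashMLA's per-ubatch static buffers line up across capture and replay; `_dummy_run` also re-synchronizes `should_ubatch` with the other DP ranks because it bypasses `_prepare_inputs`. The overall call shape, as a hedged sketch (`build_per_ubatch_metadata` is a hypothetical helper, and `builder.build` is abbreviated to the arguments relevant here):

```python
from typing import Any


def build_per_ubatch_metadata(builder: Any,
                              ubatch_metadata: list[Any]) -> list[Any]:
    """Build one attention-metadata object per micro-batch (schematic)."""
    attn_metadata = []
    for ubid, common_attn_metadata in enumerate(ubatch_metadata):
        # Tagging the build with `ubid` lets backend builders such as
        # FlashMLA route decode metadata into the matching static buffer
        # slot (cg_buf_*[ubatch_id]).
        attn_metadata.append(
            builder.build(
                common_attn_metadata=common_attn_metadata,
                ubatch_id=ubid,
            ))
    return attn_metadata
```

The patch also removes the `is_dummy_run` parameter from `_make_ubatch_contexts`, keeping the signature and its single call site in sync.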