diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index 46b38b74b4280..6b1404d0318c9 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -53,6 +53,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         super().__init__()
         self.buffers = buffers
+        for buffer in self.buffers:
+            buffer.set_num_sms(4)  # cap the SMs used by the LL kernels
         self.max_tokens_per_rank = max_tokens_per_rank
         self.use_fp8_dispatch = use_fp8_dispatch
         # The dispatch function returns a handle that the combine function
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 52fe30779ce7e..0b066de449371 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -444,6 +444,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             # 2*(192*128)*(64*1024) = 3gb
             # (assuming 192 QK head dim, 128 heads, and fp16)
             128 * 1024)
+        self.chunked_prefill_workspace_size = scheduler_config.max_num_seqs * cache_config.block_size  # min: one block per seq
         assert self.chunked_prefill_workspace_size >= \
            scheduler_config.max_num_seqs * cache_config.block_size
        self.chunked_prefill_workspace = torch.empty(
@@ -566,8 +567,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             decode_threshold=1)

     def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor,
-                      ubatch_id: Optional[int] = None):
+                      seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens,
@@ -600,8 +600,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
-              fast_build: bool = False,
-              ubatch_id: Optional[int] = None) -> M:
+              fast_build: bool = False) -> M:
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
@@ -724,7 +723,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             decode_metadata = self._build_decode(
                 block_table_tensor=block_table_tensor[:num_decodes, ...],
                 seq_lens=seq_lens[:num_decodes],
-                ubatch_id=ubatch_id
             )

         attn_metadata = self.metadata_cls(
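
The workspace override in the first common.py hunk above pins the chunked-prefill workspace to exactly the lower bound the assert checks, so the assert now holds by construction. As a back-of-the-envelope check of the sizing comment, a minimal sketch (the max_num_seqs and block_size values below are made up, not from vLLM's config):

    # fp16 (2 bytes) * 192 QK head dim * 128 heads, 64k workspace tokens,
    # matching the "2*(192*128)*(64*1024) = 3gb" comment above.
    bytes_per_token = 2 * 192 * 128
    workspace_bytes = bytes_per_token * 64 * 1024
    assert workspace_bytes == 3 * 1024**3  # exactly 3 GiB

    # The decode path needs room for one block per schedulable sequence,
    # which is the bound the assert (and now the override) encodes.
    max_num_seqs, block_size = 1024, 16               # hypothetical values
    min_workspace_tokens = max_num_seqs * block_size  # 16384 tokens
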
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index d519564d5e1e0..290dd252a1578 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -65,15 +65,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config)

-        self.cg_buf_tile_scheduler_metadata = [None, None]
-        self.cg_buf_num_splits = [None, None]
+        self.cg_buf_tile_scheduler_metadata = None
+        self.cg_buf_num_splits = None

     def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor,
-                      ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
+                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
-        # print(f"UBATCH ID: {ubatch_id}")
-        ubatch_id = 0 if ubatch_id is None else ubatch_id
-        assert ubatch_id < 2
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
                 seq_lens,
@@ -85,28 +81,25 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
-            # if False:
             n = num_splits.size(0)
             # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_num_splits[ubatch_id] is None:
-                # logger.info(f"ALLOCATING FLASH MLA DATA FOR SIZE {n}")
-                self.cg_buf_num_splits[ubatch_id] = num_splits
-                self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
-            elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
-                assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
+            if self.cg_buf_num_splits is None:
+                self.cg_buf_num_splits = num_splits
+                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
+            elif n <= self.cg_buf_num_splits.size(0):
+                assert self.cg_buf_tile_scheduler_metadata is not None

                 # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
+                assert (self.cg_buf_tile_scheduler_metadata.size() ==
                         tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata[ubatch_id].\
+                self.cg_buf_tile_scheduler_metadata.\
                     copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
+                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata

                 # Num splits is per-batch, varying size (batch_size,)
                 n = num_splits.size(0)
-                # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
                 # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
-                num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
+                assert n <= self.cg_buf_num_splits.size(0)
+                num_splits_view = self.cg_buf_num_splits[:n]
                 num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[ubatch_id][n:].fill_(0)  # fill the rest with 0s
+                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
                 num_splits = num_splits_view

         return FlashMLADecodeMetadata(
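
Dropping the per-ubatch indexing leaves the usual static-buffer pattern for CUDA graphs: the first call adopts the freshly built tensor as long-lived storage, and later calls copy into views of it so the pointer recorded in the graph stays valid. A minimal sketch of that pattern, assuming a 1-D tensor rather than the real FlashMLA metadata layout:

    from typing import Optional

    import torch

    class PersistentBuffer:
        def __init__(self) -> None:
            self.buf: Optional[torch.Tensor] = None

        def stage(self, values: torch.Tensor) -> torch.Tensor:
            if self.buf is None:
                # Capture time: adopt the tensor as the static storage.
                self.buf = values
                return self.buf
            n = values.size(0)
            assert n <= self.buf.size(0), "static buffer too small"
            view = self.buf[:n]
            view.copy_(values)     # overwrite in place; pointer unchanged
            self.buf[n:].fill_(0)  # zero the unused tail
            return view
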
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 3e2947b6d2551..6f1295b235a04 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -197,7 +197,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # self.model: nn.Module  # Set after load_model
         # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
-        self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
+        # Outer index is the KV cache group, inner index is the ubatch id.
+        self.attn_metadata_builders: list[list[AttentionMetadataBuilder]] = []
         self.attn_backends: list[type[AttentionBackend]] = []
         # self.kv_cache_config: KVCacheConfig
@@ -361,6 +362,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=self.device)

+    def get_builder(self, index: int, ubatch_id: Optional[int] = None
+                    ) -> AttentionMetadataBuilder:
+        # Fall back to the first builder when no micro-batch id is given.
+        ubatch_id = 0 if ubatch_id is None else ubatch_id
+        return self.attn_metadata_builders[index][ubatch_id]
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
@@ -379,7 +386,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if len(self.kv_cache_config.kv_cache_groups) == 0:
             return

-        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
+        self.get_builder(0).reorder_batch(self.input_batch,
                                                      scheduler_output)

         # For models with multiple KV cache groups, the groups should agree on
@@ -390,7 +397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # re-order the batch (not only the first).
         # TODO(tdoublep): verify this during engine init instead of at runtime
         for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
-            batch_reordered = self.attn_metadata_builders[i].reorder_batch(
+            batch_reordered = self.get_builder(i).reorder_batch(
                 self.input_batch, scheduler_output)
             assert not batch_reordered
@@ -943,7 +950,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # Prepare for cascade attention if enabled & beneficial.
             common_prefix_len = 0
-            builder = self.attn_metadata_builders[kv_cache_group_id]
+            builder = self.get_builder(kv_cache_group_id)
             if self.cascade_attn_enabled:
                 common_prefix_len = self._compute_cascade_attn_prefix_len(
                     num_scheduled_tokens,
@@ -960,10 +967,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                         common_attn_metadata_list):
                     assert common_attn_metadata.max_query_len == 1
                     attn_metadata_i = (
-                        builder.build(
+                        self.get_builder(kv_cache_group_id, ubatch_id=ubid).build(
                             common_prefix_len=common_prefix_len,
-                            common_attn_metadata=common_attn_metadata,
-                            ubatch_id=ubid))
+                            common_attn_metadata=common_attn_metadata))
                     for layer_name in kv_cache_group_spec.layer_names:
                         assert type(attn_metadata) is list
                         attn_metadata[ubid][layer_name] = attn_metadata_i
@@ -1023,8 +1029,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     attn_metadata[layer_name] = local_attn_metadata_i

         attention_cuda_graphs = all(
-            b.can_run_in_cudagraph(common_attn_metadata)
-            for b in self.attn_metadata_builders)
+            builder_list[0].can_run_in_cudagraph(common_attn_metadata)
+            for builder_list in self.attn_metadata_builders)

         # Hot-Swap lora model
         if self.lora_config:
@@ -1897,27 +1903,26 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 is_dummy_run=is_dummy_run)
             if num_scheduled_tokens not in self.cudagraphs \
                 and not skip_cuda_graphs and build_cuda_graph:
-                # if is_global_first_rank():
-                #     logger.info(f"CAPTURING {num_scheduled_tokens}")
+                if is_global_first_rank():
+                    logger.debug("CAPTURING %s", num_scheduled_tokens)
                 return self._capture_ubatches(ubatch_metadata, self.model)
             elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
-                # assert False
                 cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
-                # if is_global_first_rank():
-                #     logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
+                if is_global_first_rank():
+                    logger.debug("UBATCH REPLAY %s", num_scheduled_tokens)
                 cudagraph_metadata.cudagraph.replay()
                 return cudagraph_metadata.outputs
             else:
-                # if is_global_first_rank():
-                #     logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
+                if is_global_first_rank():
+                    logger.debug("RUNNING NORMALLY %s", num_scheduled_tokens)
                 return self._run_ubatches(ubatch_metadata, self.model)
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 self.model_inputs(slice(0, num_scheduled_tokens),
                                   scheduler_output, is_dummy_run)
-            # if is_global_first_rank():
-            #     logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
+            if is_global_first_rank():
+                logger.debug("RUNNING FULL BATCH %s", num_scheduled_tokens)
             skip_cuda_graphs = self.parallel_config.enable_microbatching
             with set_forward_context(attn_metadata,
                                      vllm_config=self.vllm_config,
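
The new get_builder indirection keeps call sites oblivious to whether microbatching is on: the outer index selects the KV cache group, the inner index the micro-batch, and slot 0 doubles as the default. A toy illustration with stand-in strings instead of real builder objects:

    builders = [["g0_ub0", "g0_ub1"], ["g1_ub0", "g1_ub1"]]  # [group][ubatch]

    def get_builder(index: int, ubatch_id=None):
        return builders[index][0 if ubatch_id is None else ubatch_id]

    assert get_builder(0) == "g0_ub0"               # non-microbatched path
    assert get_builder(1, ubatch_id=1) == "g1_ub1"  # per-ubatch path
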
@@ -2725,18 +2730,16 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 )
                 for ubid, common_attn_metadata in enumerate(common_attn_metadata_list):
                     attn_metadata_i = (
-                        self.attn_metadata_builders[kv_cache_group_id].
-                        build(
+                        self.get_builder(kv_cache_group_id,
+                                         ubatch_id=ubid).build(
                             common_prefix_len=0,
-                            common_attn_metadata=common_attn_metadata,
-                            ubatch_id=ubid
-                        ))
+                            common_attn_metadata=common_attn_metadata))
                     for layer_name in kv_cache_group_spec.layer_names:
                         assert type(attn_metadata) is list
                         attn_metadata[ubid][layer_name] = attn_metadata_i
             else:
-                attn_metadata_i = self.attn_metadata_builders[
-                    kv_cache_group_id].build_for_cudagraph_capture(
+                attn_metadata_i = self.get_builder(
+                    kv_cache_group_id).build_for_cudagraph_capture(
                         common_attn_metadata)
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
@@ -3071,7 +3074,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def _initialize_single_attn_backend(
         self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
-    ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
+    ) -> tuple[type[AttentionBackend], list[AttentionMetadataBuilder]]:
         if isinstance(kv_cache_spec, AttentionSpec):
             attn_backend_i = get_attn_backend(
                 kv_cache_spec.head_size,
@@ -3098,12 +3101,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             raise ValueError(
                 f"Unknown KV cache spec type: {type(kv_cache_spec)}")

         attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
             kv_cache_spec,
             layer_names,
             self.vllm_config,
             self.device,
         )
+        builders = [attn_metadata_builder_i]
+        if self.parallel_config.enable_microbatching:
+            # Allocate a second builder so that each micro-batch keeps
+            # its own metadata state (see get_builder).
+            builders.append(attn_backend_i.get_builder_cls()(
+                kv_cache_spec,
+                layer_names,
+                self.vllm_config,
+                self.device,
+            ))
+
         if (self.full_cuda_graph
                 and not attn_metadata_builder_i.full_cudagraph_supported):
@@ -3111,7 +3125,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 f"Full CUDAGraph not supported for "
                 f"{attn_backend_i.__name__}. Turn off CompilationConfig."
                 f"full_cuda_graph or use a different attention backend.")
-        return attn_backend_i, attn_metadata_builder_i
+        return attn_backend_i, builders

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
@@ -3124,11 +3138,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         for i, kv_cache_group_spec in enumerate(
                 kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec

-            attn_backend_i, attn_metadata_builder_i = (
+            attn_backend_i, attn_metadata_builders = (
                 self._initialize_single_attn_backend(
                     kv_cache_spec, kv_cache_group_spec.layer_names))
             self.attn_backends.append(attn_backend_i)
-            self.attn_metadata_builders.append(attn_metadata_builder_i)
+            self.attn_metadata_builders.append(attn_metadata_builders)

         if len(self.attn_backends) > 0:
             return
@@ -3157,11 +3171,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert len(attn_specs) == len(attn_layers), \
             "All or none of the layers are expected to be encoder-only"

-        attn_backend, attn_metadata_builder = (
+        attn_backend, attn_metadata_builders = (
             self._initialize_single_attn_backend(attn_specs[0],
                                                  attn_layers.keys()))
         self.attn_backends.append(attn_backend)
-        self.attn_metadata_builders.append(attn_metadata_builder)
+        self.attn_metadata_builders.append(attn_metadata_builders)
         self.is_encoder_only_model = True

     def may_reinitialize_input_batch(self,
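
The end state: _initialize_single_attn_backend returns one builder list per KV cache group, sized by whether microbatching is on, and get_builder indexes into it. A hedged sketch of that allocation rule (make_builders is illustrative, not a real vLLM helper):

    def make_builders(builder_cls, enable_microbatching: bool, *args):
        # Slot 0 serves the default path; slot 1 exists only for ubatch 1.
        num_builders = 2 if enable_microbatching else 1
        return [builder_cls(*args) for _ in range(num_builders)]

    assert len(make_builders(dict, False)) == 1
    assert len(make_builders(dict, True)) == 2
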