From 1c41175b2a635dc23177fc37abb51befb5db30af Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Fri, 25 Jul 2025 20:08:05 +0000
Subject: [PATCH] full cudagraphs

Signed-off-by: Sage Moore
---
 vllm/compilation/decorators.py             |   1 +
 .../layers/fused_moe/fused_moe.py          |   2 +-
 vllm/v1/attention/backends/mla/common.py   |  10 +-
 vllm/v1/attention/backends/mla/flashmla.py |  41 ++-
 vllm/v1/worker/gpu_model_runner.py         | 337 +++++++++++++-----
 5 files changed, 286 insertions(+), 105 deletions(-)

diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index f3592324d8cfa..baf98306ad241 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -189,6 +189,7 @@ def _support_torch_compile(
             CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
         ] or not supports_dynamo() or _should_ignore_torch_compile(
             self.__class__)
+        self.do_not_compile = True
         if self.do_not_compile:
             return
 
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index c412f695ae766..1df98bb915a2f 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -915,7 +915,7 @@ def fused_topk(
 
 
 # This is used by the Deepseek-V2 and Deepseek-V3 model
-@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
+# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index ebf27c3c251b3..6f5fe31722d6f 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -560,11 +560,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
         return reorder_batch_to_split_decodes_and_prefills(input_batch,
-                                                           scheduler_output,
-                                                           decode_threshold=1)
+                                                          scheduler_output,
+                                                          decode_threshold=1)
 
-    def _build_decode(self,
-                      block_table_tensor: torch.Tensor,
+    def _build_decode(self, block_table_tensor: torch.Tensor,
                       seq_lens: torch.Tensor,
                       ubatch_id: Optional[int] = None):
         return MLACommonDecodeMetadata(
@@ -723,7 +722,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             decode_metadata = self._build_decode(
                 block_table_tensor=block_table_tensor[:num_decodes, ...],
                 seq_lens=seq_lens[:num_decodes],
-                ubatch_id=ubatch_id)
+                ubatch_id=ubatch_id
+            )
 
         attn_metadata = self.metadata_cls(
             num_reqs=common_attn_metadata.num_reqs,
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 6274552c62c6c..4d3ed6d576dc4 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -64,14 +64,15 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config)
 
-        self.cg_buf_tile_scheduler_metadata = None
-        self.cg_buf_num_splits = None
+        self.cg_buf_tile_scheduler_metadata = [None, None]
+        self.cg_buf_num_splits = [None, None]
 
-    def _build_decode(
-            self,
-            block_table_tensor: torch.Tensor,
-            seq_lens: torch.Tensor,
-            ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
+    def _build_decode(self, block_table_tensor: torch.Tensor,
+                      seq_lens: torch.Tensor,
+                      ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
+        # print(f"UBATCH ID: {ubatch_id}")
+        ubatch_id = 0 if ubatch_id is None else ubatch_id
+        assert ubatch_id < 2
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
                 seq_lens,
@@ -80,27 +81,31 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             )
 
         if self.compilation_config.full_cuda_graph:
+            # if False:
+            n = num_splits.size(0)
             # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_tile_scheduler_metadata is None:
-                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
-                self.cg_buf_num_splits = num_splits
-            else:
-                assert self.cg_buf_num_splits is not None
+            if self.cg_buf_num_splits[ubatch_id] is None:
+                # logger.info(f"ALLOCATING FLASH MLA DATA FOR SIZE {n}")
+                self.cg_buf_num_splits[ubatch_id] = num_splits
+                self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
+            elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
+                assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
                 # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata.size() ==
+                assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
                         tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata.\
+                self.cg_buf_tile_scheduler_metadata[ubatch_id].\
                     copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
+                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
 
                 # Num splits is per-batch, varying size (batch_size,)
                 n = num_splits.size(0)
+                # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
                 # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits.size(0)
-                num_splits_view = self.cg_buf_num_splits[:n]
+                assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
+                num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
                 num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
+                self.cg_buf_num_splits[ubatch_id][n:].fill_(0)  # fill the rest with 0s
                 num_splits = num_splits_view
 
         return FlashMLADecodeMetadata(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e125a1217e192..f96354ac578ca 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -110,7 +110,13 @@ class UbatchMetadata:
     positions: torch.Tensor
     inputs_embeds: Optional[torch.Tensor]
     intermediate_tensors: Optional[IntermediateTensors]
+    num_tokens: int
 
+@dataclasses.dataclass
+class CUDAGraphMetaData:
+    cudagraph: torch.cuda.CUDAGraph
+    ubatch_metadata: UbatchMetadata
+    outputs: Optional[Any] = None
 
 class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
@@ -247,6 +253,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             == CompilationLevel.PIECEWISE
             and self.vllm_config.compilation_config.use_cudagraph
             and not self.model_config.enforce_eager)
+        self.use_cuda_graph = True
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
@@ -254,6 +261,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.cudagraph_batch_sizes = list(
             reversed(self.compilation_config.cudagraph_capture_sizes))
         self.full_cuda_graph = self.compilation_config.full_cuda_graph
+        self.cudagraphs: dict[int, CUDAGraphMetaData] = {}
 
         # Cache the device properties.
         self._init_device_properties()
@@ -638,7 +646,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_tokens_after_padding = None
         ubatch_abort = False
         num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(
-            ubatch_slices)
+            ubatch_slices, True)
         if num_pad_tokens > 0:
             # Check if the padding would result in an empty second ubatch.
             # If so abort ubatching
@@ -871,12 +879,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 builder,
             )
 
-        # Fill unused with -1. Needed for reshape_and_cache in full cuda
-        # graph mode.
-        if self.vllm_config.compilation_config.full_cuda_graph:
-            self.input_batch.block_table[kv_cache_group_id]\
-                .slot_mapping.fill_(-1)
-
         if ubatch_slices is not None:
             common_attn_metadata_list = split_attn_metadata(
                 ubatch_slices, common_attn_metadata)
@@ -1446,8 +1448,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
         num_tokens_padded = num_tokens_unpadded
 
-        if (self.use_cuda_graph
-                and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
+        # if (self.use_cuda_graph
+        #         and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
+        if False:
             # Use piecewise CUDA graphs.
             # Add padding to the batch size.
             num_tokens_padded = self.vllm_config.pad_for_cudagraph(
@@ -1470,7 +1473,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
 
     def get_dp_padding_ubatch(
             self,
-            ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
+            ubatch_slices: UBatchSlices,
+            include_cudagraphs: bool) -> tuple[int, Optional[torch.Tensor]]:
 
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         if dp_size == 1:
@@ -1490,6 +1494,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
         num_tokens_padded = round_up(num_tokens_unpadded, 2)
+        if (include_cudagraphs and self.use_cuda_graph
+                and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
+            # Add padding to the batch size.
+            num_tokens_padded = self.vllm_config.pad_for_cudagraph(
+                num_tokens_unpadded)
+        else:
+            # Eager mode.
+            # Pad tokens to a multiple of tensor_parallel_size when
+            # collective fusion for SP is enabled.
+            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+            if self.vllm_config.compilation_config.pass_config. \
+                enable_sequence_parallelism and tp_size > 1:
+                num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
 
         num_tokens_per_ubatch = num_tokens_padded // 2
@@ -1538,14 +1554,39 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return DPMetadata.should_ubatch_across_dp(should_ubatch, dp_size,
                                                   dp_rank)
 
+    def _get_dummy_model_inputs(self, tokens_slice: slice) -> tuple:
+        # Dummy batch. (Hopefully we are the last one, so we can just
+        # update this to a one-token batch and return.)
+        if self.is_multimodal_model:
+            input_ids = None
+            inputs_embeds = self.inputs_embeds[tokens_slice]
+        else:
+            input_ids = self.input_ids[tokens_slice]
+            inputs_embeds = None
+
+        if self.uses_mrope:
+            positions = self.mrope_positions[:, tokens_slice]
+        else:
+            positions = self.positions[tokens_slice]
+
+        if get_pp_group().is_first_rank:
+            intermediate_tensors = None
+        else:
+            # Non-first PP ranks are not supported on this path yet.
+            assert False
+
+        return input_ids, positions, inputs_embeds, intermediate_tensors
+
     def _get_model_inputs(self, tokens_slice: slice,
-                          scheduler_output: "SchedulerOutput"):
+                          scheduler_output: Optional["SchedulerOutput"]):
         assert tokens_slice.stop - tokens_slice.start > 0
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
             # Run the multimodal encoder if any.
+            assert scheduler_output is not None
             self._execute_mm_encoder(scheduler_output)
             mm_embeds = self._gather_mm_embeddings(scheduler_output)
         else:
@@ -1584,10 +1625,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 tokens_slice, intermediate_tensors, True)
 
         return input_ids, positions, inputs_embeds, intermediate_tensors
 
+    def model_inputs(self, tokens_slice: slice,
+                     scheduler_output: Optional["SchedulerOutput"],
+                     use_dummy_input: bool) -> tuple:
+        if use_dummy_input:
+            # print("MAKING DUMMY BATCH")
+            # assert num_dummy_tokens == 1
+            return self._get_dummy_model_inputs(tokens_slice)
+        else:
+            return self._get_model_inputs(tokens_slice, scheduler_output)
+
     def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
                               compute_stream, num_tokens_across_dp,
                               skip_cuda_graphs,
-                              scheduler_output) -> list[UbatchMetadata]:
+                              scheduler_output,
+                              is_dummy_run) -> list[UbatchMetadata]:
 
         # Create one forward context per ubatch
         forward_contexts = []
@@ -1611,16 +1661,80 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         ubatch_metadata: list[UbatchMetadata] = []
         for i, (_, tokens_slice) in enumerate(ubatch_slices):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                self._get_model_inputs(tokens_slice, scheduler_output)
+                self.model_inputs(tokens_slice, scheduler_output, is_dummy_run)
             ubatch_metadata.append(
                 UbatchMetadata(context=ubatch_ctxs[i],
                                input_ids=input_ids,
                                positions=positions,
                                inputs_embeds=inputs_embeds,
-                               intermediate_tensors=intermediate_tensors))
+                               intermediate_tensors=intermediate_tensors,
+                               num_tokens=tokens_slice.stop -
+                               tokens_slice.start))
 
         return ubatch_metadata
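_capture_ubatches, added next, drives both ubatch threads while a single torch.cuda.CUDAGraph records their work, then caches the graph and its output tensor in self.cudagraphs keyed by total token count. The capture/replay contract it builds on, reduced to a single-tensor toy (this is not vLLM code; it follows the idiom from the PyTorch CUDA graph documentation):

import torch


def toy_capture_and_replay() -> torch.Tensor:
    assert torch.cuda.is_available()
    static_in = torch.zeros(8, device="cuda")

    # Warm up on a side stream so lazy CUDA state exists before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        static_out = static_in * 2
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_out = static_in * 2  # kernels are recorded, not run

    # Replay: refill the static input, then rerun the recorded kernels.
    static_in.copy_(torch.arange(8, device="cuda", dtype=static_in.dtype))
    graph.replay()
    return static_out  # now reflects the new input

Replay is what the `num_scheduled_tokens in self.cudagraphs` branch of _run_model below relies on: fresh inputs land in the persistent input buffers, cudagraph.replay() reruns the recorded kernels, and the cached outputs tensor holds the new results.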
+    def _capture_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
+
+        def _capture_ubatch_thread(results, ubatch_metadata, start_signal):
+            # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
+            context = ubatch_metadata.context
+            with torch.cuda.stream(context.compute_stream):
+                _ = torch.cuda.current_blas_handle()
+            with torch.cuda.stream(context.comm_stream):
+                _ = torch.cuda.current_blas_handle()
+            with context:
+                start_signal.wait()
+                model_output = model(
+                    input_ids=ubatch_metadata.input_ids,
+                    positions=ubatch_metadata.positions,
+                    intermediate_tensors=ubatch_metadata.intermediate_tensors,
+                    inputs_embeds=ubatch_metadata.inputs_embeds,
+                )
+            results.append((ubatch_metadata.context.id, model_output))
+
+        results: list[tuple[int, torch.Tensor]] = []
+        compute_stream = ubatch_metadata[0].context.compute_stream
+        num_tokens = ubatch_metadata[0].num_tokens + \
+            ubatch_metadata[1].num_tokens
+
+        # Ubatches will manually manage the forward context, so we override
+        # it to None here so we can have it restored correctly later.
+        with override_forward_context(None):
+            ubatch_threads = []
+            start_signals = []
+            for metadata in ubatch_metadata:
+                start_signal = threading.Event()
+                thread = threading.Thread(target=_capture_ubatch_thread,
+                                          args=(
+                                              results,
+                                              metadata,
+                                              start_signal,
+                                          ))
+                ubatch_threads.append(thread)
+                thread.start()
+                start_signals.append(start_signal)
+
+            # Do the capture
+            cudagraph_metadata = \
+                CUDAGraphMetaData(
+                    cudagraph=torch.cuda.CUDAGraph(),
+                    ubatch_metadata=ubatch_metadata,
+                )
+            with torch.cuda.graph(cudagraph_metadata.cudagraph,
+                                  stream=compute_stream):
+                # logger.info("STARTING WAKEUP LOOP")
+                for start_signal in start_signals:
+                    start_signal.set()
+                # logger.info("FINISHED WAKEUP LOOP")
+                ubatch_metadata[0].context.cpu_wait_event.set()
+                for thread in ubatch_threads:
+                    thread.join()
+            sorted_results = [value for position, value in sorted(results)]
+            result = torch.cat(sorted_results, dim=0)
+            cudagraph_metadata.outputs = result
+            # if is_global_first_rank():
+            #     logger.info(f"IN UBATCH RUNNER: Capturing for {num_tokens} tokens")
+            self.cudagraphs[num_tokens] = cudagraph_metadata
+        return cudagraph_metadata.outputs
+
     def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
 
         @torch.inference_mode()
@@ -1659,12 +1773,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         return result
 
     def _run_model(self,
-                   attn_metadata: PerLayerAttnMetadata,
+                   attn_metadata: Optional[PerLayerAttnMetadata],
                    num_scheduled_tokens: int,
-                   scheduler_output: "SchedulerOutput",
+                   scheduler_output: Optional["SchedulerOutput"] = None,
                    ubatch_slices: Optional[UBatchSlices] = None,
+                   is_dummy_run: bool = False,
                    num_tokens_across_dp: Optional[torch.Tensor] = None,
-                   skip_cuda_graphs: bool = False):
+                   skip_cuda_graphs: bool = False,
+                   build_cuda_graph: bool = False):
 
         # run micro-batched
         if ubatch_slices is not None:
@@ -1677,18 +1793,36 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 compute_stream=compute_stream,
                 num_tokens_across_dp=num_tokens_across_dp,
                 skip_cuda_graphs=skip_cuda_graphs,
-                scheduler_output=scheduler_output)
-            return self._run_ubatches(ubatch_metadata, self.model)
+                scheduler_output=scheduler_output,
+                is_dummy_run=is_dummy_run)
+            if num_scheduled_tokens not in self.cudagraphs \
+                    and not skip_cuda_graphs and build_cuda_graph:
+                # if is_global_first_rank():
+                #     logger.info(f"CAPTURING {num_scheduled_tokens}")
+                return self._capture_ubatches(ubatch_metadata, self.model)
+            elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
+                # assert False
+                cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
+                # if is_global_first_rank():
+                #     logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
+                cudagraph_metadata.cudagraph.replay()
+                return cudagraph_metadata.outputs
+            else:
+                # if is_global_first_rank():
+                #     logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
+                return self._run_ubatches(ubatch_metadata, self.model)
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                self._get_model_inputs(slice(0, num_scheduled_tokens),
-                                       scheduler_output)
+                self.model_inputs(slice(0, num_scheduled_tokens),
+                                  scheduler_output, is_dummy_run)
+            # if is_global_first_rank():
+            #     logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
             with set_forward_context(attn_metadata,
                                      vllm_config=self.vllm_config,
                                      num_tokens=num_scheduled_tokens or 1,
                                      num_tokens_across_dp=num_tokens_across_dp,
-                                     skip_cuda_graphs=skip_cuda_graphs):
+                                     skip_cuda_graphs=True):
                 return self.model(
                     input_ids=input_ids,
                     positions=positions,
@@ -1776,6 +1910,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # If attention doesn't support CUDA Graphs for this batch, but we
         # compiled with full CUDA graphs, we have to skip them entirely.
         skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
+        # print(f"SKIPPING CUDA GRAPHS: {skip_cuda_graphs} {self.full_cuda_graph}")
 
         # Run the model.
         # Use persistent buffers for CUDA graphs.
@@ -2358,18 +2493,44 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
     def _dummy_run(
         self,
         num_tokens: int,
+        skip_attn: bool = True,
+        # Maybe return a cudagraph here
         capture_attn_cudagraph: bool = False,
+        # For profiling runs we don't want microbatching, but for
+        # DP dummy runs we do.
+        allow_microbatching: bool = False,
+        build_cuda_graph: bool = False,
         skip_eplb: bool = False,
         is_profile: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        # _dummy_run doesn't go through _prepare_inputs so
-        # we synchronize with other DP groups that may be
-        # attempting to microbatch here.
-        if self.parallel_config.enable_microbatching:
-            _ = self.should_ubatch(False)
-        num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
+        # if allow_microbatching:
+        #     logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
+
+        # TODO(Sage): We need some more code to properly handle
+        # mixing normal and dummy runs. The DP padding needs to
+        # be properly set up. Since we only support microbatching
+        # during CUDA graph capture, it's fine to ignore the DP
+        # padding for now.
+        ubatch_enabled = self.parallel_config.enable_microbatching
+        should_ubatch = False
+        if ubatch_enabled:
+            should_ubatch = num_tokens >= \
+                self.parallel_config.microbatching_token_threshold and \
+                allow_microbatching and capture_attn_cudagraph
+            # _dummy_run doesn't go through _prepare_inputs, so
+            # we synchronize with other DP ranks here.
+            should_ubatch = self.should_ubatch(should_ubatch)
+        # logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
+
+        # Padding for DP
+        # logger.info("PADDING DUMMY")
+        num_tokens_across_dp = None
+        num_pad = 0
+        if not should_ubatch:
+            num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
         num_tokens += num_pad
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
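The should_ubatch handshake above matters because every DP rank has to take the same path: a rank that microbatches while a peer does not would hang in the next collective. A sketch of the kind of agreement DPMetadata.should_ubatch_across_dp has to provide, written with plain torch.distributed and assuming an initialized (e.g. gloo) process group; the AND-via-MIN reduction policy shown here is illustrative, not necessarily vLLM's exact rule:

import torch
import torch.distributed as dist


def agree_on_ubatch(local_want: bool) -> bool:
    """AND a local preference across ranks: ubatch only if everyone can."""
    flag = torch.tensor([1 if local_want else 0], dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN)  # MIN over {0,1} == AND
    return bool(flag.item())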
+ # logger.info("PADDING DUMMY DONE") + if should_ubatch: + # We only support decode-only cudagraphs + assert num_reqs == num_tokens + assert num_tokens % 2 == 0 + num_tokens_per_ubatch = num_tokens // 2 + num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] * 2, + device="cpu", + dtype=torch.int32) + ubatch_slices = [(slice(0, num_reqs // 2), + slice(0, num_tokens // 2)), + (slice(num_reqs // 2, num_reqs), + slice(num_tokens // 2, num_tokens))] + + + # attn_metadata: Optional[dict[str, Any]] = None + attn_metadata: Optional[PerLayerAttnMetadata]= None if capture_attn_cudagraph: attn_metadata = {} + if ubatch_slices is not None: + attn_metadata = [dict() for _ in range(len(ubatch_slices))] # Make sure max_model_len is used at the graph capture time. self.seq_lens_np[:num_reqs] = self.max_model_len @@ -2395,7 +2577,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) - for kv_cache_group_id, kv_cache_group_spec in enumerate( + max_query_len = num_tokens + if ubatch_slices is not None: + max_query_len = 1 + for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): common_attn_metadata = CommonAttentionMetadata( query_start_loc=self.query_start_loc[:num_reqs + 1], @@ -2407,65 +2592,47 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_computed_tokens_cpu_tensor[:num_reqs], num_reqs=num_reqs, num_actual_tokens=num_tokens, - max_query_len=num_tokens, + max_query_len=max_query_len, block_table_tensor=self.input_batch.block_table[ kv_cache_group_id].get_device_tensor()[:num_reqs], slot_mapping=self.input_batch. block_table[kv_cache_group_id].slot_mapping[:num_tokens]) - attn_metadata_i = self.attn_metadata_builders[ - kv_cache_group_id].build_for_cudagraph_capture( - common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: - attn_metadata[layer_name] = attn_metadata_i + + if ubatch_slices is not None: + common_attn_metadata_list = split_attn_metadata( + ubatch_slices, + common_attn_metadata + ) + for ubid, common_attn_metadata in enumerate(common_attn_metadata_list): + attn_metadata_i = ( + self.attn_metadata_builders[kv_cache_group_id]. 
+ build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ubatch_id=ubid + )) + for layer_name in kv_cache_group_spec.layer_names: + assert type(attn_metadata) is list + attn_metadata[ubid][layer_name] = attn_metadata_i + else: + attn_metadata_i = self.attn_metadata_builders[ + kv_cache_group_id].build_for_cudagraph_capture( + common_attn_metadata) + for layer_name in kv_cache_group_spec.layer_names: + attn_metadata[layer_name] = attn_metadata_i + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): - model = self.model - if self.is_multimodal_model: - model_kwargs = self._init_model_kwargs_for_multimodal_model( - num_reqs=num_reqs) - input_ids = None - inputs_embeds = self.inputs_embeds[:num_tokens] - else: - input_ids = self.input_ids[:num_tokens] - inputs_embeds = None - model_kwargs = {} - - if self.uses_mrope: - positions = self.mrope_positions[:, :num_tokens] - else: - positions = self.positions[:num_tokens] - - if get_pp_group().is_first_rank: - intermediate_tensors = None - else: - if self.intermediate_tensors is None: - self.intermediate_tensors = ( - self.model.make_empty_intermediate_tensors( - batch_size=self.max_num_tokens, - dtype=self.model_config.dtype, - device=self.device)) - - intermediate_tensors = self.sync_and_slice_intermediate_tensors( - slice(0, num_tokens), None, False) - - with self.maybe_randomize_inputs(input_ids), set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - num_tokens_across_dp=num_tokens_across_dp): - outputs = model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **MultiModalKwargs.as_kwargs( - model_kwargs, - device=self.device, - ), - ) - + outputs = self._run_model( + attn_metadata, + num_tokens, + ubatch_slices=ubatch_slices, + is_dummy_run=True, + num_tokens_across_dp=num_tokens_across_dp, + build_cuda_graph=build_cuda_graph + ) if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: @@ -2749,6 +2916,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): # can reuse the memory pool allocated for the large shapes. with freeze_gc(), graph_capture(device=self.device): full_cg = self.full_cuda_graph + allow_microbatching = self.parallel_config.enable_microbatching # Only rank 0 should print progress bar during capture compilation_cases = reversed(self.cudagraph_batch_sizes) if is_global_first_rank(): @@ -2758,13 +2926,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): desc="Capturing CUDA graph shapes") for num_tokens in compilation_cases: # We skip EPLB here since we don't want to record dummy metrics + # if is_global_first_rank(): + # logger.info(f"CAPTURE SIZE {num_tokens} WARMING UP {self.compilation_config.cudagraph_num_of_warmups}") for _ in range( self.compilation_config.cudagraph_num_of_warmups): self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg, + allow_microbatching=allow_microbatching, skip_eplb=True) + # if is_global_first_rank(): + # logger.info(f"CAPTURE SIZE {num_tokens} STARTING CAPTURE") self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg, + allow_microbatching=allow_microbatching, + build_cuda_graph=True, skip_eplb=True) end_time = time.perf_counter()
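The capture loop above follows the usual warmup-then-capture schedule: each batch size runs eagerly a few times so lazy initialization and autotuning happen outside capture, and one final _dummy_run(..., build_cuda_graph=True) records the graph. The shape of that schedule as a self-contained toy (toy arithmetic stands in for the model forward; these are not vLLM entry points):

import torch


def capture_schedule(sizes: list[int], num_warmups: int = 1) -> dict:
    graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor]] = {}
    static_in = {n: torch.zeros(n, device="cuda") for n in sizes}
    # Large sizes first, mirroring reversed(cudagraph_capture_sizes) above.
    for n in sorted(sizes, reverse=True):
        for _ in range(num_warmups):
            _ = static_in[n] + 1  # warmup runs, outside any graph
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out = static_in[n] + 1  # the build_cuda_graph=True run
        graphs[n] = (g, out)
    return graphs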