From af68574e3d23f3e69c0d5d036d7636d109c25039 Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Thu, 26 Jun 2025 03:57:48 +0000
Subject: [PATCH] reintegrate full cudagraphs

Signed-off-by: Sage Moore
---
 .../layers/vocab_parallel_embedding.py     |   1 +
 vllm/v1/attention/backends/mla/flashmla.py |  12 +-
 vllm/v1/worker/gpu_model_runner.py         | 150 +++++++-----------
 3 files changed, 62 insertions(+), 101 deletions(-)

diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py
index 9ff3a7a7327d9..dd9fc04ce5a5b 100644
--- a/vllm/model_executor/layers/vocab_parallel_embedding.py
+++ b/vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -47,6 +47,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
 
     def embedding(self, layer: torch.nn.Module,
                   input_: torch.Tensor) -> torch.Tensor:
+        print("SHOULDNT BE HERE DURING CAPTURE")
         return F.embedding(input_, layer.weight)
 
 
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 38932ff283d6a..3fc33e2f86175 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -75,15 +75,15 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
                 1, # MQA for the decode path
             )
 
-        n = num_splits.size(0)
         # logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
-        if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
+        if self.runner.full_cuda_graph:
+            n = num_splits.size(0)
             # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_tile_scheduler_metadata is None:
-                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
+            if self.cg_buf_num_splits is None:
                 self.cg_buf_num_splits = num_splits
-            else:
-                assert self.cg_buf_num_splits is not None
+                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
+            elif n <= self.cg_buf_num_splits.size(0):
+                assert self.cg_buf_tile_scheduler_metadata is not None
 
                 # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
                 assert (self.cg_buf_tile_scheduler_metadata.size() ==
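The flashmla.py hunk above is the usual static-buffer arrangement for full CUDA graph support: the metadata tensors produced on the first (capture) pass are kept as persistent buffers, and later passes refresh those same buffers in place instead of handing the captured graph newly allocated tensors. A minimal sketch of that pattern follows; the names (StaticMLABuffers, update) are illustrative placeholders, not the builder's actual API.

    import torch

    class StaticMLABuffers:
        """Persistent buffers so a captured CUDA graph keeps reading the same storage."""

        def __init__(self) -> None:
            # Hypothetical stand-ins for cg_buf_tile_scheduler_metadata / cg_buf_num_splits.
            self.tile_scheduler_metadata: torch.Tensor | None = None
            self.num_splits: torch.Tensor | None = None

        def update(self, tile_scheduler_metadata: torch.Tensor,
                   num_splits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
            if self.num_splits is None:
                # Capture pass: keep references; the recorded graph reads these tensors.
                self.num_splits = num_splits
                self.tile_scheduler_metadata = tile_scheduler_metadata
            else:
                n = num_splits.size(0)
                assert n <= self.num_splits.size(0), "batch exceeds the captured size"
                # Replay passes: copy fresh values into the static buffers in place.
                self.tile_scheduler_metadata.copy_(tile_scheduler_metadata)
                self.num_splits[:n].copy_(num_splits)
            return self.tile_scheduler_metadata, self.num_splits
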
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 59c02fd5daa18..fa2e85bfb7531 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -221,7 +221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_cuda_graph = (self.compilation_config.level
                                == CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
-        # self.use_cuda_graph = True
+        self.use_cuda_graph = True
         logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
@@ -1691,89 +1691,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             for thread in ubatch_threads:
                 thread.join()
-            torch.cuda.synchronize()
+            # torch.cuda.synchronize()
             torch.cuda.set_stream(root_stream)
             sorted_results = [value for position, value in sorted(results)]
             return torch.cat(sorted_results, dim=0)
 
-        # def _run_for_real(input_ids,
-        #                   positions,
-        #                   intermediate_tensors,
-        #                   input_embeds,
-        #                   attn_metadata,
-        #                   num_tokens_across_dp,
-        #                   token_slices,
-        #                   skip_cuda_graphs):
-        #     # run micro-batched
-        #     if len(token_slices) > 1:
-        #         assert len(token_slices) == 2
-        #         # num_tokens = ubatch_slices[1][1].stop
-        #         # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
-        #         # assert not is_dummy_run
-        #         model_output = _run_ubatches(token_slices,
-        #                                      attn_metadata,
-        #                                      is_dummy_run,
-        #                                      num_tokens_across_dp=num_tokens_across_dp)
-        #     # run single batch
-        #     else:
-        #         # print("RUN NORMAL")
-        #         num_tokens = token_slices[0].stop - token_slices[0].start
-        #         if num_tokens == 0:
-        #             num_tokens = 1
-        #         model_output = _run(
-        #             token_slices[0],
-        #             set_forward_context(attn_metadata,
-        #                                 vllm_config=self.vllm_config,
-        #                                 num_tokens=num_tokens,
-        #                                 num_tokens_across_dp=num_tokens_across_dp,
-        #                                 skip_cuda_graphs=skip_cuda_graphs),
-        #             is_dummy_run)
-        #     return model_output
-
-        # num_tokens = token_slice[0].stop - token_slice[1].start
-        # # We have multiple sets of inputs here which is a bummer.
-        # # We'll need to pass around a bunch of lists which super sucks.
-        # input_ids, positions, inputs_embeds, intermediate_tensors = \
-        #     model_inputs(token_slice, use_dummy_input)
-        # if build_cuda_graph and num_tokens not in self.cudagraphs:
-        #     print(f"Capturing for {num_tokens}")
-        #     # assert use_dummy_input
-        #     using_ubaching = ubatch_slices is not None
-        #     assert using_ubaching
-        #     self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph(),
-        #                                                     using_ubatching=using_ubaching)
-        #     with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
-        #         # TODO (Sage) I assume we can just get these before calling this function
-        #         # Args to delete:
-        #         #   attn_metadata
-        #         #   skip_cudagraphs
-        #         model_output = self._run_for_real(
-        #             input_ids=input_ids,
-        #             positions=positions,
-        #             intermediate_tensors=intermediate_tensors,
-        #             inputs_embeds=inputs_embeds,
-        #             attn_metadata=attn_metadata,
-        #             num_tokens_across_dp=num_tokens_across_dp,
-        #             skip_cuda_graphs=skip_cuda_graphs
-        #         )
-        #         self.cudagraphs[num_tokens].outputs = model_output
-        # elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
-        #     logger.info("GRAPH REPLAY")
-        #     assert self.cudagraphs[num_tokens].using_ubatching == ubatch_slices is not None
-        #     self.cudagraphs[num_tokens].cudagraph.replay()
-        #     model_output = self.cudagraphs[num_tokens].outputs
-        # else:
-        #     # TODO (Sage) We need to figure out how to move some of this context management
-        #     # logic outside of the graph capture
-        #     model_output = self._run_for_real(
-        #         input_ids=input_ids,
-        #         positions=positions,
-        #         intermediate_tensors=intermediate_tensors,
-        #         inputs_embeds=inputs_embeds,
-        #         attn_metadata=attn_metadata,
-        #         num_tokens_across_dp=num_tokens_across_dp,
-        #         skip_cuda_graphs=skip_cuda_graphs
-        #     )
 
         # run micro-batched
         if ubatch_slices is not None:
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
@@ -1786,25 +1708,63 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 is_dummy_run=is_dummy_run,
                 num_tokens_across_dp=num_tokens_across_dp
             )
-            model_output = _run_ubatches(ubatch_metadata)
+            if num_scheduled_tokens not in self.cudagraphs \
+                    and not skip_cuda_graphs and build_cuda_graph:
+                # DO capture
+                self.cudagraphs[num_scheduled_tokens] = \
+                    CUDAGraphMetaData(
+                        cudagraph=torch.cuda.CUDAGraph(),
+                        using_ubatching=True
+                    )
+                with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
+                    model_output = _run_ubatches(ubatch_metadata)
+                self.cudagraphs[num_scheduled_tokens].outputs = model_output
+                return self.cudagraphs[num_scheduled_tokens].outputs
+            elif num_scheduled_tokens in self.cudagraphs:
+                self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
+                return self.cudagraphs[num_scheduled_tokens].outputs
+            else:
+                return _run_ubatches(ubatch_metadata)
         # run single batch
         else:
-            # print("RUN NORMAL")
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
-            model_output = _run(
-                context = set_forward_context(attn_metadata,
-                    vllm_config=self.vllm_config,
-                    num_tokens=num_scheduled_tokens or 1,
-                    num_tokens_across_dp=num_tokens_across_dp,
-                    skip_cuda_graphs=skip_cuda_graphs),
-                input_ids=input_ids,
-                positions=positions,
-                inputs_embeds=inputs_embeds,
-                intermediate_tensors=intermediate_tensors
-            )
-
-        return model_output
+            if num_scheduled_tokens not in self.cudagraphs \
+                    and not skip_cuda_graphs and build_cuda_graph:
+                self.cudagraphs[num_scheduled_tokens] = \
+                    CUDAGraphMetaData(
+                        cudagraph=torch.cuda.CUDAGraph(),
+                        using_ubatching=False
+                    )
+                with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
+                    model_output = _run(
+                        context = set_forward_context(attn_metadata,
+                            vllm_config=self.vllm_config,
+                            num_tokens=num_scheduled_tokens or 1,
+                            num_tokens_across_dp=num_tokens_across_dp,
+                            skip_cuda_graphs=skip_cuda_graphs),
+                        input_ids=input_ids,
+                        positions=positions,
+                        inputs_embeds=inputs_embeds,
+                        intermediate_tensors=intermediate_tensors
+                    )
+                self.cudagraphs[num_scheduled_tokens].outputs = model_output
+                return self.cudagraphs[num_scheduled_tokens].outputs
+            elif num_scheduled_tokens in self.cudagraphs:
+                self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
+                return self.cudagraphs[num_scheduled_tokens].outputs
+            else:
+                return _run(
+                    context = set_forward_context(attn_metadata,
+                        vllm_config=self.vllm_config,
+                        num_tokens=num_scheduled_tokens or 1,
+                        num_tokens_across_dp=num_tokens_across_dp,
+                        skip_cuda_graphs=skip_cuda_graphs),
+                    input_ids=input_ids,
+                    positions=positions,
+                    inputs_embeds=inputs_embeds,
+                    intermediate_tensors=intermediate_tensors
+                )
 
     @torch.inference_mode()
     def execute_model(
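The gpu_model_runner.py hunks implement the standard torch.cuda.CUDAGraph lifecycle keyed by num_scheduled_tokens: the first eligible call for a given token count records the forward pass under torch.cuda.graph(), stores the output tensors next to the graph, and every later call with that token count simply calls replay() and returns the cached outputs. A self-contained sketch of that lifecycle, assuming a toy model and shapes rather than vLLM's runner:

    import torch

    model = torch.nn.Linear(64, 64).cuda()
    # One captured graph per distinct token count, plus its static input/output buffers.
    graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor, torch.Tensor]] = {}

    @torch.inference_mode()
    def run(x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.size(0)
        if num_tokens not in graphs:
            static_in = torch.empty_like(x)
            for _ in range(3):              # warm up so lazy initialization is not recorded
                model(static_in)
            torch.cuda.synchronize()
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):   # capture: kernels are recorded, not executed
                static_out = model(static_in)
            graphs[num_tokens] = (graph, static_in, static_out)
        graph, static_in, static_out = graphs[num_tokens]
        static_in.copy_(x)                  # refresh the static input buffer in place
        graph.replay()                      # launch the recorded kernels on the static buffers
        return static_out.clone()           # clone so callers never alias the static output

    out = run(torch.randn(8, 64, device="cuda"))   # first call for size 8: warm up + capture
    out = run(torch.randn(8, 64, device="cuda"))   # same size again: replay only

As in the patch, inputs and outputs must live in fixed storage for replay to be meaningful, which is why the runner stores model_output on the CUDAGraphMetaData entry and returns that same object on every replay.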