diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 5f1b268a1d6fe..54d5af2ad29d8 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -157,6 +157,7 @@ def _support_torch_compile(
             vllm_config.compilation_config.level in [
             CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
         ] or not supports_dynamo()
+        self.do_not_compile = True
         if self.do_not_compile:
             return
         compilation_counter.num_models_seen += 1
diff --git a/vllm/config.py b/vllm/config.py
index c4fc320ec4d92..ab963b24f693b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -4518,6 +4518,11 @@ class VllmConfig:
                 "Piecewise compilation is not supported with "
                 "microbatching. Disabling piecewise compilation.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
+            if not self.model_config.enforce_eager:
+                self.compilation_config.full_cuda_graph = True
+                logger.warning_once(
+                    "Enabling fullcudagraphs for microbatching"
+                )
 
         if (self.kv_events_config is not None
                 and self.kv_events_config.enable_kv_cache_events
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 437e80696ac65..85ee54fc68ec2 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -903,7 +903,7 @@ def fused_topk(
 
 
 # This is used by the Deepseek-V2 and Deepseek-V3 model
-@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
+# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index be26e0060db5e..2b42e71053c81 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -75,7 +75,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
                 1, # MQA for the decode path
             )
 
-            if self.runner.full_cuda_graph:
+            n = num_splits.size(0)
+            logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
+            if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1]:
                 # First time around (CUDAGraph capture), allocate the static buffer
                 if self.cg_buf_tile_scheduler_metadata is None:
                     self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
@@ -92,6 +94,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
 
                 # Num splits is per-batch, varying size (batch_size,)
                 n = num_splits.size(0)
+                logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
                 # make sure static buffer is large enough
                 assert n <= self.cg_buf_num_splits.size(0)
                 num_splits_view = self.cg_buf_num_splits[:n]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f98fafe7996aa..76b7b6b2f1e81 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -91,6 +91,12 @@
 UbatchSlice: TypeAlias = tuple[slice, slice]
 UBatchSlices: TypeAlias = list[UbatchSlice]
 
+import dataclasses
+@dataclasses.dataclass
+class CUDAGraphMetaData:
+    cudagraph: torch.cuda.CUDAGraph
+    outputs: Optional[Any] = None
+
 
 class GPUModelRunner(LoRAModelRunnerMixin):
 
     def __init__(
@@ -132,6 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
 
+        self.cudagraphs = {}
         # Model-related.
         self.num_query_heads = model_config.get_num_attention_heads(
             parallel_config)
@@ -213,6 +220,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_cuda_graph = (self.compilation_config.level
                                == CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
+        self.use_cuda_graph = True
         logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
@@ -222,6 +230,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             reversed(self.compilation_config.cudagraph_capture_sizes))
 
         self.full_cuda_graph = self.compilation_config.full_cuda_graph
+        self.full_cuda_graph = True
+        logger.info(f"full_cuda_graph {self.full_cuda_graph}")
 
         # Cache the device properties.
         self._init_device_properties()
@@ -1553,7 +1563,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             scheduler_output: Optional["SchedulerOutput"] = None,
             is_dummy_run: bool = False,
             num_tokens_across_dp: Optional[torch.Tensor] = None,
-            skip_cuda_graphs: bool = False):
+            skip_cuda_graphs: bool = False,
+            build_cuda_graph: bool = False):
 
         num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
@@ -1566,16 +1577,43 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert scheduler_output is not None
             return self._get_model_inputs(tokens_slice, scheduler_output)
 
-        def _run(token_slice: slice, context, use_dummy_input: bool = False):
+        def _run(token_slice: slice,
+                 context,
+                 use_dummy_input: bool = False,
+                 build_cuda_graph: bool = False):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 model_inputs(token_slice, use_dummy_input)
             with context:
-                model_output = self.model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=inputs_embeds,
-                )
+                # model_output = self.model(
+                #     input_ids=input_ids,
+                #     positions=positions,
+                #     intermediate_tensors=intermediate_tensors,
+                #     inputs_embeds=inputs_embeds,
+                # )
+                num_tokens = token_slice.stop - token_slice.start
+                if build_cuda_graph and num_tokens not in self.cudagraphs:
+                    print(f"Capturing for {num_tokens}")
+                    assert use_dummy_input
+                    self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph())
+                    with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
+                        model_output = self.model(
+                            input_ids=input_ids,
+                            positions=positions,
+                            intermediate_tensors=intermediate_tensors,
+                            inputs_embeds=inputs_embeds,
+                        )
+                    self.cudagraphs[num_tokens].outputs = model_output
+                elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
+                    logger.info("GRAPH REPLAY")
+                    self.cudagraphs[num_tokens].cudagraph.replay()
+                    model_output = self.cudagraphs[num_tokens].outputs
+                else:
+                    model_output = self.model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=inputs_embeds,
+                    )
             if isinstance(context, UBatchContext):
                 # Clone before we leave the ubatch context
                 model_output = model_output.clone()
@@ -1584,16 +1622,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         @torch.inference_mode()
         def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
-                           use_dummy_input):
+                           use_dummy_input, build_cuda_graph):
             # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
-            model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
+            model_output = _run(token_slice, ubatch_ctx, use_dummy_input, build_cuda_graph)
             if save_results:
                 results.append((ubatch_ctx.id, model_output))
             # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
         def _run_ubatches(ubatch_slices,
                          attn_metadata,
-                          is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
+                          is_dummy_run, num_tokens_across_dp, build_cuda_graph) -> torch.Tensor:
             results: list[tuple[int, torch.Tensor]] = []
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             root_stream = current_stream()
@@ -1632,6 +1670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                       or is_dummy_run,
                                       is_dummy_ubatch or is_dummy_run,
+                                      build_cuda_graph
                                   ))
             ubatch_threads.append(thread)
             thread.start()
@@ -1650,8 +1689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # num_tokens = ubatch_slices[1][1].stop
             # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
             # assert not is_dummy_run
-            model_output = _run_ubatches(ubatch_slices, attn_metadata,
-                                         is_dummy_run, num_tokens_across_dp=num_tokens_across_dp)
+            model_output = _run_ubatches(ubatch_slices,
+                                         attn_metadata,
+                                         is_dummy_run,
+                                         num_tokens_across_dp=num_tokens_across_dp,
+                                         build_cuda_graph=build_cuda_graph)
         # run single batch
         else:
             # print("RUN NORMAL")
@@ -1662,7 +1704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     num_tokens=num_scheduled_tokens or 1,
                     num_tokens_across_dp=num_tokens_across_dp,
                     skip_cuda_graphs=skip_cuda_graphs),
-                is_dummy_run)
+                is_dummy_run,
+                build_cuda_graph=build_cuda_graph)
 
         return model_output
@@ -2222,6 +2265,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # For profiling runs we dont want microbatching but for
        # dp dummy runs we do.
         allow_microbatching: bool = False,
+        build_cuda_graph: bool = False
     ) -> torch.Tensor:
 
         if allow_microbatching:
@@ -2239,7 +2283,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # _dummy_run doesn't go through _prepare_inputs so
         # we synchronize with other DP ranks here
         should_ubatch = self.should_ubatch(allow_microbatching)
-        assert not should_ubatch
         # Padding for DP
         # logger.info("PADDING DUMMY")
         num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
@@ -2295,10 +2338,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # over a certain threshold.
         if should_ubatch:
             assert num_tokens % 2 == 0
-            # TODO (Sage) Add actual slices here
-            assert False
-            dummy_microbatches = [(slice(0, 0), slice(0, 0)),
-                                  (slice(0, 0), slice(0, 0))]
+            dummy_microbatches = [(slice(0, num_tokens // 2),
+                                   slice(0, num_tokens // 2)),
+                                  (slice(num_tokens // 2, num_tokens),
+                                   slice(num_tokens // 2, num_tokens))]
 
         with self.maybe_dummy_run_with_lora(self.lora_config,
                                             num_scheduled_tokens):
@@ -2307,7 +2350,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 num_tokens,
                 ubatch_slices=dummy_microbatches,
                 is_dummy_run=True,
-                num_tokens_across_dp=num_tokens_across_dp
+                num_tokens_across_dp=num_tokens_across_dp,
+                build_cuda_graph=build_cuda_graph
             )
             if self.use_aux_hidden_state_outputs:
                 hidden_states, _ = outputs
@@ -2505,10 +2549,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     self.compilation_config.cudagraph_num_of_warmups):
                 self._dummy_run(num_tokens,
                                 capture_attn_cudagraph=full_cg,
-                                allow_microbatching=allow_microbatching)
+                                allow_microbatching=allow_microbatching,
+                                build_cuda_graph=True)
             self._dummy_run(num_tokens,
                             capture_attn_cudagraph=full_cg,
-                            allow_microbatching=allow_microbatching)
+                            allow_microbatching=allow_microbatching,
+                            build_cuda_graph=True)
         logger.info("CAPTURE MODEL END")
         end_time = time.perf_counter()
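
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the patch: a minimal, standalone example of
# the capture-once / replay pattern that the build_cuda_graph path in _run()
# above uses (one CUDAGraphMetaData per token count in self.cudagraphs).
# The toy `model` and the `static_input` buffer below are hypothetical
# stand-ins; the real code captures self.model() with its persistent input
# buffers instead.
import dataclasses
from typing import Any, Optional

import torch


@dataclasses.dataclass
class CUDAGraphMetaData:
    cudagraph: torch.cuda.CUDAGraph
    outputs: Optional[Any] = None


model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream before capture, as recommended for CUDA graphs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture: run the model once under torch.cuda.graph and keep the output
# tensor; it aliases graph-owned memory that every replay overwrites.
meta = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph())
with torch.cuda.graph(meta.cudagraph):
    meta.outputs = model(static_input)

# Replay: refresh the static input buffer in place, then replay the graph.
static_input.copy_(torch.randn(8, 16, device="cuda"))
meta.cudagraph.replay()
result = meta.outputs  # now holds the outputs for the refreshed inputs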