diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py
index 476e11bb4cebe..38932ff283d6a 100644
--- a/vllm/v1/attention/backends/mla/flashmla.py
+++ b/vllm/v1/attention/backends/mla/flashmla.py
@@ -76,7 +76,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         )
 
         n = num_splits.size(0)
-        logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
+        # logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
         if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
             # First time around (CUDAGraph capture), allocate the static buffer
             if self.cg_buf_tile_scheduler_metadata is None:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index be9136fef240e..2957cfb1c85b0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -95,6 +95,7 @@ import dataclasses
 @dataclasses.dataclass
 class CUDAGraphMetaData:
     cudagraph: torch.cuda.CUDAGraph
+    using_ubatching: bool
     outputs: Optional[Any] = None
 
 class GPUModelRunner(LoRAModelRunnerMixin):
@@ -220,7 +221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_cuda_graph = (self.compilation_config.level
                                == CompilationLevel.PIECEWISE
                                and not self.model_config.enforce_eager)
-        self.use_cuda_graph = True
+        # self.use_cuda_graph = True
         logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
@@ -230,7 +231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             reversed(self.compilation_config.cudagraph_capture_sizes))
         logger.info(f"cudagraph capture sizes {self.cudagraph_batch_sizes}")
         self.full_cuda_graph = self.compilation_config.full_cuda_graph
-        self.full_cuda_graph = True
+        # self.full_cuda_graph = True
         logger.info(f"full_cuda_graph {self.full_cuda_graph}")
 
         # Cache the device properties.
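Note on the hunks above: the hard-coded `use_cuda_graph`/`full_cuda_graph` overrides and the hot-path log line are debug leftovers, now commented out, and `CUDAGraphMetaData` gains a `using_ubatching` flag so that a graph captured with micro-batching enabled is never replayed for a plain single-batch pass. A minimal sketch of the intended bookkeeping — the module-level `cudagraphs` dict and the `replay_or_none` helper are assumptions drawn from the commented-out `_run_for_real` draft later in this patch, not vLLM API:

```python
import dataclasses
from typing import Any, Optional

import torch


@dataclasses.dataclass
class CUDAGraphMetaData:
    cudagraph: torch.cuda.CUDAGraph
    using_ubatching: bool
    outputs: Optional[Any] = None


# Hypothetical cache keyed by token count, mirroring self.cudagraphs.
cudagraphs: dict[int, CUDAGraphMetaData] = {}


def replay_or_none(num_tokens: int, ubatching: bool) -> Optional[Any]:
    """Replay a captured graph only if its ubatching mode matches."""
    meta = cudagraphs.get(num_tokens)
    if meta is None:
        return None
    # Replaying a graph captured under a different batching mode would
    # execute the wrong kernel sequence, so fail loudly instead.
    assert meta.using_ubatching == ubatching
    meta.cudagraph.replay()
    return meta.outputs
```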
@@ -1567,8 +1568,47 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                   skip_cuda_graphs: bool = False,
                   build_cuda_graph: bool = False):
 
+        @dataclasses.dataclass
+        class UbatchMetadata:
+            ubatch_id: int
+            context: UBatchContext
+            ubatch_slice: UbatchSlice
+
+            input_ids: torch.Tensor
+            positions: torch.Tensor
+            inputs_embeds: Optional[torch.Tensor]
+            intermediate_tensors: Optional[IntermediateTensors]
+
         num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
+
+        def _make_ubatch_contexts(ubatch_slices,
+                                  attn_metadata,
+                                  is_dummy_run,
+                                  num_tokens_across_dp) -> list[UBatchContext]:
+            ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
+                                               compute_stream=current_stream(),
+                                               device=self.device)
+
+            for i, (_, tokens_slice) in enumerate(ubatch_slices):
+                is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
+                assert not is_dummy_ubatch or i == len(
+                    ubatch_slices) - 1 or is_dummy_run
+
+                num_tokens = num_dummy_tokens if is_dummy_ubatch or \
+                    is_dummy_run else (tokens_slice.stop - tokens_slice.start)
+                # TODO (Sage) Instead of using this setter we should be able
+                # to just create the forward context in advance and pass it
+                # to the UBatchContext's __init__ method.
+                ubatch_ctxs[i].forward_context = create_forward_context(
+                    attn_metadata[i]
+                    if attn_metadata is not None else None,
+                    self.vllm_config,
+                    num_tokens=num_tokens,
+                    num_tokens_across_dp=num_tokens_across_dp,
+                    skip_cuda_graphs=skip_cuda_graphs)
+            return ubatch_ctxs
+
         def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
             if use_dummy_input:
                 # print("MAKING DUMMY BATCH")
@@ -1580,41 +1620,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         def _run(token_slice: slice,
                  context,
-                 use_dummy_input: bool = False,
-                 build_cuda_graph: bool = False):
+                 use_dummy_input: bool = False):
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 model_inputs(token_slice, use_dummy_input)
             with context:
-                # model_output = self.model(
-                #     input_ids=input_ids,
-                #     positions=positions,
-                #     intermediate_tensors=intermediate_tensors,
-                #     inputs_embeds=inputs_embeds,
-                # )
-                num_tokens = token_slice.stop - token_slice.start
-                if build_cuda_graph and num_tokens not in self.cudagraphs:
-                    print(f"Capturing for {num_tokens}")
-                    assert use_dummy_input
-                    self.cudagraphs[num_tokens] = CUDAGraphMetaData(
-                        cudagraph=torch.cuda.CUDAGraph())
-                    with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
-                        model_output = self.model(
-                            input_ids=input_ids,
-                            positions=positions,
-                            intermediate_tensors=intermediate_tensors,
-                            inputs_embeds=inputs_embeds,
-                        )
-                    self.cudagraphs[num_tokens].outputs = model_output
-                elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
-                    logger.info("GRAPH REPLAY")
-                    self.cudagraphs[num_tokens].cudagraph.replay()
-                    model_output = self.cudagraphs[num_tokens].outputs
-                else:
-                    model_output = self.model(
-                        input_ids=input_ids,
-                        positions=positions,
-                        intermediate_tensors=intermediate_tensors,
-                        inputs_embeds=inputs_embeds,
-                    )
+                model_output = self.model(
+                    input_ids=input_ids,
+                    positions=positions,
+                    intermediate_tensors=intermediate_tensors,
+                    inputs_embeds=inputs_embeds,
+                )
             if isinstance(context, UBatchContext):
                 # Clone before we leave the ubatch context
                 model_output = model_output.clone()
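Note: `_run` is back to a plain eager forward; capture/replay is meant to move above the batch-dispatch level (see the commented-out `_run_for_real` draft below). For reference, the removed branch followed the standard `torch.cuda.CUDAGraph` recipe: static input/output buffers, warmup on a side stream so one-time allocations are not recorded, record, then replay with fresh data copied into the static buffers. A standalone sketch of that recipe, not vLLM code:

```python
import torch

model = torch.nn.Linear(64, 64).cuda()
static_in = torch.zeros(8, 64, device="cuda")  # buffers must keep fixed addresses

# Warm up on a side stream so capture doesn't record one-time setup work.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = model(static_in)  # kernels are recorded, not executed

# Replay: copy new data into the static buffer, then rerun the recorded kernels.
static_in.copy_(torch.randn(8, 64, device="cuda"))
g.replay()
print(static_out.sum().item())
```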
@@ -1623,45 +1638,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         @torch.inference_mode()
         def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
-                           use_dummy_input, build_cuda_graph):
+                           use_dummy_input):
             # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
-            model_output = _run(token_slice, ubatch_ctx, use_dummy_input,
-                                build_cuda_graph)
+            model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
             if save_results:
                 results.append((ubatch_ctx.id, model_output))
             # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
         def _run_ubatches(ubatch_slices, attn_metadata,
-                          is_dummy_run, num_tokens_across_dp,
-                          build_cuda_graph) -> torch.Tensor:
+                          is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
             results: list[tuple[int, torch.Tensor]] = []
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
 
             root_stream = current_stream()
-            ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
-                                               compute_stream=root_stream,
-                                               device=self.device)
+            ubatch_ctxs = _make_ubatch_contexts(
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                is_dummy_run=is_dummy_run,
+                num_tokens_across_dp=num_tokens_across_dp)
 
             # Ubatches will manually manage the forward context, so we override
             # it to None here so we can have it restored correctly later
             with override_forward_context(None):
                 ubatch_threads = []
                 for i, (_, tokens_slice) in enumerate(ubatch_slices):
+                    # TODO (Sage) Consolidate all of this is_dummy_run /
+                    # is_dummy_ubatch / attn_metadata==None / num_tokens==0
+                    # nonsense into some unified structure. It's way too hard
+                    # to keep track of and keep consistent right now.
                     is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
                     assert not is_dummy_ubatch or i == len(
                         ubatch_slices) - 1 or is_dummy_run
-                    num_tokens = num_dummy_tokens if is_dummy_ubatch or \
-                        is_dummy_run else (tokens_slice.stop - tokens_slice.start)
-                    # if num_tokens_across_dp is None:
-                    #     print(f"GOING TO CALL AR: {i}")
-                    ubatch_ctxs[i].forward_context = create_forward_context(
-                        attn_metadata[i]
-                        if attn_metadata is not None else None,
-                        self.vllm_config,
-                        num_tokens=num_tokens,
-                        num_tokens_across_dp=num_tokens_across_dp,
-                        skip_cuda_graphs=skip_cuda_graphs)
-
                     thread = threading.Thread(target=_ubatch_thread,
                                               args=(
                                                   ubatch_ctxs[i],
@@ -1670,8 +1678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                                   not is_dummy_ubatch
                                                   or is_dummy_run,
                                                   is_dummy_ubatch
-                                                  or is_dummy_run,
-                                                  build_cuda_graph
+                                                  or is_dummy_run
                                               ))
                     ubatch_threads.append(thread)
                     thread.start()
@@ -1685,6 +1692,84 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             sorted_results = [value for position, value in sorted(results)]
             return torch.cat(sorted_results, dim=0)
 
+        # def _run_for_real(input_ids,
+        #                   positions,
+        #                   intermediate_tensors,
+        #                   inputs_embeds,
+        #                   attn_metadata,
+        #                   num_tokens_across_dp,
+        #                   token_slices,
+        #                   skip_cuda_graphs):
+        #     # run micro-batched
+        #     if len(token_slices) > 1:
+        #         assert len(token_slices) == 2
+        #         # num_tokens = ubatch_slices[1][1].stop
+        #         # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp: {num_tokens_across_dp}")
+        #         # assert not is_dummy_run
+        #         model_output = _run_ubatches(token_slices,
+        #                                      attn_metadata,
+        #                                      is_dummy_run,
+        #                                      num_tokens_across_dp=num_tokens_across_dp)
+        #     # run single batch
+        #     else:
+        #         # print("RUN NORMAL")
+        #         num_tokens = token_slices[0].stop - token_slices[0].start
+        #         if num_tokens == 0:
+        #             num_tokens = 1
+        #         model_output = _run(
+        #             token_slices[0],
+        #             set_forward_context(attn_metadata,
+        #                                 vllm_config=self.vllm_config,
+        #                                 num_tokens=num_tokens,
+        #                                 num_tokens_across_dp=num_tokens_across_dp,
+        #                                 skip_cuda_graphs=skip_cuda_graphs),
+        #             is_dummy_run)
+        #     return model_output
+
+        # num_tokens = token_slices[0].stop - token_slices[0].start
+        # # We have multiple sets of inputs here, which is a bummer.
+        # # We'll need to pass around a bunch of lists, which super sucks.
+        # input_ids, positions, inputs_embeds, intermediate_tensors = \
+        #     model_inputs(token_slice, use_dummy_input)
+        # if build_cuda_graph and num_tokens not in self.cudagraphs:
+        #     print(f"Capturing for {num_tokens}")
+        #     # assert use_dummy_input
+        #     using_ubatching = ubatch_slices is not None
+        #     assert using_ubatching
+        #     self.cudagraphs[num_tokens] = CUDAGraphMetaData(
+        #         cudagraph=torch.cuda.CUDAGraph(),
+        #         using_ubatching=using_ubatching)
+        #     with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
+        #         # TODO (Sage) I assume we can just get these before calling
+        #         # this function. Args to delete:
+        #         #   attn_metadata
+        #         #   skip_cuda_graphs
+        #         model_output = _run_for_real(
+        #             input_ids=input_ids,
+        #             positions=positions,
+        #             intermediate_tensors=intermediate_tensors,
+        #             inputs_embeds=inputs_embeds,
+        #             attn_metadata=attn_metadata,
+        #             num_tokens_across_dp=num_tokens_across_dp,
+        #             skip_cuda_graphs=skip_cuda_graphs)
+        #     self.cudagraphs[num_tokens].outputs = model_output
+        # elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
+        #     logger.info("GRAPH REPLAY")
+        #     assert (self.cudagraphs[num_tokens].using_ubatching
+        #             == (ubatch_slices is not None))
+        #     self.cudagraphs[num_tokens].cudagraph.replay()
+        #     model_output = self.cudagraphs[num_tokens].outputs
+        # else:
+        #     # TODO (Sage) We need to figure out how to move some of this
+        #     # context management logic outside of the graph capture.
+        #     model_output = _run_for_real(
+        #         input_ids=input_ids,
+        #         positions=positions,
+        #         intermediate_tensors=intermediate_tensors,
+        #         inputs_embeds=inputs_embeds,
+        #         attn_metadata=attn_metadata,
+        #         num_tokens_across_dp=num_tokens_across_dp,
+        #         skip_cuda_graphs=skip_cuda_graphs)
 
         # run micro-batched
         if ubatch_slices is not None:
             # num_tokens = ubatch_slices[1][1].stop
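Note: `_run_ubatches` now delegates forward-context setup to `_make_ubatch_contexts` and keeps only thread dispatch. Each thread tags its output with its ubatch id so results can be reassembled in order after `join()`. A toy standalone version of that gather pattern — the two slices, the tensor shapes, and `run_ubatch` are illustrative only:

```python
import threading

import torch


def run_ubatch(ubatch_id: int, chunk: torch.Tensor,
               results: list[tuple[int, torch.Tensor]]) -> None:
    # Stand-in for _run(): each micro-batch produces its slice of the output.
    results.append((ubatch_id, chunk * 2))


tokens = torch.arange(8, dtype=torch.float32)
slices = [slice(0, 5), slice(5, 8)]  # two uneven micro-batches
results: list[tuple[int, torch.Tensor]] = []

threads = [
    threading.Thread(target=run_ubatch, args=(i, tokens[s], results))
    for i, s in enumerate(slices)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Threads may finish out of order; sort by ubatch id before concatenating,
# exactly as _run_ubatches does with `sorted(results)`.
output = torch.cat([value for _, value in sorted(results)], dim=0)
assert torch.equal(output, tokens * 2)
```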
@@ -1693,8 +1778,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             model_output = _run_ubatches(ubatch_slices,
                                          attn_metadata,
                                          is_dummy_run,
-                                         num_tokens_across_dp=num_tokens_across_dp,
-                                         build_cuda_graph=build_cuda_graph)
+                                         num_tokens_across_dp=num_tokens_across_dp)
         # run single batch
         else:
             # print("RUN NORMAL")
@@ -1705,8 +1789,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                     num_tokens=num_scheduled_tokens or 1,
                                     num_tokens_across_dp=num_tokens_across_dp,
                                     skip_cuda_graphs=skip_cuda_graphs),
-                is_dummy_run,
-                build_cuda_graph=build_cuda_graph)
+                is_dummy_run)
 
         return model_output
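Note on the single-batch path: the caller builds `set_forward_context(...)` but does not enter it; `_run` activates it with `with context:`, and `num_scheduled_tokens or 1` pads an empty dummy batch to one token so data-parallel collectives still fire. A toy illustration of handing an un-entered context manager to a callee — all names here are hypothetical:

```python
import contextlib
from typing import ContextManager


@contextlib.contextmanager
def forward_context(num_tokens: int):
    # Stand-in for set_forward_context: set up per-step state, then tear down.
    print(f"enter: num_tokens={num_tokens}")
    try:
        yield
    finally:
        print("exit")


def run(context: ContextManager[None]) -> None:
    # Like _run: the callee, not the caller, decides when the context is active.
    with context:
        print("model forward")


num_scheduled_tokens = 0
# Pad an empty batch to a single token, mirroring `num_scheduled_tokens or 1`.
run(forward_context(num_scheduled_tokens or 1))
```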