Initial full CUDA graphs support. Normal runs are working; ubatching does not work yet.

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-06-25 19:14:31 +00:00
parent 97dbafaad6
commit 144b148de2
5 changed files with 79 additions and 24 deletions

View File

@ -157,6 +157,7 @@ def _support_torch_compile(
vllm_config.compilation_config.level in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo()
self.do_not_compile = True
if self.do_not_compile:
return
compilation_counter.num_models_seen += 1
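The added self.do_not_compile = True overrides the gate computed just above it, so the model is never wrapped by torch.compile and the forward pass stays eager for the runner's full CUDA graph capture further down. A minimal sketch of that gating, assuming CompilationLevel values that mirror vLLM's enum:

    from enum import IntEnum

    class CompilationLevel(IntEnum):  # assumed values, mirroring vLLM's ordering
        NO_COMPILATION = 0
        DYNAMO_AS_IS = 1
        DYNAMO_ONCE = 2
        PIECEWISE = 3

    def should_skip_compile(level: CompilationLevel, dynamo_supported: bool) -> bool:
        # Original gate: skip for NO_COMPILATION / DYNAMO_AS_IS, or when
        # dynamo is unavailable.
        skip = level in (CompilationLevel.NO_COMPILATION,
                         CompilationLevel.DYNAMO_AS_IS) or not dynamo_supported
        # This commit forces the skip unconditionally so the uncompiled forward
        # can be captured as one full CUDA graph by the model runner.
        skip = True
        return skip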

View File

@ -4518,6 +4518,11 @@ class VllmConfig:
"Piecewise compilation is not supported with "
"microbatching. Disabling piecewiseching compilation.")
self.compilation_config.level = CompilationLevel.NO_COMPILATION
if not self.model_config.enforce_eager:
self.compilation_config.full_cuda_graph = True
logger.warning_once(
"Enabling fullcudagraphs for microbatching"
)
if (self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events
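Net effect of this hunk: microbatching is made mutually exclusive with piecewise compilation, and unless eager mode is enforced the runner opts into full CUDA graphs instead. A hedged sketch of that coercion using simplified stand-in config classes rather than vLLM's exact fields:

    from dataclasses import dataclass

    NO_COMPILATION = 0
    PIECEWISE = 3

    @dataclass
    class CompilationConfig:  # stand-in, not the real vLLM class
        level: int = PIECEWISE
        full_cuda_graph: bool = False

    def coerce_for_microbatching(comp: CompilationConfig, enforce_eager: bool) -> None:
        # Piecewise compilation is not supported together with microbatching,
        # so compilation is disabled entirely...
        comp.level = NO_COMPILATION
        # ...and, unless eager execution is forced, the whole forward pass is
        # captured as a single CUDA graph by the model runner instead.
        if not enforce_eager:
            comp.full_cuda_graph = True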

View File

@ -903,7 +903,7 @@ def fused_topk(
# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
# @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,

View File

@ -75,7 +75,9 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path
)
if self.runner.full_cuda_graph:
n = num_splits.size(0)
logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1]:
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_tile_scheduler_metadata is None:
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
@ -92,6 +94,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
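The new n-based guard and the buffer slicing above follow the usual static-buffer pattern for full CUDA graph capture: per-step metadata is copied into a persistently allocated tensor so graph replays always read from the same device addresses. A self-contained sketch of that pattern; the class, the buffer size, and the zero-fill of the tail are illustrative rather than copied from this file:

    import torch

    class NumSplitsBuffer:
        """Persistent buffer reused across CUDA graph capture and replay."""

        def __init__(self, max_batch_size: int, device: str = "cuda"):
            # Allocated once, sized for the largest batch that will be captured.
            self.buf = torch.zeros(max_batch_size + 1, dtype=torch.int32,
                                   device=device)

        def update(self, num_splits: torch.Tensor) -> torch.Tensor:
            n = num_splits.size(0)
            # Make sure the static buffer is large enough for this batch.
            assert n <= self.buf.size(0)
            view = self.buf[:n]
            view.copy_(num_splits)  # in-place copy keeps the device address stable
            self.buf[n:].zero_()    # clear stale entries beyond the current batch
            return view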

View File

@ -91,6 +91,12 @@ UbatchSlice: TypeAlias = tuple[slice, slice]
UBatchSlices: TypeAlias = list[UbatchSlice]
import dataclasses
@dataclasses.dataclass
class CUDAGraphMetaData:
cudagraph: torch.cuda.CUDAGraph
outputs: Optional[Any] = None
class GPUModelRunner(LoRAModelRunnerMixin):
def __init__(
@ -132,6 +138,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
self.cudagraphs = {}
# Model-related.
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
@ -213,6 +220,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_cuda_graph = (self.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
self.use_cuda_graph = True
logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
@ -222,6 +230,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
reversed(self.compilation_config.cudagraph_capture_sizes))
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.full_cuda_graph = True
logger.info(f"full_cuda_graph {self.full_cuda_graph}")
# Cache the device properties.
self._init_device_properties()
@ -1553,7 +1563,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: Optional["SchedulerOutput"] = None,
is_dummy_run: bool = False,
num_tokens_across_dp: Optional[torch.Tensor] = None,
skip_cuda_graphs: bool = False):
skip_cuda_graphs: bool = False,
build_cuda_graph: bool = False):
num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
@ -1566,16 +1577,43 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert scheduler_output is not None
return self._get_model_inputs(tokens_slice, scheduler_output)
def _run(token_slice: slice, context, use_dummy_input: bool = False):
def _run(token_slice: slice,
context,
use_dummy_input: bool = False,
build_cuda_graph: bool = False):
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(token_slice, use_dummy_input)
with context:
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
# model_output = self.model(
# input_ids=input_ids,
# positions=positions,
# intermediate_tensors=intermediate_tensors,
# inputs_embeds=inputs_embeds,
# )
num_tokens = token_slice.stop - token_slice.start
if build_cuda_graph and num_tokens not in self.cudagraphs:
print(f"Capturing for {num_tokens}")
assert use_dummy_input
self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph())
with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.cudagraphs[num_tokens].outputs = model_output
elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
logger.info("GRAPH REPLAY")
self.cudagraphs[num_tokens].cudagraph.replay()
model_output = self.cudagraphs[num_tokens].outputs
else:
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if isinstance(context, UBatchContext):
# Clone before we leave the ubatch context
model_output = model_output.clone()
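For reference, a hedged, stand-alone sketch of the capture/replay flow the rewritten _run implements: one torch.cuda.CUDAGraph per token count, with the outputs cached on the CUDAGraphMetaData object and returned again on replay. The toy model, shapes, and static input below are illustrative; the real code runs the model against the runner's persistent input buffers:

    import dataclasses
    from typing import Any, Optional

    import torch

    @dataclasses.dataclass
    class CUDAGraphMetaData:
        cudagraph: torch.cuda.CUDAGraph
        outputs: Optional[Any] = None

    graphs: dict[int, CUDAGraphMetaData] = {}
    model = torch.nn.Linear(16, 16).cuda()
    static_input = torch.zeros(8, 16, device="cuda")  # reused by capture and replay

    def run(num_tokens: int, build_cuda_graph: bool) -> torch.Tensor:
        if build_cuda_graph and num_tokens not in graphs:
            # Warm up outside the graph so lazy kernel/workspace initialization
            # does not happen during capture.
            for _ in range(2):
                model(static_input[:num_tokens])
            torch.cuda.synchronize()
            graphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph())
            with torch.cuda.graph(graphs[num_tokens].cudagraph):
                graphs[num_tokens].outputs = model(static_input[:num_tokens])
            return graphs[num_tokens].outputs
        if num_tokens in graphs:
            # Replay re-runs the recorded kernels against the same static buffers,
            # so fresh inputs must be written into static_input beforehand.
            graphs[num_tokens].cudagraph.replay()
            return graphs[num_tokens].outputs
        # Eager fallback when no graph has been captured for this size.
        return model(static_input[:num_tokens])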
@ -1584,16 +1622,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
use_dummy_input):
use_dummy_input, build_cuda_graph):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input, build_cuda_graph)
if save_results:
results.append((ubatch_ctx.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
is_dummy_run, num_tokens_across_dp, build_cuda_graph) -> torch.Tensor:
results: list[tuple[int, torch.Tensor]] = []
assert len(ubatch_slices) == 2, "Only two ubatches have been tested"
root_stream = current_stream()
@ -1632,6 +1670,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
or is_dummy_run,
is_dummy_ubatch
or is_dummy_run,
build_cuda_graph
))
ubatch_threads.append(thread)
thread.start()
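For orientation, a hedged, self-contained sketch of the two-ubatch thread pattern _ubatch_thread and _run_ubatches follow: each thread runs its slice and appends an (ubatch_id, output) pair, and results are re-ordered by id before being combined. Ubatch contexts, streams, and the CUDA graph plumbing are omitted; the per-slice multiply stands in for the model forward:

    import threading

    import torch

    def run_ubatches(tokens: torch.Tensor, token_slices: list[slice]) -> torch.Tensor:
        results: list[tuple[int, torch.Tensor]] = []
        threads: list[threading.Thread] = []
        for ubatch_id, tok_slice in enumerate(token_slices):
            def work(i: int = ubatch_id, s: slice = tok_slice) -> None:
                results.append((i, tokens[s] * 2))  # stand-in for the model forward
            thread = threading.Thread(target=work)
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()
        # Threads may finish out of order; restore ubatch order before concatenating.
        results.sort(key=lambda pair: pair[0])
        return torch.cat([out for _, out in results])

    print(run_ubatches(torch.arange(8), [slice(0, 4), slice(4, 8)]))
    # tensor([ 0,  2,  4,  6,  8, 10, 12, 14])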
@ -1650,8 +1689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# num_tokens = ubatch_slices[1][1].stop
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
# assert not is_dummy_run
model_output = _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run, num_tokens_across_dp=num_tokens_across_dp)
model_output = _run_ubatches(ubatch_slices,
attn_metadata,
is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp,
build_cuda_graph=build_cuda_graph)
# run single batch
else:
# print("RUN NORMAL")
@ -1662,7 +1704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
is_dummy_run)
is_dummy_run,
build_cuda_graph=build_cuda_graph)
return model_output
@ -2222,6 +2265,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# For profiling runs we don't want microbatching, but for
# DP dummy runs we do.
allow_microbatching: bool = False,
build_cuda_graph: bool = False
) -> torch.Tensor:
if allow_microbatching:
@ -2239,7 +2283,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# _dummy_run doesn't go through _prepare_inputs so
# we synchronize with other DP ranks here
should_ubatch = self.should_ubatch(allow_microbatching)
assert not should_ubatch
# Padding for DP
# logger.info("PADDING DUMMY")
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
@ -2295,10 +2338,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# over a certain threshold.
if should_ubatch:
assert num_tokens % 2 == 0
# TODO (Sage) Add actual slices here
assert False
dummy_microbatches = [(slice(0, 0), slice(0, 0)),
(slice(0, 0), slice(0, 0))]
dummy_microbatches = [(slice(0, num_tokens // 2),
slice(0, num_tokens // 2)),
(slice(num_tokens // 2, num_tokens),
slice(num_tokens // 2, num_tokens))]
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
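The new dummy slices split the padded token count evenly into two ubatches, replacing the old assert False placeholder. A small sketch of that construction, assuming (per the UbatchSlice alias earlier in this file) that each entry pairs a request slice with a token slice:

    import torch

    UbatchSlice = tuple[slice, slice]

    def make_dummy_microbatches(num_tokens: int) -> list[UbatchSlice]:
        assert num_tokens % 2 == 0, "dummy ubatching assumes an even token count"
        half = num_tokens // 2
        return [(slice(0, half), slice(0, half)),
                (slice(half, num_tokens), slice(half, num_tokens))]

    tokens = torch.arange(8)
    for _req_slice, tok_slice in make_dummy_microbatches(tokens.numel()):
        print(tok_slice, tokens[tok_slice])
    # slice(0, 4, None) tensor([0, 1, 2, 3])
    # slice(4, 8, None) tensor([4, 5, 6, 7])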
@ -2307,7 +2350,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens,
ubatch_slices=dummy_microbatches,
is_dummy_run=True,
num_tokens_across_dp=num_tokens_across_dp
num_tokens_across_dp=num_tokens_across_dp,
build_cuda_graph=build_cuda_graph
)
if self.use_aux_hidden_state_outputs:
hidden_states, _ = outputs
@ -2505,10 +2549,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.compilation_config.cudagraph_num_of_warmups):
self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg,
allow_microbatching=allow_microbatching)
allow_microbatching=allow_microbatching,
build_cuda_graph=True)
self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg,
allow_microbatching=allow_microbatching)
allow_microbatching=allow_microbatching,
build_cuda_graph=True)
logger.info("CAPTURE MODEL END")
end_time = time.perf_counter()
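Overall shape of the capture pass after this change: every dummy run for a given batch size now passes build_cuda_graph=True, and because a graph is only captured the first time a token count is seen, the first call captures while the remaining calls replay. A hedged sketch reusing the run() helper from the earlier sketch; sorting largest-first mirrors the common practice of capturing the biggest sizes up front:

    def capture_all_sizes(cudagraph_batch_sizes: list[int], num_warmups: int = 1) -> None:
        for num_tokens in sorted(cudagraph_batch_sizes, reverse=True):
            for _ in range(num_warmups):
                # First call for this size captures; later calls replay.
                run(num_tokens, build_cuda_graph=True)
            run(num_tokens, build_cuda_graph=True)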