reintegrate full cudagraphs

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-06-26 03:57:48 +00:00
parent 78228a67ce
commit af68574e3d
3 changed files with 62 additions and 101 deletions


@@ -47,6 +47,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor:
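# Temporary debug marker: this eager embedding path is not expected to be hit while a full CUDA graph is being captured.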
print("SHOULDNT BE HERE DURING CAPTURE")
return F.embedding(input_, layer.weight)


@@ -75,15 +75,15 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
1, # MQA for the decode path
)
n = num_splits.size(0)
# logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
if self.runner.full_cuda_graph:
n = num_splits.size(0)
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_tile_scheduler_metadata is None:
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
if self.cg_buf_num_splits is None:
self.cg_buf_num_splits = num_splits
else:
assert self.cg_buf_num_splits is not None
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
elif n <= self.cg_buf_num_splits.size(0):
assert self.cg_buf_tile_scheduler_metadata is not None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert (self.cg_buf_tile_scheduler_metadata.size() ==

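The branch above follows the usual persistent-buffer idiom for full CUDA graph capture: the metadata tensors produced at capture time are kept alive so their device addresses stay valid, and later calls copy new values into them in place. A minimal sketch under assumed names (MetadataCache and update are illustrative, not part of this file):

from typing import Optional

import torch


class MetadataCache:
    """Holds a tensor whose storage must stay fixed across CUDA graph replays."""

    def __init__(self) -> None:
        # Allocated on the first (capture-time) call.
        self.buf: Optional[torch.Tensor] = None

    def update(self, new_metadata: torch.Tensor) -> torch.Tensor:
        if self.buf is None:
            # Capture time: keep this tensor; its address is baked into the graph.
            self.buf = new_metadata
        else:
            # Replay time: new values must fit into a prefix of the static buffer
            # and are copied in place so the captured address is reused.
            n = new_metadata.size(0)
            assert n <= self.buf.size(0)
            self.buf[:n].copy_(new_metadata)
        return self.buf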

@@ -221,7 +221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_cuda_graph = (self.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
# self.use_cuda_graph = True
self.use_cuda_graph = True
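# NOTE: forced on unconditionally here, overriding the piecewise-compilation check above, presumably to exercise full CUDA graph capture.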
logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
@@ -1691,89 +1691,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for thread in ubatch_threads:
thread.join()
torch.cuda.synchronize()
# torch.cuda.synchronize()
torch.cuda.set_stream(root_stream)
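# Reassemble the per-micro-batch outputs in their original order before concatenating along the token dimension.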
sorted_results = [value for position, value in sorted(results)]
return torch.cat(sorted_results, dim=0)
# def _run_for_real(input_ids,
# positions,
# intermediate_tensors,
# input_embeds,
# attn_metadata,
# num_tokens_across_dp,
# token_slices,
# skip_cuda_graphs):
# # run micro-batched
# if len(token_slices) > 1:
# assert len(token_slices) == 2
# # num_tokens = ubatch_slices[1][1].stop
# # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
# # assert not is_dummy_run
# model_output = _run_ubatches(token_slices,
# attn_metadata,
# is_dummy_run,
# num_tokens_across_dp=num_tokens_across_dp)
# # run single batch
# else:
# # print("RUN NORMAL")
# num_tokens = token_slices[0].stop - token_slices[0].start
# if num_tokens == 0:
# num_tokens = 1
# model_output = _run(
# token_slices[0],
# set_forward_context(attn_metadata,
# vllm_config=self.vllm_config,
# num_tokens=num_tokens,
# num_tokens_across_dp=num_tokens_across_dp,
# skip_cuda_graphs=skip_cuda_graphs),
# is_dummy_run)
# return model_output
# num_tokens = token_slice[0].stop - token_slice[1].start
# # We have multiple sets of inputs here which is a bummer.
# # We'll need to pass around a bunch of lists which super sucks.
# input_ids, positions, inputs_embeds, intermediate_tensors = \
# model_inputs(token_slice, use_dummy_input)
# if build_cuda_graph and num_tokens not in self.cudagraphs:
# print(f"Capturing for {num_tokens}")
# # assert use_dummy_input
# using_ubatching = ubatch_slices is not None
# assert using_ubatching
# self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph(),
# using_ubatching=using_ubatching)
# with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
# # TODO (Sage) I assume we can just get these before calling this function
# # Args to delete:
# # attn_metadata
# # skip_cudagraphs
# model_output = self._run_for_real(
# input_ids=input_ids,
# positions=positions,
# intermediate_tensors=intermediate_tensors,
# inputs_embeds=inputs_embeds,
# attn_metadata=attn_metadata,
# num_tokens_across_dp=num_tokens_across_dp,
# skip_cuda_graphs=skip_cuda_graphs
# )
# self.cudagraphs[num_tokens].outputs = model_output
# elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
# logger.info("GRAPH REPLAY")
# assert self.cudagraphs[num_tokens].using_ubatching == ubatch_slices is not None
# self.cudagraphs[num_tokens].cudagraph.replay()
# model_output = self.cudagraphs[num_tokens].outputs
# else:
# # TODO (Sage) We need to figure out how to move some of this context management
# # logic outside of the graph capture
# model_output = self._run_for_real(
# input_ids=input_ids,
# positions=positions,
# intermediate_tensors=intermediate_tensors,
# inputs_embeds=inputs_embeds,
# attn_metadata=attn_metadata,
# num_tokens_across_dp=num_tokens_across_dp,
# skip_cuda_graphs=skip_cuda_graphs
# )
# run micro-batched
if ubatch_slices is not None:
assert len(ubatch_slices) == 2, "Only two ubatches have been tested"
@@ -1786,25 +1708,63 @@ class GPUModelRunner(LoRAModelRunnerMixin):
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp
)
model_output = _run_ubatches(ubatch_metadata)
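# First time this token count is seen (and capture is enabled): record the micro-batched forward pass into a CUDA graph; later calls replay it instead of re-running eagerly.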
if num_scheduled_tokens not in self.cudagraphs \
and not skip_cuda_graphs and build_cuda_graph:
# Capture path: record this token count's forward pass into a new CUDA graph.
self.cudagraphs[num_scheduled_tokens] = \
CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
using_ubatching=True
)
with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
model_output = _run_ubatches(ubatch_metadata)
self.cudagraphs[num_scheduled_tokens].outputs = model_output
return self.cudagraphs[num_scheduled_tokens].outputs
elif num_scheduled_tokens in self.cudagraphs:
self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
return self.cudagraphs[num_scheduled_tokens].outputs
else:
return _run_ubatches(ubatch_metadata)
# run single batch
else:
# print("RUN NORMAL")
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
model_output = _run(
context=set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
)
return model_output
if num_scheduled_tokens not in self.cudagraphs \
and not skip_cuda_graphs and build_cuda_graph:
self.cudagraphs[num_scheduled_tokens] = \
CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
using_ubatching=False
)
with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
model_output = _run(
context=set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
)
self.cudagraphs[num_scheduled_tokens].outputs = model_output
return self.cudagraphs[num_scheduled_tokens].outputs
elif num_scheduled_tokens in self.cudagraphs:
self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
return self.cudagraphs[num_scheduled_tokens].outputs
else:
return _run(
context=set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
)
@torch.inference_mode()
def execute_model(
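
For reference, the capture/replay plumbing added in this hunk boils down to the standard torch.cuda.CUDAGraph workflow. A minimal, self-contained sketch (illustrative model and buffer names, not taken from this diff; assumes a CUDA device):

import torch

device = torch.device("cuda")
model = torch.nn.Linear(64, 64).to(device)

# Static input buffer: the graph records this tensor's address, so replays reuse it.
static_input = torch.zeros(8, 64, device=device)

# Warm up on a side stream so lazy allocations happen outside the capture.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side)

# Capture: kernels launched in this context are recorded into the graph, not executed per call.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy fresh data into the static buffer, rerun the recorded kernels, read the output.
static_input.copy_(torch.randn(8, 64, device=device))
graph.replay()
result = static_output.clone()

The dict of CUDAGraphMetaData objects above plays the role of graph plus static_output here, keyed by num_scheduled_tokens.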