factored out some of the context creation code along with misc commeted infra

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-25 23:16:59 +00:00
parent 44a2b3494e
commit e2ba707d64
2 changed files with 142 additions and 59 deletions

View File

@ -76,7 +76,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
)
n = num_splits.size(0)
logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
# logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_tile_scheduler_metadata is None:

View File

@ -95,6 +95,7 @@ import dataclasses
@dataclasses.dataclass
class CUDAGraphMetaData:
cudagraph: torch.cuda.CUDAGraph
using_ubatching: bool
outputs: Optional[Any] = None
class GPUModelRunner(LoRAModelRunnerMixin):
@ -220,7 +221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_cuda_graph = (self.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
self.use_cuda_graph = True
# self.use_cuda_graph = True
logger.info(f"self.use_cuda_graph {self.use_cuda_graph}")
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different.
@ -230,7 +231,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
reversed(self.compilation_config.cudagraph_capture_sizes))
logger.info(f"cudagraph capture sizes {self.cudagraph_batch_sizes}")
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.full_cuda_graph = True
# self.full_cuda_graph = True
logger.info(f"full_cuda_graph {self.full_cuda_graph}")
# Cache the device properties.
@ -1567,8 +1568,47 @@ class GPUModelRunner(LoRAModelRunnerMixin):
skip_cuda_graphs: bool = False,
build_cuda_graph: bool = False):
@dataclasses.dataclass
class UbatchMetadata:
ubatch_id: int
context: UBatchContext
ubatch_slice: UbatchSlice
input_ids: torch.Tensor
positions: torch.Tensor
inputs_embeds: Optional[torch.Tensor]
intermediate_tensors: Optional[IntermediateTensors]
num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
def _make_ubatch_contexts(ubatch_slices,
attn_metadata,
is_dummy_run,
num_tokens_across_dp) -> list[UBatchContext]:
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
compute_stream=current_stream(),
device=self.device)
for i, (_, tokens_slice) in enumerate(ubatch_slices):
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
assert not is_dummy_ubatch or i == len(
ubatch_slices) - 1 or is_dummy_run
num_tokens = num_dummy_tokens if is_dummy_ubatch or \
is_dummy_run else (tokens_slice.stop - tokens_slice.start)
# TODO (Sage) Instead of using this setter we should be able
# to just create the forward context in advance and pass it
# to the UBatchContext's __init__ method
ubatch_ctxs[i].forward_context = create_forward_context(
attn_metadata[i]
if attn_metadata is not None else None,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs)
return ubatch_ctxs
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
if use_dummy_input:
# print("MAKING DUMMY BATCH")
@ -1580,41 +1620,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _run(token_slice: slice,
context,
use_dummy_input: bool = False,
build_cuda_graph: bool = False):
use_dummy_input: bool = False):
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(token_slice, use_dummy_input)
with context:
# model_output = self.model(
# input_ids=input_ids,
# positions=positions,
# intermediate_tensors=intermediate_tensors,
# inputs_embeds=inputs_embeds,
# )
num_tokens = token_slice.stop - token_slice.start
if build_cuda_graph and num_tokens not in self.cudagraphs:
print(f"Capturing for {num_tokens}")
assert use_dummy_input
self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph())
with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.cudagraphs[num_tokens].outputs = model_output
elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
logger.info("GRAPH REPLAY")
self.cudagraphs[num_tokens].cudagraph.replay()
model_output = self.cudagraphs[num_tokens].outputs
else:
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if isinstance(context, UBatchContext):
# Clone before we leave the ubatch context
model_output = model_output.clone()
@ -1623,45 +1638,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
use_dummy_input, build_cuda_graph):
use_dummy_input):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input, build_cuda_graph)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
if save_results:
results.append((ubatch_ctx.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run, num_tokens_across_dp, build_cuda_graph) -> torch.Tensor:
is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
results: list[tuple[int, torch.Tensor]] = []
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
root_stream = current_stream()
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
compute_stream=root_stream,
device=self.device)
ubatch_ctxs = _make_ubatch_contexts(ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp)
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
ubatch_threads = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
# TODO (Sage) Consolidate all of this is_dummy_run
# is_dummy_ubatch, is attn_metadata==None, num_tokens==0
# nonsense into some unified structure. It's way to hard
# to keep track of and keep consistent right now.
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
assert not is_dummy_ubatch or i == len(
ubatch_slices) - 1 or is_dummy_run
num_tokens = num_dummy_tokens if is_dummy_ubatch or \
is_dummy_run else (tokens_slice.stop - tokens_slice.start)
# if num_tokens_across_dp is None:
# print(f"GOING TO CALL AR: {i}")
ubatch_ctxs[i].forward_context = create_forward_context(
attn_metadata[i]
if attn_metadata is not None else None,
self.vllm_config,
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs)
thread = threading.Thread(target=_ubatch_thread,
args=(
ubatch_ctxs[i],
@ -1670,8 +1678,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
not is_dummy_ubatch
or is_dummy_run,
is_dummy_ubatch
or is_dummy_run,
build_cuda_graph
or is_dummy_run
))
ubatch_threads.append(thread)
thread.start()
@ -1685,6 +1692,84 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sorted_results = [value for position, value in sorted(results)]
return torch.cat(sorted_results, dim=0)
# def _run_for_real(input_ids,
# positions,
# intermediate_tensors,
# input_embeds,
# attn_metadata,
# num_tokens_across_dp,
# token_slices,
# skip_cuda_graphs):
# # run micro-batched
# if len(token_slices) > 1:
# assert len(token_slices) == 2
# # num_tokens = ubatch_slices[1][1].stop
# # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
# # assert not is_dummy_run
# model_output = _run_ubatches(token_slices,
# attn_metadata,
# is_dummy_run,
# num_tokens_across_dp=num_tokens_across_dp)
# # run single batch
# else:
# # print("RUN NORMAL")
# num_tokens = token_slices[0].stop - token_slices[0].start
# if num_tokens == 0:
# num_tokens = 1
# model_output = _run(
# token_slices[0],
# set_forward_context(attn_metadata,
# vllm_config=self.vllm_config,
# num_tokens=num_tokens,
# num_tokens_across_dp=num_tokens_across_dp,
# skip_cuda_graphs=skip_cuda_graphs),
# is_dummy_run)
# return model_output
# num_tokens = token_slice[0].stop - token_slice[1].start
# # We have multiple sets of inputs here which is a bummer.
# # We'll need to pass around a bunch of lists which super sucks.
# input_ids, positions, inputs_embeds, intermediate_tensors = \
# model_inputs(token_slice, use_dummy_input)
# if build_cuda_graph and num_tokens not in self.cudagraphs:
# print(f"Capturing for {num_tokens}")
# # assert use_dummy_input
# using_ubaching = ubatch_slices is not None
# assert using_ubaching
# self.cudagraphs[num_tokens] = CUDAGraphMetaData(cudagraph=torch.cuda.CUDAGraph(),
# using_ubatching=using_ubaching)
# with torch.cuda.graph(self.cudagraphs[num_tokens].cudagraph):
# # TODO (Sage) I assume we can just get these before calling this function
# # Args to delete:
# # attn_metadata
# # skip_cudagraphs
# model_output = self._run_for_real(
# input_ids=input_ids,
# positions=positions,
# intermediate_tensors=intermediate_tensors,
# inputs_embeds=inputs_embeds,
# attn_metadata=attn_metadata,
# num_tokens_across_dp=num_tokens_across_dp,
# skip_cuda_graphs=skip_cuda_graphs
# )
# self.cudagraphs[num_tokens].outputs = model_output
# elif num_tokens in self.cudagraphs and not skip_cuda_graphs:
# logger.info("GRAPH REPLAY")
# assert self.cudagraphs[num_tokens].using_ubatching == ubatch_slices is not None
# self.cudagraphs[num_tokens].cudagraph.replay()
# model_output = self.cudagraphs[num_tokens].outputs
# else:
# # TODO (Sage) We need to figure out how to move some of this context management
# # logic outside of the graph capture
# model_output = self._run_for_real(
# input_ids=input_ids,
# positions=positions,
# intermediate_tensors=intermediate_tensors,
# inputs_embeds=inputs_embeds,
# attn_metadata=attn_metadata,
# num_tokens_across_dp=num_tokens_across_dp,
# skip_cuda_graphs=skip_cuda_graphs
# )
# run micro-batched
if ubatch_slices is not None:
# num_tokens = ubatch_slices[1][1].stop
@ -1693,8 +1778,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_output = _run_ubatches(ubatch_slices,
attn_metadata,
is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp,
build_cuda_graph=build_cuda_graph)
num_tokens_across_dp=num_tokens_across_dp)
# run single batch
else:
# print("RUN NORMAL")
@ -1705,8 +1789,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
is_dummy_run,
build_cuda_graph=build_cuda_graph)
is_dummy_run)
return model_output