Capture works; replay does not.

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-28 19:14:48 +00:00
parent af68574e3d
commit 4672c72f44
6 changed files with 121 additions and 82 deletions

View File

@ -84,7 +84,7 @@ def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] * 100
] * 5
# import random
# import string
# prompts = [''.join(random.choices(string.ascii_letters, k=128)) for _ in range(2048)]

View File

@ -155,7 +155,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1 = a1 * rank_topk_weights.to(a1.dtype)
# Dispatch
yield_and_switch_from_compute_to_comm_impl(schedule="default")
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
expert_x, expert_num_tokens, handle, event, hook = \
self.buffers[a2a_idx].low_latency_dispatch(a1,
rank_topk_ids,
@ -165,7 +165,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
async_finish=False,
return_recv_hook=False)
self.handles[a2a_idx] = handle
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
a1.dtype)
@ -188,7 +188,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
combine_topk_weights = torch.ones_like(topk_weights)
# TODO (varun) : Enable zero copy mode
yield_and_switch_from_compute_to_comm_impl(schedule="default")
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
_, event, hook = self.buffers[a2a_idx].low_latency_combine(
fused_expert_output,
topk_ids,
@ -199,5 +199,5 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return_recv_hook=False,
out=output)
# event.current_stream_wait()
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# yield_and_switch_from_comm_to_compute_impl(schedule="default")

View File

@ -47,7 +47,6 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor:
print("SHOULDNT BE HERE DURING CAPTURE")
return F.embedding(input_, layer.weight)

View File

@ -94,7 +94,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]

View File

@ -722,6 +722,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# logger.info(f"ATTEMPTING TO PAD UBATCH {should_ubatch}")
assert should_ubatch
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}")
# logger.info("UBATCH PADDING DONE")
if num_pad_tokens > 0:
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
@ -730,7 +731,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# We bail out of ubatching here. This accounts for the case where
# the padding would result in an "empty" second ubatch.
# TODO: just make the second ubatch a dummy ubatch
logger.info("FALLING BACK AND DISABLING UBATCHING")
# logger.info("FALLING BACK AND DISABLING UBATCHING")
ubatch_bailout = True
# Note that if we are attempting to ubatch by this point then we know that no
@ -738,7 +739,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if ubatch_slices:
should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
if not should_ubatch:
logger.info("SUCCESSFULLY BAILED OUT")
# logger.info("SUCCESSFULLY BAILED OUT")
num_pad_tokens = 0
num_tokens_after_padding = None
ubatch_slices = None
@ -1344,7 +1345,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_padded = num_tokens_unpadded
logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
# logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
if (self.use_cuda_graph
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
@ -1388,7 +1389,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_padded = round_up(num_tokens_unpadded, 2)
if (self.use_cuda_graph
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
else:
@ -1581,10 +1581,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _make_ubatch_contexts(ubatch_slices,
attn_metadata,
compute_stream,
is_dummy_run,
num_tokens_across_dp) -> list[UBatchContext]:
num_tokens_across_dp,
skip_cuda_graphs) -> list[UBatchContext]:
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
compute_stream=current_stream(),
compute_stream=compute_stream,
device=self.device)
for i, (_, tokens_slice) in enumerate(ubatch_slices):
@ -1617,13 +1619,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _make_ubatch_metadata(ubatch_slices,
attn_metadata,
compute_stream,
is_dummy_run,
num_tokens_across_dp) -> list[UbatchMetadata]:
num_tokens_across_dp,
skip_cuda_graphs) -> list[UbatchMetadata]:
ubatch_ctxs = _make_ubatch_contexts(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
compute_stream=compute_stream,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs
)
# First get some inputs
ubatch_metadata: list[UbatchMetadata] = []
@ -1644,8 +1650,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids,
positions,
inputs_embeds,
intermediate_tensors):
intermediate_tensors,
start_signal=None):
with context:
if start_signal is not None:
start_signal.wait()
model_output = self.model(
input_ids=input_ids,
positions=positions,
@ -1659,42 +1668,64 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return model_output
@torch.inference_mode()
def _ubatch_thread(results, ubatch_metadata):
def _ubatch_thread(results, ubatch_metadata, start_signal):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
context = ubatch_metadata.context
with torch.cuda.stream(context.compute_stream):
_ = torch.cuda.current_blas_handle()
with torch.cuda.stream(context.comm_stream):
_ = torch.cuda.current_blas_handle()
model_output = _run(context=ubatch_metadata.context,
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
inputs_embeds=ubatch_metadata.inputs_embeds,
intermediate_tensors=ubatch_metadata.intermediate_tensors)
intermediate_tensors=ubatch_metadata.intermediate_tensors,
start_signal=start_signal)
results.append((ubatch_metadata.context.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_metadata) -> torch.Tensor:
def _run_ubatches(ubatch_metadata, num_tokens, should_capture=False) -> torch.Tensor:
results: list[tuple[int, torch.Tensor]] = []
root_stream = current_stream()
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
ubatch_threads = []
start_signals = []
for metadata in ubatch_metadata:
start_signal = threading.Event()
thread = threading.Thread(target=_ubatch_thread,
args=(
results,
metadata,
start_signal,
))
ubatch_threads.append(thread)
thread.start()
ubatch_metadata[0].context.cpu_wait_event.set()
start_signals.append(start_signal)
for thread in ubatch_threads:
thread.join()
# torch.cuda.synchronize()
torch.cuda.set_stream(root_stream)
sorted_results = [value for position, value in sorted(results)]
return torch.cat(sorted_results, dim=0)
# DO capture
cudagraph_metadata = \
CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
using_ubatching=True
)
with torch.cuda.graph(cudagraph_metadata.cudagraph,
stream=compute_stream):
# logger.info("STARTING WAKEUP LOOP")
for start_signal in start_signals:
start_signal.set()
# logger.info("FINISHED WAKEUP LOOP")
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
cudagraph_metadata.outputs = result
logger.info(f"Capturing for {num_tokens} tokens")
self.cudagraphs[num_tokens] = cudagraph_metadata
return cudagraph_metadata.outputs
# run micro-batched
if ubatch_slices is not None:
@ -1702,69 +1733,71 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# num_tokens = ubatch_slices[1][1].stop
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
# assert not is_dummy_run
compute_stream = torch.cuda.Stream(device=self.device)
ubatch_metadata = _make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
compute_stream=compute_stream,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs
)
if num_scheduled_tokens not in self.cudagraphs \
and not skip_cuda_graphs and build_cuda_graph:
# DO capture
self.cudagraphs[num_scheduled_tokens] = \
CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
using_ubatching=True
)
with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
model_output = _run_ubatches(ubatch_metadata)
self.cudagraphs[num_scheduled_tokens].outputs = model_output
return self.cudagraphs[num_scheduled_tokens].outputs
elif num_scheduled_tokens in self.cudagraphs:
self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
return self.cudagraphs[num_scheduled_tokens].outputs
return _run_ubatches(ubatch_metadata, num_scheduled_tokens, should_capture=True)
elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
cudagraph_metadata.cudagraph.replay()
return cudagraph_metadata.outputs
else:
assert False
return _run_ubatches(ubatch_metadata)
# run single batch
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
if num_scheduled_tokens not in self.cudagraphs \
and not skip_cuda_graphs and build_cuda_graph:
self.cudagraphs[num_scheduled_tokens] = \
CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
using_ubatching=False
)
with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
model_output = _run(
context = set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
)
self.cudagraphs[num_scheduled_tokens].outputs = model_output
return self.cudagraphs[num_scheduled_tokens].outputs
elif num_scheduled_tokens in self.cudagraphs:
self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
return self.cudagraphs[num_scheduled_tokens].outputs
else:
return _run(
context = set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
)
# if num_scheduled_tokens not in self.cudagraphs \
# and not skip_cuda_graphs and build_cuda_graph:
# assert False
# logger.info(f"GRAPH BUILD{num_scheduled_tokens}")
# self.cudagraphs[num_scheduled_tokens] = \
# CUDAGraphMetaData(
# cudagraph=torch.cuda.CUDAGraph(),
# using_ubatching=False
# )
# with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
# model_output = _run(
# context = set_forward_context(attn_metadata,
# vllm_config=self.vllm_config,
# num_tokens=num_scheduled_tokens or 1,
# num_tokens_across_dp=num_tokens_across_dp,
# skip_cuda_graphs=skip_cuda_graphs),
# input_ids=input_ids,
# positions=positions,
# inputs_embeds=inputs_embeds,
# intermediate_tensors=intermediate_tensors
# )
# self.cudagraphs[num_scheduled_tokens].outputs = model_output
# return self.cudagraphs[num_scheduled_tokens].outputs
# elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
# assert False
# # logger.info(f"GRAPH REPLAY {num_scheduled_tokens}")
# self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
# return self.cudagraphs[num_scheduled_tokens].outputs
# else:
# logger.info(f"NORMAL RUN {num_scheduled_tokens}")
return _run(
context = set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
)
@torch.inference_mode()
def execute_model(
@ -2325,8 +2358,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
build_cuda_graph: bool = False
) -> torch.Tensor:
if allow_microbatching:
logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
# if allow_microbatching:
# logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
# TODO(Sage) We need some more code to properly handle
@ -2339,6 +2372,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
allow_microbatching and capture_attn_cudagraph
# _dummy_run doesn't go through _prepare_inputs so
# we synchronize with other DP ranks here
# logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
should_ubatch = self.should_ubatch(allow_microbatching)
# Padding for DP
# logger.info("PADDING DUMMY")

View File

@ -30,7 +30,6 @@ class UBatchContext:
self.id = id
self.comm_stream = comm_stream
self.compute_stream = compute_stream
self.original_stream = current_stream()
self.forward_context = None #fwd_ctx
self.cpu_wait_event = cpu_wait_event
self.cpu_signal_event = cpu_signal_event
@ -58,7 +57,7 @@ class UBatchContext:
self.cpu_signal_event.set()
self.cpu_wait_event.clear()
self.current_stream = self.compute_stream
torch.cuda.set_stream(self.original_stream)
torch.cuda.set_stream(self.current_stream)
return False
def _restore_context(self):
@ -76,26 +75,31 @@ class UBatchContext:
pass
def _signal_comm_done(self):
assert False
self.ctx_valid_state()
self.gpu_comm_done_event.record(self.comm_stream)
def _signal_compute_done(self):
assert False
self.ctx_valid_state()
self.gpu_compute_done_event.record(self.compute_stream)
def _wait_compute_done(self):
assert False
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
self.ctx_valid_state()
self.comm_stream.wait_event(self.gpu_compute_done_event)
# print("Compute stream done", flush=True)
def _wait_comm_done(self):
assert False
# print(f"{self.id} Waiting on COMM stream", flush=True)
self.ctx_valid_state()
self.compute_stream.wait_event(self.gpu_comm_done_event)
# print("Comm stream done", flush=True)
def stream_string(self):
assert False
if current_stream() == self.compute_stream:
assert self.current_stream == self.compute_stream
return "COMPUTE"
@ -114,6 +118,7 @@ class UBatchContext:
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
def yield_and_switch_from_compute_to_comm(self):
assert False
assert current_stream() == self.compute_stream
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} "
@ -129,6 +134,7 @@ class UBatchContext:
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert False
assert current_stream() == self.comm_stream
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} "