capture works replay does not

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-06-28 19:14:48 +00:00
parent af68574e3d
commit 4672c72f44
6 changed files with 121 additions and 82 deletions
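
For context, the changes below revolve around PyTorch's CUDA graph capture and replay. The following is a minimal, self-contained sketch of that pattern; it is not part of this commit, and `model`, `static_input`, and `static_output` are placeholder names rather than vLLM identifiers.

import torch

# Placeholder model and static input buffer (illustrative only).
static_input = torch.randn(8, 16, device="cuda")
model = torch.nn.Linear(16, 16).to("cuda")

# Warm up on a side stream so lazy initialization (cuBLAS handles, etc.)
# happens outside of capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture: the kernels are recorded into the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into the static input buffer, then replay the graph.
# Outputs land in the same static buffer on every replay.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_output)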

View File

@@ -84,7 +84,7 @@ def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         "The president of the United States is",
         "The capital of France is",
         "The future of AI is",
-    ] * 100
+    ] * 5
     # import random
     # import string
     # prompts = [''.join(random.choices(string.ascii_letters, k=128)) for _ in range(2048)]

View File

@@ -155,7 +155,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             a1 = a1 * rank_topk_weights.to(a1.dtype)

         # Dispatch
-        yield_and_switch_from_compute_to_comm_impl(schedule="default")
+        # yield_and_switch_from_compute_to_comm_impl(schedule="default")
         expert_x, expert_num_tokens, handle, event, hook = \
             self.buffers[a2a_idx].low_latency_dispatch(a1,
                                                        rank_topk_ids,
@@ -165,7 +165,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                                                        async_finish=False,
                                                        return_recv_hook=False)
         self.handles[a2a_idx] = handle
-        yield_and_switch_from_comm_to_compute_impl(schedule="default")
+        # yield_and_switch_from_comm_to_compute_impl(schedule="default")

         expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
                                                   a1.dtype)
@@ -188,7 +188,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         combine_topk_weights = torch.ones_like(topk_weights)

         # TODO (varun) : Enable zero copy mode
-        yield_and_switch_from_compute_to_comm_impl(schedule="default")
+        # yield_and_switch_from_compute_to_comm_impl(schedule="default")
         _, event, hook = self.buffers[a2a_idx].low_latency_combine(
             fused_expert_output,
             topk_ids,
@@ -199,5 +199,5 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             return_recv_hook=False,
             out=output)
         # event.current_stream_wait()
-        yield_and_switch_from_comm_to_compute_impl(schedule="default")
+        # yield_and_switch_from_comm_to_compute_impl(schedule="default")

View File

@@ -47,7 +47,6 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     def embedding(self, layer: torch.nn.Module,
                   input_: torch.Tensor) -> torch.Tensor:
-        print("SHOULDNT BE HERE DURING CAPTURE")
         return F.embedding(input_, layer.weight)

View File

@@ -94,7 +94,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
             # Num splits is per-batch, varying size (batch_size,)
             n = num_splits.size(0)
-            logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
+            # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
             # make sure static buffer is large enough
             assert n <= self.cg_buf_num_splits.size(0)
             num_splits_view = self.cg_buf_num_splits[:n]

View File

@@ -722,6 +722,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # logger.info(f"ATTEMPTING TO PAD UBATCH {should_ubatch}")
             assert should_ubatch
             num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
+            logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}")
             # logger.info("UBATCH PADDING DONE")
             if num_pad_tokens > 0:
                 if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
@@ -730,7 +731,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     # We bail out of ubatching here. This accounts for the case where
                     # the padding would result in an "empty" second ubatch.
                     # TODO: just make the second ubatch a dummy ubatch
-                    logger.info("FALLING BACK AND DISABLING UBATCHING")
+                    # logger.info("FALLING BACK AND DISABLING UBATCHING")
                     ubatch_bailout = True

             # Note that if we are attempting to ubatch by this point then we know that no
@@ -738,7 +739,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             if ubatch_slices:
                 should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
                 if not should_ubatch:
-                    logger.info("SUCCESSFULLY BAILED OUT")
+                    # logger.info("SUCCESSFULLY BAILED OUT")
                     num_pad_tokens = 0
                     num_tokens_after_padding = None
                     ubatch_slices = None
@@ -1344,7 +1345,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_tokens_padded = num_tokens_unpadded
-            logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
+            # logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
             if (self.use_cuda_graph
                     and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
                 # Use piecewise CUDA graphs.
@@ -1388,7 +1389,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             num_tokens_padded = round_up(num_tokens_unpadded, 2)
             if (self.use_cuda_graph
                     and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
-                # Use piecewise CUDA graphs.
                 # Add padding to the batch size.
                 num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
             else:
@@ -1581,10 +1581,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         def _make_ubatch_contexts(ubatch_slices,
                                   attn_metadata,
+                                  compute_stream,
                                   is_dummy_run,
-                                  num_tokens_across_dp) -> list[UBatchContext]:
+                                  num_tokens_across_dp,
+                                  skip_cuda_graphs) -> list[UBatchContext]:
             ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
-                                               compute_stream=current_stream(),
+                                               compute_stream=compute_stream,
                                                device=self.device)

             for i, (_, tokens_slice) in enumerate(ubatch_slices):
@@ -1617,13 +1619,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         def _make_ubatch_metadata(ubatch_slices,
                                   attn_metadata,
+                                  compute_stream,
                                   is_dummy_run,
-                                  num_tokens_across_dp) -> list[UbatchMetadata]:
+                                  num_tokens_across_dp,
+                                  skip_cuda_graphs) -> list[UbatchMetadata]:
             ubatch_ctxs = _make_ubatch_contexts(
                 ubatch_slices=ubatch_slices,
                 attn_metadata=attn_metadata,
+                compute_stream=compute_stream,
                 is_dummy_run=is_dummy_run,
-                num_tokens_across_dp=num_tokens_across_dp
+                num_tokens_across_dp=num_tokens_across_dp,
+                skip_cuda_graphs=skip_cuda_graphs
             )
             # First get some inputs
             ubatch_metadata: list[UbatchMetadata] = []
@@ -1644,8 +1650,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                  input_ids,
                  positions,
                  inputs_embeds,
-                 intermediate_tensors):
+                 intermediate_tensors,
+                 start_signal=None):
             with context:
+                if start_signal is not None:
+                    start_signal.wait()
                 model_output = self.model(
                     input_ids=input_ids,
                     positions=positions,
@@ -1659,42 +1668,64 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return model_output

         @torch.inference_mode()
-        def _ubatch_thread(results, ubatch_metadata):
+        def _ubatch_thread(results, ubatch_metadata, start_signal):
             # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
+            context = ubatch_metadata.context
+            with torch.cuda.stream(context.compute_stream):
+                _ = torch.cuda.current_blas_handle()
+            with torch.cuda.stream(context.comm_stream):
+                _ = torch.cuda.current_blas_handle()
             model_output = _run(context=ubatch_metadata.context,
                                 input_ids=ubatch_metadata.input_ids,
                                 positions=ubatch_metadata.positions,
                                 inputs_embeds=ubatch_metadata.inputs_embeds,
-                                intermediate_tensors=ubatch_metadata.intermediate_tensors)
+                                intermediate_tensors=ubatch_metadata.intermediate_tensors,
+                                start_signal=start_signal)
             results.append((ubatch_metadata.context.id, model_output))
             # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)

-        def _run_ubatches(ubatch_metadata) -> torch.Tensor:
+        def _run_ubatches(ubatch_metadata, num_tokens, should_capture=False) -> torch.Tensor:
             results: list[tuple[int, torch.Tensor]] = []
-            root_stream = current_stream()

             # Ubatches will manually manage the forward context, so we override
             # it to None here so we can have it restored correctly later
             with override_forward_context(None):
                 ubatch_threads = []
+                start_signals = []
                 for metadata in ubatch_metadata:
+                    start_signal = threading.Event()
                     thread = threading.Thread(target=_ubatch_thread,
                                               args=(
                                                   results,
                                                   metadata,
+                                                  start_signal,
                                               ))
                     ubatch_threads.append(thread)
                     thread.start()
-                ubatch_metadata[0].context.cpu_wait_event.set()
-                for thread in ubatch_threads:
-                    thread.join()
-
-            # torch.cuda.synchronize()
-            torch.cuda.set_stream(root_stream)
-            sorted_results = [value for position, value in sorted(results)]
-            return torch.cat(sorted_results, dim=0)
+                    start_signals.append(start_signal)
+                # DO capture
+                cudagraph_metadata = \
+                    CUDAGraphMetaData(
+                        cudagraph=torch.cuda.CUDAGraph(),
+                        using_ubatching=True
+                    )
+                with torch.cuda.graph(cudagraph_metadata.cudagraph,
+                                      stream=compute_stream):
+                    # logger.info("STARTING WAKEUP LOOP")
+                    for start_signal in start_signals:
+                        start_signal.set()
+                    # logger.info("FINISHED WAKEUP LOOP")
+                    ubatch_metadata[0].context.cpu_wait_event.set()
+                    for thread in ubatch_threads:
+                        thread.join()
+                    sorted_results = [value for position, value in sorted(results)]
+                    result = torch.cat(sorted_results, dim=0)
+                cudagraph_metadata.outputs = result
+                logger.info(f"Capturing for {num_tokens} tokens")
+                self.cudagraphs[num_tokens] = cudagraph_metadata
+                return cudagraph_metadata.outputs

         # run micro-batched
         if ubatch_slices is not None:
@@ -1702,69 +1733,71 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # num_tokens = ubatch_slices[1][1].stop
             # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
             # assert not is_dummy_run
+            compute_stream = torch.cuda.Stream(device=self.device)
             ubatch_metadata = _make_ubatch_metadata(
                 ubatch_slices=ubatch_slices,
                 attn_metadata=attn_metadata,
+                compute_stream=compute_stream,
                 is_dummy_run=is_dummy_run,
-                num_tokens_across_dp=num_tokens_across_dp
+                num_tokens_across_dp=num_tokens_across_dp,
+                skip_cuda_graphs=skip_cuda_graphs
             )

             if num_scheduled_tokens not in self.cudagraphs \
                 and not skip_cuda_graphs and build_cuda_graph:
-                # DO capture
-                self.cudagraphs[num_scheduled_tokens] = \
-                    CUDAGraphMetaData(
-                        cudagraph=torch.cuda.CUDAGraph(),
-                        using_ubatching=True
-                    )
-                with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
-                    model_output = _run_ubatches(ubatch_metadata)
-                self.cudagraphs[num_scheduled_tokens].outputs = model_output
-                return self.cudagraphs[num_scheduled_tokens].outputs
-            elif num_scheduled_tokens in self.cudagraphs:
-                self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
-                return self.cudagraphs[num_scheduled_tokens].outputs
+                return _run_ubatches(ubatch_metadata, num_scheduled_tokens, should_capture=True)
+            elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
+                cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
+                logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
+                cudagraph_metadata.cudagraph.replay()
+                return cudagraph_metadata.outputs
             else:
+                assert False
                 return _run_ubatches(ubatch_metadata)
         # run single batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
-            if num_scheduled_tokens not in self.cudagraphs \
-                and not skip_cuda_graphs and build_cuda_graph:
-                self.cudagraphs[num_scheduled_tokens] = \
-                    CUDAGraphMetaData(
-                        cudagraph=torch.cuda.CUDAGraph(),
-                        using_ubatching=False
-                    )
-                with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
-                    model_output = _run(
-                        context = set_forward_context(attn_metadata,
-                                        vllm_config=self.vllm_config,
-                                        num_tokens=num_scheduled_tokens or 1,
-                                        num_tokens_across_dp=num_tokens_across_dp,
-                                        skip_cuda_graphs=skip_cuda_graphs),
-                        input_ids=input_ids,
-                        positions=positions,
-                        inputs_embeds=inputs_embeds,
-                        intermediate_tensors=intermediate_tensors
-                    )
-                self.cudagraphs[num_scheduled_tokens].outputs = model_output
-                return self.cudagraphs[num_scheduled_tokens].outputs
-            elif num_scheduled_tokens in self.cudagraphs:
-                self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
-                return self.cudagraphs[num_scheduled_tokens].outputs
-            else:
-                return _run(
-                    context = set_forward_context(attn_metadata,
-                                    vllm_config=self.vllm_config,
-                                    num_tokens=num_scheduled_tokens or 1,
-                                    num_tokens_across_dp=num_tokens_across_dp,
-                                    skip_cuda_graphs=skip_cuda_graphs),
-                    input_ids=input_ids,
-                    positions=positions,
-                    inputs_embeds=inputs_embeds,
-                    intermediate_tensors=intermediate_tensors
-                )
+            # if num_scheduled_tokens not in self.cudagraphs \
+            #     and not skip_cuda_graphs and build_cuda_graph:
+            #     assert False
+            #     logger.info(f"GRAPH BUILD{num_scheduled_tokens}")
+            #     self.cudagraphs[num_scheduled_tokens] = \
+            #         CUDAGraphMetaData(
+            #             cudagraph=torch.cuda.CUDAGraph(),
+            #             using_ubatching=False
+            #         )
+            #     with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
+            #         model_output = _run(
+            #             context = set_forward_context(attn_metadata,
+            #                             vllm_config=self.vllm_config,
+            #                             num_tokens=num_scheduled_tokens or 1,
+            #                             num_tokens_across_dp=num_tokens_across_dp,
+            #                             skip_cuda_graphs=skip_cuda_graphs),
+            #             input_ids=input_ids,
+            #             positions=positions,
+            #             inputs_embeds=inputs_embeds,
+            #             intermediate_tensors=intermediate_tensors
+            #         )
+            #     self.cudagraphs[num_scheduled_tokens].outputs = model_output
+            #     return self.cudagraphs[num_scheduled_tokens].outputs
+            # elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
+            #     assert False
+            #     # logger.info(f"GRAPH REPLAY {num_scheduled_tokens}")
+            #     self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
+            #     return self.cudagraphs[num_scheduled_tokens].outputs
+            # else:
+            #     logger.info(f"NORMAL RUN {num_scheduled_tokens}")
+            return _run(
+                context = set_forward_context(attn_metadata,
+                                vllm_config=self.vllm_config,
+                                num_tokens=num_scheduled_tokens or 1,
+                                num_tokens_across_dp=num_tokens_across_dp,
+                                skip_cuda_graphs=skip_cuda_graphs),
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors
+            )

     @torch.inference_mode()
     def execute_model(
@@ -2325,8 +2358,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         build_cuda_graph: bool = False
     ) -> torch.Tensor:
-        if allow_microbatching:
-            logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
+        # if allow_microbatching:
+        #     logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
         # TODO(Sage) We need some more code to properly handle
@@ -2339,6 +2372,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             allow_microbatching and capture_attn_cudagraph
         # _dummy_run doesn't go through _prepare_inputs so
        # we synchronize with other DP ranks here
+        # logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
         should_ubatch = self.should_ubatch(allow_microbatching)
         # Padding for DP
         # logger.info("PADDING DUMMY")

View File

@@ -30,7 +30,6 @@ class UBatchContext:
         self.id = id
         self.comm_stream = comm_stream
         self.compute_stream = compute_stream
-        self.original_stream = current_stream()
         self.forward_context = None #fwd_ctx
         self.cpu_wait_event = cpu_wait_event
         self.cpu_signal_event = cpu_signal_event
@@ -58,7 +57,7 @@ class UBatchContext:
         self.cpu_signal_event.set()
         self.cpu_wait_event.clear()
         self.current_stream = self.compute_stream
-        torch.cuda.set_stream(self.original_stream)
+        torch.cuda.set_stream(self.current_stream)
         return False

     def _restore_context(self):
@@ -76,26 +75,31 @@ class UBatchContext:
         pass

     def _signal_comm_done(self):
+        assert False
         self.ctx_valid_state()
         self.gpu_comm_done_event.record(self.comm_stream)

     def _signal_compute_done(self):
+        assert False
         self.ctx_valid_state()
         self.gpu_compute_done_event.record(self.compute_stream)

     def _wait_compute_done(self):
+        assert False
         # print(f"{self.id} Waiting on COMPUTE stream", flush=True)
         self.ctx_valid_state()
         self.comm_stream.wait_event(self.gpu_compute_done_event)
         # print("Compute stream done", flush=True)

     def _wait_comm_done(self):
+        assert False
         # print(f"{self.id} Waiting on COMM stream", flush=True)
         self.ctx_valid_state()
         self.compute_stream.wait_event(self.gpu_comm_done_event)
         # print("Comm stream done", flush=True)

     def stream_string(self):
+        assert False
         if current_stream() == self.compute_stream:
             assert self.current_stream == self.compute_stream
             return "COMPUTE"
@@ -114,6 +118,7 @@ class UBatchContext:
         # print(f"UBatchContext: {self.id} resuming CPU", flush=True)

     def yield_and_switch_from_compute_to_comm(self):
+        assert False
         assert current_stream() == self.compute_stream
         # dp_rank = get_dp_group().rank_in_group
         # print(f"DP: {dp_rank} UB: {self.id} "
@@ -129,6 +134,7 @@ class UBatchContext:
         self._wait_compute_done()

     def yield_and_switch_from_comm_to_compute(self):
+        assert False
         assert current_stream() == self.comm_stream
         # dp_rank = get_dp_group().rank_in_group
         # print(f"DP: {dp_rank} UB: {self.id} "