debugging

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-03 19:25:01 +00:00
parent e080e068ed
commit 2e3484c237
4 changed files with 21 additions and 21 deletions

View File

@ -83,7 +83,7 @@ def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
] ] * 100
# with DP, each rank should process different prompts. # with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset, # usually all the DP ranks process a full dataset,

View File

@ -125,15 +125,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
ubatch_ctx = get_current_ubatch_context() ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
yield_and_switch_from_compute_to_comm_impl(schedule="default") # yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send dispatch(True) # Send
torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True) # print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv dispatch(False) # Recv
torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True) # print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default") # yield_and_switch_from_comm_to_compute_impl(schedule="default")
torch.cuda.synchronize() # torch.cuda.synchronize()
return expert_x, expert_x_scale, expert_num_tokens return expert_x, expert_x_scale, expert_num_tokens
def finalize( def finalize(
@ -173,11 +173,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send, do_recv=not send,
) )
yield_and_switch_from_compute_to_comm_impl(schedule="default") # yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True) combine(True)
torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True) # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False) combine(False)
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True) # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default") # yield_and_switch_from_comm_to_compute_impl(schedule="default")
torch.cuda.synchronize() torch.cuda.synchronize()

View File

@ -1363,12 +1363,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode() @torch.inference_mode()
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
use_dummy_input): use_dummy_input):
print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True) # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input) model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
if save_results: if save_results:
results.append(model_output) results.append(model_output)
print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True) # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, attn_metadata, def _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run) -> torch.Tensor: is_dummy_run) -> torch.Tensor:

View File

@ -115,32 +115,32 @@ class UBatchContext:
def yield_and_switch_from_compute_to_comm(self): def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream assert current_stream() == self.compute_stream
dp_rank = get_dp_group().rank_in_group # dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} " # print(f"DP: {dp_rank} UB: {self.id} "
f"Yield and switch from {self.stream_string()}", flush=True) # f"Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
self._signal_compute_done() self._signal_compute_done()
self._cpu_yield() self._cpu_yield()
self.ctx_valid_state() self.ctx_valid_state()
assert self.current_stream == self.compute_stream assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream) self.update_stream(self.comm_stream)
print(f"DP: {dp_rank} UB: {self.id} " # print(f"DP: {dp_rank} UB: {self.id} "
f"Resuming on stream {self.stream_string()}", flush=True) # f"Resuming on stream {self.stream_string()}", flush=True)
self._wait_compute_done() self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self): def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream assert current_stream() == self.comm_stream
dp_rank = get_dp_group().rank_in_group # dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} " # print(f"DP: {dp_rank} UB: {self.id} "
f"Yield and switch from {self.stream_string()}", flush=True) # f"Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
self._signal_comm_done() self._signal_comm_done()
self._cpu_yield() self._cpu_yield()
self.ctx_valid_state() self.ctx_valid_state()
assert self.current_stream == self.comm_stream assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream) self.update_stream(self.compute_stream)
print(f"DP: {dp_rank} UB: {self.id} " # print(f"DP: {dp_rank} UB: {self.id} "
f"Resuming on stream {self.stream_string()}", flush=True) # f"Resuming on stream {self.stream_string()}", flush=True)
self._wait_comm_done() self._wait_comm_done()