debugging: comment out debug prints and torch.cuda.synchronize() calls in the ubatch/pplx paths; multiply the example prompt list by 100

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-03 19:25:01 +00:00
parent e080e068ed
commit 2e3484c237
4 changed files with 21 additions and 21 deletions

View File

@ -83,7 +83,7 @@ def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
] * 100
# with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset,

View File

@ -125,15 +125,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
yield_and_switch_from_compute_to_comm_impl(schedule="default")
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
torch.cuda.synchronize()
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv
torch.cuda.synchronize()
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
torch.cuda.synchronize()
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()
return expert_x, expert_x_scale, expert_num_tokens
def finalize(
@ -173,11 +173,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send,
)
yield_and_switch_from_compute_to_comm_impl(schedule="default")
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True)
torch.cuda.synchronize()
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False)
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
torch.cuda.synchronize()

View File

@ -1363,12 +1363,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
use_dummy_input):
print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
if save_results:
results.append(model_output)
print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run) -> torch.Tensor:

View File

@ -115,32 +115,32 @@ class UBatchContext:
def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} "
f"Yield and switch from {self.stream_string()}", flush=True)
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state()
self._signal_compute_done()
self._cpu_yield()
self.ctx_valid_state()
assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream)
print(f"DP: {dp_rank} UB: {self.id} "
f"Resuming on stream {self.stream_string()}", flush=True)
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Resuming on stream {self.stream_string()}", flush=True)
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} "
f"Yield and switch from {self.stream_string()}", flush=True)
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state()
self._signal_comm_done()
self._cpu_yield()
self.ctx_valid_state()
assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream)
print(f"DP: {dp_rank} UB: {self.id} "
f"Resuming on stream {self.stream_string()}", flush=True)
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Resuming on stream {self.stream_string()}", flush=True)
self._wait_comm_done()