various fixes

This commit is contained in:
Sage Moore 2025-05-30 14:19:12 +00:00
parent 895a6c2a08
commit 5b0249b86e
3 changed files with 13 additions and 13 deletions

View File

@ -661,7 +661,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
f"Hidden size mismatch {hidden_states.size(-1)} " f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}") f"!= {w1.size(2)}")
print("in batched triton experts", hidden_states.shape, expert_num_tokens) # print("in batched triton experts", hidden_states.shape, expert_num_tokens)
assert hidden_states.is_contiguous( assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous" ), "Hidden_states must be contiguous"

View File

@ -129,7 +129,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1 ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
yield_and_switch_from_compute_to_comm_impl(schedule="default") yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send dispatch(True) # Send
torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True) # print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv dispatch(False) # Recv
# torch.cuda.synchronize() # torch.cuda.synchronize()
@ -176,7 +176,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
yield_and_switch_from_compute_to_comm_impl(schedule="default") yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True) combine(True)
torch.cuda.synchronize() # torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True) # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False) combine(False)
# torch.cuda.synchronize() # torch.cuda.synchronize()

View File

@ -113,29 +113,29 @@ class UBatchContext:
def yield_and_switch_from_compute_to_comm(self): def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream assert current_stream() == self.compute_stream
dp_rank = get_dp_group().rank_in_group # dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True) # print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
# self._signal_compute_done() self._signal_compute_done()
self._cpu_yield() self._cpu_yield()
self.ctx_valid_state() self.ctx_valid_state()
assert self.current_stream == self.compute_stream assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream) self.update_stream(self.comm_stream)
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True) # print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# self._wait_compute_done() self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self): def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream assert current_stream() == self.comm_stream
dp_rank = get_dp_group().rank_in_group # dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True) # print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
# self._signal_comm_done() self._signal_comm_done()
self._cpu_yield() self._cpu_yield()
self.ctx_valid_state() self.ctx_valid_state()
assert self.current_stream == self.comm_stream assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream) self.update_stream(self.compute_stream)
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True) # print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# self._wait_comm_done() self._wait_comm_done()
_CURRENT_CONTEXT: dict = {} _CURRENT_CONTEXT: dict = {}