various fixes

This commit is contained in:
Sage Moore 2025-05-30 14:19:12 +00:00
parent 895a6c2a08
commit 5b0249b86e
3 changed files with 13 additions and 13 deletions

View File

@ -661,7 +661,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
print("in batched triton experts", hidden_states.shape, expert_num_tokens)
# print("in batched triton experts", hidden_states.shape, expert_num_tokens)
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"

View File

@ -129,7 +129,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
torch.cuda.synchronize()
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv
# torch.cuda.synchronize()
@ -176,7 +176,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True)
torch.cuda.synchronize()
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False)
# torch.cuda.synchronize()

View File

@ -113,29 +113,29 @@ class UBatchContext:
def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state()
# self._signal_compute_done()
self._signal_compute_done()
self._cpu_yield()
self.ctx_valid_state()
assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream)
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# self._wait_compute_done()
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state()
# self._signal_comm_done()
self._signal_comm_done()
self._cpu_yield()
self.ctx_valid_state()
assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream)
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# self._wait_comm_done()
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
self._wait_comm_done()
_CURRENT_CONTEXT: dict = {}