mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 22:07:07 +08:00
various fixes
This commit is contained in:
parent
895a6c2a08
commit
5b0249b86e
@ -661,7 +661,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
f"Hidden size mismatch {hidden_states.size(-1)} "
|
||||
f"!= {w1.size(2)}")
|
||||
|
||||
print("in batched triton experts", hidden_states.shape, expert_num_tokens)
|
||||
# print("in batched triton experts", hidden_states.shape, expert_num_tokens)
|
||||
|
||||
assert hidden_states.is_contiguous(
|
||||
), "Hidden_states must be contiguous"
|
||||
|
||||
@ -129,7 +129,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
dispatch(True) # Send
|
||||
torch.cuda.synchronize()
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
||||
dispatch(False) # Recv
|
||||
# torch.cuda.synchronize()
|
||||
@ -176,7 +176,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
)
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
combine(True)
|
||||
torch.cuda.synchronize()
|
||||
# torch.cuda.synchronize()
|
||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||
combine(False)
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
@ -113,29 +113,29 @@ class UBatchContext:
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
assert current_stream() == self.compute_stream
|
||||
dp_rank = get_dp_group().rank_in_group
|
||||
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
||||
# dp_rank = get_dp_group().rank_in_group
|
||||
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
||||
self.ctx_valid_state()
|
||||
# self._signal_compute_done()
|
||||
self._signal_compute_done()
|
||||
self._cpu_yield()
|
||||
self.ctx_valid_state()
|
||||
assert self.current_stream == self.compute_stream
|
||||
self.update_stream(self.comm_stream)
|
||||
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
||||
# self._wait_compute_done()
|
||||
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
assert current_stream() == self.comm_stream
|
||||
dp_rank = get_dp_group().rank_in_group
|
||||
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
||||
# dp_rank = get_dp_group().rank_in_group
|
||||
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
||||
self.ctx_valid_state()
|
||||
# self._signal_comm_done()
|
||||
self._signal_comm_done()
|
||||
self._cpu_yield()
|
||||
self.ctx_valid_state()
|
||||
assert self.current_stream == self.comm_stream
|
||||
self.update_stream(self.compute_stream)
|
||||
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
||||
# self._wait_comm_done()
|
||||
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
||||
self._wait_comm_done()
|
||||
|
||||
|
||||
_CURRENT_CONTEXT: dict = {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user