mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-03 18:35:42 +08:00
various fixes
This commit is contained in:
parent
895a6c2a08
commit
5b0249b86e
@@ -661,7 +661,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
f"Hidden size mismatch {hidden_states.size(-1)} "
|
f"Hidden size mismatch {hidden_states.size(-1)} "
|
||||||
f"!= {w1.size(2)}")
|
f"!= {w1.size(2)}")
|
||||||
|
|
||||||
print("in batched triton experts", hidden_states.shape, expert_num_tokens)
|
# print("in batched triton experts", hidden_states.shape, expert_num_tokens)
|
||||||
|
|
||||||
assert hidden_states.is_contiguous(
|
assert hidden_states.is_contiguous(
|
||||||
), "Hidden_states must be contiguous"
|
), "Hidden_states must be contiguous"
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
dispatch(True) # Send
|
dispatch(True) # Send
|
||||||
torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
||||||
dispatch(False) # Recv
|
dispatch(False) # Recv
|
||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
@@ -176,7 +176,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
)
|
)
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
combine(True)
|
combine(True)
|
||||||
torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||||
combine(False)
|
combine(False)
|
||||||
# torch.cuda.synchronize()
|
# torch.cuda.synchronize()
|
||||||
|
|||||||
@@ -113,29 +113,29 @@ class UBatchContext:
|
|||||||
|
|
||||||
def yield_and_switch_from_compute_to_comm(self):
|
def yield_and_switch_from_compute_to_comm(self):
|
||||||
assert current_stream() == self.compute_stream
|
assert current_stream() == self.compute_stream
|
||||||
dp_rank = get_dp_group().rank_in_group
|
# dp_rank = get_dp_group().rank_in_group
|
||||||
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
# self._signal_compute_done()
|
self._signal_compute_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
assert self.current_stream == self.compute_stream
|
assert self.current_stream == self.compute_stream
|
||||||
self.update_stream(self.comm_stream)
|
self.update_stream(self.comm_stream)
|
||||||
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
||||||
# self._wait_compute_done()
|
self._wait_compute_done()
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute(self):
|
def yield_and_switch_from_comm_to_compute(self):
|
||||||
assert current_stream() == self.comm_stream
|
assert current_stream() == self.comm_stream
|
||||||
dp_rank = get_dp_group().rank_in_group
|
# dp_rank = get_dp_group().rank_in_group
|
||||||
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
# self._signal_comm_done()
|
self._signal_comm_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
assert self.current_stream == self.comm_stream
|
assert self.current_stream == self.comm_stream
|
||||||
self.update_stream(self.compute_stream)
|
self.update_stream(self.compute_stream)
|
||||||
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
||||||
# self._wait_comm_done()
|
self._wait_comm_done()
|
||||||
|
|
||||||
|
|
||||||
_CURRENT_CONTEXT: dict = {}
|
_CURRENT_CONTEXT: dict = {}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user