fix pplx a2a

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-03 18:21:17 +00:00
parent 5f4a501b9a
commit e080e068ed
5 changed files with 48 additions and 36 deletions

View File

@ -100,13 +100,23 @@ class PPLXAll2AllManager(All2AllManagerBase):
logger.debug("PPLX NVSHMEM UID = %s", uid)
nvshmem_init(uid, self.rank, self.world_size)
self.handle_cache = Cache()
# self.handle_cache = Cache()
self.handle_caches = [Cache(), Cache()]
def get_handle(self, kwargs):
import pplx_kernels as pplx
return self.handle_cache.get_or_create(
return self.handle_caches[0].get_or_create(
kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
def get_handles(self, kwargs):
import pplx_kernels as pplx
first_handle = self.handle_caches[0].get_or_create(kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
second_handle = self.handle_caches[1].get_or_create(kwargs, pplx.AllToAll.internode
if self.internode else pplx.AllToAll.intranode)
return [first_handle, second_handle]
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
@ -116,9 +126,10 @@ class PPLXAll2AllManager(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
with self.handle_cache._lock:
for _, handle in self.handle_cache._cache.items():
handle.destroy()
for handle_cache in self.handle_caches:
with handle_cache._lock:
for _, handle in handle_cache._cache.items():
handle.destroy()
if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize

View File

@ -272,10 +272,10 @@ class FusedMoEMethodBase(QuantizeMethodBase):
group_name=all2all_manager.cpu_group.group_name,
)
handle = all2all_manager.get_handle(all_to_all_args)
handles = all2all_manager.get_handles(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
handles,
max_num_tokens=moe.max_num_tokens,
world_size=all2all_manager.world_size,
rank=all2all_manager.rank,

View File

@ -45,7 +45,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
assert False
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
ubatch_ctx = get_current_ubatch_context()
@ -128,10 +127,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
# torch.cuda.synchronize()
torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv
# torch.cuda.synchronize()
torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
torch.cuda.synchronize()
@ -145,7 +144,6 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
assert False
num_tokens = output.size(0) # M
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
@ -177,9 +175,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True)
# torch.cuda.synchronize()
torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False)
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()
torch.cuda.synchronize()

View File

@ -1363,10 +1363,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
use_dummy_input):
print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
if save_results:
results.append(model_output)
print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run) -> torch.Tensor:

View File

@ -8,6 +8,7 @@ from torch.library import custom_op
from vllm import forward_context
from vllm.utils import current_stream
from vllm.distributed.parallel_state import get_dp_group
class UBatchContext:
@ -69,28 +70,28 @@ class UBatchContext:
torch.cuda.set_stream(self.current_stream)
def ctx_valid_state(self):
# assert forward_context._forward_context == self.forward_context
# assert current_stream() == self.current_stream
# assert not self.cpu_wait_event.is_set()
assert forward_context._forward_context == self.forward_context
assert current_stream() == self.current_stream
assert not self.cpu_wait_event.is_set()
pass
def _signal_comm_done(self):
# self.ctx_valid_state()
self.ctx_valid_state()
self.gpu_comm_done_event.record(self.comm_stream)
def _signal_compute_done(self):
# self.ctx_valid_state()
self.ctx_valid_state()
self.gpu_compute_done_event.record(self.compute_stream)
def _wait_compute_done(self):
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
# self.ctx_valid_state()
self.ctx_valid_state()
self.comm_stream.wait_event(self.gpu_compute_done_event)
# print("Compute stream done", flush=True)
def _wait_comm_done(self):
# print(f"{self.id} Waiting on COMM stream", flush=True)
# self.ctx_valid_state()
self.ctx_valid_state()
self.compute_stream.wait_event(self.gpu_comm_done_event)
# print("Comm stream done", flush=True)
@ -104,42 +105,42 @@ class UBatchContext:
def _cpu_yield(self):
# print(f"UBatchContext: {self.id} yielding CPU", flush=True)
# self.ctx_valid_state()
self.ctx_valid_state()
self.cpu_signal_event.set()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
# self.ctx_valid_state()
self.ctx_valid_state()
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
def yield_and_switch_from_compute_to_comm(self):
assert current_stream() == self.compute_stream
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Yield and switch from {self.stream_string()}", flush=True)
# self.ctx_valid_state()
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} "
f"Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state()
self._signal_compute_done()
self._cpu_yield()
# self.ctx_valid_state()
self.ctx_valid_state()
assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream)
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Resuming on stream {self.stream_string()}", flush=True)
print(f"DP: {dp_rank} UB: {self.id} "
f"Resuming on stream {self.stream_string()}", flush=True)
self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
# dp_rank = get_dp_group().rank_in_group
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Yield and switch from {self.stream_string()}", flush=True)
# self.ctx_valid_state()
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} "
f"Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state()
self._signal_comm_done()
self._cpu_yield()
# self.ctx_valid_state()
self.ctx_valid_state()
assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream)
# print(f"DP: {dp_rank} UB: {self.id} "
# f"Resuming on stream {self.stream_string()}", flush=True)
print(f"DP: {dp_rank} UB: {self.id} "
f"Resuming on stream {self.stream_string()}", flush=True)
self._wait_comm_done()