one a2a kernel per microbatch group

This commit is contained in:
Sage Moore 2025-05-30 04:06:39 +00:00
parent 5cc573e791
commit 895a6c2a08
3 changed files with 80 additions and 39 deletions

View File

@ -307,13 +307,23 @@ class AllToAllCache:
return instance
# Global singleton
_all_to_all_cache = AllToAllCache()
from typing import List
_all_to_all_cache: List[AllToAllCache] = [AllToAllCache(), AllToAllCache()]
# Factory function as a cleaner interface
def get_all_to_all(**kwargs):
return _all_to_all_cache.get_or_create(**kwargs)
# Factory function with cache ID support
def get_all_to_all(cache_id: int, **kwargs):
"""Get or create an AllToAll instance from the specified cache.
Args:
cache_id: Integer ID of the cache to use (0 or 1)
**kwargs: Arguments passed to AllToAll creation
Returns:
AllToAll instance from the specified cache
"""
assert cache_id in (0, 1), f"cache_id must be 0 or 1, got {cache_id}"
return _all_to_all_cache[cache_id].get_or_create(**kwargs)
@CustomOp.register("unquantized_fused_moe")
@ -692,25 +702,26 @@ def _construct_prepare_finalize(
if moe.use_pplx_kernels:
logger.debug("using PplxPrepareAndFinalize")
all_to_all = get_all_to_all(
max_num_tokens=max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
kwargs = {
"max_num_tokens" :max_num_tokens,
"num_experts" :moe.num_experts,
"experts_per_token" :moe.experts_per_token, # topk
"rank" :rank,
"world_size" :world_size,
"dp_size" :dp_size,
"hidden_dim":moe.hidden_dim,
"hidden_dim_bytes" :moe.hidden_dim * moe.in_dtype.itemsize,
"hidden_dim_scale_bytes" :(0 if moe.in_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) //
moe.block_size * torch.float32.itemsize)))
moe.block_size * torch.float32.itemsize)),
}
a2as = [get_all_to_all(0, **kwargs), get_all_to_all(1, **kwargs)]
return PplxPrepareAndFinalize(
all_to_all,
a2as,
max_num_tokens=max_num_tokens,
world_size=world_size,
rank=rank,

View File

@ -19,7 +19,7 @@ from vllm.v1.worker.ubatching import (
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(self,
a2a: pplx.AllToAll,
a2as: list[pplx.AllToAll],
max_num_tokens: int,
world_size: int,
rank: int,
@ -28,7 +28,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
block_shape: Optional[list[int]] = None):
super().__init__()
assert max_num_tokens > 0
self.a2a = a2a
self.a2as = a2as
self.block_shape = block_shape
self.max_num_tokens = max_num_tokens
self.world_size = world_size
@ -49,6 +49,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
assert rank_topk_ids.size(0) == num_tokens
# assert expert_map is None, "NYI"
@ -110,7 +113,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
bound_m: Optional[torch.Tensor] = None
def dispatch(send: bool):
self.a2a.dispatch(
self.a2as[a2a_idx].dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
@ -122,9 +125,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send,
)
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
return expert_x, expert_x_scale, expert_num_tokens
@ -141,6 +150,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0)
bound_m: Optional[torch.Tensor] = None
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
assert topk_ids.size(0) == num_tokens, (
f"{topk_ids.size(0)} == {num_tokens}")
@ -153,7 +165,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights = torch.ones_like(topk_weights)
def combine(send: bool):
self.a2a.combine(
self.a2as[a2a_idx].combine(
out_tokens=output,
indices=topk_ids,
weights=topk_weights,
@ -162,8 +174,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_send=send,
do_recv=not send,
)
yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True)
torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False)
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")

View File

@ -7,6 +7,7 @@ import os
from typing import Optional
from torch.library import Library
from torch.library import custom_op, register_kernel
from vllm.distributed import (get_dp_group)
from vllm.utils import current_stream
from vllm import forward_context
@ -53,7 +54,7 @@ class UBatchContext:
def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = None
print("Finishing ubatch %d\n" % self.id)
print("Finishing ubatch %d\n" % self.id, flush=True)
self.cpu_signal_event.set()
self.cpu_wait_event.clear()
self.current_stream = self.compute_stream
@ -81,49 +82,63 @@ class UBatchContext:
self.gpu_compute_done_event.record(self.compute_stream)
def _wait_compute_done(self):
print("Waiting on compute stream")
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
self.ctx_valid_state()
self.comm_stream.wait_event(self.gpu_compute_done_event)
print("Compute stream done")
# print("Compute stream done", flush=True)
def _wait_comm_done(self):
print("Waiting on comm stream")
# print(f"{self.id} Waiting on COMM stream", flush=True)
self.ctx_valid_state()
self.compute_stream.wait_event(self.gpu_comm_done_event)
print("Comm stream done")
# print("Comm stream done", flush=True)
def stream_string(self):
if current_stream() == self.compute_stream:
assert self.current_stream == self.compute_stream
return "COMPUTE"
elif current_stream() == self.comm_stream:
assert self.current_stream == self.comm_stream
return "COMM"
def _cpu_yield(self):
print("UBatchContext: %d yielding CPU\n" % self.id)
# print(f"UBatchContext: {self.id} yielding CPU", flush=True)
self.ctx_valid_state()
self.cpu_signal_event.set()
self.cpu_wait_event.wait()
self.cpu_wait_event.clear()
self._restore_context()
self.ctx_valid_state()
print("UBatchContext: %d resuming CPU\n" % self.id)
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
def yield_and_switch_from_compute_to_comm(self):
print("Yield and switch from compute")
assert current_stream() == self.compute_stream
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state()
self._signal_compute_done()
# self._signal_compute_done()
self._cpu_yield()
self.ctx_valid_state()
assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream)
self._wait_compute_done()
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state()
self._signal_comm_done()
# self._signal_comm_done()
self._cpu_yield()
self.ctx_valid_state()
assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream)
self._wait_comm_done()
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# self._wait_comm_done()
_CURRENT_CONTEXT: dict = {}
def get_current_ubatch_context() -> Optional[UBatchContext]:
global _CURRENT_CONTEXT
"""