one a2a kernel per microbatch group

This commit is contained in:
Sage Moore 2025-05-30 04:06:39 +00:00
parent 5cc573e791
commit 895a6c2a08
3 changed files with 80 additions and 39 deletions

View File

@ -307,13 +307,23 @@ class AllToAllCache:
return instance return instance
# Global singleton from typing import List
_all_to_all_cache = AllToAllCache() _all_to_all_cache: List[AllToAllCache] = [AllToAllCache(), AllToAllCache()]
# Factory function as a cleaner interface # Factory function with cache ID support
def get_all_to_all(**kwargs): def get_all_to_all(cache_id: int, **kwargs):
return _all_to_all_cache.get_or_create(**kwargs) """Get or create an AllToAll instance from the specified cache.
Args:
cache_id: Integer ID of the cache to use (0 or 1)
**kwargs: Arguments passed to AllToAll creation
Returns:
AllToAll instance from the specified cache
"""
assert cache_id in (0, 1), f"cache_id must be 0 or 1, got {cache_id}"
return _all_to_all_cache[cache_id].get_or_create(**kwargs)
@CustomOp.register("unquantized_fused_moe") @CustomOp.register("unquantized_fused_moe")
@ -693,24 +703,25 @@ def _construct_prepare_finalize(
if moe.use_pplx_kernels: if moe.use_pplx_kernels:
logger.debug("using PplxPrepareAndFinalize") logger.debug("using PplxPrepareAndFinalize")
all_to_all = get_all_to_all( kwargs = {
max_num_tokens=max_num_tokens, "max_num_tokens" :max_num_tokens,
num_experts=moe.num_experts, "num_experts" :moe.num_experts,
experts_per_token=moe.experts_per_token, # topk "experts_per_token" :moe.experts_per_token, # topk
rank=rank, "rank" :rank,
world_size=world_size, "world_size" :world_size,
dp_size=dp_size, "dp_size" :dp_size,
hidden_dim=moe.hidden_dim, "hidden_dim":moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, "hidden_dim_bytes" :moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to "hidden_dim_scale_bytes" :(0 if moe.in_dtype.itemsize != 1 else
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) // ((moe.hidden_dim + moe.block_size - 1) //
moe.block_size * torch.float32.itemsize))) moe.block_size * torch.float32.itemsize)),
}
a2as = [get_all_to_all(0, **kwargs), get_all_to_all(1, **kwargs)]
return PplxPrepareAndFinalize( return PplxPrepareAndFinalize(
all_to_all, a2as,
max_num_tokens=max_num_tokens, max_num_tokens=max_num_tokens,
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,

View File

@ -19,7 +19,7 @@ from vllm.v1.worker.ubatching import (
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(self, def __init__(self,
a2a: pplx.AllToAll, a2as: list[pplx.AllToAll],
max_num_tokens: int, max_num_tokens: int,
world_size: int, world_size: int,
rank: int, rank: int,
@ -28,7 +28,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
block_shape: Optional[list[int]] = None): block_shape: Optional[list[int]] = None):
super().__init__() super().__init__()
assert max_num_tokens > 0 assert max_num_tokens > 0
self.a2a = a2a self.a2as = a2as
self.block_shape = block_shape self.block_shape = block_shape
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.world_size = world_size self.world_size = world_size
@ -49,6 +49,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
assert rank_topk_ids.size(0) == num_tokens assert rank_topk_ids.size(0) == num_tokens
# assert expert_map is None, "NYI" # assert expert_map is None, "NYI"
@ -110,7 +113,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
bound_m: Optional[torch.Tensor] = None bound_m: Optional[torch.Tensor] = None
def dispatch(send: bool): def dispatch(send: bool):
self.a2a.dispatch( self.a2as[a2a_idx].dispatch(
out_expert_num_tokens=expert_num_tokens, out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x, out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale, out_expert_x_scale=expert_x_scale,
@ -122,9 +125,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send, do_recv=not send,
) )
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
yield_and_switch_from_compute_to_comm_impl(schedule="default") yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send dispatch(True) # Send
torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv dispatch(False) # Recv
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default") yield_and_switch_from_comm_to_compute_impl(schedule="default")
return expert_x, expert_x_scale, expert_num_tokens return expert_x, expert_x_scale, expert_num_tokens
@ -141,6 +150,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# This argument is optional # This argument is optional
# There's not much point setting this unless it is != topk_ids.size(0) # There's not much point setting this unless it is != topk_ids.size(0)
bound_m: Optional[torch.Tensor] = None bound_m: Optional[torch.Tensor] = None
ubatch_ctx = get_current_ubatch_context()
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
assert topk_ids.size(0) == num_tokens, ( assert topk_ids.size(0) == num_tokens, (
f"{topk_ids.size(0)} == {num_tokens}") f"{topk_ids.size(0)} == {num_tokens}")
@ -153,7 +165,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights = torch.ones_like(topk_weights) topk_weights = torch.ones_like(topk_weights)
def combine(send: bool): def combine(send: bool):
self.a2a.combine( self.a2as[a2a_idx].combine(
out_tokens=output, out_tokens=output,
indices=topk_ids, indices=topk_ids,
weights=topk_weights, weights=topk_weights,
@ -162,8 +174,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_send=send, do_send=send,
do_recv=not send, do_recv=not send,
) )
yield_and_switch_from_compute_to_comm_impl(schedule="default") yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True) combine(True)
torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False) combine(False)
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default") yield_and_switch_from_comm_to_compute_impl(schedule="default")

View File

@ -7,6 +7,7 @@ import os
from typing import Optional from typing import Optional
from torch.library import Library from torch.library import Library
from torch.library import custom_op, register_kernel from torch.library import custom_op, register_kernel
from vllm.distributed import (get_dp_group)
from vllm.utils import current_stream from vllm.utils import current_stream
from vllm import forward_context from vllm import forward_context
@ -53,7 +54,7 @@ class UBatchContext:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
global _CURRENT_CONTEXT global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = None _CURRENT_CONTEXT[threading.get_ident()] = None
print("Finishing ubatch %d\n" % self.id) print("Finishing ubatch %d\n" % self.id, flush=True)
self.cpu_signal_event.set() self.cpu_signal_event.set()
self.cpu_wait_event.clear() self.cpu_wait_event.clear()
self.current_stream = self.compute_stream self.current_stream = self.compute_stream
@ -81,49 +82,63 @@ class UBatchContext:
self.gpu_compute_done_event.record(self.compute_stream) self.gpu_compute_done_event.record(self.compute_stream)
def _wait_compute_done(self): def _wait_compute_done(self):
print("Waiting on compute stream") # print(f"{self.id} Waiting on COMPUTE stream", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
self.comm_stream.wait_event(self.gpu_compute_done_event) self.comm_stream.wait_event(self.gpu_compute_done_event)
print("Compute stream done") # print("Compute stream done", flush=True)
def _wait_comm_done(self): def _wait_comm_done(self):
print("Waiting on comm stream") # print(f"{self.id} Waiting on COMM stream", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
self.compute_stream.wait_event(self.gpu_comm_done_event) self.compute_stream.wait_event(self.gpu_comm_done_event)
print("Comm stream done") # print("Comm stream done", flush=True)
def stream_string(self):
if current_stream() == self.compute_stream:
assert self.current_stream == self.compute_stream
return "COMPUTE"
elif current_stream() == self.comm_stream:
assert self.current_stream == self.comm_stream
return "COMM"
def _cpu_yield(self): def _cpu_yield(self):
print("UBatchContext: %d yielding CPU\n" % self.id) # print(f"UBatchContext: {self.id} yielding CPU", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
self.cpu_signal_event.set() self.cpu_signal_event.set()
self.cpu_wait_event.wait() self.cpu_wait_event.wait()
self.cpu_wait_event.clear() self.cpu_wait_event.clear()
self._restore_context() self._restore_context()
self.ctx_valid_state() self.ctx_valid_state()
print("UBatchContext: %d resuming CPU\n" % self.id) # print(f"UBatchContext: {self.id} resuming CPU", flush=True)
def yield_and_switch_from_compute_to_comm(self): def yield_and_switch_from_compute_to_comm(self):
print("Yield and switch from compute") assert current_stream() == self.compute_stream
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
self._signal_compute_done() # self._signal_compute_done()
self._cpu_yield() self._cpu_yield()
self.ctx_valid_state() self.ctx_valid_state()
assert self.current_stream == self.compute_stream assert self.current_stream == self.compute_stream
self.update_stream(self.comm_stream) self.update_stream(self.comm_stream)
self._wait_compute_done() print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# self._wait_compute_done()
def yield_and_switch_from_comm_to_compute(self): def yield_and_switch_from_comm_to_compute(self):
assert current_stream() == self.comm_stream
dp_rank = get_dp_group().rank_in_group
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
self.ctx_valid_state() self.ctx_valid_state()
self._signal_comm_done() # self._signal_comm_done()
self._cpu_yield() self._cpu_yield()
self.ctx_valid_state() self.ctx_valid_state()
assert self.current_stream == self.comm_stream assert self.current_stream == self.comm_stream
self.update_stream(self.compute_stream) self.update_stream(self.compute_stream)
self._wait_comm_done() print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
# self._wait_comm_done()
_CURRENT_CONTEXT: dict = {} _CURRENT_CONTEXT: dict = {}
def get_current_ubatch_context() -> Optional[UBatchContext]: def get_current_ubatch_context() -> Optional[UBatchContext]:
global _CURRENT_CONTEXT global _CURRENT_CONTEXT
""" """