one a2a kernel per microbatch group
commit 895a6c2a08
parent 5cc573e791
@@ -307,13 +307,23 @@ class AllToAllCache:
         return instance
 
 
 # Global singleton
-_all_to_all_cache = AllToAllCache()
+from typing import List
+_all_to_all_cache: List[AllToAllCache] = [AllToAllCache(), AllToAllCache()]
 
 
-# Factory function as a cleaner interface
-def get_all_to_all(**kwargs):
-    return _all_to_all_cache.get_or_create(**kwargs)
+# Factory function with cache ID support
+def get_all_to_all(cache_id: int, **kwargs):
+    """Get or create an AllToAll instance from the specified cache.
+
+    Args:
+        cache_id: Integer ID of the cache to use (0 or 1)
+        **kwargs: Arguments passed to AllToAll creation
+
+    Returns:
+        AllToAll instance from the specified cache
+    """
+    assert cache_id in (0, 1), f"cache_id must be 0 or 1, got {cache_id}"
+    return _all_to_all_cache[cache_id].get_or_create(**kwargs)
 
 
 @CustomOp.register("unquantized_fused_moe")
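
Note (editor's sketch, not part of the diff): with two caches, each microbatch group resolves its own pplx AllToAll by integer id, so the two groups never share an a2a kernel. A minimal usage sketch, assuming get_or_create returns the cached instance on repeated calls; a2a_kwargs is a hypothetical stand-in for the construction arguments shown in the next hunk:

    # Resolve one AllToAll per microbatch group (a2a_kwargs is hypothetical).
    a2a_ub0 = get_all_to_all(0, **a2a_kwargs)
    a2a_ub1 = get_all_to_all(1, **a2a_kwargs)
    # Assuming get_or_create caches per id, repeated lookups are stable:
    assert a2a_ub0 is get_all_to_all(0, **a2a_kwargs)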

@@ -692,25 +702,26 @@ def _construct_prepare_finalize(
 
     if moe.use_pplx_kernels:
         logger.debug("using PplxPrepareAndFinalize")
 
-        all_to_all = get_all_to_all(
-            max_num_tokens=max_num_tokens,
-            num_experts=moe.num_experts,
-            experts_per_token=moe.experts_per_token,  # topk
-            rank=rank,
-            world_size=world_size,
-            dp_size=dp_size,
-            hidden_dim=moe.hidden_dim,
-            hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
-            # For blocked per token: set to
-            # ceil_div(hidden_dim, block_size) * sizeof(float32)
-            # For per-token: set to sizeof(float32)
-            hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
-                                    ((moe.hidden_dim + moe.block_size - 1) //
-                                     moe.block_size * torch.float32.itemsize)))
+        kwargs = {
+            "max_num_tokens": max_num_tokens,
+            "num_experts": moe.num_experts,
+            "experts_per_token": moe.experts_per_token,  # topk
+            "rank": rank,
+            "world_size": world_size,
+            "dp_size": dp_size,
+            "hidden_dim": moe.hidden_dim,
+            "hidden_dim_bytes": moe.hidden_dim * moe.in_dtype.itemsize,
+            # For blocked per token: set to
+            # ceil_div(hidden_dim, block_size) * sizeof(float32)
+            # For per-token: set to sizeof(float32)
+            "hidden_dim_scale_bytes": (0 if moe.in_dtype.itemsize != 1 else
+                                       ((moe.hidden_dim + moe.block_size - 1) //
+                                        moe.block_size * torch.float32.itemsize)),
+        }
+
+        a2as = [get_all_to_all(0, **kwargs), get_all_to_all(1, **kwargs)]
 
         return PplxPrepareAndFinalize(
-            all_to_all,
+            a2as,
             max_num_tokens=max_num_tokens,
             world_size=world_size,
             rank=rank,
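
Note (not part of the diff): hidden_dim_scale_bytes is only nonzero for 1-byte activation dtypes (fp8-style), where each token carries one float32 scale per quantization block. A worked sketch of the arithmetic with illustrative sizes:

    # Illustrative values; torch.float32.itemsize == 4.
    hidden_dim, block_size = 7168, 128
    num_blocks = (hidden_dim + block_size - 1) // block_size  # ceil_div -> 56
    hidden_dim_scale_bytes = num_blocks * 4                   # 224 bytes per token
    # For 2-byte dtypes (bf16/fp16) no scales are exchanged, hence the 0 branch.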

@@ -19,7 +19,7 @@ from vllm.v1.worker.ubatching import (
 class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 
     def __init__(self,
-                 a2a: pplx.AllToAll,
+                 a2as: list[pplx.AllToAll],
                  max_num_tokens: int,
                  world_size: int,
                  rank: int,

@@ -28,7 +28,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                  block_shape: Optional[list[int]] = None):
         super().__init__()
         assert max_num_tokens > 0
-        self.a2a = a2a
+        self.a2as = a2as
         self.block_shape = block_shape
         self.max_num_tokens = max_num_tokens
         self.world_size = world_size

@@ -49,6 +49,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
         num_tokens = a1.size(0)  # M
         hidden_dim = a1.size(-1)  # K
+        ubatch_ctx = get_current_ubatch_context()
+        ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
+        a2a_idx = 0 if ubatch_id == -1 else ubatch_id
 
         assert rank_topk_ids.size(0) == num_tokens
         # assert expert_map is None, "NYI"

@@ -110,7 +113,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         bound_m: Optional[torch.Tensor] = None
 
         def dispatch(send: bool):
-            self.a2a.dispatch(
+            self.a2as[a2a_idx].dispatch(
                 out_expert_num_tokens=expert_num_tokens,
                 out_expert_x=expert_x,
                 out_expert_x_scale=expert_x_scale,

@@ -122,9 +125,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 do_recv=not send,
             )
 
+        ubatch_ctx = get_current_ubatch_context()
+        ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
         yield_and_switch_from_compute_to_comm_impl(schedule="default")
         dispatch(True)  # Send
+        torch.cuda.synchronize()
+        # print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
         dispatch(False)  # Recv
+        # torch.cuda.synchronize()
+        # print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
         yield_and_switch_from_comm_to_compute_impl(schedule="default")
 
         return expert_x, expert_x_scale, expert_num_tokens
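
Note (not part of the diff): dispatch (and combine below) is issued twice, once with do_send=True and once with do_recv=True, so the CPU can yield to the other microbatch between the two halves; the torch.cuda.synchronize() wedged between them is a debugging barrier. A stripped-down sketch of that split-phase shape (the do_send/do_recv flags come from the diff, everything else is a stand-in):

    import torch

    def split_phase(op, **kw):
        # Stand-in for self.a2as[a2a_idx].dispatch(...) or .combine(...).
        op(**kw, do_send=True, do_recv=False)  # enqueue the send half
        torch.cuda.synchronize()               # debug-only barrier (needs CUDA)
        op(**kw, do_recv=True, do_send=False)  # complete with the recv half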

@@ -141,6 +150,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         # This argument is optional
         # There's not much point setting this unless it is != topk_ids.size(0)
         bound_m: Optional[torch.Tensor] = None
+        ubatch_ctx = get_current_ubatch_context()
+        ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
+        a2a_idx = 0 if ubatch_id == -1 else ubatch_id
 
         assert topk_ids.size(0) == num_tokens, (
             f"{topk_ids.size(0)} == {num_tokens}")

@@ -153,7 +165,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             topk_weights = torch.ones_like(topk_weights)
 
         def combine(send: bool):
-            self.a2a.combine(
+            self.a2as[a2a_idx].combine(
                 out_tokens=output,
                 indices=topk_ids,
                 weights=topk_weights,

@@ -162,8 +174,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 do_send=send,
                 do_recv=not send,
             )
 
         yield_and_switch_from_compute_to_comm_impl(schedule="default")
         combine(True)
+        torch.cuda.synchronize()
+        # print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
         combine(False)
+        # torch.cuda.synchronize()
+        # print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
         yield_and_switch_from_comm_to_compute_impl(schedule="default")

@@ -7,6 +7,7 @@ import os
 from typing import Optional
 from torch.library import Library
 from torch.library import custom_op, register_kernel
+from vllm.distributed import (get_dp_group)
 
 from vllm.utils import current_stream
 from vllm import forward_context

@@ -53,7 +54,7 @@ class UBatchContext:
     def __exit__(self, exc_type, exc_val, exc_tb):
         global _CURRENT_CONTEXT
         _CURRENT_CONTEXT[threading.get_ident()] = None
-        print("Finishing ubatch %d\n" % self.id)
+        print("Finishing ubatch %d\n" % self.id, flush=True)
         self.cpu_signal_event.set()
         self.cpu_wait_event.clear()
         self.current_stream = self.compute_stream

@@ -81,49 +82,63 @@
         self.gpu_compute_done_event.record(self.compute_stream)
 
     def _wait_compute_done(self):
-        print("Waiting on compute stream")
+        # print(f"{self.id} Waiting on COMPUTE stream", flush=True)
         self.ctx_valid_state()
         self.comm_stream.wait_event(self.gpu_compute_done_event)
-        print("Compute stream done")
+        # print("Compute stream done", flush=True)
 
     def _wait_comm_done(self):
-        print("Waiting on comm stream")
+        # print(f"{self.id} Waiting on COMM stream", flush=True)
         self.ctx_valid_state()
         self.compute_stream.wait_event(self.gpu_comm_done_event)
-        print("Comm stream done")
+        # print("Comm stream done", flush=True)
 
+    def stream_string(self):
+        if current_stream() == self.compute_stream:
+            assert self.current_stream == self.compute_stream
+            return "COMPUTE"
+        elif current_stream() == self.comm_stream:
+            assert self.current_stream == self.comm_stream
+            return "COMM"
+
     def _cpu_yield(self):
-        print("UBatchContext: %d yielding CPU\n" % self.id)
+        # print(f"UBatchContext: {self.id} yielding CPU", flush=True)
         self.ctx_valid_state()
         self.cpu_signal_event.set()
         self.cpu_wait_event.wait()
         self.cpu_wait_event.clear()
         self._restore_context()
         self.ctx_valid_state()
-        print("UBatchContext: %d resuming CPU\n" % self.id)
+        # print(f"UBatchContext: {self.id} resuming CPU", flush=True)
 
     def yield_and_switch_from_compute_to_comm(self):
-        print("Yield and switch from compute")
         assert current_stream() == self.compute_stream
+        dp_rank = get_dp_group().rank_in_group
+        print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
         self.ctx_valid_state()
-        self._signal_compute_done()
+        # self._signal_compute_done()
         self._cpu_yield()
         self.ctx_valid_state()
         assert self.current_stream == self.compute_stream
         self.update_stream(self.comm_stream)
-        self._wait_compute_done()
+        print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
+        # self._wait_compute_done()
 
     def yield_and_switch_from_comm_to_compute(self):
         assert current_stream() == self.comm_stream
+        dp_rank = get_dp_group().rank_in_group
+        print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
         self.ctx_valid_state()
-        self._signal_comm_done()
+        # self._signal_comm_done()
         self._cpu_yield()
         self.ctx_valid_state()
         assert self.current_stream == self.comm_stream
         self.update_stream(self.compute_stream)
-        self._wait_comm_done()
+        print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
+        # self._wait_comm_done()
 
 
 _CURRENT_CONTEXT: dict = {}
 
 
 def get_current_ubatch_context() -> Optional[UBatchContext]:
     global _CURRENT_CONTEXT
     """
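
Note (not part of the diff; the page truncates above): _cpu_yield hands the CPU back and forth between the two microbatch threads via cpu_signal_event / cpu_wait_event. A self-contained token-passing sketch of that handoff, simplified from the UBatchContext protocol in the hunk above:

    import threading

    def ubatch(uid, my_ev, peer_ev, steps=3):
        for step in range(steps):
            my_ev.wait()    # block until we hold the token (cpu_wait_event role)
            my_ev.clear()
            print(f"ubatch {uid} step {step}", flush=True)
            peer_ev.set()   # yield the CPU to the peer (cpu_signal_event role)

    ev0, ev1 = threading.Event(), threading.Event()
    ev0.set()  # microbatch 0 starts with the token
    threads = [threading.Thread(target=ubatch, args=(i, e, p))
               for i, e, p in ((0, ev0, ev1), (1, ev1, ev0))]
    for t in threads: t.start()
    for t in threads: t.join()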