mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 15:55:40 +08:00
one a2a kernel per microbatch group
This commit is contained in:
parent
5cc573e791
commit
895a6c2a08
@ -307,13 +307,23 @@ class AllToAllCache:
|
|||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
||||||
# Global singleton
|
from typing import List
|
||||||
_all_to_all_cache = AllToAllCache()
|
_all_to_all_cache: List[AllToAllCache] = [AllToAllCache(), AllToAllCache()]
|
||||||
|
|
||||||
|
|
||||||
# Factory function as a cleaner interface
|
# Factory function with cache ID support
|
||||||
def get_all_to_all(**kwargs):
|
def get_all_to_all(cache_id: int, **kwargs):
|
||||||
return _all_to_all_cache.get_or_create(**kwargs)
|
"""Get or create an AllToAll instance from the specified cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache_id: Integer ID of the cache to use (0 or 1)
|
||||||
|
**kwargs: Arguments passed to AllToAll creation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AllToAll instance from the specified cache
|
||||||
|
"""
|
||||||
|
assert cache_id in (0, 1), f"cache_id must be 0 or 1, got {cache_id}"
|
||||||
|
return _all_to_all_cache[cache_id].get_or_create(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@CustomOp.register("unquantized_fused_moe")
|
@CustomOp.register("unquantized_fused_moe")
|
||||||
@ -692,25 +702,26 @@ def _construct_prepare_finalize(
|
|||||||
|
|
||||||
if moe.use_pplx_kernels:
|
if moe.use_pplx_kernels:
|
||||||
logger.debug("using PplxPrepareAndFinalize")
|
logger.debug("using PplxPrepareAndFinalize")
|
||||||
|
|
||||||
all_to_all = get_all_to_all(
|
kwargs = {
|
||||||
max_num_tokens=max_num_tokens,
|
"max_num_tokens" :max_num_tokens,
|
||||||
num_experts=moe.num_experts,
|
"num_experts" :moe.num_experts,
|
||||||
experts_per_token=moe.experts_per_token, # topk
|
"experts_per_token" :moe.experts_per_token, # topk
|
||||||
rank=rank,
|
"rank" :rank,
|
||||||
world_size=world_size,
|
"world_size" :world_size,
|
||||||
dp_size=dp_size,
|
"dp_size" :dp_size,
|
||||||
hidden_dim=moe.hidden_dim,
|
"hidden_dim":moe.hidden_dim,
|
||||||
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
|
"hidden_dim_bytes" :moe.hidden_dim * moe.in_dtype.itemsize,
|
||||||
# For blocked per token: set to
|
"hidden_dim_scale_bytes" :(0 if moe.in_dtype.itemsize != 1 else
|
||||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
|
||||||
# For per-token: set to sizeof(float32)
|
|
||||||
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
|
|
||||||
((moe.hidden_dim + moe.block_size - 1) //
|
((moe.hidden_dim + moe.block_size - 1) //
|
||||||
moe.block_size * torch.float32.itemsize)))
|
moe.block_size * torch.float32.itemsize)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
a2as = [get_all_to_all(0, **kwargs), get_all_to_all(1, **kwargs)]
|
||||||
|
|
||||||
return PplxPrepareAndFinalize(
|
return PplxPrepareAndFinalize(
|
||||||
all_to_all,
|
a2as,
|
||||||
max_num_tokens=max_num_tokens,
|
max_num_tokens=max_num_tokens,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
|
|||||||
@ -19,7 +19,7 @@ from vllm.v1.worker.ubatching import (
|
|||||||
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
a2a: pplx.AllToAll,
|
a2as: list[pplx.AllToAll],
|
||||||
max_num_tokens: int,
|
max_num_tokens: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
rank: int,
|
rank: int,
|
||||||
@ -28,7 +28,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
block_shape: Optional[list[int]] = None):
|
block_shape: Optional[list[int]] = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert max_num_tokens > 0
|
assert max_num_tokens > 0
|
||||||
self.a2a = a2a
|
self.a2as = a2as
|
||||||
self.block_shape = block_shape
|
self.block_shape = block_shape
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
@ -49,6 +49,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
num_tokens = a1.size(0) # M
|
num_tokens = a1.size(0) # M
|
||||||
hidden_dim = a1.size(-1) # K
|
hidden_dim = a1.size(-1) # K
|
||||||
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
|
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||||
|
|
||||||
assert rank_topk_ids.size(0) == num_tokens
|
assert rank_topk_ids.size(0) == num_tokens
|
||||||
# assert expert_map is None, "NYI"
|
# assert expert_map is None, "NYI"
|
||||||
@ -110,7 +113,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
bound_m: Optional[torch.Tensor] = None
|
bound_m: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def dispatch(send: bool):
|
def dispatch(send: bool):
|
||||||
self.a2a.dispatch(
|
self.a2as[a2a_idx].dispatch(
|
||||||
out_expert_num_tokens=expert_num_tokens,
|
out_expert_num_tokens=expert_num_tokens,
|
||||||
out_expert_x=expert_x,
|
out_expert_x=expert_x,
|
||||||
out_expert_x_scale=expert_x_scale,
|
out_expert_x_scale=expert_x_scale,
|
||||||
@ -122,9 +125,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
do_recv=not send,
|
do_recv=not send,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
dispatch(True) # Send
|
dispatch(True) # Send
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
|
||||||
dispatch(False) # Recv
|
dispatch(False) # Recv
|
||||||
|
# torch.cuda.synchronize()
|
||||||
|
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||||
|
|
||||||
return expert_x, expert_x_scale, expert_num_tokens
|
return expert_x, expert_x_scale, expert_num_tokens
|
||||||
@ -141,6 +150,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
# This argument is optional
|
# This argument is optional
|
||||||
# There's not much point setting this unless it is != topk_ids.size(0)
|
# There's not much point setting this unless it is != topk_ids.size(0)
|
||||||
bound_m: Optional[torch.Tensor] = None
|
bound_m: Optional[torch.Tensor] = None
|
||||||
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
|
a2a_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||||
|
|
||||||
assert topk_ids.size(0) == num_tokens, (
|
assert topk_ids.size(0) == num_tokens, (
|
||||||
f"{topk_ids.size(0)} == {num_tokens}")
|
f"{topk_ids.size(0)} == {num_tokens}")
|
||||||
@ -153,7 +165,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
topk_weights = torch.ones_like(topk_weights)
|
topk_weights = torch.ones_like(topk_weights)
|
||||||
|
|
||||||
def combine(send: bool):
|
def combine(send: bool):
|
||||||
self.a2a.combine(
|
self.a2as[a2a_idx].combine(
|
||||||
out_tokens=output,
|
out_tokens=output,
|
||||||
indices=topk_ids,
|
indices=topk_ids,
|
||||||
weights=topk_weights,
|
weights=topk_weights,
|
||||||
@ -162,8 +174,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
do_send=send,
|
do_send=send,
|
||||||
do_recv=not send,
|
do_recv=not send,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||||
combine(True)
|
combine(True)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
|
||||||
combine(False)
|
combine(False)
|
||||||
|
# torch.cuda.synchronize()
|
||||||
|
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||||
@ -7,6 +7,7 @@ import os
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from torch.library import Library
|
from torch.library import Library
|
||||||
from torch.library import custom_op, register_kernel
|
from torch.library import custom_op, register_kernel
|
||||||
|
from vllm.distributed import (get_dp_group)
|
||||||
|
|
||||||
from vllm.utils import current_stream
|
from vllm.utils import current_stream
|
||||||
from vllm import forward_context
|
from vllm import forward_context
|
||||||
@ -53,7 +54,7 @@ class UBatchContext:
|
|||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
global _CURRENT_CONTEXT
|
global _CURRENT_CONTEXT
|
||||||
_CURRENT_CONTEXT[threading.get_ident()] = None
|
_CURRENT_CONTEXT[threading.get_ident()] = None
|
||||||
print("Finishing ubatch %d\n" % self.id)
|
print("Finishing ubatch %d\n" % self.id, flush=True)
|
||||||
self.cpu_signal_event.set()
|
self.cpu_signal_event.set()
|
||||||
self.cpu_wait_event.clear()
|
self.cpu_wait_event.clear()
|
||||||
self.current_stream = self.compute_stream
|
self.current_stream = self.compute_stream
|
||||||
@ -81,49 +82,63 @@ class UBatchContext:
|
|||||||
self.gpu_compute_done_event.record(self.compute_stream)
|
self.gpu_compute_done_event.record(self.compute_stream)
|
||||||
|
|
||||||
def _wait_compute_done(self):
|
def _wait_compute_done(self):
|
||||||
print("Waiting on compute stream")
|
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||||
print("Compute stream done")
|
# print("Compute stream done", flush=True)
|
||||||
|
|
||||||
def _wait_comm_done(self):
|
def _wait_comm_done(self):
|
||||||
print("Waiting on comm stream")
|
# print(f"{self.id} Waiting on COMM stream", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||||
print("Comm stream done")
|
# print("Comm stream done", flush=True)
|
||||||
|
|
||||||
|
def stream_string(self):
|
||||||
|
if current_stream() == self.compute_stream:
|
||||||
|
assert self.current_stream == self.compute_stream
|
||||||
|
return "COMPUTE"
|
||||||
|
elif current_stream() == self.comm_stream:
|
||||||
|
assert self.current_stream == self.comm_stream
|
||||||
|
return "COMM"
|
||||||
|
|
||||||
def _cpu_yield(self):
|
def _cpu_yield(self):
|
||||||
print("UBatchContext: %d yielding CPU\n" % self.id)
|
# print(f"UBatchContext: {self.id} yielding CPU", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.cpu_signal_event.set()
|
self.cpu_signal_event.set()
|
||||||
self.cpu_wait_event.wait()
|
self.cpu_wait_event.wait()
|
||||||
self.cpu_wait_event.clear()
|
self.cpu_wait_event.clear()
|
||||||
self._restore_context()
|
self._restore_context()
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
print("UBatchContext: %d resuming CPU\n" % self.id)
|
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
|
||||||
|
|
||||||
def yield_and_switch_from_compute_to_comm(self):
|
def yield_and_switch_from_compute_to_comm(self):
|
||||||
print("Yield and switch from compute")
|
assert current_stream() == self.compute_stream
|
||||||
|
dp_rank = get_dp_group().rank_in_group
|
||||||
|
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self._signal_compute_done()
|
# self._signal_compute_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
assert self.current_stream == self.compute_stream
|
assert self.current_stream == self.compute_stream
|
||||||
self.update_stream(self.comm_stream)
|
self.update_stream(self.comm_stream)
|
||||||
self._wait_compute_done()
|
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
||||||
|
# self._wait_compute_done()
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute(self):
|
def yield_and_switch_from_comm_to_compute(self):
|
||||||
|
assert current_stream() == self.comm_stream
|
||||||
|
dp_rank = get_dp_group().rank_in_group
|
||||||
|
print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self._signal_comm_done()
|
# self._signal_comm_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
assert self.current_stream == self.comm_stream
|
assert self.current_stream == self.comm_stream
|
||||||
self.update_stream(self.compute_stream)
|
self.update_stream(self.compute_stream)
|
||||||
self._wait_comm_done()
|
print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
||||||
|
# self._wait_comm_done()
|
||||||
|
|
||||||
|
|
||||||
_CURRENT_CONTEXT: dict = {}
|
_CURRENT_CONTEXT: dict = {}
|
||||||
|
|
||||||
def get_current_ubatch_context() -> Optional[UBatchContext]:
|
def get_current_ubatch_context() -> Optional[UBatchContext]:
|
||||||
global _CURRENT_CONTEXT
|
global _CURRENT_CONTEXT
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user