mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 06:17:51 +08:00
Add support for splitting the DeepEP low-latency (LL) dispatch/combine kernels
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
1c41175b2a
commit
582d301f98
@ -1937,6 +1937,9 @@ class ParallelConfig:
|
|||||||
request is greater than this threshold, microbatching will be used.
|
request is greater than this threshold, microbatching will be used.
|
||||||
Otherwise, the request will be processed in a single batch."""
|
Otherwise, the request will be processed in a single batch."""
|
||||||
|
|
||||||
|
enable_async_comms: bool = False
|
||||||
|
"""enable async comms"""
|
||||||
|
|
||||||
ray_workers_use_nsight: bool = False
|
ray_workers_use_nsight: bool = False
|
||||||
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
|
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
|
||||||
|
|
||||||
|
|||||||
@ -304,6 +304,7 @@ class EngineArgs:
|
|||||||
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
||||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||||
enable_microbatching: bool = ParallelConfig.enable_microbatching
|
enable_microbatching: bool = ParallelConfig.enable_microbatching
|
||||||
|
enable_async_comms: bool = ParallelConfig.enable_async_comms
|
||||||
enable_eplb: bool = ParallelConfig.enable_eplb
|
enable_eplb: bool = ParallelConfig.enable_eplb
|
||||||
num_redundant_experts: int = ParallelConfig.num_redundant_experts
|
num_redundant_experts: int = ParallelConfig.num_redundant_experts
|
||||||
eplb_window_size: int = ParallelConfig.eplb_window_size
|
eplb_window_size: int = ParallelConfig.eplb_window_size
|
||||||
@ -640,6 +641,8 @@ class EngineArgs:
|
|||||||
**parallel_kwargs["enable_expert_parallel"])
|
**parallel_kwargs["enable_expert_parallel"])
|
||||||
parallel_group.add_argument("--enable-microbatching",
|
parallel_group.add_argument("--enable-microbatching",
|
||||||
**parallel_kwargs["enable_microbatching"])
|
**parallel_kwargs["enable_microbatching"])
|
||||||
|
parallel_group.add_argument("--enable-async-comms",
|
||||||
|
**parallel_kwargs["enable_async_comms"])
|
||||||
parallel_group.add_argument("--enable-eplb",
|
parallel_group.add_argument("--enable-eplb",
|
||||||
**parallel_kwargs["enable_eplb"])
|
**parallel_kwargs["enable_eplb"])
|
||||||
parallel_group.add_argument("--num-redundant-experts",
|
parallel_group.add_argument("--num-redundant-experts",
|
||||||
@ -1166,6 +1169,7 @@ class EngineArgs:
|
|||||||
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
|
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
|
||||||
enable_expert_parallel=self.enable_expert_parallel,
|
enable_expert_parallel=self.enable_expert_parallel,
|
||||||
enable_microbatching=self.enable_microbatching,
|
enable_microbatching=self.enable_microbatching,
|
||||||
|
enable_async_comms=self.enable_async_comms,
|
||||||
enable_eplb=self.enable_eplb,
|
enable_eplb=self.enable_eplb,
|
||||||
num_redundant_experts=self.num_redundant_experts,
|
num_redundant_experts=self.num_redundant_experts,
|
||||||
eplb_window_size=self.eplb_window_size,
|
eplb_window_size=self.eplb_window_size,
|
||||||
|
|||||||
@ -127,6 +127,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
hidden_size = a1.size(1)
|
hidden_size = a1.size(1)
|
||||||
ubatch_ctx = get_current_ubatch_context()
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
|
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
|
||||||
|
do_recv_hook = True if ubatch_ctx is not None and \
|
||||||
|
ubatch_ctx.enable_async_comms else False
|
||||||
|
|
||||||
if self.use_fp8_dispatch:
|
if self.use_fp8_dispatch:
|
||||||
assert hidden_size % 128 == 0, \
|
assert hidden_size % 128 == 0, \
|
||||||
@ -147,16 +149,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
# Dispatch
|
# Dispatch
|
||||||
yield_and_switch_from_compute_to_comm(schedule="default")
|
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||||
expert_x, expert_num_tokens, handle, _, _= \
|
expert_x, expert_num_tokens, handle, _, recv_hook= \
|
||||||
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
self.max_tokens_per_rank,
|
self.max_tokens_per_rank,
|
||||||
num_experts,
|
num_experts,
|
||||||
use_fp8=self.use_fp8_dispatch,
|
use_fp8=self.use_fp8_dispatch,
|
||||||
async_finish=False,
|
async_finish=False,
|
||||||
return_recv_hook=False)
|
return_recv_hook=do_recv_hook)
|
||||||
self.handles[a2a_idx] = handle
|
self.handles[a2a_idx] = handle
|
||||||
yield_and_switch_from_comm_to_compute(schedule="default")
|
yield_and_switch_from_comm_to_compute(schedule="default", recv_hook=recv_hook)
|
||||||
|
|
||||||
expert_x, expert_x_scale = self._do_quant(
|
expert_x, expert_x_scale = self._do_quant(
|
||||||
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
|
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
|
||||||
@ -178,6 +180,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
ubatch_ctx = get_current_ubatch_context()
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
|
a2a_idx = ubatch_ctx.id if ubatch_ctx is not None else 0
|
||||||
handle = self.handles[a2a_idx]
|
handle = self.handles[a2a_idx]
|
||||||
|
do_recv_hook = True if ubatch_ctx is not None and \
|
||||||
|
ubatch_ctx.enable_async_comms else False
|
||||||
assert handle is not None
|
assert handle is not None
|
||||||
|
|
||||||
combine_topk_weights = topk_weights
|
combine_topk_weights = topk_weights
|
||||||
@ -187,12 +191,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
|
|
||||||
# TODO (varun) : Enable zero copy mode
|
# TODO (varun) : Enable zero copy mode
|
||||||
yield_and_switch_from_compute_to_comm(schedule="default")
|
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||||
_ = self.buffers[a2a_idx].low_latency_combine(fused_expert_output,
|
_, _, recv_hook = self.buffers[a2a_idx].low_latency_combine(fused_expert_output,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
combine_topk_weights,
|
combine_topk_weights,
|
||||||
handle,
|
handle,
|
||||||
async_finish=False,
|
async_finish=False,
|
||||||
zero_copy=False,
|
zero_copy=False,
|
||||||
return_recv_hook=False,
|
return_recv_hook=do_recv_hook,
|
||||||
out=output)
|
out=output)
|
||||||
yield_and_switch_from_comm_to_compute(schedule="default")
|
yield_and_switch_from_comm_to_compute(schedule="default", recv_hook=recv_hook)
|
||||||
|
|||||||
@ -1656,7 +1656,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
comm_stream=self.comm_stream,
|
comm_stream=self.comm_stream,
|
||||||
compute_stream=compute_stream,
|
compute_stream=compute_stream,
|
||||||
forward_contexts=forward_contexts,
|
forward_contexts=forward_contexts,
|
||||||
device=self.device)
|
device=self.device,
|
||||||
|
enable_async_comms=self.parallel_config.enable_async_comms)
|
||||||
|
|
||||||
ubatch_metadata: list[UbatchMetadata] = []
|
ubatch_metadata: list[UbatchMetadata] = []
|
||||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||||
|
|||||||
@ -24,6 +24,7 @@ class UBatchContext:
|
|||||||
cpu_signal_event: threading.Event,
|
cpu_signal_event: threading.Event,
|
||||||
gpu_comm_done_event: torch.cuda.Event,
|
gpu_comm_done_event: torch.cuda.Event,
|
||||||
gpu_compute_done_event: torch.cuda.Event,
|
gpu_compute_done_event: torch.cuda.Event,
|
||||||
|
enable_async_comms: bool,
|
||||||
schedule: str = "default"):
|
schedule: str = "default"):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.comm_stream = comm_stream
|
self.comm_stream = comm_stream
|
||||||
@ -34,6 +35,7 @@ class UBatchContext:
|
|||||||
self.current_stream = compute_stream
|
self.current_stream = compute_stream
|
||||||
self.gpu_comm_done_event = gpu_comm_done_event
|
self.gpu_comm_done_event = gpu_comm_done_event
|
||||||
self.gpu_compute_done_event = gpu_compute_done_event
|
self.gpu_compute_done_event = gpu_compute_done_event
|
||||||
|
self.enable_async_comms = enable_async_comms
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@ -106,10 +108,14 @@ class UBatchContext:
|
|||||||
self.update_stream(self.comm_stream)
|
self.update_stream(self.comm_stream)
|
||||||
self._wait_compute_done()
|
self._wait_compute_done()
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute(self):
|
def yield_and_switch_from_comm_to_compute(self, recv_hook = None):
|
||||||
assert current_stream() == self.comm_stream
|
assert current_stream() == self.comm_stream
|
||||||
self._signal_comm_done()
|
if recv_hook is None:
|
||||||
|
self._signal_comm_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
|
if recv_hook is not None:
|
||||||
|
recv_hook()
|
||||||
|
self._signal_comm_done()
|
||||||
assert self.current_stream == self.comm_stream
|
assert self.current_stream == self.comm_stream
|
||||||
self.update_stream(self.compute_stream)
|
self.update_stream(self.compute_stream)
|
||||||
self._wait_comm_done()
|
self._wait_comm_done()
|
||||||
@ -133,11 +139,11 @@ def yield_and_switch_from_compute_to_comm(schedule="default"):
|
|||||||
ctx.yield_and_switch_from_compute_to_comm()
|
ctx.yield_and_switch_from_compute_to_comm()
|
||||||
|
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute(schedule="default"):
|
def yield_and_switch_from_comm_to_compute(schedule="default", recv_hook = None):
|
||||||
# Perform the barrier if a context exists for this thread
|
# Perform the barrier if a context exists for this thread
|
||||||
ctx = get_current_ubatch_context()
|
ctx = get_current_ubatch_context()
|
||||||
if ctx is not None and ctx.schedule == schedule:
|
if ctx is not None and ctx.schedule == schedule:
|
||||||
ctx.yield_and_switch_from_comm_to_compute()
|
ctx.yield_and_switch_from_comm_to_compute(recv_hook=recv_hook)
|
||||||
|
|
||||||
|
|
||||||
def make_ubatch_contexts(
|
def make_ubatch_contexts(
|
||||||
@ -146,6 +152,7 @@ def make_ubatch_contexts(
|
|||||||
comm_stream: torch.cuda.Stream,
|
comm_stream: torch.cuda.Stream,
|
||||||
forward_contexts: list[ForwardContext],
|
forward_contexts: list[ForwardContext],
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
|
enable_async_comms: bool = False,
|
||||||
schedule: str = "default",
|
schedule: str = "default",
|
||||||
) -> list[UBatchContext]:
|
) -> list[UBatchContext]:
|
||||||
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
||||||
@ -175,6 +182,7 @@ def make_ubatch_contexts(
|
|||||||
num_micro_batches],
|
num_micro_batches],
|
||||||
gpu_comm_done_event=gpu_comm_done_events[i],
|
gpu_comm_done_event=gpu_comm_done_events[i],
|
||||||
gpu_compute_done_event=gpu_compute_done_events[i],
|
gpu_compute_done_event=gpu_compute_done_events[i],
|
||||||
|
enable_async_comms=enable_async_comms,
|
||||||
schedule=schedule)
|
schedule=schedule)
|
||||||
ctxs.append(ctx)
|
ctxs.append(ctx)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user