mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 21:47:05 +08:00
cleanups for ubatching.py
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
2f3461ad23
commit
83caef8bac
@ -8,8 +8,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.v1.worker.ubatching import (
|
||||
get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl,
|
||||
yield_and_switch_from_compute_to_comm_impl)
|
||||
get_current_ubatch_context, yield_and_switch_from_comm_to_compute,
|
||||
yield_and_switch_from_compute_to_comm)
|
||||
|
||||
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
||||
DEEPEP_QUANT_BLOCK_SIZE = 128
|
||||
@ -154,7 +154,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||
|
||||
# Dispatch
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||
expert_x, expert_num_tokens, handle, _, _= \
|
||||
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
||||
rank_topk_ids,
|
||||
@ -164,7 +164,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
async_finish=False,
|
||||
return_recv_hook=False)
|
||||
self.handles[a2a_idx] = handle
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||
|
||||
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
|
||||
a1.dtype)
|
||||
@ -186,7 +186,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
combine_topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
# TODO (varun) : Enable zero copy mode
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||
_ = self.buffers[a2a_idx].low_latency_combine(
|
||||
fused_expert_output,
|
||||
topk_ids,
|
||||
@ -196,5 +196,5 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
zero_copy=False,
|
||||
return_recv_hook=False,
|
||||
out=output)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||
|
||||
|
||||
@ -9,8 +9,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.v1.worker.ubatching import (
|
||||
get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl,
|
||||
yield_and_switch_from_compute_to_comm_impl)
|
||||
get_current_ubatch_context, yield_and_switch_from_comm_to_compute,
|
||||
yield_and_switch_from_compute_to_comm)
|
||||
|
||||
|
||||
# The max_num_tokens, world_size and dp_size must be the same
|
||||
@ -120,7 +120,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# There's not much point setting this unless it is != indices.size(0)
|
||||
bound_m: Optional[torch.Tensor] = None
|
||||
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||
self.a2as[a2a_idx].dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
@ -130,7 +130,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
indices=rank_topk_ids,
|
||||
bound_m=bound_m,
|
||||
)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||
if expert_x_scale is not None:
|
||||
expert_x_scale = expert_x_scale[:, :, 0:1]
|
||||
|
||||
@ -162,7 +162,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
if apply_router_weight_on_input:
|
||||
topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||
self.a2as[a2a_idx].combine(
|
||||
out_tokens=output,
|
||||
indices=topk_ids,
|
||||
@ -170,4 +170,4 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m,
|
||||
)
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||
|
||||
@ -3,12 +3,9 @@ import threading
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch._dynamo
|
||||
from torch.library import custom_op
|
||||
|
||||
from vllm import forward_context
|
||||
from vllm.utils import current_stream
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
|
||||
|
||||
class UBatchContext:
|
||||
@ -75,31 +72,26 @@ class UBatchContext:
|
||||
pass
|
||||
|
||||
def _signal_comm_done(self):
|
||||
# assert False
|
||||
self.ctx_valid_state()
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
|
||||
def _signal_compute_done(self):
|
||||
# assert False
|
||||
self.ctx_valid_state()
|
||||
self.gpu_compute_done_event.record(self.compute_stream)
|
||||
|
||||
def _wait_compute_done(self):
|
||||
# assert False
|
||||
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
|
||||
self.ctx_valid_state()
|
||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||
# print("Compute stream done", flush=True)
|
||||
|
||||
def _wait_comm_done(self):
|
||||
# assert False
|
||||
# print(f"{self.id} Waiting on COMM stream", flush=True)
|
||||
self.ctx_valid_state()
|
||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||
# print("Comm stream done", flush=True)
|
||||
|
||||
def stream_string(self):
|
||||
# assert False
|
||||
if current_stream() == self.compute_stream:
|
||||
assert self.current_stream == self.compute_stream
|
||||
return "COMPUTE"
|
||||
@ -118,7 +110,6 @@ class UBatchContext:
|
||||
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
# assert False
|
||||
assert current_stream() == self.compute_stream
|
||||
# dp_rank = get_dp_group().rank_in_group
|
||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||
@ -134,7 +125,6 @@ class UBatchContext:
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
# assert False
|
||||
assert current_stream() == self.comm_stream
|
||||
# dp_rank = get_dp_group().rank_in_group
|
||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||
@ -161,7 +151,7 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
|
||||
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
||||
|
||||
|
||||
def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
|
||||
def yield_and_switch_from_compute_to_comm(schedule="default"):
|
||||
# Perform the barrier if a context exists for this thread
|
||||
ctx = get_current_ubatch_context()
|
||||
#print("you are in yield_impl", ctx)
|
||||
@ -169,50 +159,12 @@ def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
|
||||
ctx.yield_and_switch_from_compute_to_comm()
|
||||
|
||||
|
||||
def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
|
||||
def yield_and_switch_from_comm_to_compute(schedule="default"):
|
||||
# Perform the barrier if a context exists for this thread
|
||||
ctx = get_current_ubatch_context()
|
||||
if ctx is not None and ctx.schedule == schedule:
|
||||
ctx.yield_and_switch_from_comm_to_compute()
|
||||
|
||||
|
||||
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
|
||||
# optimizing it away (TODO: see if this is actually needed)
|
||||
@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", ))
|
||||
def yield_and_switch_from_compute_to_comm(x: torch.Tensor,
|
||||
schedule: str = "default") -> None:
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule)
|
||||
|
||||
|
||||
# 3) Fake implementation for shape prop and FX tracing
|
||||
@yield_and_switch_from_compute_to_comm.register_fake
|
||||
def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor,
|
||||
schedule: str = "default"
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", ))
|
||||
def yield_and_switch_from_comm_to_compute(x: torch.Tensor,
|
||||
schedule: str = "default") -> None:
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule)
|
||||
|
||||
|
||||
@yield_and_switch_from_comm_to_compute.register_fake
|
||||
def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor,
|
||||
schedule: str = "default"
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def dump_ubatching_state():
|
||||
pass
|
||||
|
||||
|
||||
"""
|
||||
"""
|
||||
|
||||
|
||||
def make_ubatch_contexts(
|
||||
num_micro_batches: int,
|
||||
compute_stream: torch.cuda.Stream,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user