mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 22:16:42 +08:00
cleanups for ubatching.py
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
2f3461ad23
commit
83caef8bac
@ -8,8 +8,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
moe_kernel_quantize_input)
|
moe_kernel_quantize_input)
|
||||||
from vllm.v1.worker.ubatching import (
|
from vllm.v1.worker.ubatching import (
|
||||||
get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl,
|
get_current_ubatch_context, yield_and_switch_from_comm_to_compute,
|
||||||
yield_and_switch_from_compute_to_comm_impl)
|
yield_and_switch_from_compute_to_comm)
|
||||||
|
|
||||||
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
||||||
DEEPEP_QUANT_BLOCK_SIZE = 128
|
DEEPEP_QUANT_BLOCK_SIZE = 128
|
||||||
@ -154,7 +154,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||||
|
|
||||||
# Dispatch
|
# Dispatch
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||||
expert_x, expert_num_tokens, handle, _, _= \
|
expert_x, expert_num_tokens, handle, _, _= \
|
||||||
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
||||||
rank_topk_ids,
|
rank_topk_ids,
|
||||||
@ -164,7 +164,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
async_finish=False,
|
async_finish=False,
|
||||||
return_recv_hook=False)
|
return_recv_hook=False)
|
||||||
self.handles[a2a_idx] = handle
|
self.handles[a2a_idx] = handle
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||||
|
|
||||||
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
|
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
|
||||||
a1.dtype)
|
a1.dtype)
|
||||||
@ -186,7 +186,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
combine_topk_weights = torch.ones_like(topk_weights)
|
combine_topk_weights = torch.ones_like(topk_weights)
|
||||||
|
|
||||||
# TODO (varun) : Enable zero copy mode
|
# TODO (varun) : Enable zero copy mode
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||||
_ = self.buffers[a2a_idx].low_latency_combine(
|
_ = self.buffers[a2a_idx].low_latency_combine(
|
||||||
fused_expert_output,
|
fused_expert_output,
|
||||||
topk_ids,
|
topk_ids,
|
||||||
@ -196,5 +196,5 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
zero_copy=False,
|
zero_copy=False,
|
||||||
return_recv_hook=False,
|
return_recv_hook=False,
|
||||||
out=output)
|
out=output)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||||
|
|
||||||
|
|||||||
@ -9,8 +9,8 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|||||||
from vllm.model_executor.layers.fused_moe.utils import (
|
from vllm.model_executor.layers.fused_moe.utils import (
|
||||||
moe_kernel_quantize_input)
|
moe_kernel_quantize_input)
|
||||||
from vllm.v1.worker.ubatching import (
|
from vllm.v1.worker.ubatching import (
|
||||||
get_current_ubatch_context, yield_and_switch_from_comm_to_compute_impl,
|
get_current_ubatch_context, yield_and_switch_from_comm_to_compute,
|
||||||
yield_and_switch_from_compute_to_comm_impl)
|
yield_and_switch_from_compute_to_comm)
|
||||||
|
|
||||||
|
|
||||||
# The max_num_tokens, world_size and dp_size must be the same
|
# The max_num_tokens, world_size and dp_size must be the same
|
||||||
@ -120,7 +120,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
# There's not much point setting this unless it is != indices.size(0)
|
# There's not much point setting this unless it is != indices.size(0)
|
||||||
bound_m: Optional[torch.Tensor] = None
|
bound_m: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||||
self.a2as[a2a_idx].dispatch(
|
self.a2as[a2a_idx].dispatch(
|
||||||
out_expert_num_tokens=expert_num_tokens,
|
out_expert_num_tokens=expert_num_tokens,
|
||||||
out_expert_x=expert_x,
|
out_expert_x=expert_x,
|
||||||
@ -130,7 +130,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
indices=rank_topk_ids,
|
indices=rank_topk_ids,
|
||||||
bound_m=bound_m,
|
bound_m=bound_m,
|
||||||
)
|
)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||||
if expert_x_scale is not None:
|
if expert_x_scale is not None:
|
||||||
expert_x_scale = expert_x_scale[:, :, 0:1]
|
expert_x_scale = expert_x_scale[:, :, 0:1]
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
if apply_router_weight_on_input:
|
if apply_router_weight_on_input:
|
||||||
topk_weights = torch.ones_like(topk_weights)
|
topk_weights = torch.ones_like(topk_weights)
|
||||||
|
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
yield_and_switch_from_compute_to_comm(schedule="default")
|
||||||
self.a2as[a2a_idx].combine(
|
self.a2as[a2a_idx].combine(
|
||||||
out_tokens=output,
|
out_tokens=output,
|
||||||
indices=topk_ids,
|
indices=topk_ids,
|
||||||
@ -170,4 +170,4 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
expert_y=fused_expert_output,
|
expert_y=fused_expert_output,
|
||||||
bound_m=bound_m,
|
bound_m=bound_m,
|
||||||
)
|
)
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
yield_and_switch_from_comm_to_compute(schedule="default")
|
||||||
|
|||||||
@ -3,12 +3,9 @@ import threading
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._dynamo
|
|
||||||
from torch.library import custom_op
|
|
||||||
|
|
||||||
from vllm import forward_context
|
from vllm import forward_context
|
||||||
from vllm.utils import current_stream
|
from vllm.utils import current_stream
|
||||||
from vllm.distributed.parallel_state import get_dp_group
|
|
||||||
|
|
||||||
|
|
||||||
class UBatchContext:
|
class UBatchContext:
|
||||||
@ -75,31 +72,26 @@ class UBatchContext:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def _signal_comm_done(self):
|
def _signal_comm_done(self):
|
||||||
# assert False
|
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.gpu_comm_done_event.record(self.comm_stream)
|
self.gpu_comm_done_event.record(self.comm_stream)
|
||||||
|
|
||||||
def _signal_compute_done(self):
|
def _signal_compute_done(self):
|
||||||
# assert False
|
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.gpu_compute_done_event.record(self.compute_stream)
|
self.gpu_compute_done_event.record(self.compute_stream)
|
||||||
|
|
||||||
def _wait_compute_done(self):
|
def _wait_compute_done(self):
|
||||||
# assert False
|
|
||||||
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
|
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||||
# print("Compute stream done", flush=True)
|
# print("Compute stream done", flush=True)
|
||||||
|
|
||||||
def _wait_comm_done(self):
|
def _wait_comm_done(self):
|
||||||
# assert False
|
|
||||||
# print(f"{self.id} Waiting on COMM stream", flush=True)
|
# print(f"{self.id} Waiting on COMM stream", flush=True)
|
||||||
self.ctx_valid_state()
|
self.ctx_valid_state()
|
||||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||||
# print("Comm stream done", flush=True)
|
# print("Comm stream done", flush=True)
|
||||||
|
|
||||||
def stream_string(self):
|
def stream_string(self):
|
||||||
# assert False
|
|
||||||
if current_stream() == self.compute_stream:
|
if current_stream() == self.compute_stream:
|
||||||
assert self.current_stream == self.compute_stream
|
assert self.current_stream == self.compute_stream
|
||||||
return "COMPUTE"
|
return "COMPUTE"
|
||||||
@ -118,7 +110,6 @@ class UBatchContext:
|
|||||||
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
|
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
|
||||||
|
|
||||||
def yield_and_switch_from_compute_to_comm(self):
|
def yield_and_switch_from_compute_to_comm(self):
|
||||||
# assert False
|
|
||||||
assert current_stream() == self.compute_stream
|
assert current_stream() == self.compute_stream
|
||||||
# dp_rank = get_dp_group().rank_in_group
|
# dp_rank = get_dp_group().rank_in_group
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
@ -134,7 +125,6 @@ class UBatchContext:
|
|||||||
self._wait_compute_done()
|
self._wait_compute_done()
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute(self):
|
def yield_and_switch_from_comm_to_compute(self):
|
||||||
# assert False
|
|
||||||
assert current_stream() == self.comm_stream
|
assert current_stream() == self.comm_stream
|
||||||
# dp_rank = get_dp_group().rank_in_group
|
# dp_rank = get_dp_group().rank_in_group
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
@ -161,7 +151,7 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
|
|||||||
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
||||||
|
|
||||||
|
|
||||||
def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
|
def yield_and_switch_from_compute_to_comm(schedule="default"):
|
||||||
# Perform the barrier if a context exists for this thread
|
# Perform the barrier if a context exists for this thread
|
||||||
ctx = get_current_ubatch_context()
|
ctx = get_current_ubatch_context()
|
||||||
#print("you are in yield_impl", ctx)
|
#print("you are in yield_impl", ctx)
|
||||||
@ -169,50 +159,12 @@ def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
|
|||||||
ctx.yield_and_switch_from_compute_to_comm()
|
ctx.yield_and_switch_from_compute_to_comm()
|
||||||
|
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
|
def yield_and_switch_from_comm_to_compute(schedule="default"):
|
||||||
# Perform the barrier if a context exists for this thread
|
# Perform the barrier if a context exists for this thread
|
||||||
ctx = get_current_ubatch_context()
|
ctx = get_current_ubatch_context()
|
||||||
if ctx is not None and ctx.schedule == schedule:
|
if ctx is not None and ctx.schedule == schedule:
|
||||||
ctx.yield_and_switch_from_comm_to_compute()
|
ctx.yield_and_switch_from_comm_to_compute()
|
||||||
|
|
||||||
|
|
||||||
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
|
|
||||||
# optimizing it away (TODO: see if this is actually needed)
|
|
||||||
@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", ))
|
|
||||||
def yield_and_switch_from_compute_to_comm(x: torch.Tensor,
|
|
||||||
schedule: str = "default") -> None:
|
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule)
|
|
||||||
|
|
||||||
|
|
||||||
# 3) Fake implementation for shape prop and FX tracing
|
|
||||||
@yield_and_switch_from_compute_to_comm.register_fake
|
|
||||||
def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor,
|
|
||||||
schedule: str = "default"
|
|
||||||
) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", ))
|
|
||||||
def yield_and_switch_from_comm_to_compute(x: torch.Tensor,
|
|
||||||
schedule: str = "default") -> None:
|
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule)
|
|
||||||
|
|
||||||
|
|
||||||
@yield_and_switch_from_comm_to_compute.register_fake
|
|
||||||
def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor,
|
|
||||||
schedule: str = "default"
|
|
||||||
) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def dump_ubatching_state():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def make_ubatch_contexts(
|
def make_ubatch_contexts(
|
||||||
num_micro_batches: int,
|
num_micro_batches: int,
|
||||||
compute_stream: torch.cuda.Stream,
|
compute_stream: torch.cuda.Stream,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user