From e4419df256d2d32c834774bd727af435661f832c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 23 May 2025 18:23:29 +0000 Subject: [PATCH] better debug utils Signed-off-by: Lucas Wilkinson --- .../layers/fused_moe/modular_kernel.py | 3 ++- vllm/v1/worker/ubatching.py | 21 ++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 47d0880ee8071..35eef966e2717 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from typing import Optional import torch -from vllm.v1.worker.ubatching import get_current_ubatch_context +from vllm.v1.worker.ubatching import get_current_ubatch_context, dump_ubatching_state # # This file defines a set of base classes used to make MoE kernels more modular. @@ -347,6 +347,7 @@ class FusedMoEModularKernel(torch.nn.Module): # print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts) print("pre synchronize") + dump_ubatching_state() torch.cuda.synchronize(a1.device) print("post synchronize") diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index c4026f7eae014..4c8801f5c6157 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -40,9 +40,9 @@ class UBatchContext: global _CURRENT_CONTEXT _CURRENT_CONTEXT[threading.get_ident()] = self self._cpu_wait() - start_event = torch.cuda.Event() - self.original_stream.record_event(start_event) - self.stream.wait_event(start_event) + # start_event = torch.cuda.Event() + # self.original_stream.record_event(start_event) + # self.stream.wait_event(start_event) print("Starting ubatch %d" % self.id) # if self.gpu_wait_on_launch: # self.gpu_stream_wait() @@ -131,6 +131,21 @@ def yield_(x: torch.Tensor, schedule: str="default") -> None: def yield_(x: torch.Tensor, schedule: str="default") -> None: pass +def dump_ubatching_state(): + """ + Dump the current UBatchContext state for debugging. + """ + for ctx in _CURRENT_CONTEXT.values(): + print(f"UBatchContext: {ctx.id}\n" + f" Stream: {ctx.stream}, ({ctx.stream.query()})\n" + f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n" + f" CPU Wait Event: {ctx.cpu_wait_event}\n" + f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n" + f" CPU Signal Event: {ctx.cpu_signal_event}\n" + f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n") + + + """ """