better debug utils

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-23 18:23:29 +00:00
parent 952f3c5c1e
commit e4419df256
2 changed files with 20 additions and 4 deletions

View File

@@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from typing import Optional
import torch
from vllm.v1.worker.ubatching import get_current_ubatch_context
from vllm.v1.worker.ubatching import get_current_ubatch_context, dump_ubatching_state
#
# This file defines a set of base classes used to make MoE kernels more modular.
@@ -347,6 +347,7 @@ class FusedMoEModularKernel(torch.nn.Module):
# print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts)
print("pre synchronize")
dump_ubatching_state()
torch.cuda.synchronize(a1.device)
print("post synchronize")

View File

@@ -40,9 +40,9 @@ class UBatchContext:
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = self
self._cpu_wait()
start_event = torch.cuda.Event()
self.original_stream.record_event(start_event)
self.stream.wait_event(start_event)
# start_event = torch.cuda.Event()
# self.original_stream.record_event(start_event)
# self.stream.wait_event(start_event)
print("Starting ubatch %d" % self.id)
# if self.gpu_wait_on_launch:
# self.gpu_stream_wait()
@@ -131,6 +131,21 @@ def yield_(x: torch.Tensor, schedule: str="default") -> None:
def yield_(x: torch.Tensor, schedule: str="default") -> None:
    """Micro-batch yield point stub.

    Accepts the tensor being handed off and a scheduling-policy name, but
    intentionally performs no work in this build.
    """
    return None
def dump_ubatching_state():
    """
    Dump the current UBatchContext state for debugging.

    For every registered context, prints its id, the ubatch stream and the
    original stream (each with its ``query()`` completion status), and the
    CPU/GPU wait and signal events (GPU events with ``query()`` status).
    """
    # Snapshot the registered contexts first: other threads insert into
    # _CURRENT_CONTEXT concurrently (keyed by thread ident), and iterating
    # the live dict could raise "dictionary changed size during iteration".
    for ctx in list(_CURRENT_CONTEXT.values()):
        print(f"UBatchContext: {ctx.id}\n"
              f" Stream: {ctx.stream}, ({ctx.stream.query()})\n"
              f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
              f" CPU Wait Event: {ctx.cpu_wait_event}\n"
              f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
              f" CPU Signal Event: {ctx.cpu_signal_event}\n"
              f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")