debugging hang

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-23 15:22:50 +00:00
parent 2dc3b8b0a2
commit 9edd08231b
7 changed files with 71 additions and 20 deletions

View File

@@ -661,6 +661,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
f"Hidden size mismatch {hidden_states.size(-1)} "
f"!= {w1.size(2)}")
print("in batched triton experts", hidden_states.shape, expert_num_tokens)
assert hidden_states.is_contiguous(
), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"

View File

@@ -29,6 +29,7 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
from vllm.v1.worker.ubatching import get_current_ubatch_context
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
@@ -442,7 +443,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
print("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
@@ -454,7 +456,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
print("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
@@ -1298,6 +1300,9 @@ class FusedMoE(torch.nn.Module):
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
if (ubatch_ctdx := get_current_ubatch_context()) is not None:
print("in fused moe, ubatch:", ubatch_ctdx.id, "chunk size:", max_tokens_across_dp, "moe_dp_chunk_size_per_rank", moe_dp_chunk_size_per_rank)
num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp,
moe_dp_chunk_size_per_rank):
@@ -1396,6 +1401,8 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in fused moe, ubatch:", ubatch_ctx.id, self)
return self.forward_impl(hidden_states, router_logits)

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Optional
import torch
from vllm.v1.worker.ubatching import get_current_ubatch_context
#
# This file defines a set of base classes used to make MoE kernels more modular.
@@ -335,10 +336,20 @@ class FusedMoEModularKernel(torch.nn.Module):
device=a1.device,
dtype=workspace_dtype)
if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in modular moe, ubatch:", ubatch_ctx.id)
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
expert_map, apply_router_weight_on_input)
if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts)
print("pre synchronize")
torch.cuda.synchronize(a1.device)
print("post synchronize")
fused_out = self.fused_experts.apply(
a1q,
w1,
@@ -358,6 +369,9 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens=expert_num_tokens,
)
if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in modular moe3, ubatch:", ubatch_ctx.id, self.fused_experts)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)

View File

@@ -119,11 +119,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send,
)
# if ubatch_ctx is not None:
# ubatch_ctx.gpu_stream_wait()
print("Dispatch pre-wait")
if (ubatch_ctx := get_current_ubatch_context()) is not None:
ubatch_ctx.gpu_stream_wait()
print("Dispatch launched")
dispatch(True) # Send
yield_impl(gpu_wait=False)
dispatch(False) # Recv
print("Finished dispatch")
return expert_x, expert_x_scale, expert_num_tokens
@@ -160,8 +163,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_send=send,
do_recv=not send,
)
# if ubatch_ctx is not None:
# ubatch_ctx.gpu_stream_wait()
print("Combine pre-wait")
if (ubatch_ctx := get_current_ubatch_context()) is not None:
ubatch_ctx.gpu_stream_wait()
combine(True)
print("Combine launched")
yield_impl(gpu_wait=False)
combine(False)
print("Finished combine")

View File

@@ -564,6 +564,8 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in decoder, ubatch:", ubatch_ctx.id)
# Self Attention
if residual is None:
residual = hidden_states
@@ -657,7 +659,7 @@ class DeepseekV2Model(nn.Module):
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if ubatch_ctx := get_current_ubatch_context() is not None:
if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in forward, ubatch:", ubatch_ctx.id)
if get_pp_group().is_first_rank:

View File

@@ -1287,10 +1287,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output: Optional["SchedulerOutput"] = None,
is_dummy_run: bool = False):
num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
if use_dummy_input:
num_tokens = num_scheduled_tokens or 1
return self._get_dummy_model_inputs(num_tokens)
return self._get_dummy_model_inputs(num_dummy_tokens)
else:
assert scheduler_output is not None
return self._get_model_inputs(tokens_slice, scheduler_output)
@@ -1301,7 +1302,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
model_inputs(token_slice, use_dummy_input)
with context:
if isinstance(context, UBatchContext):
print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape}")
print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
model_output = self.model(
input_ids=input_ids,
positions=positions,
@@ -1333,19 +1334,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not hasattr(self, "ubatch_streams"):
# Create the ubatch streams
self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))]
ubatch_fwd_ctxs = [create_forward_context(
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
# We have to be careful creating the forward contexts here otherwise we can end
# up with the dummy contexts have num_tokens set to 0
# ubatch_fwd_ctxs = [create_forward_context(
# attn_metadata[i] if attn_metadata is not None else None,
# self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
# ) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
ubatch_ctxs, start_hook = make_ubatch_context_chain(
len(ubatch_slices),
fwd_ctxs=ubatch_fwd_ctxs,
#fwd_ctxs=ubatch_fwd_ctxs,
streams=self.ubatch_streams, #stream=root_stream, # Only works currently if everything is run on the same stream
device=self.device)
setup_done = threading.Event()
ubatch_threads = []
# Initialize Events? not sure if this helps
for ubatch_ctx in ubatch_ctxs:
ubatch_ctx.gpu_wait_event.record(ubatch_ctx.stream)
ubatch_ctx.stream.wait_event(ubatch_ctx.gpu_wait_event)
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
@@ -1354,6 +1363,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start)
ubatch_ctxs[i].forward_context = create_forward_context(
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config, num_tokens=num_tokens)
thread = threading.Thread(target=_ubatch_thread, args=(
ubatch_ctxs[i],
root_stream,

View File

@@ -17,7 +17,7 @@ class UBatchContext:
def __init__(self,
id: int,
stream: torch.cuda.Stream,
fwd_ctx: forward_context.ForwardContext,
#fwd_ctx: forward_context.ForwardContext,
cpu_wait_event: threading.Event,
cpu_signal_event: threading.Event,
gpu_wait_event: torch.cuda.Event,
@@ -27,7 +27,7 @@ class UBatchContext:
self.id = id
self.stream = stream
self.original_stream = current_stream()
self.forward_context = fwd_ctx
self.forward_context = None #fwd_ctx
self.cpu_wait_event = cpu_wait_event
self.cpu_signal_event = cpu_signal_event
self.gpu_wait_event = gpu_wait_event
@@ -80,6 +80,7 @@ class UBatchContext:
# until ubatch0-dispatch is done avoiding overlapping dispatches that
# might share underlying buffers
def gpu_stream_wait(self):
print("Waiting ubatch %d on %s in stream %s" % (self.id, self.gpu_wait_event, self.stream))
self.stream.wait_event(self.gpu_wait_event)
def _yield(self, gpu_wait: bool = True):
@@ -92,6 +93,7 @@ class UBatchContext:
def _signal(self):
# Wait for the next batch to signal back
print(f"signaling ubatch {self.id} to {self.gpu_signal_event} on {self.stream}")
self.gpu_signal_event.record(self.stream)
# Signal that this batch reached the barrier
self.cpu_signal_event.set()
@@ -134,7 +136,7 @@ def yield_(x: torch.Tensor, schedule: str="default") -> None:
"""
def make_ubatch_context_chain(
num_micro_batches: int,
fwd_ctxs: forward_context.ForwardContext,
#fwd_ctxs: forward_context.ForwardContext,
streams: Optional[list[torch.Stream]] = None,
device: Optional[torch.device] = None
) -> list[UBatchContext]:
@@ -152,7 +154,7 @@ def make_ubatch_context_chain(
stream = (streams[i] if streams else None) or torch.cuda.Stream(device)
ctx = UBatchContext(id=i,
stream=stream,
fwd_ctx=fwd_ctxs[i],
#fwd_ctx=fwd_ctxs[i],
cpu_wait_event=cpu_events[i],
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
gpu_wait_event=gpu_events[i],
@@ -163,6 +165,7 @@ def start_hook(from_stream: torch.cuda.Stream):
def start_hook(from_stream: torch.cuda.Stream):
ctxs[0].gpu_wait_event.record(from_stream)
print('singal to ubatch %d event %s from stream %s' % (ctxs[0].id, ctxs[0].gpu_wait_event, from_stream))
ctxs[0].cpu_wait_event.set()
return ctxs, start_hook