Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-27 18:14:59 +00:00
parent 7b31e8a8ff
commit a743a35948
3 changed files with 18 additions and 51 deletions

View File

@ -336,21 +336,10 @@ class FusedMoEModularKernel(torch.nn.Module):
device=a1.device,
dtype=workspace_dtype)
# if (ubatch_ctx := get_current_ubatch_context()) is not None:
# print("in modular moe, ubatch:", ubatch_ctx.id)
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
expert_map, apply_router_weight_on_input)
# if (ubatch_ctx := get_current_ubatch_context()) is not None:
# print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts)
print("pre synchronize")
dump_ubatching_state()
torch.cuda.synchronize(a1.device)
print("post synchronize")
fused_out = self.fused_experts.apply(
a1q,
w1,

View File

@ -1318,9 +1318,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return model_output
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, root_stream, token_slice, attn_metadata, results, save_results, use_dummy_input, setup_done_evt):
ubatch_ctx.stream.wait_stream(root_stream)
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input):
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
if save_results:
@ -1330,24 +1328,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
results = []
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
root_stream = current_stream()
if not hasattr(self, "ubatch_streams"):
# Create the ubatch streams
self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))]
# We have to be careful creating the forward contexts here, otherwise we can end
# up with the dummy contexts having num_tokens set to 0
# ubatch_fwd_ctxs = [create_forward_context(
# attn_metadata[i] if attn_metadata is not None else None,
# self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
# ) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
ubatch_ctxs, start_hook = make_ubatch_contexts(
ubatch_ctxs = make_ubatch_contexts(
len(ubatch_slices),
compute_stream=root_stream,
device=self.device)
setup_done = threading.Event()
ubatch_threads = []
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
@ -1366,20 +1351,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
thread = threading.Thread(target=_ubatch_thread, args=(
ubatch_ctxs[i],
root_stream,
tokens_slice,
attn_metadata[i] if attn_metadata is not None else None,
results,
not is_dummy_ubatch or is_dummy_run,
is_dummy_ubatch or is_dummy_run,
setup_done,
))
ubatch_threads.append(thread)
thread.start()
# Signal the first ubatch to start
start_hook(root_stream)
for thread in ubatch_threads:
thread.join()

View File

@ -39,6 +39,7 @@ class UBatchContext:
def __enter__(self):
global _CURRENT_CONTEXT
_CURRENT_CONTEXT[threading.get_ident()] = self
self._restore_context()
# Assume we start on the compute stream
assert current_stream() == self.compute_stream, \
"Expected to start on the compute stream, but found %s" % current_stream()
@ -52,9 +53,6 @@ class UBatchContext:
return False
def _restore_context(self):
# When we resume i.e. switch back to this micro-batch, we make sure
# we have the correct stream and forward context
torch.cuda.set_stream(self.stream)
forward_context._forward_context = self.forward_context
def _signal_comm_done(self):
@ -130,20 +128,21 @@ def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="defaul
pass
def dump_ubatching_state():
"""
Dump the current UBatchContext state for debugging.
"""
pass
# """
# Dump the current UBatchContext state for debugging.
# """
dp_rank = os.getenv("VLLM_DP_RANK", None)
# dp_rank = os.getenv("VLLM_DP_RANK", None)
for ctx in _CURRENT_CONTEXT.values():
print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
f" Stream: {ctx.stream}, ({ctx.stream.query()})\n"
f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
f" CPU Wait Event: {ctx.cpu_wait_event}\n"
f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
f" CPU Signal Event: {ctx.cpu_signal_event}\n"
f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")
# for ctx in _CURRENT_CONTEXT.values():
# print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
# f" Stream: {ctx.stream}, ({ctx.stream.query()})\n"
# f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
# f" CPU Wait Event: {ctx.cpu_wait_event}\n"
# f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
# f" CPU Signal Event: {ctx.cpu_signal_event}\n"
# f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")
"""
"""
@ -181,4 +180,4 @@ def make_ubatch_contexts(
)
ctxs.append(ctx)
return ctxs,
return ctxs