diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index c2db793659312..b28c441769204 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -661,6 +661,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): f"Hidden size mismatch {hidden_states.size(-1)} " f"!= {w1.size(2)}") + print("in batched triton experts", hidden_states.shape, expert_num_tokens) + assert hidden_states.is_contiguous( ), "Hidden_states must be contiguous" assert w1.stride(-1) == 1, "Stride of last dimension must be 1" diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f1cb77f64eae7..4eea5714e1be3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -29,6 +29,7 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op +from vllm.v1.worker.ubatching import get_current_ubatch_context has_pplx = importlib.util.find_spec("pplx_kernels") is not None @@ -442,7 +443,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): if isinstance(prepare_finalize, (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): - logger.debug("BatchedTritonExperts %s", self.moe) + print("BatchedTritonExperts %s", self.moe) + experts = BatchedTritonExperts( max_num_tokens=MOE_DP_CHUNK_SIZE, world_size=world_size, @@ -454,7 +456,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): block_shape=None, ) else: - logger.debug("TritonExperts %s", self.moe) + print("TritonExperts %s", self.moe) experts = TritonExperts( use_fp8_w8a8=False, use_int8_w8a8=False, @@ -1298,6 +1300,9 @@ class FusedMoE(torch.nn.Module): max_tokens_across_dp = 
ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE + if (ubatch_ctx := get_current_ubatch_context()) is not None: + print("in fused moe, ubatch:", ubatch_ctx.id, "chunk size:", max_tokens_across_dp, "moe_dp_chunk_size_per_rank", moe_dp_chunk_size_per_rank) + num_tokens = full_hidden_states.size(0) for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): @@ -1396,6 +1401,8 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, forward_context: ForwardContext = get_forward_context() self = forward_context.no_compile_layers[layer_name] assert self.quant_method is not None + if (ubatch_ctx := get_current_ubatch_context()) is not None: + print("in fused moe, ubatch:", ubatch_ctx.id, self) return self.forward_impl(hidden_states, router_logits) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 7d3ddf8f14c4d..56317f6ee6adc 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Optional import torch +from vllm.v1.worker.ubatching import get_current_ubatch_context # # This file defines a set of base classes used to make MoE kernels more modular. 
@@ -335,10 +336,20 @@ class FusedMoEModularKernel(torch.nn.Module): device=a1.device, dtype=workspace_dtype) + if (ubatch_ctx := get_current_ubatch_context()) is not None: + print("in modular moe, ubatch:", ubatch_ctx.id) + a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare( a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, expert_map, apply_router_weight_on_input) + if (ubatch_ctx := get_current_ubatch_context()) is not None: + print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts) + + print("pre synchronize") + torch.cuda.synchronize(a1.device) + print("post synchronize") + fused_out = self.fused_experts.apply( a1q, w1, @@ -358,6 +369,9 @@ class FusedMoEModularKernel(torch.nn.Module): expert_num_tokens=expert_num_tokens, ) + if (ubatch_ctx := get_current_ubatch_context()) is not None: + print("in modular moe3, ubatch:", ubatch_ctx.id, self.fused_experts) + self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index b2774382ef6f0..a179a4d6d6f2b 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -119,11 +119,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_recv=not send, ) - # if ubatch_ctx is not None: - # ubatch_ctx.gpu_stream_wait() + print("Dispatch pre-wait") + if (ubatch_ctx := get_current_ubatch_context()) is not None: + ubatch_ctx.gpu_stream_wait() + print("Dispatch launched") dispatch(True) # Send yield_impl(gpu_wait=False) dispatch(False) # Recv + print("Finished dispatch") return expert_x, expert_x_scale, expert_num_tokens @@ -160,8 +163,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): do_send=send, do_recv=not send, ) - # if ubatch_ctx is not None: - # ubatch_ctx.gpu_stream_wait() + + 
print("Combine pre-wait") + if (ubatch_ctx := get_current_ubatch_context()) is not None: + ubatch_ctx.gpu_stream_wait() combine(True) + print("Combine launched") yield_impl(gpu_wait=False) combine(False) + print("Finished combine") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 1cb22a109dacc..eb1eec53fcbd0 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -564,6 +564,8 @@ class DeepseekV2DecoderLayer(nn.Module): hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> torch.Tensor: + if (ubatch_ctx := get_current_ubatch_context()) is not None: + print("in decoder, ubatch:", ubatch_ctx.id) # Self Attention if residual is None: residual = hidden_states @@ -657,7 +659,7 @@ class DeepseekV2Model(nn.Module): intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - if ubatch_ctx := get_current_ubatch_context() is not None: + if (ubatch_ctx := get_current_ubatch_context()) is not None: print("in forward, ubatch:", ubatch_ctx.id) if get_pp_group().is_first_rank: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1b9fcb63a7804..be756929959ff 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1287,10 +1287,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output: Optional["SchedulerOutput"] = None, is_dummy_run: bool = False): + num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1 + def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple: if use_dummy_input: - num_tokens = num_scheduled_tokens or 1 - return self._get_dummy_model_inputs(num_tokens) + return self._get_dummy_model_inputs(num_dummy_tokens) else: assert scheduler_output is not None return self._get_model_inputs(tokens_slice, scheduler_output) @@ -1301,7 +1302,7 @@ class 
GPUModelRunner(LoRAModelRunnerMixin): model_inputs(token_slice, use_dummy_input) with context: if isinstance(context, UBatchContext): - print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape}") + print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}") model_output = self.model( input_ids=input_ids, positions=positions, @@ -1333,19 +1334,27 @@ class GPUModelRunner(LoRAModelRunnerMixin): if not hasattr(self, "ubatch_streams"): # Create the ubatch streams self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))] - - ubatch_fwd_ctxs = [create_forward_context( - attn_metadata[i] if attn_metadata is not None else None, - self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start) - ) for i, (_, tokens_slice) in enumerate(ubatch_slices)] + + + # We have to be careful creating the forward contexts here otherwise we can end + # up with the dummy contexts have num_tokens set to 0 + # ubatch_fwd_ctxs = [create_forward_context( + # attn_metadata[i] if attn_metadata is not None else None, + # self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start) + # ) for i, (_, tokens_slice) in enumerate(ubatch_slices)] ubatch_ctxs, start_hook = make_ubatch_context_chain( len(ubatch_slices), - fwd_ctxs=ubatch_fwd_ctxs, + #fwd_ctxs=ubatch_fwd_ctxs, streams=self.ubatch_streams, #stream=root_stream, # Only works currently if everything is run on the same stream device=self.device) setup_done = threading.Event() ubatch_threads = [] + # Initialize Events? 
not sure if this helps + for ubatch_ctx in ubatch_ctxs: + ubatch_ctx.gpu_wait_event.record(ubatch_ctx.stream) + ubatch_ctx.stream.wait_event(ubatch_ctx.gpu_wait_event) + # Ubatches will manually manage the forward context, so we override # it to None here so we can have it restored correctly later with override_forward_context(None): @@ -1354,6 +1363,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run + print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run) + + num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start) + ubatch_ctxs[i].forward_context = create_forward_context( + attn_metadata[i] if attn_metadata is not None else None, + self.vllm_config, num_tokens=num_tokens) + thread = threading.Thread(target=_ubatch_thread, args=( ubatch_ctxs[i], root_stream, diff --git a/vllm/v1/worker/ubatching.py b/vllm/v1/worker/ubatching.py index aab9c2a8d68a7..d24435f227cc1 100644 --- a/vllm/v1/worker/ubatching.py +++ b/vllm/v1/worker/ubatching.py @@ -17,7 +17,7 @@ class UBatchContext: def __init__(self, id: int, stream: torch.cuda.Stream, - fwd_ctx: forward_context.ForwardContext, + #fwd_ctx: forward_context.ForwardContext, cpu_wait_event: threading.Event, cpu_signal_event: threading.Event, gpu_wait_event: torch.cuda.Event, @@ -27,7 +27,7 @@ class UBatchContext: self.id = id self.stream = stream self.original_stream = current_stream() - self.forward_context = fwd_ctx + self.forward_context = None #fwd_ctx self.cpu_wait_event = cpu_wait_event self.cpu_signal_event = cpu_signal_event self.gpu_wait_event = gpu_wait_event @@ -80,6 +80,7 @@ class UBatchContext: # until ubatch0-dispatch is done avoiding overlapping dispatches that # might share underlying buffers def gpu_stream_wait(self): + print("Waiting ubatch %d on %s in stream %s" % 
(self.id, self.gpu_wait_event, self.stream)) self.stream.wait_event(self.gpu_wait_event) def _yield(self, gpu_wait: bool = True): @@ -92,6 +93,7 @@ def _signal(self): # Wait for the next batch to signal back + print(f"signaling ubatch {self.id} to {self.gpu_signal_event} on {self.stream}") self.gpu_signal_event.record(self.stream) # Signal that this batch reached the barrier self.cpu_signal_event.set() @@ -134,7 +136,7 @@ def yield_(x: torch.Tensor, schedule: str="default") -> None: """ def make_ubatch_context_chain( num_micro_batches: int, - fwd_ctxs: forward_context.ForwardContext, + #fwd_ctxs: forward_context.ForwardContext, streams: Optional[list[torch.Stream]] = None, device: Optional[torch.device] = None ) -> list[UBatchContext]: @@ -152,7 +154,7 @@ stream = (streams[i] if streams else None) or torch.cuda.Stream(device) ctx = UBatchContext(id=i, stream=stream, - fwd_ctx=fwd_ctxs[i], + #fwd_ctx=fwd_ctxs[i], cpu_wait_event=cpu_events[i], cpu_signal_event=cpu_events[(i + 1) % num_micro_batches], gpu_wait_event=gpu_events[i], @@ -163,6 +165,7 @@ def start_hook(from_stream: torch.cuda.Stream): ctxs[0].gpu_wait_event.record(from_stream) + print('signal to ubatch %d event %s from stream %s' % (ctxs[0].id, ctxs[0].gpu_wait_event, from_stream)) ctxs[0].cpu_wait_event.set() return ctxs, start_hook \ No newline at end of file