tone down prints

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-23 18:18:05 +00:00
parent 9edd08231b
commit 952f3c5c1e
5 changed files with 26 additions and 26 deletions

View File

@@ -1300,8 +1300,8 @@ class FusedMoE(torch.nn.Module):
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE
if (ubatch_ctdx := get_current_ubatch_context()) is not None: # if (ubatch_ctdx := get_current_ubatch_context()) is not None:
print("in fused moe, ubatch:", ubatch_ctdx.id, "chunk size:", max_tokens_across_dp, "moe_dp_chunk_size_per_rank", moe_dp_chunk_size_per_rank) # print("in fused moe, ubatch:", ubatch_ctdx.id, "chunk size:", max_tokens_across_dp, "moe_dp_chunk_size_per_rank", moe_dp_chunk_size_per_rank)
num_tokens = full_hidden_states.size(0) num_tokens = full_hidden_states.size(0)
for chunk_start_ in range(0, max_tokens_across_dp, for chunk_start_ in range(0, max_tokens_across_dp,
@@ -1401,8 +1401,8 @@ def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
if (ubatch_ctx := get_current_ubatch_context()) is not None: # if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in fused moe, ubatch:", ubatch_ctx.id, self) # print("in fused moe, ubatch:", ubatch_ctx.id, self)
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)

View File

@@ -336,15 +336,15 @@ class FusedMoEModularKernel(torch.nn.Module):
device=a1.device, device=a1.device,
dtype=workspace_dtype) dtype=workspace_dtype)
if (ubatch_ctx := get_current_ubatch_context()) is not None: # if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in modular moe, ubatch:", ubatch_ctx.id) # print("in modular moe, ubatch:", ubatch_ctx.id)
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare( a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
expert_map, apply_router_weight_on_input) expert_map, apply_router_weight_on_input)
if (ubatch_ctx := get_current_ubatch_context()) is not None: # if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts) # print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts)
print("pre synchronize") print("pre synchronize")
torch.cuda.synchronize(a1.device) torch.cuda.synchronize(a1.device)
@@ -369,8 +369,8 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_num_tokens=expert_num_tokens, expert_num_tokens=expert_num_tokens,
) )
if (ubatch_ctx := get_current_ubatch_context()) is not None: # if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in modular moe3, ubatch:", ubatch_ctx.id, self.fused_experts) # print("in modular moe3, ubatch:", ubatch_ctx.id, self.fused_experts)
self.prepare_finalize.finalize(output, fused_out, topk_weights, self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input) topk_ids, apply_router_weight_on_input)

View File

@@ -119,14 +119,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send, do_recv=not send,
) )
print("Dispatch pre-wait") #print("Dispatch pre-wait")
if (ubatch_ctx := get_current_ubatch_context()) is not None: if (ubatch_ctx := get_current_ubatch_context()) is not None:
ubatch_ctx.gpu_stream_wait() ubatch_ctx.gpu_stream_wait()
print("Dispatch launched") #print("Dispatch launched")
dispatch(True) # Send dispatch(True) # Send
yield_impl(gpu_wait=False) yield_impl(gpu_wait=False)
dispatch(False) # Recv dispatch(False) # Recv
print("Finished dispatch") #print("Finished dispatch")
return expert_x, expert_x_scale, expert_num_tokens return expert_x, expert_x_scale, expert_num_tokens
@@ -164,11 +164,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send, do_recv=not send,
) )
print("Combine pre-wait") #print("Combine pre-wait")
if (ubatch_ctx := get_current_ubatch_context()) is not None: if (ubatch_ctx := get_current_ubatch_context()) is not None:
ubatch_ctx.gpu_stream_wait() ubatch_ctx.gpu_stream_wait()
combine(True) combine(True)
print("Combine launched") #print("Combine launched")
yield_impl(gpu_wait=False) yield_impl(gpu_wait=False)
combine(False) combine(False)
print("Finished combine") #print("Finished combine")

View File

@@ -564,8 +564,8 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
if (ubatch_ctx := get_current_ubatch_context()) is not None: # if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in decoder, ubatch:", ubatch_ctx.id) # print("in decoder, ubatch:", ubatch_ctx.id)
# Self Attention # Self Attention
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
@@ -659,8 +659,8 @@ class DeepseekV2Model(nn.Module):
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if (ubatch_ctx := get_current_ubatch_context()) is not None: # if (ubatch_ctx := get_current_ubatch_context()) is not None:
print("in forward, ubatch:", ubatch_ctx.id) # print("in forward, ubatch:", ubatch_ctx.id)
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:

View File

@@ -44,8 +44,8 @@ class UBatchContext:
self.original_stream.record_event(start_event) self.original_stream.record_event(start_event)
self.stream.wait_event(start_event) self.stream.wait_event(start_event)
print("Starting ubatch %d" % self.id) print("Starting ubatch %d" % self.id)
if self.gpu_wait_on_launch: # if self.gpu_wait_on_launch:
self.gpu_stream_wait() # self.gpu_stream_wait()
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
@@ -84,10 +84,10 @@ class UBatchContext:
self.stream.wait_event(self.gpu_wait_event) self.stream.wait_event(self.gpu_wait_event)
def _yield(self, gpu_wait: bool = True): def _yield(self, gpu_wait: bool = True):
print("Yielding ubatch %d" % self.id) #print("Yielding ubatch %d" % self.id)
self._signal() self._signal()
self._cpu_wait() self._cpu_wait()
print("Resuming ubatch %d" % self.id) #print("Resuming ubatch %d" % self.id)
if gpu_wait: if gpu_wait:
self.gpu_stream_wait() self.gpu_stream_wait()
@@ -115,7 +115,7 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
def yield_impl(schedule="default", gpu_wait: bool = True): def yield_impl(schedule="default", gpu_wait: bool = True):
# Perform the barrier if a context exists for this thread # Perform the barrier if a context exists for this thread
ctx = get_current_ubatch_context() ctx = get_current_ubatch_context()
print("you are in yield_impl", ctx) #print("you are in yield_impl", ctx)
if ctx is not None: if ctx is not None:
ctx._yield(gpu_wait=gpu_wait) ctx._yield(gpu_wait=gpu_wait)
@@ -146,7 +146,7 @@ def make_ubatch_context_chain(
Create a context manager for micro-batching synchronization. Create a context manager for micro-batching synchronization.
""" """
cpu_events = [threading.Event() for _ in range(num_micro_batches)] cpu_events = [threading.Event() for _ in range(num_micro_batches)]
gpu_events = [torch.cuda.Event() for _ in range(num_micro_batches)] gpu_events = [torch.cuda.Event(blocking=True) for _ in range(num_micro_batches)]
device = device or torch.cuda.current_device() device = device or torch.cuda.current_device()
ctxs = [] ctxs = []