DP working, no yields

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-22 21:49:14 +00:00
parent 2a7f25fbe2
commit a8439e2fd4
5 changed files with 29 additions and 24 deletions

View File

@ -4332,7 +4332,6 @@ class VllmConfig:
logger.warning_once( logger.warning_once(
"Piecewise compilation is not supported with " "Piecewise compilation is not supported with "
"microbatching. Disabling piecewise compilation.") "microbatching. Disabling piecewise compilation.")
self.parallel_config.enable_microbatching = False
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE self.compilation_config.level = CompilationLevel.DYNAMO_ONCE

View File

@ -102,7 +102,6 @@ def override_forward_context(forward_context: Optional[ForwardContext]):
""" """
global _forward_context global _forward_context
prev_context = _forward_context prev_context = _forward_context
print("overriding forward context with", forward_context)
_forward_context = forward_context _forward_context = forward_context
try: try:
yield yield

View File

@ -74,7 +74,7 @@ class FusedMoEParallelConfig:
@property @property
def use_pplx_kernels(self): def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and has_pplx and False return self.dp_size > 1 and self.use_ep and has_pplx
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,

View File

@ -105,15 +105,20 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# There's not much point setting this unless it is != indices.size(0) # There's not much point setting this unless it is != indices.size(0)
bound_m: Optional[torch.Tensor] = None bound_m: Optional[torch.Tensor] = None
self.a2a.dispatch( def dispatch(send: bool):
out_expert_num_tokens=expert_num_tokens, self.a2a.dispatch(
out_expert_x=expert_x, out_expert_num_tokens=expert_num_tokens,
out_expert_x_scale=expert_x_scale, out_expert_x=expert_x,
dp_x=a1q, out_expert_x_scale=expert_x_scale,
dp_x_scale=a1q_scale, dp_x=a1q,
indices=rank_topk_ids, dp_x_scale=a1q_scale,
bound_m=bound_m, indices=rank_topk_ids,
) bound_m=bound_m,
do_send=send,
do_recv=not send,
)
dispatch(True) # Send
dispatch(False) # Recv
return expert_x, expert_x_scale, expert_num_tokens return expert_x, expert_x_scale, expert_num_tokens
@ -140,8 +145,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights) topk_weights = torch.ones_like(topk_weights)
self.a2a.combine(out_tokens=output, def combine(send: bool):
indices=topk_ids, self.a2a.combine(
weights=topk_weights, out_tokens=output,
expert_y=fused_expert_output, indices=topk_ids,
bound_m=bound_m) weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=send,
do_recv=not send,
)
combine(True)
combine(False)

View File

@ -1300,16 +1300,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids, positions, inputs_embeds, intermediate_tensors = \ input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(token_slice, use_dummy_input) model_inputs(token_slice, use_dummy_input)
with context: with context:
if isinstance(context, UBatchContext):
print("running ubatch ctx", context.id)
model_output = self.model( model_output = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
if isinstance(context, UBatchContext):
print("done ubatch ctx", context.id)
if isinstance(context, UBatchContext): if isinstance(context, UBatchContext):
# Clone before we leave the ubatch context # Clone before we leave the ubatch context
model_output = model_output.clone() model_output = model_output.clone()
@ -1335,7 +1331,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))] self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))]
ubatch_fwd_ctxs = [create_forward_context( ubatch_fwd_ctxs = [create_forward_context(
attn_metadata[i], self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start) attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
) for i, (_, tokens_slice) in enumerate(ubatch_slices)] ) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
ubatch_ctxs, start_hook = make_ubatch_context_chain( ubatch_ctxs, start_hook = make_ubatch_context_chain(
len(ubatch_slices), len(ubatch_slices),
@ -1368,7 +1365,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Signal the first ubatch to start
start_hook(root_stream) start_hook(root_stream)
print("started first ubatch")
for thread in ubatch_threads: for thread in ubatch_threads:
thread.join() thread.join()
@ -1376,7 +1372,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for ubatch_ctx in ubatch_ctxs: for ubatch_ctx in ubatch_ctxs:
root_stream.wait_stream(ubatch_ctx.stream) root_stream.wait_stream(ubatch_ctx.stream)
print("torch cat")
torch.cuda.set_stream(root_stream) torch.cuda.set_stream(root_stream)
return torch.cat(results, dim=0) return torch.cat(results, dim=0)