dp working no yields

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-22 21:49:14 +00:00
parent 2a7f25fbe2
commit a8439e2fd4
5 changed files with 29 additions and 24 deletions

View File

@ -4332,7 +4332,6 @@ class VllmConfig:
logger.warning_once(
"Piecewise compilation is not supported with "
"microbatching. Disabling piecewise compilation.")
self.parallel_config.enable_microbatching = False
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE

View File

@ -102,7 +102,6 @@ def override_forward_context(forward_context: Optional[ForwardContext]):
"""
global _forward_context
prev_context = _forward_context
print("overriding forward context with", forward_context)
_forward_context = forward_context
try:
yield

View File

@ -74,7 +74,7 @@ class FusedMoEParallelConfig:
@property
def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and has_pplx and False
return self.dp_size > 1 and self.use_ep and has_pplx
@staticmethod
def make(tp_size_: int, dp_size_: int,

View File

@ -105,15 +105,20 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# There's not much point setting this unless it is != indices.size(0)
bound_m: Optional[torch.Tensor] = None
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=rank_topk_ids,
bound_m=bound_m,
)
def dispatch(send: bool):
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=rank_topk_ids,
bound_m=bound_m,
do_send=send,
do_recv=not send,
)
dispatch(True) # Send
dispatch(False) # Recv
return expert_x, expert_x_scale, expert_num_tokens
@ -140,8 +145,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
self.a2a.combine(out_tokens=output,
indices=topk_ids,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m)
def combine(send: bool):
self.a2a.combine(
out_tokens=output,
indices=topk_ids,
weights=topk_weights,
expert_y=fused_expert_output,
bound_m=bound_m,
do_send=send,
do_recv=not send,
)
combine(True)
combine(False)

View File

@ -1300,16 +1300,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(token_slice, use_dummy_input)
with context:
if isinstance(context, UBatchContext):
print("running ubatch ctx", context.id)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if isinstance(context, UBatchContext):
print("done ubatch ctx", context.id)
if isinstance(context, UBatchContext):
# Clone before we leave the ubatch context
model_output = model_output.clone()
@ -1335,7 +1331,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))]
ubatch_fwd_ctxs = [create_forward_context(
attn_metadata[i], self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
attn_metadata[i] if attn_metadata is not None else None,
self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
ubatch_ctxs, start_hook = make_ubatch_context_chain(
len(ubatch_slices),
@ -1368,7 +1365,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Signal the first ubatch to start
start_hook(root_stream)
print("started first ubatch")
for thread in ubatch_threads:
thread.join()
@ -1376,7 +1372,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for ubatch_ctx in ubatch_ctxs:
root_stream.wait_stream(ubatch_ctx.stream)
print("torch cat")
torch.cuda.set_stream(root_stream)
return torch.cat(results, dim=0)