mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 19:27:07 +08:00
dp working no yields
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
2a7f25fbe2
commit
a8439e2fd4
@ -4332,7 +4332,6 @@ class VllmConfig:
|
||||
logger.warning_once(
|
||||
"Piecewise compilation is not supported with "
|
||||
"microbatching. Disabling piecewiseching compilation.")
|
||||
self.parallel_config.enable_microbatching = False
|
||||
self.compilation_config.level = CompilationLevel.DYNAMO_ONCE
|
||||
|
||||
|
||||
|
||||
@ -102,7 +102,6 @@ def override_forward_context(forward_context: Optional[ForwardContext]):
|
||||
"""
|
||||
global _forward_context
|
||||
prev_context = _forward_context
|
||||
print("overriding forward context with", forward_context)
|
||||
_forward_context = forward_context
|
||||
try:
|
||||
yield
|
||||
|
||||
@ -74,7 +74,7 @@ class FusedMoEParallelConfig:
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep and has_pplx and False
|
||||
return self.dp_size > 1 and self.use_ep and has_pplx
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int,
|
||||
|
||||
@ -105,15 +105,20 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# There's not much point setting this unless it is != indices.size(0)
|
||||
bound_m: Optional[torch.Tensor] = None
|
||||
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=rank_topk_ids,
|
||||
bound_m=bound_m,
|
||||
)
|
||||
def dispatch(send: bool):
|
||||
self.a2a.dispatch(
|
||||
out_expert_num_tokens=expert_num_tokens,
|
||||
out_expert_x=expert_x,
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=rank_topk_ids,
|
||||
bound_m=bound_m,
|
||||
do_send=send,
|
||||
do_recv=not send,
|
||||
)
|
||||
dispatch(True) # Send
|
||||
dispatch(False) # Recv
|
||||
|
||||
return expert_x, expert_x_scale, expert_num_tokens
|
||||
|
||||
@ -140,8 +145,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
if apply_router_weight_on_input:
|
||||
topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
self.a2a.combine(out_tokens=output,
|
||||
indices=topk_ids,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m)
|
||||
def combine(send: bool):
|
||||
self.a2a.combine(
|
||||
out_tokens=output,
|
||||
indices=topk_ids,
|
||||
weights=topk_weights,
|
||||
expert_y=fused_expert_output,
|
||||
bound_m=bound_m,
|
||||
do_send=send,
|
||||
do_recv=not send,
|
||||
)
|
||||
combine(True)
|
||||
combine(False)
|
||||
|
||||
@ -1300,16 +1300,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(token_slice, use_dummy_input)
|
||||
with context:
|
||||
if isinstance(context, UBatchContext):
|
||||
print("running ubatch ctx", context.id)
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if isinstance(context, UBatchContext):
|
||||
print("done ubatch ctx", context.id)
|
||||
if isinstance(context, UBatchContext):
|
||||
# Clone before we leave the ubatch context
|
||||
model_output = model_output.clone()
|
||||
@ -1335,7 +1331,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))]
|
||||
|
||||
ubatch_fwd_ctxs = [create_forward_context(
|
||||
attn_metadata[i], self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
|
||||
attn_metadata[i] if attn_metadata is not None else None,
|
||||
self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
|
||||
) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
|
||||
ubatch_ctxs, start_hook = make_ubatch_context_chain(
|
||||
len(ubatch_slices),
|
||||
@ -1368,7 +1365,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# Single the first ubatch to start
|
||||
start_hook(root_stream)
|
||||
print("started first ubatch")
|
||||
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
@ -1376,7 +1372,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
for ubatch_ctx in ubatch_ctxs:
|
||||
root_stream.wait_stream(ubatch_ctx.stream)
|
||||
|
||||
print("torch cat")
|
||||
torch.cuda.set_stream(root_stream)
|
||||
return torch.cat(results, dim=0)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user