mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 11:07:03 +08:00
fixes
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
7b31e8a8ff
commit
a743a35948
@ -336,21 +336,10 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
# if (ubatch_ctx := get_current_ubatch_context()) is not None:
|
||||
# print("in modular moe, ubatch:", ubatch_ctx.id)
|
||||
|
||||
a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare(
|
||||
a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts,
|
||||
expert_map, apply_router_weight_on_input)
|
||||
|
||||
# if (ubatch_ctx := get_current_ubatch_context()) is not None:
|
||||
# print("in modular moe2, ubatch:", ubatch_ctx.id, self.fused_experts)
|
||||
|
||||
print("pre synchronize")
|
||||
dump_ubatching_state()
|
||||
torch.cuda.synchronize(a1.device)
|
||||
print("post synchronize")
|
||||
|
||||
fused_out = self.fused_experts.apply(
|
||||
a1q,
|
||||
w1,
|
||||
|
||||
@ -1318,9 +1318,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return model_output
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(ubatch_ctx, root_stream, token_slice, attn_metadata, results, save_results, use_dummy_input, setup_done_evt):
|
||||
ubatch_ctx.stream.wait_stream(root_stream)
|
||||
|
||||
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input):
|
||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
||||
|
||||
if save_results:
|
||||
@ -1330,24 +1328,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
results = []
|
||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||
root_stream = current_stream()
|
||||
|
||||
if not hasattr(self, "ubatch_streams"):
|
||||
# Create the ubatch streams
|
||||
self.ubatch_streams = [torch.cuda.Stream(self.device) for _ in range(len(ubatch_slices))]
|
||||
|
||||
|
||||
# We have to be careful creating the forward contexts here otherwise we can end
|
||||
# up with the dummy contexts have num_tokens set to 0
|
||||
# ubatch_fwd_ctxs = [create_forward_context(
|
||||
# attn_metadata[i] if attn_metadata is not None else None,
|
||||
# self.vllm_config, num_tokens=(tokens_slice.stop - tokens_slice.start)
|
||||
# ) for i, (_, tokens_slice) in enumerate(ubatch_slices)]
|
||||
ubatch_ctxs, start_hook = make_ubatch_contexts(
|
||||
|
||||
ubatch_ctxs = make_ubatch_contexts(
|
||||
len(ubatch_slices),
|
||||
compute_stream=root_stream,
|
||||
device=self.device)
|
||||
setup_done = threading.Event()
|
||||
ubatch_threads = []
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
@ -1366,20 +1351,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
thread = threading.Thread(target=_ubatch_thread, args=(
|
||||
ubatch_ctxs[i],
|
||||
root_stream,
|
||||
tokens_slice,
|
||||
attn_metadata[i] if attn_metadata is not None else None,
|
||||
results,
|
||||
not is_dummy_ubatch or is_dummy_run,
|
||||
is_dummy_ubatch or is_dummy_run,
|
||||
setup_done,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Single the first ubatch to start
|
||||
start_hook(root_stream)
|
||||
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
|
||||
|
||||
@ -39,6 +39,7 @@ class UBatchContext:
|
||||
def __enter__(self):
|
||||
global _CURRENT_CONTEXT
|
||||
_CURRENT_CONTEXT[threading.get_ident()] = self
|
||||
self._restore_context()
|
||||
# Assume we start on the compute stream
|
||||
assert current_stream() == self.compute_stream, \
|
||||
"Expected to start on the compute stream, but found %s" % current_stream()
|
||||
@ -52,9 +53,6 @@ class UBatchContext:
|
||||
return False
|
||||
|
||||
def _restore_context(self):
|
||||
# When we resume i.e. switch back to this micro-batch, we make sure
|
||||
# we have the correct stream and forward context
|
||||
torch.cuda.set_stream(self.stream)
|
||||
forward_context._forward_context = self.forward_context
|
||||
|
||||
def _signal_comm_done(self):
|
||||
@ -130,20 +128,21 @@ def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="defaul
|
||||
pass
|
||||
|
||||
def dump_ubatching_state():
|
||||
"""
|
||||
Dump the current UBatchContext state for debugging.
|
||||
"""
|
||||
pass
|
||||
# """
|
||||
# Dump the current UBatchContext state for debugging.
|
||||
# """
|
||||
|
||||
dp_rank = os.getenv("VLLM_DP_RANK", None)
|
||||
# dp_rank = os.getenv("VLLM_DP_RANK", None)
|
||||
|
||||
for ctx in _CURRENT_CONTEXT.values():
|
||||
print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
|
||||
f" Stream: {ctx.stream}, ({ctx.stream.query()})\n"
|
||||
f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
|
||||
f" CPU Wait Event: {ctx.cpu_wait_event}\n"
|
||||
f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
|
||||
f" CPU Signal Event: {ctx.cpu_signal_event}\n"
|
||||
f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")
|
||||
# for ctx in _CURRENT_CONTEXT.values():
|
||||
# print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
|
||||
# f" Stream: {ctx.stream}, ({ctx.stream.query()})\n"
|
||||
# f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
|
||||
# f" CPU Wait Event: {ctx.cpu_wait_event}\n"
|
||||
# f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
|
||||
# f" CPU Signal Event: {ctx.cpu_signal_event}\n"
|
||||
# f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")
|
||||
|
||||
"""
|
||||
"""
|
||||
@ -181,4 +180,4 @@ def make_ubatch_contexts(
|
||||
)
|
||||
ctxs.append(ctx)
|
||||
|
||||
return ctxs,
|
||||
return ctxs
|
||||
Loading…
x
Reference in New Issue
Block a user