mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 17:17:10 +08:00
capture works replay does not
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
af68574e3d
commit
4672c72f44
@ -84,7 +84,7 @@ def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
] * 100
|
||||
] * 5
|
||||
# import random
|
||||
# import string
|
||||
# prompts = [''.join(random.choices(string.ascii_letters, k=128)) for _ in range(2048)]
|
||||
|
||||
@ -155,7 +155,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||
|
||||
# Dispatch
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
expert_x, expert_num_tokens, handle, event, hook = \
|
||||
self.buffers[a2a_idx].low_latency_dispatch(a1,
|
||||
rank_topk_ids,
|
||||
@ -165,7 +165,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
async_finish=False,
|
||||
return_recv_hook=False)
|
||||
self.handles[a2a_idx] = handle
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
|
||||
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
|
||||
a1.dtype)
|
||||
@ -188,7 +188,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
combine_topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
# TODO (varun) : Enable zero copy mode
|
||||
yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
|
||||
_, event, hook = self.buffers[a2a_idx].low_latency_combine(
|
||||
fused_expert_output,
|
||||
topk_ids,
|
||||
@ -199,5 +199,5 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return_recv_hook=False,
|
||||
out=output)
|
||||
# event.current_stream_wait()
|
||||
yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
|
||||
|
||||
|
||||
@ -47,7 +47,6 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
||||
|
||||
def embedding(self, layer: torch.nn.Module,
|
||||
input_: torch.Tensor) -> torch.Tensor:
|
||||
print("SHOULDNT BE HERE DURING CAPTURE")
|
||||
return F.embedding(input_, layer.weight)
|
||||
|
||||
|
||||
|
||||
@ -94,7 +94,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
||||
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
|
||||
@ -722,6 +722,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# logger.info(f"ATTEMPTING TO PAD UBATCH {should_ubatch}")
|
||||
assert should_ubatch
|
||||
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
|
||||
logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}")
|
||||
# logger.info("UBATCH PADDING DONE")
|
||||
if num_pad_tokens > 0:
|
||||
if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
|
||||
@ -730,7 +731,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# We bail out of ubatching here. This accounts for the case where
|
||||
# the padding would result in an "empty" second ubatch.
|
||||
# TODO: just make the second ubatch a dummy ubatch
|
||||
logger.info("FALLING BACK AND DISABLING UBATCHING")
|
||||
# logger.info("FALLING BACK AND DISABLING UBATCHING")
|
||||
ubatch_bailout = True
|
||||
|
||||
# Note that if we are attempting to ubatch by this point then we know that no
|
||||
@ -738,7 +739,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if ubatch_slices:
|
||||
should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
|
||||
if not should_ubatch:
|
||||
logger.info("SUCCESSFULLY BAILED OUT")
|
||||
# logger.info("SUCCESSFULLY BAILED OUT")
|
||||
num_pad_tokens = 0
|
||||
num_tokens_after_padding = None
|
||||
ubatch_slices = None
|
||||
@ -1344,7 +1345,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
num_tokens_padded = num_tokens_unpadded
|
||||
|
||||
logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
|
||||
# logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
|
||||
if (self.use_cuda_graph
|
||||
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
@ -1388,7 +1389,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens_padded = round_up(num_tokens_unpadded, 2)
|
||||
if (self.use_cuda_graph
|
||||
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
|
||||
else:
|
||||
@ -1581,10 +1581,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def _make_ubatch_contexts(ubatch_slices,
|
||||
attn_metadata,
|
||||
compute_stream,
|
||||
is_dummy_run,
|
||||
num_tokens_across_dp) -> list[UBatchContext]:
|
||||
num_tokens_across_dp,
|
||||
skip_cuda_graphs) -> list[UBatchContext]:
|
||||
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
|
||||
compute_stream=current_stream(),
|
||||
compute_stream=compute_stream,
|
||||
device=self.device)
|
||||
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
@ -1617,13 +1619,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
def _make_ubatch_metadata(ubatch_slices,
|
||||
attn_metadata,
|
||||
compute_stream,
|
||||
is_dummy_run,
|
||||
num_tokens_across_dp) -> list[UbatchMetadata]:
|
||||
num_tokens_across_dp,
|
||||
skip_cuda_graphs) -> list[UbatchMetadata]:
|
||||
ubatch_ctxs = _make_ubatch_contexts(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
compute_stream=compute_stream,
|
||||
is_dummy_run=is_dummy_run,
|
||||
num_tokens_across_dp=num_tokens_across_dp
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs
|
||||
)
|
||||
# First get some inputs
|
||||
ubatch_metadata: list[UbatchMetadata] = []
|
||||
@ -1644,8 +1650,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
intermediate_tensors):
|
||||
intermediate_tensors,
|
||||
start_signal=None):
|
||||
with context:
|
||||
if start_signal is not None:
|
||||
start_signal.wait()
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@ -1659,42 +1668,64 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return model_output
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(results, ubatch_metadata):
|
||||
def _ubatch_thread(results, ubatch_metadata, start_signal):
|
||||
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
context = ubatch_metadata.context
|
||||
with torch.cuda.stream(context.compute_stream):
|
||||
_ = torch.cuda.current_blas_handle()
|
||||
with torch.cuda.stream(context.comm_stream):
|
||||
_ = torch.cuda.current_blas_handle()
|
||||
model_output = _run(context=ubatch_metadata.context,
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors)
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors,
|
||||
start_signal=start_signal)
|
||||
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
|
||||
def _run_ubatches(ubatch_metadata) -> torch.Tensor:
|
||||
def _run_ubatches(ubatch_metadata, num_tokens, should_capture=False) -> torch.Tensor:
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
root_stream = current_stream()
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
start_signals = []
|
||||
for metadata in ubatch_metadata:
|
||||
start_signal = threading.Event()
|
||||
thread = threading.Thread(target=_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
metadata,
|
||||
start_signal,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
start_signals.append(start_signal)
|
||||
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
|
||||
# torch.cuda.synchronize()
|
||||
torch.cuda.set_stream(root_stream)
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
return torch.cat(sorted_results, dim=0)
|
||||
# DO capture
|
||||
cudagraph_metadata = \
|
||||
CUDAGraphMetaData(
|
||||
cudagraph=torch.cuda.CUDAGraph(),
|
||||
using_ubatching=True
|
||||
)
|
||||
with torch.cuda.graph(cudagraph_metadata.cudagraph,
|
||||
stream=compute_stream):
|
||||
# logger.info("STARTING WAKEUP LOOP")
|
||||
for start_signal in start_signals:
|
||||
start_signal.set()
|
||||
# logger.info("FINISHED WAKEUP LOOP")
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
sorted_results = [value for position, value in sorted(results)]
|
||||
result = torch.cat(sorted_results, dim=0)
|
||||
cudagraph_metadata.outputs = result
|
||||
logger.info(f"Capturing for {num_tokens} tokens")
|
||||
self.cudagraphs[num_tokens] = cudagraph_metadata
|
||||
return cudagraph_metadata.outputs
|
||||
|
||||
# run micro-batched
|
||||
if ubatch_slices is not None:
|
||||
@ -1702,69 +1733,71 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# num_tokens = ubatch_slices[1][1].stop
|
||||
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
|
||||
# assert not is_dummy_run
|
||||
compute_stream = torch.cuda.Stream(device=self.device)
|
||||
ubatch_metadata = _make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
compute_stream=compute_stream,
|
||||
is_dummy_run=is_dummy_run,
|
||||
num_tokens_across_dp=num_tokens_across_dp
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs
|
||||
)
|
||||
if num_scheduled_tokens not in self.cudagraphs \
|
||||
and not skip_cuda_graphs and build_cuda_graph:
|
||||
# DO capture
|
||||
self.cudagraphs[num_scheduled_tokens] = \
|
||||
CUDAGraphMetaData(
|
||||
cudagraph=torch.cuda.CUDAGraph(),
|
||||
using_ubatching=True
|
||||
)
|
||||
with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
|
||||
model_output = _run_ubatches(ubatch_metadata)
|
||||
self.cudagraphs[num_scheduled_tokens].outputs = model_output
|
||||
return self.cudagraphs[num_scheduled_tokens].outputs
|
||||
elif num_scheduled_tokens in self.cudagraphs:
|
||||
self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
|
||||
return self.cudagraphs[num_scheduled_tokens].outputs
|
||||
return _run_ubatches(ubatch_metadata, num_scheduled_tokens, should_capture=True)
|
||||
elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
|
||||
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
|
||||
logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
|
||||
cudagraph_metadata.cudagraph.replay()
|
||||
return cudagraph_metadata.outputs
|
||||
else:
|
||||
assert False
|
||||
return _run_ubatches(ubatch_metadata)
|
||||
# run single batch
|
||||
else:
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
|
||||
if num_scheduled_tokens not in self.cudagraphs \
|
||||
and not skip_cuda_graphs and build_cuda_graph:
|
||||
self.cudagraphs[num_scheduled_tokens] = \
|
||||
CUDAGraphMetaData(
|
||||
cudagraph=torch.cuda.CUDAGraph(),
|
||||
using_ubatching=False
|
||||
)
|
||||
with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
|
||||
model_output = _run(
|
||||
context = set_forward_context(attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_scheduled_tokens or 1,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs),
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors
|
||||
)
|
||||
self.cudagraphs[num_scheduled_tokens].outputs = model_output
|
||||
return self.cudagraphs[num_scheduled_tokens].outputs
|
||||
elif num_scheduled_tokens in self.cudagraphs:
|
||||
self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
|
||||
return self.cudagraphs[num_scheduled_tokens].outputs
|
||||
else:
|
||||
return _run(
|
||||
context = set_forward_context(attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_scheduled_tokens or 1,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs),
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors
|
||||
)
|
||||
# if num_scheduled_tokens not in self.cudagraphs \
|
||||
# and not skip_cuda_graphs and build_cuda_graph:
|
||||
# assert False
|
||||
# logger.info(f"GRAPH BUILD{num_scheduled_tokens}")
|
||||
# self.cudagraphs[num_scheduled_tokens] = \
|
||||
# CUDAGraphMetaData(
|
||||
# cudagraph=torch.cuda.CUDAGraph(),
|
||||
# using_ubatching=False
|
||||
# )
|
||||
# with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
|
||||
# model_output = _run(
|
||||
# context = set_forward_context(attn_metadata,
|
||||
# vllm_config=self.vllm_config,
|
||||
# num_tokens=num_scheduled_tokens or 1,
|
||||
# num_tokens_across_dp=num_tokens_across_dp,
|
||||
# skip_cuda_graphs=skip_cuda_graphs),
|
||||
# input_ids=input_ids,
|
||||
# positions=positions,
|
||||
# inputs_embeds=inputs_embeds,
|
||||
# intermediate_tensors=intermediate_tensors
|
||||
# )
|
||||
# self.cudagraphs[num_scheduled_tokens].outputs = model_output
|
||||
# return self.cudagraphs[num_scheduled_tokens].outputs
|
||||
# elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
|
||||
# assert False
|
||||
# # logger.info(f"GRAPH REPLAY {num_scheduled_tokens}")
|
||||
# self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
|
||||
# return self.cudagraphs[num_scheduled_tokens].outputs
|
||||
# else:
|
||||
# logger.info(f"NORMAL RUN {num_scheduled_tokens}")
|
||||
return _run(
|
||||
context = set_forward_context(attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_scheduled_tokens or 1,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs),
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
@ -2325,8 +2358,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
build_cuda_graph: bool = False
|
||||
) -> torch.Tensor:
|
||||
|
||||
if allow_microbatching:
|
||||
logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
|
||||
# if allow_microbatching:
|
||||
# logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
|
||||
|
||||
|
||||
# TODO(Sage) We need some more code to properly handle
|
||||
@ -2339,6 +2372,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
allow_microbatching and capture_attn_cudagraph
|
||||
# _dummy_run doesn't go through _prepare_inputs so
|
||||
# we synchronize with other DP ranks here
|
||||
# logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
|
||||
should_ubatch = self.should_ubatch(allow_microbatching)
|
||||
# Padding for DP
|
||||
# logger.info("PADDING DUMMY")
|
||||
|
||||
@ -30,7 +30,6 @@ class UBatchContext:
|
||||
self.id = id
|
||||
self.comm_stream = comm_stream
|
||||
self.compute_stream = compute_stream
|
||||
self.original_stream = current_stream()
|
||||
self.forward_context = None #fwd_ctx
|
||||
self.cpu_wait_event = cpu_wait_event
|
||||
self.cpu_signal_event = cpu_signal_event
|
||||
@ -58,7 +57,7 @@ class UBatchContext:
|
||||
self.cpu_signal_event.set()
|
||||
self.cpu_wait_event.clear()
|
||||
self.current_stream = self.compute_stream
|
||||
torch.cuda.set_stream(self.original_stream)
|
||||
torch.cuda.set_stream(self.current_stream)
|
||||
return False
|
||||
|
||||
def _restore_context(self):
|
||||
@ -76,26 +75,31 @@ class UBatchContext:
|
||||
pass
|
||||
|
||||
def _signal_comm_done(self):
|
||||
assert False
|
||||
self.ctx_valid_state()
|
||||
self.gpu_comm_done_event.record(self.comm_stream)
|
||||
|
||||
def _signal_compute_done(self):
|
||||
assert False
|
||||
self.ctx_valid_state()
|
||||
self.gpu_compute_done_event.record(self.compute_stream)
|
||||
|
||||
def _wait_compute_done(self):
|
||||
assert False
|
||||
# print(f"{self.id} Waiting on COMPUTE stream", flush=True)
|
||||
self.ctx_valid_state()
|
||||
self.comm_stream.wait_event(self.gpu_compute_done_event)
|
||||
# print("Compute stream done", flush=True)
|
||||
|
||||
def _wait_comm_done(self):
|
||||
assert False
|
||||
# print(f"{self.id} Waiting on COMM stream", flush=True)
|
||||
self.ctx_valid_state()
|
||||
self.compute_stream.wait_event(self.gpu_comm_done_event)
|
||||
# print("Comm stream done", flush=True)
|
||||
|
||||
def stream_string(self):
|
||||
assert False
|
||||
if current_stream() == self.compute_stream:
|
||||
assert self.current_stream == self.compute_stream
|
||||
return "COMPUTE"
|
||||
@ -114,6 +118,7 @@ class UBatchContext:
|
||||
# print(f"UBatchContext: {self.id} resuming CPU", flush=True)
|
||||
|
||||
def yield_and_switch_from_compute_to_comm(self):
|
||||
assert False
|
||||
assert current_stream() == self.compute_stream
|
||||
# dp_rank = get_dp_group().rank_in_group
|
||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||
@ -129,6 +134,7 @@ class UBatchContext:
|
||||
self._wait_compute_done()
|
||||
|
||||
def yield_and_switch_from_comm_to_compute(self):
|
||||
assert False
|
||||
assert current_stream() == self.comm_stream
|
||||
# dp_rank = get_dp_group().rank_in_group
|
||||
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user