commit 4672c72f44 (parent af68574e3d)

capture works replay does not

Signed-off-by: Sage Moore <sage@neuralmagic.com>
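
The commit message refers to the two phases of a CUDA graph's life: kernels are first recorded ("capture") and later re-executed ("replay"). A minimal, self-contained sketch of that pattern using only public PyTorch APIs (tensor names are illustrative, not code from this commit):

import torch

x = torch.ones(8, device="cuda")

# Warm up once so no lazy initialization happens inside the capture region.
_ = x * 2 + 1
torch.cuda.synchronize()

# Capture: kernels launched in this block are recorded into the graph,
# not executed eagerly.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    y = x * 2 + 1

# Replay: re-run the recorded kernels against the same (static) buffers.
x.fill_(2.0)
graph.replay()
torch.cuda.synchronize()
print(y)  # reflects the new contents of x, since replay reads the captured buffer
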
@@ -84,7 +84,7 @@ def main(args, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
 "The president of the United States is",
 "The capital of France is",
 "The future of AI is",
-] * 100
+] * 5
 # import random
 # import string
 # prompts = [''.join(random.choices(string.ascii_letters, k=128)) for _ in range(2048)]
@@ -155,7 +155,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 a1 = a1 * rank_topk_weights.to(a1.dtype)

 # Dispatch
-yield_and_switch_from_compute_to_comm_impl(schedule="default")
+# yield_and_switch_from_compute_to_comm_impl(schedule="default")
 expert_x, expert_num_tokens, handle, event, hook = \
 self.buffers[a2a_idx].low_latency_dispatch(a1,
 rank_topk_ids,
@@ -165,7 +165,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 async_finish=False,
 return_recv_hook=False)
 self.handles[a2a_idx] = handle
-yield_and_switch_from_comm_to_compute_impl(schedule="default")
+# yield_and_switch_from_comm_to_compute_impl(schedule="default")

 expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
 a1.dtype)
@@ -188,7 +188,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 combine_topk_weights = torch.ones_like(topk_weights)

 # TODO (varun) : Enable zero copy mode
-yield_and_switch_from_compute_to_comm_impl(schedule="default")
+# yield_and_switch_from_compute_to_comm_impl(schedule="default")
 _, event, hook = self.buffers[a2a_idx].low_latency_combine(
 fused_expert_output,
 topk_ids,
@@ -199,5 +199,5 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
 return_recv_hook=False,
 out=output)
 # event.current_stream_wait()
-yield_and_switch_from_comm_to_compute_impl(schedule="default")
+# yield_and_switch_from_comm_to_compute_impl(schedule="default")

@@ -47,7 +47,6 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):

 def embedding(self, layer: torch.nn.Module,
 input_: torch.Tensor) -> torch.Tensor:
-print("SHOULDNT BE HERE DURING CAPTURE")
 return F.embedding(input_, layer.weight)


@@ -94,7 +94,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

 # Num splits is per-batch, varying size (batch_size,)
 n = num_splits.size(0)
-logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
+# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
 # make sure static buffer is large enough
 assert n <= self.cg_buf_num_splits.size(0)
 num_splits_view = self.cg_buf_num_splits[:n]
@@ -722,6 +722,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 # logger.info(f"ATTEMPTING TO PAD UBATCH {should_ubatch}")
 assert should_ubatch
 num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
+logger.info(f"num_tokens {scheduler_output.total_num_scheduled_tokens} num_pad_tokens {num_pad_tokens} num_toknes_after {num_tokens_after_padding}")
 # logger.info("UBATCH PADDING DONE")
 if num_pad_tokens > 0:
 if num_pad_tokens < scheduler_output.total_num_scheduled_tokens:
@@ -730,7 +731,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 # We bail out of ubatching here. This accounts for the case where
 # the padding would result in an "empty" second ubatch.
 # TODO: just make the second ubatch a dummy ubatch
-logger.info("FALLING BACK AND DISABLING UBATCHING")
+# logger.info("FALLING BACK AND DISABLING UBATCHING")
 ubatch_bailout = True

 # Note that if we are attempting to ubatch by this point then we know that no
@@ -738,7 +739,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 if ubatch_slices:
 should_ubatch = self.should_ubatch(False if ubatch_bailout else True)
 if not should_ubatch:
-logger.info("SUCCESSFULLY BAILED OUT")
+# logger.info("SUCCESSFULLY BAILED OUT")
 num_pad_tokens = 0
 num_tokens_after_padding = None
 ubatch_slices = None
@@ -1344,7 +1345,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):

 num_tokens_padded = num_tokens_unpadded

-logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
+# logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
 if (self.use_cuda_graph
 and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
 # Use piecewise CUDA graphs.
@@ -1388,7 +1389,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 num_tokens_padded = round_up(num_tokens_unpadded, 2)
 if (self.use_cuda_graph
 and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
-# Use piecewise CUDA graphs.
 # Add padding to the batch size.
 num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
 else:
@@ -1581,10 +1581,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):

 def _make_ubatch_contexts(ubatch_slices,
 attn_metadata,
+compute_stream,
 is_dummy_run,
-num_tokens_across_dp) -> list[UBatchContext]:
+num_tokens_across_dp,
+skip_cuda_graphs) -> list[UBatchContext]:
 ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
-compute_stream=current_stream(),
+compute_stream=compute_stream,
 device=self.device)

 for i, (_, tokens_slice) in enumerate(ubatch_slices):
@@ -1617,13 +1619,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):

 def _make_ubatch_metadata(ubatch_slices,
 attn_metadata,
+compute_stream,
 is_dummy_run,
-num_tokens_across_dp) -> list[UbatchMetadata]:
+num_tokens_across_dp,
+skip_cuda_graphs) -> list[UbatchMetadata]:
 ubatch_ctxs = _make_ubatch_contexts(
 ubatch_slices=ubatch_slices,
 attn_metadata=attn_metadata,
+compute_stream=compute_stream,
 is_dummy_run=is_dummy_run,
-num_tokens_across_dp=num_tokens_across_dp
+num_tokens_across_dp=num_tokens_across_dp,
+skip_cuda_graphs=skip_cuda_graphs
 )
 # First get some inputs
 ubatch_metadata: list[UbatchMetadata] = []
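
The two helpers above now take an explicit compute_stream (and a skip_cuda_graphs flag) instead of reading current_stream() internally: graph capture has to happen on the same stream the micro-batch threads later launch onto, so the stream is created once and threaded down. A small sketch of capturing on an explicitly created side stream, assuming nothing beyond public PyTorch APIs:

import torch

compute_stream = torch.cuda.Stream()
x = torch.ones(4, device="cuda")

# Warm up on the capture stream first (the diff does something analogous by
# touching torch.cuda.current_blas_handle() on each stream before capture).
with torch.cuda.stream(compute_stream):
    _ = x * 2
compute_stream.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=compute_stream):
    y = x * 2  # recorded on compute_stream, not executed eagerly

graph.replay()
torch.cuda.synchronize()
assert y.eq(2).all()
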
@@ -1644,8 +1650,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 input_ids,
 positions,
 inputs_embeds,
-intermediate_tensors):
+intermediate_tensors,
+start_signal=None):
 with context:
+if start_signal is not None:
+start_signal.wait()
 model_output = self.model(
 input_ids=input_ids,
 positions=positions,
@@ -1659,42 +1668,64 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 return model_output

 @torch.inference_mode()
-def _ubatch_thread(results, ubatch_metadata):
+def _ubatch_thread(results, ubatch_metadata, start_signal):
 # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
+context = ubatch_metadata.context
+with torch.cuda.stream(context.compute_stream):
+_ = torch.cuda.current_blas_handle()
+with torch.cuda.stream(context.comm_stream):
+_ = torch.cuda.current_blas_handle()
 model_output = _run(context=ubatch_metadata.context,
 input_ids=ubatch_metadata.input_ids,
 positions=ubatch_metadata.positions,
 inputs_embeds=ubatch_metadata.inputs_embeds,
-intermediate_tensors=ubatch_metadata.intermediate_tensors)
+intermediate_tensors=ubatch_metadata.intermediate_tensors,
+start_signal=start_signal)

 results.append((ubatch_metadata.context.id, model_output))
 # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)

-def _run_ubatches(ubatch_metadata) -> torch.Tensor:
+def _run_ubatches(ubatch_metadata, num_tokens, should_capture=False) -> torch.Tensor:
 results: list[tuple[int, torch.Tensor]] = []
-root_stream = current_stream()

 # Ubatches will manually manage the forward context, so we override
 # it to None here so we can have it restored correctly later
 with override_forward_context(None):
 ubatch_threads = []
+start_signals = []
 for metadata in ubatch_metadata:
+start_signal = threading.Event()
 thread = threading.Thread(target=_ubatch_thread,
 args=(
 results,
 metadata,
+start_signal,
 ))
 ubatch_threads.append(thread)
 thread.start()
-ubatch_metadata[0].context.cpu_wait_event.set()
+start_signals.append(start_signal)

-for thread in ubatch_threads:
-thread.join()
-# torch.cuda.synchronize()
-torch.cuda.set_stream(root_stream)
-sorted_results = [value for position, value in sorted(results)]
-return torch.cat(sorted_results, dim=0)
+# DO capture
+cudagraph_metadata = \
+CUDAGraphMetaData(
+cudagraph=torch.cuda.CUDAGraph(),
+using_ubatching=True
+)
+with torch.cuda.graph(cudagraph_metadata.cudagraph,
+stream=compute_stream):
+# logger.info("STARTING WAKEUP LOOP")
+for start_signal in start_signals:
+start_signal.set()
+# logger.info("FINISHED WAKEUP LOOP")
+ubatch_metadata[0].context.cpu_wait_event.set()
+for thread in ubatch_threads:
+thread.join()
+sorted_results = [value for position, value in sorted(results)]
+result = torch.cat(sorted_results, dim=0)
+cudagraph_metadata.outputs = result
+logger.info(f"Capturing for {num_tokens} tokens")
+self.cudagraphs[num_tokens] = cudagraph_metadata
+return cudagraph_metadata.outputs

 # run micro-batched
 if ubatch_slices is not None:
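
In the new _run_ubatches path above, the worker threads are started first, graph capture is opened on the dedicated compute stream, and only then are the per-thread start_signal events set, so the kernels each thread launches land inside the capture region of a single graph. A rough, stripped-down sketch of that ordering (illustrative names only; one trivial kernel per thread stands in for the model forward):

import threading
import torch

def worker(start_signal: threading.Event, stream: torch.cuda.Stream,
           x: torch.Tensor, results: list) -> None:
    # Wait until capture has begun, then launch onto the capture stream
    # so the work is recorded into the graph rather than run eagerly.
    start_signal.wait()
    with torch.cuda.stream(stream):
        results.append(x * 2)

compute_stream = torch.cuda.Stream()
x = torch.ones(4, device="cuda")

# Warm up the kernel on the capture stream before capturing.
with torch.cuda.stream(compute_stream):
    _ = x * 2
compute_stream.synchronize()

results: list[torch.Tensor] = []
signal = threading.Event()
thread = threading.Thread(target=worker, args=(signal, compute_stream, x, results))
thread.start()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=compute_stream):
    signal.set()   # release the worker inside the capture region
    thread.join()  # wait for it to finish enqueueing its kernels

graph.replay()
torch.cuda.synchronize()
assert results[0].eq(2).all()
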
@@ -1702,69 +1733,71 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 # num_tokens = ubatch_slices[1][1].stop
 # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
 # assert not is_dummy_run
+compute_stream = torch.cuda.Stream(device=self.device)
 ubatch_metadata = _make_ubatch_metadata(
 ubatch_slices=ubatch_slices,
 attn_metadata=attn_metadata,
+compute_stream=compute_stream,
 is_dummy_run=is_dummy_run,
-num_tokens_across_dp=num_tokens_across_dp
+num_tokens_across_dp=num_tokens_across_dp,
+skip_cuda_graphs=skip_cuda_graphs
 )
 if num_scheduled_tokens not in self.cudagraphs \
 and not skip_cuda_graphs and build_cuda_graph:
-# DO capture
-self.cudagraphs[num_scheduled_tokens] = \
-CUDAGraphMetaData(
-cudagraph=torch.cuda.CUDAGraph(),
-using_ubatching=True
-)
-with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
-model_output = _run_ubatches(ubatch_metadata)
-self.cudagraphs[num_scheduled_tokens].outputs = model_output
-return self.cudagraphs[num_scheduled_tokens].outputs
-elif num_scheduled_tokens in self.cudagraphs:
-self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
-return self.cudagraphs[num_scheduled_tokens].outputs
+return _run_ubatches(ubatch_metadata, num_scheduled_tokens, should_capture=True)
+elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
+cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
+logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
+cudagraph_metadata.cudagraph.replay()
+return cudagraph_metadata.outputs
 else:
+assert False
 return _run_ubatches(ubatch_metadata)
 # run single batch
 else:
 input_ids, positions, inputs_embeds, intermediate_tensors = \
 model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
-if num_scheduled_tokens not in self.cudagraphs \
-and not skip_cuda_graphs and build_cuda_graph:
-self.cudagraphs[num_scheduled_tokens] = \
-CUDAGraphMetaData(
-cudagraph=torch.cuda.CUDAGraph(),
-using_ubatching=False
-)
-with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
-model_output = _run(
-context = set_forward_context(attn_metadata,
-vllm_config=self.vllm_config,
-num_tokens=num_scheduled_tokens or 1,
-num_tokens_across_dp=num_tokens_across_dp,
-skip_cuda_graphs=skip_cuda_graphs),
-input_ids=input_ids,
-positions=positions,
-inputs_embeds=inputs_embeds,
-intermediate_tensors=intermediate_tensors
-)
-self.cudagraphs[num_scheduled_tokens].outputs = model_output
-return self.cudagraphs[num_scheduled_tokens].outputs
-elif num_scheduled_tokens in self.cudagraphs:
-self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
-return self.cudagraphs[num_scheduled_tokens].outputs
-else:
-return _run(
-context = set_forward_context(attn_metadata,
-vllm_config=self.vllm_config,
-num_tokens=num_scheduled_tokens or 1,
-num_tokens_across_dp=num_tokens_across_dp,
-skip_cuda_graphs=skip_cuda_graphs),
-input_ids=input_ids,
-positions=positions,
-inputs_embeds=inputs_embeds,
-intermediate_tensors=intermediate_tensors
-)
+# if num_scheduled_tokens not in self.cudagraphs \
+# and not skip_cuda_graphs and build_cuda_graph:
+# assert False
+# logger.info(f"GRAPH BUILD{num_scheduled_tokens}")
+# self.cudagraphs[num_scheduled_tokens] = \
+# CUDAGraphMetaData(
+# cudagraph=torch.cuda.CUDAGraph(),
+# using_ubatching=False
+# )
+# with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
+# model_output = _run(
+# context = set_forward_context(attn_metadata,
+# vllm_config=self.vllm_config,
+# num_tokens=num_scheduled_tokens or 1,
+# num_tokens_across_dp=num_tokens_across_dp,
+# skip_cuda_graphs=skip_cuda_graphs),
+# input_ids=input_ids,
+# positions=positions,
+# inputs_embeds=inputs_embeds,
+# intermediate_tensors=intermediate_tensors
+# )
+# self.cudagraphs[num_scheduled_tokens].outputs = model_output
+# return self.cudagraphs[num_scheduled_tokens].outputs
+# elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
+# assert False
+# # logger.info(f"GRAPH REPLAY {num_scheduled_tokens}")
+# self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
+# return self.cudagraphs[num_scheduled_tokens].outputs
+# else:
+# logger.info(f"NORMAL RUN {num_scheduled_tokens}")
+return _run(
+context = set_forward_context(attn_metadata,
+vllm_config=self.vllm_config,
+num_tokens=num_scheduled_tokens or 1,
+num_tokens_across_dp=num_tokens_across_dp,
+skip_cuda_graphs=skip_cuda_graphs),
+input_ids=input_ids,
+positions=positions,
+inputs_embeds=inputs_embeds,
+intermediate_tensors=intermediate_tensors
+)

 @torch.inference_mode()
 def execute_model(
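
The hunk above keys captured graphs on the scheduled token count: capture on the first sight of a size, replay afterwards. A simplified, self-contained sketch of that dispatch (illustrative names; one Linear layer stands in for the model, and the graph is replayed once right after capture so the returned tensor actually holds data):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GraphEntry:
    graph: torch.cuda.CUDAGraph
    static_input: torch.Tensor
    output: Optional[torch.Tensor] = None


layer = torch.nn.Linear(32, 32).cuda()
cudagraphs: dict[int, GraphEntry] = {}


def run(x: torch.Tensor) -> torch.Tensor:
    n = x.size(0)
    if n not in cudagraphs:
        # Capture path: record the forward pass against a static input buffer.
        entry = GraphEntry(torch.cuda.CUDAGraph(), x.clone())
        for _ in range(3):  # warm-up before capture
            _ = layer(entry.static_input)
        torch.cuda.synchronize()
        with torch.cuda.graph(entry.graph):
            entry.output = layer(entry.static_input)
        entry.graph.replay()  # run once so entry.output is materialized
        cudagraphs[n] = entry
        return entry.output
    # Replay path: copy fresh data into the captured buffer, then replay.
    entry = cudagraphs[n]
    entry.static_input.copy_(x)
    entry.graph.replay()
    return entry.output


out = run(torch.randn(4, 32, device="cuda"))  # captures for batch size 4
out = run(torch.randn(4, 32, device="cuda"))  # replays the captured graph
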
@@ -2325,8 +2358,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 build_cuda_graph: bool = False
 ) -> torch.Tensor:

-if allow_microbatching:
-logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
+# if allow_microbatching:
+# logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")


 # TODO(Sage) We need some more code to properly handle
@@ -2339,6 +2372,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 allow_microbatching and capture_attn_cudagraph
 # _dummy_run doesn't go through _prepare_inputs so
 # we synchronize with other DP ranks here
+# logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
 should_ubatch = self.should_ubatch(allow_microbatching)
 # Padding for DP
 # logger.info("PADDING DUMMY")
@@ -30,7 +30,6 @@ class UBatchContext:
 self.id = id
 self.comm_stream = comm_stream
 self.compute_stream = compute_stream
-self.original_stream = current_stream()
 self.forward_context = None #fwd_ctx
 self.cpu_wait_event = cpu_wait_event
 self.cpu_signal_event = cpu_signal_event
@@ -58,7 +57,7 @@ class UBatchContext:
 self.cpu_signal_event.set()
 self.cpu_wait_event.clear()
 self.current_stream = self.compute_stream
-torch.cuda.set_stream(self.original_stream)
+torch.cuda.set_stream(self.current_stream)
 return False

 def _restore_context(self):
@@ -76,26 +75,31 @@ class UBatchContext:
 pass

 def _signal_comm_done(self):
+assert False
 self.ctx_valid_state()
 self.gpu_comm_done_event.record(self.comm_stream)

 def _signal_compute_done(self):
+assert False
 self.ctx_valid_state()
 self.gpu_compute_done_event.record(self.compute_stream)

 def _wait_compute_done(self):
+assert False
 # print(f"{self.id} Waiting on COMPUTE stream", flush=True)
 self.ctx_valid_state()
 self.comm_stream.wait_event(self.gpu_compute_done_event)
 # print("Compute stream done", flush=True)

 def _wait_comm_done(self):
+assert False
 # print(f"{self.id} Waiting on COMM stream", flush=True)
 self.ctx_valid_state()
 self.compute_stream.wait_event(self.gpu_comm_done_event)
 # print("Comm stream done", flush=True)

 def stream_string(self):
+assert False
 if current_stream() == self.compute_stream:
 assert self.current_stream == self.compute_stream
 return "COMPUTE"
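
For reference, the _signal_*/_wait_* helpers above (now short-circuited with assert False while the graph-capture path is being debugged) implement a standard event handoff between the compute and communication streams. A minimal sketch of that handoff with plain PyTorch streams and events (the work on each stream is illustrative):

import torch

compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()
compute_done = torch.cuda.Event()

x = torch.ones(1 << 20, device="cuda")

with torch.cuda.stream(compute_stream):
    y = x * 2                      # "compute" work
    compute_done.record(compute_stream)

# The comm stream must not start until the compute result is ready.
comm_stream.wait_event(compute_done)
with torch.cuda.stream(comm_stream):
    z = y + 1                      # stand-in for the all-to-all "comm" work

torch.cuda.synchronize()
assert z.eq(3).all()
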
@@ -114,6 +118,7 @@ class UBatchContext:
 # print(f"UBatchContext: {self.id} resuming CPU", flush=True)

 def yield_and_switch_from_compute_to_comm(self):
+assert False
 assert current_stream() == self.compute_stream
 # dp_rank = get_dp_group().rank_in_group
 # print(f"DP: {dp_rank} UB: {self.id} "
@@ -129,6 +134,7 @@ class UBatchContext:
 self._wait_compute_done()

 def yield_and_switch_from_comm_to_compute(self):
+assert False
 assert current_stream() == self.comm_stream
 # dp_rank = get_dp_group().rank_in_group
 # print(f"DP: {dp_rank} UB: {self.id} "