misc changes

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-07-02 15:58:41 +00:00
parent f7a3ee0ea1
commit c0efbbb5de
3 changed files with 29 additions and 137 deletions

View File

@@ -150,7 +150,9 @@ def main(
max_num_seqs=max_num_seqs,
gpu_memory_utilization=gpu_memory_utilization,
)
print("BEFORE GENERATE")
outputs = llm.generate(prompts, sampling_params)
print("AFTER GENERATE")
# Print the outputs.
for i, output in enumerate(outputs):
if i >= 5:
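
For context, the hunk above brackets the benchmark's generation call with debug prints. A minimal sketch of that flow, assuming a placeholder model and prompts (only the kwargs, the prints, and the first-5-outputs cap come from the diff):

# Hedged sketch of the benchmark flow around the new prints.
# Model name and prompt values are placeholders, not from the commit.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is"] * 8
sampling_params = SamplingParams(temperature=0.8, max_tokens=32)

llm = LLM(
    model="facebook/opt-125m",          # placeholder model
    max_num_seqs=64,                    # kwarg shown in the hunk
    gpu_memory_utilization=0.9,         # kwarg shown in the hunk
)
print("BEFORE GENERATE")
outputs = llm.generate(prompts, sampling_params)
print("AFTER GENERATE")
for i, output in enumerate(outputs):
    if i >= 5:                          # the diff caps printing at 5 outputs
        break
    print(output.outputs[0].text)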

View File

@@ -134,14 +134,14 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send,
)
yield_and_switch_from_compute_to_comm_impl(schedule="default")
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
dispatch(True) # Send
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER SEND SYNC", flush=True)
dispatch(False) # Recv
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, 0:1]
@@ -185,11 +185,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
do_recv=not send,
)
yield_and_switch_from_compute_to_comm_impl(schedule="default")
# yield_and_switch_from_compute_to_comm_impl(schedule="default")
combine(True)
# torch.cuda.synchronize()
# print(f"{ubatch_id} AFTER COMBINE SEND SYNC", flush=True)
combine(False)
# print(f"{ubatch_id} AFTER COMBINE RECV SYNC", flush=True)
yield_and_switch_from_comm_to_compute_impl(schedule="default")
# yield_and_switch_from_comm_to_compute_impl(schedule="default")
# torch.cuda.synchronize()
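
The two hunks above toggle the compute-to-comm stream switches around dispatch() and combine() while keeping the synchronize/print debugging commented out. A generic two-stream overlap sketch of the pattern these helpers coordinate, with illustrative stream and event names that are not vLLM's:

# Hedged sketch: hand off between a compute stream and a comm stream
# with events, so send/recv can overlap another micro-batch's compute.
import torch

compute_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()
compute_done = torch.cuda.Event()
comm_done = torch.cuda.Event()

with torch.cuda.stream(compute_stream):
    a = torch.randn(1024, 1024, device="cuda")
    x = a @ a                       # stand-in for expert compute
    compute_done.record()           # compute -> comm switch point

with torch.cuda.stream(comm_stream):
    comm_stream.wait_event(compute_done)
    y = x.clone()                   # stand-in for the all-to-all send/recv
    comm_done.record()              # comm -> compute switch point

compute_stream.wait_event(comm_done)
torch.cuda.synchronize()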

View File

@@ -245,10 +245,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# The batch sizes in the config are in descending order.
self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes))
logger.info(f"cudagraph capture sizes {self.cudagraph_batch_sizes}")
self.full_cuda_graph = self.compilation_config.full_cuda_graph
# self.full_cuda_graph = True
logger.info(f"full_cuda_graph {self.full_cuda_graph}")
# Cache the device properties.
self._init_device_properties()
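
The hunk above drops two debug logger.info lines; the capture-size setup it keeps can be sketched as follows, with hypothetical config values:

# Hedged sketch: capture sizes come from the compilation config in
# ascending order and are reversed so capture runs largest-first.
cudagraph_capture_sizes = [1, 2, 4, 8, 16]      # hypothetical config
cudagraph_batch_sizes = list(reversed(cudagraph_capture_sizes))
assert cudagraph_batch_sizes == [16, 8, 4, 2, 1]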
@@ -1683,11 +1680,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
start_signal=None):
intermediate_tensors):
with context:
if start_signal is not None:
start_signal.wait()
model_output = self.model(
input_ids=input_ids,
positions=positions,
@@ -1701,64 +1695,41 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return model_output
@torch.inference_mode()
def _ubatch_thread(results, ubatch_metadata, start_signal):
def _ubatch_thread(results, ubatch_metadata):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
context = ubatch_metadata.context
with torch.cuda.stream(context.compute_stream):
_ = torch.cuda.current_blas_handle()
with torch.cuda.stream(context.comm_stream):
_ = torch.cuda.current_blas_handle()
model_output = _run(context=ubatch_metadata.context,
model_output = _run(context=context,
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
inputs_embeds=ubatch_metadata.inputs_embeds,
intermediate_tensors=ubatch_metadata.intermediate_tensors,
start_signal=start_signal)
intermediate_tensors=ubatch_metadata.intermediate_tensors)
results.append((ubatch_metadata.context.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_metadata, num_tokens, should_capture=False) -> torch.Tensor:
def _run_ubatches(ubatch_metadata) -> torch.Tensor:
results: list[tuple[int, torch.Tensor]] = []
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
ubatch_threads = []
start_signals = []
for metadata in ubatch_metadata:
start_signal = threading.Event()
thread = threading.Thread(target=_ubatch_thread,
args=(
results,
metadata,
start_signal,
))
ubatch_threads.append(thread)
thread.start()
start_signals.append(start_signal)
# DO capture
cudagraph_metadata = \
CUDAGraphMetaData(
cudagraph=torch.cuda.CUDAGraph(),
using_ubatching=True
)
with torch.cuda.graph(cudagraph_metadata.cudagraph,
stream=compute_stream):
# logger.info("STARTING WAKEUP LOOP")
for start_signal in start_signals:
start_signal.set()
# logger.info("FINISHED WAKEUP LOOP")
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
cudagraph_metadata.outputs = result
logger.info(f"Capturing for {num_tokens} tokens")
self.cudagraphs[num_tokens] = cudagraph_metadata
return cudagraph_metadata.outputs
# logger.info("FINISHED WAKEUP LOOP")
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)
return result
# run micro-batched
if ubatch_slices is not None:
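
This hunk strips the start-signal plumbing and in-place CUDA graph capture out of the micro-batch path, leaving plain thread-per-ubatch execution. The surviving pattern, as a self-contained sketch in which run_ubatch stands in for _ubatch_thread:

# Hedged sketch of thread-per-micro-batch execution: each thread
# appends (ubatch_id, output); results are ordered by id and
# concatenated back into one batch along the token dimension.
import threading
import torch

def run_ubatch(results, ubatch_id, tokens):
    results.append((ubatch_id, tokens * 2.0))   # stand-in for the model

results: list[tuple[int, torch.Tensor]] = []
threads = []
for ubatch_id, tokens in enumerate(torch.randn(8, 4).chunk(2, dim=0)):
    t = threading.Thread(target=run_ubatch,
                         args=(results, ubatch_id, tokens))
    threads.append(t)
    t.start()
for t in threads:
    t.join()

sorted_results = [value for position, value in sorted(results)]
result = torch.cat(sorted_results, dim=0)       # shape (8, 4) again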
@@ -1775,51 +1746,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs
)
if num_scheduled_tokens not in self.cudagraphs \
and not skip_cuda_graphs and build_cuda_graph:
return _run_ubatches(ubatch_metadata, num_scheduled_tokens, should_capture=True)
elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
cudagraph_metadata.cudagraph.replay()
return cudagraph_metadata.outputs
else:
assert False
return _run_ubatches(ubatch_metadata)
return _run_ubatches(ubatch_metadata)
# run single batch
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
# if num_scheduled_tokens not in self.cudagraphs \
# and not skip_cuda_graphs and build_cuda_graph:
# assert False
# logger.info(f"GRAPH BUILD{num_scheduled_tokens}")
# self.cudagraphs[num_scheduled_tokens] = \
# CUDAGraphMetaData(
# cudagraph=torch.cuda.CUDAGraph(),
# using_ubatching=False
# )
# with torch.cuda.graph(self.cudagraphs[num_scheduled_tokens].cudagraph):
# model_output = _run(
# context = set_forward_context(attn_metadata,
# vllm_config=self.vllm_config,
# num_tokens=num_scheduled_tokens or 1,
# num_tokens_across_dp=num_tokens_across_dp,
# skip_cuda_graphs=skip_cuda_graphs),
# input_ids=input_ids,
# positions=positions,
# inputs_embeds=inputs_embeds,
# intermediate_tensors=intermediate_tensors
# )
# self.cudagraphs[num_scheduled_tokens].outputs = model_output
# return self.cudagraphs[num_scheduled_tokens].outputs
# elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
# assert False
# # logger.info(f"GRAPH REPLAY {num_scheduled_tokens}")
# self.cudagraphs[num_scheduled_tokens].cudagraph.replay()
# return self.cudagraphs[num_scheduled_tokens].outputs
# else:
# logger.info(f"NORMAL RUN {num_scheduled_tokens}")
logger.info(f"NORMAL RUN {num_scheduled_tokens}")
return _run(
context = set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
@@ -2518,13 +2450,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# _dummy_run doesn't go through _prepare_inputs so
# we synchronize with other DP ranks here
# logger.info(f"NUM TOKENS {num_tokens} SHOULD UBATCH {should_ubatch}")
should_ubatch = self.should_ubatch(False)
_ = self.should_ubatch(False)
# Padding for DP
# logger.info("PADDING DUMMY")
num_tokens_across_dp = None
num_pad = 0
if not should_ubatch:
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
num_tokens += num_pad
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
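
With the should_ubatch branch gone, DP padding is now applied unconditionally. What get_dp_padding computes is, roughly, that each data-parallel rank pads its token count up to the maximum across ranks so collective ops see matching shapes. A sketch under that assumption, not the actual implementation:

# Hedged sketch of DP padding (hypothetical per-rank token counts).
num_tokens_per_rank = [96, 128, 64]
max_tokens = max(num_tokens_per_rank)
for rank, n in enumerate(num_tokens_per_rank):
    num_pad = max_tokens - n
    print(f"rank {rank}: {n} tokens + {num_pad} pad -> {max_tokens}")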
@@ -2544,26 +2473,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# We currently only microbatch if the number of tokens is
# over a certain threshold.
# logger.info("PADDING DUMMY DONE")
if should_ubatch:
# We only support decode-only cudagraphs
assert num_reqs == num_tokens
assert num_tokens % 2 == 0
num_tokens_per_ubatch = num_tokens // 2
num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] * 2,
device="cpu",
dtype=torch.int32)
ubatch_slices = [(slice(0, num_reqs // 2),
slice(0, num_tokens // 2)),
(slice(num_reqs // 2, num_reqs),
slice(num_tokens // 2, num_tokens))]
# attn_metadata: Optional[dict[str, Any]] = None
attn_metadata: Optional[PerLayerAttnMetadata] = None
attn_metadata: Optional[dict[str, Any]] = None
if capture_attn_cudagraph:
attn_metadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
query_start_loc = self.query_start_loc[:num_reqs + 1]
# Make sure max_model_len is used at the graph capture time.
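
For reference, the removed decode-only ubatch split halves both requests and tokens into two (request_slice, token_slice) pairs:

# Sketch of the deleted split logic (decode-only: one token per request).
num_reqs = num_tokens = 8
assert num_reqs == num_tokens and num_tokens % 2 == 0
ubatch_slices = [
    (slice(0, num_reqs // 2), slice(0, num_tokens // 2)),
    (slice(num_reqs // 2, num_reqs), slice(num_tokens // 2, num_tokens)),
]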
@@ -2586,33 +2498,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
if ubatch_slices is not None:
for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
# Run a dummy batch if it's an empty ubatch
if token_slice.stop <= token_slice.start:
attn_metadata_i = None
else:
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].
build_slice(
req_slice=req_slice,
token_slice=token_slice,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
ubatch_id=ubid
))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
# assert attn_metadata_i is not None
# What if it's None? Do we still add it to the list?
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
with self.maybe_dummy_run_with_lora(self.lora_config,
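
The hunk above collapses the per-ubatch metadata branch back to the single-batch path: one metadata object is built per kv-cache group and shared by every layer in that group. Schematically, with hypothetical names and a stand-in for the builder call:

# Hedged sketch of the fan-out this hunk keeps.
attn_metadata: dict[str, object] = {}
kv_cache_groups = {0: ["layers.0.attn", "layers.1.attn"]}   # hypothetical
for group_id, layer_names in kv_cache_groups.items():
    metadata_i = {"group": group_id}    # stand-in for build_for_cudagraph_capture()
    for layer_name in layer_names:
        attn_metadata[layer_name] = metadata_i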
@@ -2870,7 +2760,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# can reuse the memory pool allocated for the large shapes.
with graph_capture(device=self.device):
full_cg = self.full_cuda_graph
allow_microbatching = self.parallel_config.enable_microbatching
allow_microbatching = False
for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes),
desc="Capturing CUDA graphs",
total=len(self.cudagraph_batch_sizes)):
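
Capture runs over cudagraph_batch_sizes in descending order so smaller graphs can reuse the memory pool allocated for the large shapes. The underlying PyTorch capture/replay mechanics, as a generic sketch using the public torch.cuda API rather than vLLM's graph_capture wrapper:

# Hedged sketch of CUDA graph capture and replay with static buffers.
import torch

static_in = torch.zeros(128, 1024, device="cuda")
weight = torch.randn(1024, 1024, device="cuda")

# Warm up once so lazy init (e.g. cuBLAS handles) happens before capture.
_ = static_in @ weight
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = static_in @ weight     # kernels recorded into the graph

static_in.copy_(torch.randn(128, 1024, device="cuda"))
graph.replay()                          # rerun captured kernels on new data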