add attention splitting to dummy runs

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-25 21:39:33 +00:00
parent 144b148de2
commit 44a2b3494e
3 changed files with 58 additions and 27 deletions

View File

@ -630,10 +630,10 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
m.max_query_len = 1 # decode-only
# Update state usually set in reorder_batch.
self._num_decodes = m.num_reqs
self._num_decode_tokens = m.num_actual_tokens
self._num_prefills = 0
self._num_prefill_tokens = 0
# self._num_decodes = m.num_reqs
# self._num_decode_tokens = m.num_actual_tokens
# self._num_prefills = 0
# self._num_prefill_tokens = 0
return self.build(0, m)
def use_cascade_attention(self, *args, **kwargs) -> bool:

View File

@ -77,7 +77,7 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
n = num_splits.size(0)
logger.info(f"N : {n} bs: {self.runner.cudagraph_batch_sizes[-1]}")
if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1]:
if self.runner.full_cuda_graph and (n-1) <= self.runner.cudagraph_batch_sizes[-1] // 2:
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_tile_scheduler_metadata is None:
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata

View File

@ -228,7 +228,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# The batch sizes in the config are in descending order.
self.cudagraph_batch_sizes = list(
reversed(self.compilation_config.cudagraph_capture_sizes))
logger.info(f"cudagraph capture sizes {self.cudagraph_batch_sizes}")
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.full_cuda_graph = True
logger.info(f"full_cuda_graph {self.full_cuda_graph}")
@ -558,7 +558,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata()
def _ubatch_split(
self, query_start_loc_np: torch.Tensor,
self,
max_num_scheduled_tokens: int,
scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@ -707,7 +707,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
self.query_start_loc_np, max_num_scheduled_tokens,
max_num_scheduled_tokens,
scheduler_output)
should_ubatch = self.should_ubatch(True if ubatch_slices else False)
# Don't attempt to microbatch unless every other DP worker is also microbatching
@ -1343,6 +1343,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_padded = num_tokens_unpadded
logger.info(f"num tokens unpadded: {num_tokens_unpadded} cudagraphs: {self.cudagraph_batch_sizes}")
if (self.use_cuda_graph
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
@ -2279,7 +2280,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# for now.
should_ubatch = num_tokens >= \
self.parallel_config.microbatching_token_threshold and \
allow_microbatching
allow_microbatching and capture_attn_cudagraph
# _dummy_run doesn't go through _prepare_inputs so
# we synchronize with other DP ranks here
should_ubatch = self.should_ubatch(allow_microbatching)
@ -2304,9 +2305,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
attn_metadata: Optional[dict[str, Any]] = None
ubatch_slices = None
# We currently only microbatch if the number of tokens is
# over a certain threshold.
if should_ubatch:
# We only support decode-only cudagraphs
assert num_reqs == num_tokens
assert num_tokens % 2 == 0
ubatch_slices = [(slice(0, num_reqs // 2),
slice(0, num_tokens // 2)),
(slice(num_reqs // 2, num_reqs),
slice(num_tokens // 2, num_tokens))]
# attn_metadata: Optional[dict[str, Any]] = None
attn_metadata: Optional[PerLayerAttnMetadata]= None
if capture_attn_cudagraph:
attn_metadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
query_start_loc = self.query_start_loc[:num_reqs + 1]
# Make sure max_model_len is used at the graph capture time.
@ -2316,39 +2333,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
non_blocking=True)
seq_lens = self.seq_lens[:num_reqs]
max_query_len = num_tokens
if ubatch_slices is not None:
max_query_len = 1
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc,
seq_lens=seq_lens,
num_reqs=num_reqs,
num_actual_tokens=num_tokens,
max_query_len=num_tokens,
max_query_len=max_query_len,
)
for kv_cache_group_id, kv_cache_group_spec in enumerate(
self.kv_cache_config.kv_cache_groups):
if ubatch_slices is not None:
for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
# Run a dummy batch if it's an empty ubatch
if token_slice.stop <= token_slice.start:
attn_metadata_i = None
else:
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].
build_slice(
req_slice=req_slice,
token_slice=token_slice,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
# assert attn_metadata_i is not None
# What if it's None? Do we still add it to the list?
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
dummy_microbatches = None
# We currently only microbatch if the number of tokens is
# over a certain threshold.
if should_ubatch:
assert num_tokens % 2 == 0
dummy_microbatches = [(slice(0, num_tokens // 2),
slice(0, num_tokens // 2)),
(slice(num_tokens // 2, num_tokens),
slice(num_tokens // 2, num_tokens))]
with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens):
outputs = self._run_model(
attn_metadata,
num_tokens,
ubatch_slices=dummy_microbatches,
ubatch_slices=ubatch_slices,
is_dummy_run=True,
num_tokens_across_dp=num_tokens_across_dp,
build_cuda_graph=build_cuda_graph