Add initial code for capturing ubatches in CUDA graphs

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-24 22:19:24 +00:00
parent 930efd02ab
commit 96c0c4ea66
2 changed files with 32 additions and 13 deletions

View File

@ -1649,7 +1649,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if ubatch_slices is not None: if ubatch_slices is not None:
# num_tokens = ubatch_slices[1][1].stop # num_tokens = ubatch_slices[1][1].stop
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}") # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
assert not is_dummy_run # assert not is_dummy_run
model_output = _run_ubatches(ubatch_slices, attn_metadata, model_output = _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run, num_tokens_across_dp=num_tokens_across_dp) is_dummy_run, num_tokens_across_dp=num_tokens_across_dp)
# run single batch # run single batch
@ -2224,10 +2224,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
allow_microbatching: bool = False, allow_microbatching: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
should_microbatch = False if allow_microbatching:
logger.info("ATTEMPTING TO UBATCH THE DUMMY RUN")
# TODO(Sage): We need some more code to properly handle
# mixing normal and dummy runs. The DP padding needs to
# be properly set up. Since we only support microbatching
# in CUDA graph capture, it's fine to ignore the DP padding
# for now.
should_ubatch = num_tokens >= \
self.parallel_config.microbatching_token_threshold and \
allow_microbatching
# _dummy_run doesn't go through _prepare_inputs so # _dummy_run doesn't go through _prepare_inputs so
# we synchronize with other DP ranks here # we synchronize with other DP ranks here
self.should_ubatch(should_microbatch) should_ubatch = self.should_ubatch(allow_microbatching)
assert not should_ubatch
# Padding for DP # Padding for DP
# logger.info("PADDING DUMMY") # logger.info("PADDING DUMMY")
num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
@ -2278,19 +2290,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for layer_name in kv_cache_group_spec.layer_names: for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i attn_metadata[layer_name] = attn_metadata_i
# should_microbatch = ( dummy_microbatches = None
# allow_microbatching # We currently only microbatch if the number of tokens is
# and self.vllm_config.parallel_config.enable_microbatching # over a certain threshold.
# and self.vllm_config.parallel_config.always_microbatch_if_enabled) if should_ubatch:
# dummy_microbatches = [(slice(0, 0), slice(0, 0)), assert num_tokens % 2 == 0
# (slice(0, 0), slice(0, 0))] # TODO (Sage) Add actual slices here
assert False
dummy_microbatches = [(slice(0, 0), slice(0, 0)),
(slice(0, 0), slice(0, 0))]
with self.maybe_dummy_run_with_lora(self.lora_config, with self.maybe_dummy_run_with_lora(self.lora_config,
num_scheduled_tokens): num_scheduled_tokens):
outputs = self._run_model( outputs = self._run_model(
attn_metadata, attn_metadata,
num_tokens, num_tokens,
ubatch_slices=None, ubatch_slices=dummy_microbatches,
is_dummy_run=True, is_dummy_run=True,
num_tokens_across_dp=num_tokens_across_dp num_tokens_across_dp=num_tokens_across_dp
) )
@ -2488,8 +2503,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total=len(self.cudagraph_batch_sizes)): total=len(self.cudagraph_batch_sizes)):
for _ in range( for _ in range(
self.compilation_config.cudagraph_num_of_warmups): self.compilation_config.cudagraph_num_of_warmups):
self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) self._dummy_run(num_tokens,
self._dummy_run(num_tokens, capture_attn_cudagraph=full_cg) capture_attn_cudagraph=full_cg,
allow_microbatching=allow_microbatching)
self._dummy_run(num_tokens,
capture_attn_cudagraph=full_cg,
allow_microbatching=allow_microbatching)
logger.info("CAPTURE MODEL END") logger.info("CAPTURE MODEL END")
end_time = time.perf_counter() end_time = time.perf_counter()

View File

@ -319,7 +319,7 @@ class Worker(WorkerBase):
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
# TODO: adding allow_microbatching will break non-gpu backends # TODO: adding allow_microbatching will break non-gpu backends
self.model_runner._dummy_run(1, allow_microbatching=True) self.model_runner._dummy_run(1)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request) return self.model_runner.add_lora(lora_request)