[BugFix] AssertionError: Do not capture num_reqs > max_num_reqs for uniform batch (#25505)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Lucas Wilkinson 2025-09-23 20:00:29 -04:00 committed by yewentao256
parent faae7a7eab
commit 8e6a5e7dd4

View File

@ -2828,7 +2828,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _dummy_run( def _dummy_run(
self, self,
num_tokens: int, num_tokens: int,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
force_attention: bool = False, force_attention: bool = False,
uniform_decode: bool = False, uniform_decode: bool = False,
allow_microbatching: bool = True, allow_microbatching: bool = True,
@ -2844,6 +2844,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
Args: Args:
num_tokens: Number of tokens to run the dummy forward pass. num_tokens: Number of tokens to run the dummy forward pass.
cudagraph_runtime_mode: used to control the behavior. cudagraph_runtime_mode: used to control the behavior.
- if not set will determine the cudagraph mode based on using
the self.cudagraph_dispatcher.
- CUDAGraphMode.NONE: No cudagraph, for warm up and profile run - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
- CUDAGraphMode.PIECEWISE: Piecewise cudagraph. - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
- CUDAGraphMode.FULL: Full cudagraph, attention metadata is - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
@ -2857,7 +2859,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
(1 token) and prefill (multiple tokens) requests. (1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run remove_lora: If False, dummy LoRAs are not destroyed after the run
""" """
assert cudagraph_runtime_mode in { assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
} }
@ -2899,10 +2901,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
elif uniform_decode: elif uniform_decode:
assert not create_mixed_batch assert not create_mixed_batch
num_reqs = cdiv(num_tokens, max_query_len) num_reqs = cdiv(num_tokens, max_query_len)
assert num_reqs <= max_num_reqs, \
f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
f"{max_num_reqs} for uniform batch. Num tokens: " \
f"{num_tokens}, max_query_len: {max_query_len}"
num_scheduled_tokens_list = [max_query_len] * num_reqs num_scheduled_tokens_list = [max_query_len] * num_reqs
if num_tokens % max_query_len != 0: if num_tokens % max_query_len != 0:
num_scheduled_tokens_list[-1] = num_tokens % max_query_len num_scheduled_tokens_list[-1] = num_tokens % max_query_len
@ -3043,18 +3041,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors( intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens, None, False) num_tokens, None, False)
if cudagraph_runtime_mode == CUDAGraphMode.NONE:
batch_descriptor = None # filter out the valid batch descriptor
else: _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
# filter out the valid batch descriptor BatchDescriptor(num_tokens=num_tokens,
_cg_mode, batch_descriptor = \ uniform_decode=uniform_decode))
self.cudagraph_dispatcher.dispatch( if cudagraph_runtime_mode is not None:
BatchDescriptor(num_tokens=num_tokens, # we allow forcing NONE when the dispatcher disagrees to support
uniform_decode=uniform_decode)) # warm ups for cudagraph capture
# sanity check assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
assert cudagraph_runtime_mode == _cg_mode, ( cudagraph_runtime_mode == _cg_mode, (
f"Cudagraph runtime mode mismatch at dummy_run. " f"Cudagraph runtime mode mismatch at dummy_run. "
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.") f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
else:
cudagraph_runtime_mode = _cg_mode
if ubatch_slices is not None: if ubatch_slices is not None:
num_tokens = num_tokens // 2 num_tokens = num_tokens // 2