_prepare_inputs cleanup

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Sage Moore 2025-07-08 13:02:21 +00:00
parent 82ae694de6
commit 1a0e7110dd

@@ -599,7 +599,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if not should_ubatch:
             return (None, 0, None)
-        # For pure decode we can just create ubatchs by cutting the request
+        # For pure decode we can just create ubatches by cutting the request
         # in half
         b0_reqs_end = num_reqs // 2
         b0_tokens_end = total_num_scheduled_tokens // 2
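
For reference, a minimal self-contained sketch of the pure-decode split performed above; the function name and the example call are illustrative, not part of the diff:

    def make_ubatch_slices(num_reqs: int, total_num_scheduled_tokens: int):
        # Cut the batch in half along both the request and token dimensions.
        b0_reqs_end = num_reqs // 2
        b0_tokens_end = total_num_scheduled_tokens // 2
        return [
            (slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
            (slice(b0_reqs_end, num_reqs),
             slice(b0_tokens_end, total_num_scheduled_tokens)),
        ]

    # Example: 8 decode requests, one scheduled token each.
    print(make_ubatch_slices(8, 8))
    # [(slice(0, 4), slice(0, 4)), (slice(4, 8), slice(4, 8))]

In pure decode each request contributes exactly one token, so cutting both dimensions at the midpoint keeps each request aligned with its tokens.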
@@ -610,6 +610,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             (slice(b0_reqs_end, num_reqs),
              slice(b0_tokens_end, total_num_scheduled_tokens)),
         ]
+        # Compute ubatch padding. This currently only accounts
+        # for DP padding.
         num_pad_tokens = 0
         num_tokens_after_padding = None
         ubatch_abort = False
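
The new comment notes that padding currently only accounts for DP (data-parallel) padding. A hedged sketch of that idea, assuming each DP rank pads its token count up to the maximum across ranks so all ranks step through the forward pass together; the function name is hypothetical:

    def compute_dp_pad_tokens(local_num_tokens: int,
                              tokens_per_dp_rank: list[int]) -> int:
        # Pad this rank's batch up to the largest batch across DP ranks
        # so every rank executes a forward pass of the same size.
        return max(tokens_per_dp_rank) - local_num_tokens

    print(compute_dp_pad_tokens(96, [96, 128]))  # 32 pad tokens on this rank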
@@ -817,19 +819,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if ubatch_slices is not None:
             for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
-                # Run a dummy batch if its a empty ubatch
-                if token_slice.stop <= token_slice.start:
-                    attn_metadata_i = None
-                else:
-                    attn_metadata_i = (
-                        self.attn_metadata_builders[kv_cache_group_id].
-                        build_slice(
-                            req_slice=req_slice,
-                            token_slice=token_slice,
-                            max_query_len=max(tokens[req_slice]),
-                            common_prefix_len=common_prefix_len,
-                            common_attn_metadata=common_attn_metadata,
-                        ))
+                assert token_slice.stop > token_slice.start
+                attn_metadata_i = (
+                    self.attn_metadata_builders[kv_cache_group_id].
+                    build_slice(
+                        req_slice=req_slice,
+                        token_slice=token_slice,
+                        max_query_len=max(tokens[req_slice]),
+                        common_prefix_len=common_prefix_len,
+                        common_attn_metadata=common_attn_metadata,
+                    ))
                 for layer_name in kv_cache_group_spec.layer_names:
                     assert type(attn_metadata) is list
                     attn_metadata[ubid][layer_name] = attn_metadata_i
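
The refactor replaces the empty-ubatch branch (which used to substitute attn_metadata_i = None) with a hard assertion: by this point every ubatch slice must be non-empty, and one metadata object is built per slice and shared by every layer in the KV-cache group. A self-contained sketch of that control flow; the stub builder and names are placeholders, not vLLM APIs:

    def build_ubatch_metadata(ubatch_slices, builder, layer_names, tokens):
        # One dict of per-layer metadata per ubatch.
        attn_metadata = [dict() for _ in ubatch_slices]
        for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
            # New invariant: no empty ubatches reach this point.
            assert token_slice.stop > token_slice.start
            attn_metadata_i = builder.build_slice(
                req_slice=req_slice,
                token_slice=token_slice,
                max_query_len=max(tokens[req_slice]),
            )
            # The same metadata object is fanned out to all layers
            # in the KV-cache group.
            for layer_name in layer_names:
                attn_metadata[ubid][layer_name] = attn_metadata_i
        return attn_metadata

    class StubBuilder:
        def build_slice(self, **kwargs):
            return kwargs  # stand-in for real attention metadata

    slices = [(slice(0, 2), slice(0, 2)), (slice(2, 4), slice(2, 4))]
    md = build_ubatch_metadata(slices, StubBuilder(),
                               ["layers.0.attn"], [1, 1, 1, 1])
    print(md[0]["layers.0.attn"]["max_query_len"])  # 1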
@@ -1416,7 +1415,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             ubatch_slices[1] = (padded_second_ubatch_slice, padded_second_ubatch_slice)
     # This is where the second ubatch is adjusted to account for the padding.
-    # Should be called after attention metadata creation. This just extends
+    # Should be called after attention metadata creation. This just pads
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
     def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
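
A hedged sketch of what pad_out_ubatch_second_stage does per the updated comment: after attention metadata has been built, the second ubatch slice is padded out to the total token count (num_tokens + padding). The (padded, padded) tuple mirrors the assignment shown in the context line above; the standalone signature and type alias are illustrative:

    UBatchSlices = list[tuple[slice, slice]]

    def pad_out_ubatch_second_stage(ubatch_slices: UBatchSlices,
                                    num_total_tokens: int) -> None:
        token_slice = ubatch_slices[1][1]
        padded = slice(token_slice.start, num_total_tokens)
        # In the pure-decode case the request and token slices coincide,
        # matching the (padded, padded) assignment in the diff above.
        ubatch_slices[1] = (padded, padded)

    slices = [(slice(0, 64), slice(0, 64)), (slice(64, 128), slice(64, 128))]
    pad_out_ubatch_second_stage(slices, 160)  # 32 pad tokens appended
    print(slices[1])  # (slice(64, 160), slice(64, 160))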