delete any notion of dummy_ubatch

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-25 23:48:16 +00:00
parent 0e2b4bd546
commit 54deb61b87

View File

@ -1638,7 +1638,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@torch.inference_mode()
def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
intermediate_tensors, save_results):
intermediate_tensors):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(context=ubatch_ctx,
input_ids=input_ids,
@ -1646,10 +1646,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
# TODO (Sage) I think we can just delete this check now that we
# enforce that all microbatches are valid
if save_results:
results.append((ubatch_ctx.id, model_output))
results.append((ubatch_ctx.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, ubatch_ctxs,
@ -1663,19 +1660,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
with override_forward_context(None):
ubatch_threads = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
# TODO (Sage) Consolidate all of this is_dummy_run
# is_dummy_ubatch, is attn_metadata==None, num_tokens==0
# nonsense into some unified structure. It's way to hard
# to keep track of and keep consistent right now.
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
assert not is_dummy_ubatch or i == len(
ubatch_slices) - 1 or is_dummy_run
assert tokens_slice.stop > tokens_slice.start
use_dummy_input = is_dummy_run or is_dummy_ubatch
use_dummy_input = is_dummy_run
# The only time we don't save results is when one of our ubatches
# is a dummy batch
save_results = not is_dummy_ubatch or is_dummy_run
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(tokens_slice, use_dummy_input)
thread = threading.Thread(target=_ubatch_thread,
@ -1686,7 +1676,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
positions,
inputs_embeds,
intermediate_tensors,
save_results
))
ubatch_threads.append(thread)
thread.start()