ubatching fix: use the current CUDA stream instead of allocating a new compute stream

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-02 22:22:41 +00:00
parent 3d833aa759
commit 18f7bfb501

View File

@ -1716,7 +1716,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# num_tokens = ubatch_slices[1][1].stop
print(f"RUNNING UBATCH {ubatch_slices} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
# assert not is_dummy_run
compute_stream = torch.cuda.Stream(device=self.device)
compute_stream = torch.cuda.current_stream()
ubatch_metadata = _make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,