refactor a bunch of misc parameters into a UbatchMetadata class

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-26 00:14:18 +00:00
parent 54deb61b87
commit 78228a67ce

View File

@ -1570,10 +1570,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
@dataclasses.dataclass
class UbatchMetadata:
ubatch_id: int
context: UBatchContext
ubatch_slice: UbatchSlice
input_ids: torch.Tensor
positions: torch.Tensor
inputs_embeds: Optional[torch.Tensor]
@ -1618,6 +1615,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert scheduler_output is not None
return self._get_model_inputs(tokens_slice, scheduler_output)
def _make_ubatch_metadata(ubatch_slices,
attn_metadata,
is_dummy_run,
num_tokens_across_dp) -> list[UbatchMetadata]:
ubatch_ctxs = _make_ubatch_contexts(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp
)
# First get some inputs
ubatch_metadata: list[UbatchMetadata] = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(tokens_slice, is_dummy_run)
ubatch_metadata.append(UbatchMetadata(
context=ubatch_ctxs[i],
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
))
return ubatch_metadata
def _run(context,
input_ids,
positions,
@ -1637,49 +1659,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return model_output
@torch.inference_mode()
def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
intermediate_tensors):
def _ubatch_thread(results, ubatch_metadata):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(context=ubatch_ctx,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
model_output = _run(context=ubatch_metadata.context,
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
inputs_embeds=ubatch_metadata.inputs_embeds,
intermediate_tensors=ubatch_metadata.intermediate_tensors)
results.append((ubatch_ctx.id, model_output))
results.append((ubatch_metadata.context.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, ubatch_ctxs,
is_dummy_run) -> torch.Tensor:
def _run_ubatches(ubatch_metadata) -> torch.Tensor:
results: list[tuple[int, torch.Tensor]] = []
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
root_stream = current_stream()
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
ubatch_threads = []
for i, (_, tokens_slice) in enumerate(ubatch_slices):
assert tokens_slice.stop > tokens_slice.start
use_dummy_input = is_dummy_run
# The only time we don't save results is when one of our ubatches
# is a dummy batch
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(tokens_slice, use_dummy_input)
for metadata in ubatch_metadata:
thread = threading.Thread(target=_ubatch_thread,
args=(
results,
ubatch_ctxs[i],
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
metadata,
))
ubatch_threads.append(thread)
thread.start()
ubatch_ctxs[0].cpu_wait_event.set()
ubatch_metadata[0].context.cpu_wait_event.set()
for thread in ubatch_threads:
thread.join()
@ -1769,18 +1776,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# )
# run micro-batched
if ubatch_slices is not None:
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
# num_tokens = ubatch_slices[1][1].stop
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
# assert not is_dummy_run
ubatch_ctxs = _make_ubatch_contexts(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp
)
model_output = _run_ubatches(ubatch_slices,
ubatch_ctxs,
is_dummy_run)
ubatch_metadata = _make_ubatch_metadata(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp
)
model_output = _run_ubatches(ubatch_metadata)
# run single batch
else:
# print("RUN NORMAL")