refactor a bunch of misc parameters into a UbatchMetadata class

Signed-off-by: Sage Moore <sage@neuralmagic.com>
Author: Sage Moore 2025-06-26 00:14:18 +00:00
parent 54deb61b87
commit 78228a67ce

@@ -1570,10 +1570,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 @dataclasses.dataclass
 class UbatchMetadata:
-    ubatch_id: int
     context: UBatchContext
-    ubatch_slice: UbatchSlice
     input_ids: torch.Tensor
     positions: torch.Tensor
     inputs_embeds: Optional[torch.Tensor]
@@ -1618,6 +1615,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         assert scheduler_output is not None
         return self._get_model_inputs(tokens_slice, scheduler_output)
 
+    def _make_ubatch_metadata(ubatch_slices,
+                              attn_metadata,
+                              is_dummy_run,
+                              num_tokens_across_dp) -> list[UbatchMetadata]:
+        ubatch_ctxs = _make_ubatch_contexts(
+            ubatch_slices=ubatch_slices,
+            attn_metadata=attn_metadata,
+            is_dummy_run=is_dummy_run,
+            num_tokens_across_dp=num_tokens_across_dp
+        )
+
+        # First get some inputs
+        ubatch_metadata: list[UbatchMetadata] = []
+        for i, (_, tokens_slice) in enumerate(ubatch_slices):
+            input_ids, positions, inputs_embeds, intermediate_tensors = \
+                model_inputs(tokens_slice, is_dummy_run)
+            ubatch_metadata.append(UbatchMetadata(
+                context=ubatch_ctxs[i],
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors
+            ))
+        return ubatch_metadata
+
     def _run(context,
              input_ids,
              positions,
@@ -1637,49 +1659,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         return model_output
 
     @torch.inference_mode()
-    def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
-                       intermediate_tensors):
+    def _ubatch_thread(results, ubatch_metadata):
         # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
-        model_output = _run(context=ubatch_ctx,
-                            input_ids=input_ids,
-                            positions=positions,
-                            inputs_embeds=inputs_embeds,
-                            intermediate_tensors=intermediate_tensors)
-        results.append((ubatch_ctx.id, model_output))
+        model_output = _run(context=ubatch_metadata.context,
+                            input_ids=ubatch_metadata.input_ids,
+                            positions=ubatch_metadata.positions,
+                            inputs_embeds=ubatch_metadata.inputs_embeds,
+                            intermediate_tensors=ubatch_metadata.intermediate_tensors)
+        results.append((ubatch_metadata.context.id, model_output))
         # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
-    def _run_ubatches(ubatch_slices, ubatch_ctxs,
-                      is_dummy_run) -> torch.Tensor:
+    def _run_ubatches(ubatch_metadata) -> torch.Tensor:
         results: list[tuple[int, torch.Tensor]] = []
-        assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
         root_stream = current_stream()
 
         # Ubatches will manually manage the forward context, so we override
         # it to None here so we can have it restored correctly later
         with override_forward_context(None):
             ubatch_threads = []
-            for i, (_, tokens_slice) in enumerate(ubatch_slices):
-                assert tokens_slice.stop > tokens_slice.start
-                use_dummy_input = is_dummy_run
-                # The only time we don't save results is when one of our ubatches
-                # is a dummy batch
-                input_ids, positions, inputs_embeds, intermediate_tensors = \
-                    model_inputs(tokens_slice, use_dummy_input)
+            for metadata in ubatch_metadata:
                 thread = threading.Thread(target=_ubatch_thread,
                                           args=(
                                               results,
-                                              ubatch_ctxs[i],
-                                              input_ids,
-                                              positions,
-                                              inputs_embeds,
-                                              intermediate_tensors,
+                                              metadata,
                                           ))
                 ubatch_threads.append(thread)
                 thread.start()
-            ubatch_ctxs[0].cpu_wait_event.set()
+            ubatch_metadata[0].context.cpu_wait_event.set()
             for thread in ubatch_threads:
                 thread.join()
@@ -1769,18 +1776,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # )
 
         # run micro-batched
         if ubatch_slices is not None:
+            assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             # num_tokens = ubatch_slices[1][1].stop
             # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
             # assert not is_dummy_run
-            ubatch_ctxs = _make_ubatch_contexts(
+            ubatch_metadata = _make_ubatch_metadata(
                 ubatch_slices=ubatch_slices,
                 attn_metadata=attn_metadata,
                 is_dummy_run=is_dummy_run,
                 num_tokens_across_dp=num_tokens_across_dp
             )
-            model_output = _run_ubatches(ubatch_slices,
-                                         ubatch_ctxs,
-                                         is_dummy_run)
+            model_output = _run_ubatches(ubatch_metadata)
         # run single batch
         else:
             # print("RUN NORMAL")