more refactoring

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-06-25 23:43:49 +00:00
parent e2ba707d64
commit 0e2b4bd546

View File

@ -1618,11 +1618,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert scheduler_output is not None
return self._get_model_inputs(tokens_slice, scheduler_output)
def _run(token_slice: slice,
context,
use_dummy_input: bool = False):
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(token_slice, use_dummy_input)
def _run(context,
input_ids,
positions,
inputs_embeds,
intermediate_tensors):
with context:
model_output = self.model(
input_ids=input_ids,
@ -1637,26 +1637,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return model_output
@torch.inference_mode()
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
use_dummy_input):
def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
intermediate_tensors, save_results):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
model_output = _run(context=ubatch_ctx,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors)
# TODO (Sage) I think we can just delete this check now that we
# enforce that all microbatches are valid
if save_results:
results.append((ubatch_ctx.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_slices, attn_metadata,
is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
def _run_ubatches(ubatch_slices, ubatch_ctxs,
is_dummy_run) -> torch.Tensor:
results: list[tuple[int, torch.Tensor]] = []
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
root_stream = current_stream()
ubatch_ctxs = _make_ubatch_contexts(ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp)
# Ubatches will manually manage the forward context, so we override
# it to None here so we can have it restored correctly later
with override_forward_context(None):
@ -1670,15 +1671,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert not is_dummy_ubatch or i == len(
ubatch_slices) - 1 or is_dummy_run
use_dummy_input = is_dummy_run or is_dummy_ubatch
# The only time we don't save results is when one of our ubatches
# is a dummy batch
save_results = not is_dummy_ubatch or is_dummy_run
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(tokens_slice, use_dummy_input)
thread = threading.Thread(target=_ubatch_thread,
args=(
ubatch_ctxs[i],
tokens_slice,
results,
not is_dummy_ubatch
or is_dummy_run,
is_dummy_ubatch
or is_dummy_run
ubatch_ctxs[i],
input_ids,
positions,
inputs_embeds,
intermediate_tensors,
save_results
))
ubatch_threads.append(thread)
thread.start()
@ -1775,21 +1783,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# num_tokens = ubatch_slices[1][1].stop
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
# assert not is_dummy_run
ubatch_ctxs = _make_ubatch_contexts(
ubatch_slices=ubatch_slices,
attn_metadata=attn_metadata,
is_dummy_run=is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp
)
model_output = _run_ubatches(ubatch_slices,
attn_metadata,
is_dummy_run,
num_tokens_across_dp=num_tokens_across_dp)
ubatch_ctxs,
is_dummy_run)
# run single batch
else:
# print("RUN NORMAL")
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
model_output = _run(
slice(0, num_scheduled_tokens),
set_forward_context(attn_metadata,
context = set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
is_dummy_run)
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
)
return model_output