separate ubatch and normal runs

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-07-03 17:07:58 +00:00
parent 510e839429
commit bb0645c644

View File

@@ -1588,38 +1588,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return ubatch_metadata
def _run(context,
input_ids,
positions,
inputs_embeds,
intermediate_tensors):
with context:
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if isinstance(context, UBatchContext):
# Clone before we leave the ubatch context
model_output = model_output.clone()
return model_output
@torch.inference_mode()
def _ubatch_thread(results, ubatch_metadata):
def _ubatch_thread(results, model, ubatch_metadata):
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
context = ubatch_metadata.context
model_output = _run(context=context,
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
inputs_embeds=ubatch_metadata.inputs_embeds,
intermediate_tensors=ubatch_metadata.intermediate_tensors)
with ubatch_metadata.context:
model_output = model(
input_ids=ubatch_metadata.input_ids,
positions=ubatch_metadata.positions,
intermediate_tensors=ubatch_metadata.intermediate_tensors,
inputs_embeds=ubatch_metadata.inputs_embeds,
)
results.append((ubatch_metadata.context.id, model_output))
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
def _run_ubatches(ubatch_metadata) -> torch.Tensor:
def _run_ubatches(ubatch_metadata, model) -> torch.Tensor:
results: list[tuple[int, torch.Tensor]] = []
# Ubatches will manually manage the forward context, so we override
@@ -1630,6 +1612,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
thread = threading.Thread(target=_ubatch_thread,
args=(
results,
model,
metadata,
))
ubatch_threads.append(thread)
@@ -1657,22 +1640,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs
)
return _run_ubatches(ubatch_metadata)
return _run_ubatches(ubatch_metadata, self.model)
# run normal batch
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
logger.info(f"NORMAL RUN {num_scheduled_tokens}")
return _run(
context = set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs),
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors
with set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_scheduled_tokens or 1,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs):
return self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
def _pool(
self,