mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 19:57:08 +08:00
separate ubatch and normal runs
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
510e839429
commit
bb0645c644
@ -1588,38 +1588,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return ubatch_metadata
|
||||
|
||||
def _run(context,
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
intermediate_tensors):
|
||||
with context:
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if isinstance(context, UBatchContext):
|
||||
# Clone before we leave the ubatch context
|
||||
model_output = model_output.clone()
|
||||
|
||||
return model_output
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(results, ubatch_metadata):
|
||||
def _ubatch_thread(results, model, ubatch_metadata):
|
||||
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
context = ubatch_metadata.context
|
||||
model_output = _run(context=context,
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors)
|
||||
|
||||
with ubatch_metadata.context:
|
||||
model_output = model(
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
)
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
|
||||
def _run_ubatches(ubatch_metadata) -> torch.Tensor:
|
||||
def _run_ubatches(ubatch_metadata, model) -> torch.Tensor:
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
@ -1630,6 +1612,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
thread = threading.Thread(target=_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
model,
|
||||
metadata,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
@ -1657,22 +1640,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs
|
||||
)
|
||||
return _run_ubatches(ubatch_metadata)
|
||||
return _run_ubatches(ubatch_metadata, self.model)
|
||||
# run normal batch
|
||||
else:
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
|
||||
logger.info(f"NORMAL RUN {num_scheduled_tokens}")
|
||||
return _run(
|
||||
context = set_forward_context(attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_scheduled_tokens or 1,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs),
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors
|
||||
with set_forward_context(attn_metadata,
|
||||
vllm_config=self.vllm_config,
|
||||
num_tokens=num_scheduled_tokens or 1,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
skip_cuda_graphs=skip_cuda_graphs):
|
||||
return self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
def _pool(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user