commit 0e2b4bd546 (parent e2ba707d64)

more refactoring

Signed-off-by: Sage Moore <sage@neuralmagic.com>
@@ -1618,11 +1618,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             assert scheduler_output is not None
             return self._get_model_inputs(tokens_slice, scheduler_output)
 
-        def _run(token_slice: slice,
-                 context,
-                 use_dummy_input: bool = False):
-            input_ids, positions, inputs_embeds, intermediate_tensors = \
-                model_inputs(token_slice, use_dummy_input)
+        def _run(context,
+                 input_ids,
+                 positions,
+                 inputs_embeds,
+                 intermediate_tensors):
             with context:
                 model_output = self.model(
                     input_ids=input_ids,
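This hunk narrows `_run` to pure execution: the caller now prepares inputs via `model_inputs` and passes the tensors in explicitly, while `_run` only enters the supplied forward context and invokes the model. A minimal sketch of the same separation, using hypothetical stand-ins (`prepare_inputs`, `fake_model`) rather than the real vLLM helpers:

import contextlib

# Hypothetical stand-ins; the real helpers are model_inputs() and self.model().
def prepare_inputs(token_slice):
    input_ids = list(range(token_slice.start, token_slice.stop))
    positions = list(range(len(input_ids)))
    return input_ids, positions

def fake_model(input_ids, positions):
    return [t + p for t, p in zip(input_ids, positions)]

def _run(context, input_ids, positions):
    # Execution only: enter the provided context and call the model.
    with context:
        return fake_model(input_ids=input_ids, positions=positions)

# The caller prepares inputs once, then hands tensors plus a context to _run().
input_ids, positions = prepare_inputs(slice(0, 4))
print(_run(contextlib.nullcontext(), input_ids, positions))  # [0, 2, 4, 6]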
@@ -1637,26 +1637,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             return model_output
 
         @torch.inference_mode()
-        def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
-                           use_dummy_input):
+        def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
+                           intermediate_tensors, save_results):
             # print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
-            model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
+            model_output = _run(context=ubatch_ctx,
+                                input_ids=input_ids,
+                                positions=positions,
+                                inputs_embeds=inputs_embeds,
+                                intermediate_tensors=intermediate_tensors)
 
             # TODO (Sage) I think we can just delete this check now that we
             # enforce that all microbatches are valid
             if save_results:
                 results.append((ubatch_ctx.id, model_output))
             # print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
 
-        def _run_ubatches(ubatch_slices, attn_metadata,
-                          is_dummy_run, num_tokens_across_dp) -> torch.Tensor:
+        def _run_ubatches(ubatch_slices, ubatch_ctxs,
+                          is_dummy_run) -> torch.Tensor:
             results: list[tuple[int, torch.Tensor]] = []
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
             root_stream = current_stream()
 
-            ubatch_ctxs = _make_ubatch_contexts(ubatch_slices=ubatch_slices,
-                                                attn_metadata=attn_metadata,
-                                                is_dummy_run=is_dummy_run,
-                                                num_tokens_across_dp=num_tokens_across_dp)
-
             # Ubatches will manually manage the forward context, so we override
             # it to None here so we can have it restored correctly later
             with override_forward_context(None):
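The microbatch threads here follow a plain producer pattern: each thread enters its own context, runs the model, and appends an `(id, output)` pair so the parent can reassemble outputs in microbatch order. A self-contained sketch of that pattern, with a trivial context manager standing in for the real ubatch context (which additionally manages CUDA streams and forward metadata):

import threading
from contextlib import contextmanager

@contextmanager
def ubatch_context(ubatch_id):
    # Stand-in for the real per-ubatch context (stream + forward metadata).
    yield ubatch_id

def _ubatch_thread(results, ubatch_id, input_ids, save_results):
    with ubatch_context(ubatch_id) as ctx_id:
        model_output = [x * 2 for x in input_ids]  # stand-in model call
    if save_results:
        results.append((ctx_id, model_output))

results: list[tuple[int, list[int]]] = []
threads = []
for i, chunk in enumerate(([1, 2], [3, 4])):  # exactly two microbatches
    t = threading.Thread(target=_ubatch_thread, args=(results, i, chunk, True))
    threads.append(t)
    t.start()
for t in threads:
    t.join()
results.sort(key=lambda pair: pair[0])  # completion order is arbitrary
print([out for _, out in results])  # [[2, 4], [6, 8]]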
@@ -1670,15 +1671,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 assert not is_dummy_ubatch or i == len(
                     ubatch_slices) - 1 or is_dummy_run
 
+                use_dummy_input = is_dummy_run or is_dummy_ubatch
+
                 # The only time we don't save results is when one of our ubatches
                 # is a dummy batch
+                save_results = not is_dummy_ubatch or is_dummy_run
+                input_ids, positions, inputs_embeds, intermediate_tensors = \
+                    model_inputs(tokens_slice, use_dummy_input)
                 thread = threading.Thread(target=_ubatch_thread,
                                           args=(
-                                              ubatch_ctxs[i],
-                                              tokens_slice,
                                               results,
-                                              not is_dummy_ubatch
-                                              or is_dummy_run,
-                                              is_dummy_ubatch
-                                              or is_dummy_run
+                                              ubatch_ctxs[i],
+                                              input_ids,
+                                              positions,
+                                              inputs_embeds,
+                                              intermediate_tensors,
+                                              save_results
                                           ))
                 ubatch_threads.append(thread)
                 thread.start()
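The `save_results` flag introduced here encodes when a microbatch's output is kept: a dummy ubatch's output is discarded unless the entire run is a dummy run. Enumerating `not is_dummy_ubatch or is_dummy_run` over all four cases makes the single discarded combination explicit:

for is_dummy_ubatch in (False, True):
    for is_dummy_run in (False, True):
        save_results = not is_dummy_ubatch or is_dummy_run
        print(is_dummy_ubatch, is_dummy_run, "->", save_results)
# The only False case is is_dummy_ubatch=True with is_dummy_run=False,
# i.e. a padding ubatch inside a real run.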
@@ -1775,21 +1783,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # num_tokens = ubatch_slices[1][1].stop
             # print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
             # assert not is_dummy_run
+            ubatch_ctxs = _make_ubatch_contexts(
+                ubatch_slices=ubatch_slices,
+                attn_metadata=attn_metadata,
+                is_dummy_run=is_dummy_run,
+                num_tokens_across_dp=num_tokens_across_dp
+            )
             model_output = _run_ubatches(ubatch_slices,
-                                         attn_metadata,
-                                         is_dummy_run,
-                                         num_tokens_across_dp=num_tokens_across_dp)
+                                         ubatch_ctxs,
+                                         is_dummy_run)
         # run single batch
         else:
             # print("RUN NORMAL")
+            input_ids, positions, inputs_embeds, intermediate_tensors = \
+                model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
+            context = set_forward_context(attn_metadata,
+                                          vllm_config=self.vllm_config,
+                                          num_tokens=num_scheduled_tokens or 1,
+                                          num_tokens_across_dp=num_tokens_across_dp,
+                                          skip_cuda_graphs=skip_cuda_graphs)
             model_output = _run(
-                slice(0, num_scheduled_tokens),
-                set_forward_context(attn_metadata,
-                                    vllm_config=self.vllm_config,
-                                    num_tokens=num_scheduled_tokens or 1,
-                                    num_tokens_across_dp=num_tokens_across_dp,
-                                    skip_cuda_graphs=skip_cuda_graphs),
-                is_dummy_run)
+                context=context,
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors
+            )
 
         return model_output
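Across both branches the shape is now the same: build the contexts and inputs up front, then call into a function that only executes. Hoisting `_make_ubatch_contexts` out of `_run_ubatches` is the dependency-injection half of that; a sketch of the shape, with `make_ubatch_contexts`/`run_ubatches` as hypothetical stand-ins for the real helpers:

from contextlib import nullcontext

def make_ubatch_contexts(ubatch_slices):
    # Stand-in for _make_ubatch_contexts(): one context per microbatch.
    return [nullcontext() for _ in ubatch_slices]

def run_ubatches(ubatch_slices, ubatch_ctxs):
    # Consumes contexts it no longer has to construct itself.
    outputs = []
    for s, ctx in zip(ubatch_slices, ubatch_ctxs):
        with ctx:
            outputs.append(list(range(s.start, s.stop)))
    return outputs

ubatch_slices = [slice(0, 2), slice(2, 4)]
print(run_ubatches(ubatch_slices, make_ubatch_contexts(ubatch_slices)))
# [[0, 1], [2, 3]]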