mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 19:37:10 +08:00
refactor a bunch of misc parameters into a UbatchMetadata class
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
54deb61b87
commit
78228a67ce
@ -1570,10 +1570,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UbatchMetadata:
|
||||
ubatch_id: int
|
||||
context: UBatchContext
|
||||
ubatch_slice: UbatchSlice
|
||||
|
||||
input_ids: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
inputs_embeds: Optional[torch.Tensor]
|
||||
@ -1618,6 +1615,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert scheduler_output is not None
|
||||
return self._get_model_inputs(tokens_slice, scheduler_output)
|
||||
|
||||
def _make_ubatch_metadata(ubatch_slices,
|
||||
attn_metadata,
|
||||
is_dummy_run,
|
||||
num_tokens_across_dp) -> list[UbatchMetadata]:
|
||||
ubatch_ctxs = _make_ubatch_contexts(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
is_dummy_run=is_dummy_run,
|
||||
num_tokens_across_dp=num_tokens_across_dp
|
||||
)
|
||||
# First get some inputs
|
||||
ubatch_metadata: list[UbatchMetadata] = []
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(tokens_slice, is_dummy_run)
|
||||
ubatch_metadata.append(UbatchMetadata(
|
||||
context=ubatch_ctxs[i],
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors
|
||||
))
|
||||
|
||||
return ubatch_metadata
|
||||
|
||||
def _run(context,
|
||||
input_ids,
|
||||
positions,
|
||||
@ -1637,49 +1659,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return model_output
|
||||
|
||||
@torch.inference_mode()
|
||||
def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
|
||||
intermediate_tensors):
|
||||
def _ubatch_thread(results, ubatch_metadata):
|
||||
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
model_output = _run(context=ubatch_ctx,
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors)
|
||||
model_output = _run(context=ubatch_metadata.context,
|
||||
input_ids=ubatch_metadata.input_ids,
|
||||
positions=ubatch_metadata.positions,
|
||||
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||
intermediate_tensors=ubatch_metadata.intermediate_tensors)
|
||||
|
||||
results.append((ubatch_ctx.id, model_output))
|
||||
results.append((ubatch_metadata.context.id, model_output))
|
||||
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||
|
||||
def _run_ubatches(ubatch_slices, ubatch_ctxs,
|
||||
is_dummy_run) -> torch.Tensor:
|
||||
def _run_ubatches(ubatch_metadata) -> torch.Tensor:
|
||||
results: list[tuple[int, torch.Tensor]] = []
|
||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||
root_stream = current_stream()
|
||||
|
||||
# Ubatches will manually manage the forward context, so we override
|
||||
# it to None here so we can have it restored correctly later
|
||||
with override_forward_context(None):
|
||||
ubatch_threads = []
|
||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||
assert tokens_slice.stop > tokens_slice.start
|
||||
|
||||
use_dummy_input = is_dummy_run
|
||||
|
||||
# The only time we don't save results is when one of our ubatches
|
||||
# is a dummy batch
|
||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||
model_inputs(tokens_slice, use_dummy_input)
|
||||
for metadata in ubatch_metadata:
|
||||
thread = threading.Thread(target=_ubatch_thread,
|
||||
args=(
|
||||
results,
|
||||
ubatch_ctxs[i],
|
||||
input_ids,
|
||||
positions,
|
||||
inputs_embeds,
|
||||
intermediate_tensors,
|
||||
metadata,
|
||||
))
|
||||
ubatch_threads.append(thread)
|
||||
thread.start()
|
||||
ubatch_ctxs[0].cpu_wait_event.set()
|
||||
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||
|
||||
for thread in ubatch_threads:
|
||||
thread.join()
|
||||
@ -1769,18 +1776,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# )
|
||||
# run micro-batched
|
||||
if ubatch_slices is not None:
|
||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||
# num_tokens = ubatch_slices[1][1].stop
|
||||
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
|
||||
# assert not is_dummy_run
|
||||
ubatch_ctxs = _make_ubatch_contexts(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
is_dummy_run=is_dummy_run,
|
||||
num_tokens_across_dp=num_tokens_across_dp
|
||||
)
|
||||
model_output = _run_ubatches(ubatch_slices,
|
||||
ubatch_ctxs,
|
||||
is_dummy_run)
|
||||
ubatch_metadata = _make_ubatch_metadata(
|
||||
ubatch_slices=ubatch_slices,
|
||||
attn_metadata=attn_metadata,
|
||||
is_dummy_run=is_dummy_run,
|
||||
num_tokens_across_dp=num_tokens_across_dp
|
||||
)
|
||||
model_output = _run_ubatches(ubatch_metadata)
|
||||
# run single batch
|
||||
else:
|
||||
# print("RUN NORMAL")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user