mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-06 08:32:18 +08:00
refactor a bunch of misc parameters into a UbatchMetadata class
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
54deb61b87
commit
78228a67ce
@ -1570,10 +1570,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class UbatchMetadata:
|
class UbatchMetadata:
|
||||||
ubatch_id: int
|
|
||||||
context: UBatchContext
|
context: UBatchContext
|
||||||
ubatch_slice: UbatchSlice
|
|
||||||
|
|
||||||
input_ids: torch.Tensor
|
input_ids: torch.Tensor
|
||||||
positions: torch.Tensor
|
positions: torch.Tensor
|
||||||
inputs_embeds: Optional[torch.Tensor]
|
inputs_embeds: Optional[torch.Tensor]
|
||||||
@ -1618,6 +1615,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
assert scheduler_output is not None
|
assert scheduler_output is not None
|
||||||
return self._get_model_inputs(tokens_slice, scheduler_output)
|
return self._get_model_inputs(tokens_slice, scheduler_output)
|
||||||
|
|
||||||
|
def _make_ubatch_metadata(ubatch_slices,
|
||||||
|
attn_metadata,
|
||||||
|
is_dummy_run,
|
||||||
|
num_tokens_across_dp) -> list[UbatchMetadata]:
|
||||||
|
ubatch_ctxs = _make_ubatch_contexts(
|
||||||
|
ubatch_slices=ubatch_slices,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
is_dummy_run=is_dummy_run,
|
||||||
|
num_tokens_across_dp=num_tokens_across_dp
|
||||||
|
)
|
||||||
|
# First get some inputs
|
||||||
|
ubatch_metadata: list[UbatchMetadata] = []
|
||||||
|
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||||
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
|
model_inputs(tokens_slice, is_dummy_run)
|
||||||
|
ubatch_metadata.append(UbatchMetadata(
|
||||||
|
context=ubatch_ctxs[i],
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
intermediate_tensors=intermediate_tensors
|
||||||
|
))
|
||||||
|
|
||||||
|
return ubatch_metadata
|
||||||
|
|
||||||
def _run(context,
|
def _run(context,
|
||||||
input_ids,
|
input_ids,
|
||||||
positions,
|
positions,
|
||||||
@ -1637,49 +1659,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return model_output
|
return model_output
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _ubatch_thread(results, ubatch_ctx, input_ids, positions, inputs_embeds,
|
def _ubatch_thread(results, ubatch_metadata):
|
||||||
intermediate_tensors):
|
|
||||||
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
# print(f"Starting Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||||
model_output = _run(context=ubatch_ctx,
|
model_output = _run(context=ubatch_metadata.context,
|
||||||
input_ids=input_ids,
|
input_ids=ubatch_metadata.input_ids,
|
||||||
positions=positions,
|
positions=ubatch_metadata.positions,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=ubatch_metadata.inputs_embeds,
|
||||||
intermediate_tensors=intermediate_tensors)
|
intermediate_tensors=ubatch_metadata.intermediate_tensors)
|
||||||
|
|
||||||
results.append((ubatch_ctx.id, model_output))
|
results.append((ubatch_metadata.context.id, model_output))
|
||||||
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
# print(f"Finishing Request on ubatch: {ubatch_ctx.id}", flush=True)
|
||||||
|
|
||||||
def _run_ubatches(ubatch_slices, ubatch_ctxs,
|
def _run_ubatches(ubatch_metadata) -> torch.Tensor:
|
||||||
is_dummy_run) -> torch.Tensor:
|
|
||||||
results: list[tuple[int, torch.Tensor]] = []
|
results: list[tuple[int, torch.Tensor]] = []
|
||||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
|
||||||
root_stream = current_stream()
|
root_stream = current_stream()
|
||||||
|
|
||||||
# Ubatches will manually manage the forward context, so we override
|
# Ubatches will manually manage the forward context, so we override
|
||||||
# it to None here so we can have it restored correctly later
|
# it to None here so we can have it restored correctly later
|
||||||
with override_forward_context(None):
|
with override_forward_context(None):
|
||||||
ubatch_threads = []
|
ubatch_threads = []
|
||||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
for metadata in ubatch_metadata:
|
||||||
assert tokens_slice.stop > tokens_slice.start
|
|
||||||
|
|
||||||
use_dummy_input = is_dummy_run
|
|
||||||
|
|
||||||
# The only time we don't save results is when one of our ubatches
|
|
||||||
# is a dummy batch
|
|
||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
|
||||||
model_inputs(tokens_slice, use_dummy_input)
|
|
||||||
thread = threading.Thread(target=_ubatch_thread,
|
thread = threading.Thread(target=_ubatch_thread,
|
||||||
args=(
|
args=(
|
||||||
results,
|
results,
|
||||||
ubatch_ctxs[i],
|
metadata,
|
||||||
input_ids,
|
|
||||||
positions,
|
|
||||||
inputs_embeds,
|
|
||||||
intermediate_tensors,
|
|
||||||
))
|
))
|
||||||
ubatch_threads.append(thread)
|
ubatch_threads.append(thread)
|
||||||
thread.start()
|
thread.start()
|
||||||
ubatch_ctxs[0].cpu_wait_event.set()
|
ubatch_metadata[0].context.cpu_wait_event.set()
|
||||||
|
|
||||||
for thread in ubatch_threads:
|
for thread in ubatch_threads:
|
||||||
thread.join()
|
thread.join()
|
||||||
@ -1769,18 +1776,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# )
|
# )
|
||||||
# run micro-batched
|
# run micro-batched
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
|
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||||
# num_tokens = ubatch_slices[1][1].stop
|
# num_tokens = ubatch_slices[1][1].stop
|
||||||
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
|
# print(f"RUNNING UBATCH {num_tokens} is_dummy_run: {is_dummy_run} num_tokens_across_dp{num_tokens_across_dp}")
|
||||||
# assert not is_dummy_run
|
# assert not is_dummy_run
|
||||||
ubatch_ctxs = _make_ubatch_contexts(
|
ubatch_metadata = _make_ubatch_metadata(
|
||||||
ubatch_slices=ubatch_slices,
|
ubatch_slices=ubatch_slices,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
is_dummy_run=is_dummy_run,
|
is_dummy_run=is_dummy_run,
|
||||||
num_tokens_across_dp=num_tokens_across_dp
|
num_tokens_across_dp=num_tokens_across_dp
|
||||||
)
|
)
|
||||||
model_output = _run_ubatches(ubatch_slices,
|
model_output = _run_ubatches(ubatch_metadata)
|
||||||
ubatch_ctxs,
|
|
||||||
is_dummy_run)
|
|
||||||
# run single batch
|
# run single batch
|
||||||
else:
|
else:
|
||||||
# print("RUN NORMAL")
|
# print("RUN NORMAL")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user