split some of the ubatching logic out of _run_model
Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent 908e9f8f54
commit 10ca263058
vllm/v1/worker/gpu_model_runner.py
@@ -75,6 +75,7 @@ from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
 from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
+import dataclasses
 
 if TYPE_CHECKING:
     import xgrammar as xgr
@@ -98,7 +99,14 @@ PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
 UbatchSlice: TypeAlias = tuple[slice, slice]
 UBatchSlices: TypeAlias = list[UbatchSlice]
 
-import dataclasses
+@dataclasses.dataclass
+class UbatchMetadata:
+    context: UBatchContext
+    input_ids: torch.Tensor
+    positions: torch.Tensor
+    inputs_embeds: Optional[torch.Tensor]
+    intermediate_tensors: Optional[IntermediateTensors]
 
 
 class GPUModelRunner(LoRAModelRunnerMixin):
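Note on the UbatchSlice alias above: it pairs two slice objects, and the runner code below only ever reads the second, token-level slice (for i, (_, tokens_slice) in enumerate(ubatch_slices)). A quick illustrative sketch with made-up numbers; reading the first slice as a request-level slice is an assumption, not something this diff confirms:

# Hypothetical example values, not from this commit.
ubatch_slices: list[tuple[slice, slice]] = [
    (slice(0, 2), slice(0, 4)),  # ubatch 0: tokens [0, 4)
    (slice(2, 4), slice(4, 8)),  # ubatch 1: tokens [4, 8)
]
for _, tokens_slice in ubatch_slices:
    num_tokens = tokens_slice.stop - tokens_slice.start  # 4 tokens per ubatch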
@@ -1498,85 +1506,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 tokens_slice, intermediate_tensors, True)
         return input_ids, positions, inputs_embeds, intermediate_tensors
 
-    def _run_model(self,
-                   attn_metadata: Optional[PerLayerAttnMetadata],
-                   num_scheduled_tokens: Optional[int],
-                   ubatch_slices: Optional[UBatchSlices] = None,
-                   scheduler_output: Optional["SchedulerOutput"] = None,
-                   is_dummy_run: bool = False,
-                   num_tokens_across_dp: Optional[torch.Tensor] = None,
-                   skip_cuda_graphs: bool = False):
-
-        @dataclasses.dataclass
-        class UbatchMetadata:
-            context: UBatchContext
-            input_ids: torch.Tensor
-            positions: torch.Tensor
-            inputs_embeds: Optional[torch.Tensor]
-            intermediate_tensors: Optional[IntermediateTensors]
-
-        num_dummy_tokens = num_scheduled_tokens if is_dummy_run else 1
-
-        def _make_ubatch_contexts(ubatch_slices,
-                                  attn_metadata,
-                                  compute_stream,
-                                  num_tokens_across_dp,
-                                  skip_cuda_graphs) -> list[UBatchContext]:
-            ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
-                                               compute_stream=compute_stream,
-                                               device=self.device)
-            for i, (_, tokens_slice) in enumerate(ubatch_slices):
-                num_tokens = (tokens_slice.stop - tokens_slice.start)
-                # TODO (Sage) Instead of using this setter we should be able
-                # to just create the forward context in advance and pass it
-                # to the UBatchContext's __init__ method
-                ubatch_ctxs[i].forward_context = create_forward_context(
-                    attn_metadata[i]
-                    if attn_metadata is not None else None,
-                    self.vllm_config,
-                    num_tokens=num_tokens,
-                    num_tokens_across_dp=num_tokens_across_dp,
-                    skip_cuda_graphs=skip_cuda_graphs)
-            return ubatch_ctxs
-
-        def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
-            if use_dummy_input:
-                assert num_dummy_tokens == tokens_slice.stop - tokens_slice.start
-                return self._get_dummy_model_inputs(num_dummy_tokens)
-            else:
-                assert scheduler_output is not None
-                return self._get_model_inputs(tokens_slice, scheduler_output)
-
-        def _make_ubatch_metadata(ubatch_slices,
-                                  attn_metadata,
-                                  compute_stream,
-                                  is_dummy_run,
-                                  num_tokens_across_dp,
-                                  skip_cuda_graphs) -> list[UbatchMetadata]:
-            ubatch_ctxs = _make_ubatch_contexts(
-                ubatch_slices=ubatch_slices,
-                attn_metadata=attn_metadata,
-                compute_stream=compute_stream,
-                num_tokens_across_dp=num_tokens_across_dp,
-                skip_cuda_graphs=skip_cuda_graphs
-            )
-
-            ubatch_metadata: list[UbatchMetadata] = []
-            for i, (_, tokens_slice) in enumerate(ubatch_slices):
-                input_ids, positions, inputs_embeds, intermediate_tensors = \
-                    model_inputs(tokens_slice, is_dummy_run)
-                ubatch_metadata.append(UbatchMetadata(
-                    context=ubatch_ctxs[i],
-                    input_ids=input_ids,
-                    positions=positions,
-                    inputs_embeds=inputs_embeds,
-                    intermediate_tensors=intermediate_tensors
-                ))
-
-            return ubatch_metadata
+    def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
+                     scheduler_output: Optional["SchedulerOutput"]) -> tuple:
+        if use_dummy_input:
+            return self._get_dummy_model_inputs(tokens_slice.stop - tokens_slice.start)
+        else:
+            assert scheduler_output is not None
+            return self._get_model_inputs(tokens_slice, scheduler_output)
+
+    def _make_ubatch_metadata(self,
+                              ubatch_slices,
+                              attn_metadata,
+                              compute_stream,
+                              is_dummy_run,
+                              num_tokens_across_dp,
+                              skip_cuda_graphs,
+                              scheduler_output) -> list[UbatchMetadata]:
+
+        # Create one forward context per ubatch
+        forward_contexts = []
+        for i, (_, tokens_slice) in enumerate(ubatch_slices):
+            num_tokens = (tokens_slice.stop - tokens_slice.start)
+            forward_contexts.append(create_forward_context(
+                attn_metadata[i]
+                if attn_metadata is not None else None,
+                self.vllm_config,
+                num_tokens=num_tokens,
+                num_tokens_across_dp=num_tokens_across_dp,
+                skip_cuda_graphs=skip_cuda_graphs))
+
+        ubatch_ctxs = make_ubatch_contexts(num_micro_batches=len(ubatch_slices),
+                                           compute_stream=compute_stream,
+                                           forward_contexts=forward_contexts,
+                                           device=self.device)
+
+        ubatch_metadata: list[UbatchMetadata] = []
+        for i, (_, tokens_slice) in enumerate(ubatch_slices):
+            input_ids, positions, inputs_embeds, intermediate_tensors = \
+                self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
+            ubatch_metadata.append(UbatchMetadata(
+                context=ubatch_ctxs[i],
+                input_ids=input_ids,
+                positions=positions,
+                inputs_embeds=inputs_embeds,
+                intermediate_tensors=intermediate_tensors
+            ))
+
+        return ubatch_metadata
 
+    def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
         @torch.inference_mode()
         def _ubatch_thread(results, model, ubatch_metadata):
             with ubatch_metadata.context:
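The hunk above resolves the old TODO (Sage): instead of constructing the contexts and then patching ubatch_ctxs[i].forward_context through a setter, _make_ubatch_metadata now builds one forward context per ubatch up front and injects the list into make_ubatch_contexts. A minimal sketch of that construction-order change, using simplified stand-in types (Ctx and make_contexts are hypothetical, not the vLLM API):

from dataclasses import dataclass

@dataclass
class Ctx:
    id: int
    forward_context: object  # required at construction, never patched later

def make_contexts(num: int, forward_contexts: list[object]) -> list[Ctx]:
    # Fail fast if the caller didn't build one forward context per ubatch.
    assert len(forward_contexts) == num
    return [Ctx(id=i, forward_context=forward_contexts[i]) for i in range(num)]

forward_contexts = [object() for _ in range(2)]  # dependencies first
ctxs = make_contexts(2, forward_contexts)        # then the objects that own them

Injecting through __init__ removes the window in which forward_context is None and keeps each context fully initialized from the moment it exists.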
@@ -1588,49 +1569,59 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             )
             results.append((ubatch_metadata.context.id, model_output))
 
-        def _run_ubatches(ubatch_metadata, model) -> torch.Tensor:
-            results: list[tuple[int, torch.Tensor]] = []
+        results: list[tuple[int, torch.Tensor]] = []
 
-            # Ubatch threads will manually manage the forward context, so we
-            # override it to None here so we can have it restored correctly
-            # after both threads have finished
-            with override_forward_context(None):
-                ubatch_threads = []
-                for metadata in ubatch_metadata:
-                    thread = threading.Thread(target=_ubatch_thread,
-                                              args=(
-                                                  results,
-                                                  model,
-                                                  metadata,
-                                              ))
-                    ubatch_threads.append(thread)
-                    thread.start()
-
-                ubatch_metadata[0].context.cpu_wait_event.set()
-                for thread in ubatch_threads:
-                    thread.join()
-            sorted_results = [value for position, value in sorted(results)]
-            result = torch.cat(sorted_results, dim=0)
-            return result
+        # Ubatch threads will manually manage the forward context, so we
+        # override it to None here so we can have it restored correctly
+        # after both threads have finished
+        with override_forward_context(None):
+            ubatch_threads = []
+            for metadata in ubatch_metadata:
+                thread = threading.Thread(target=_ubatch_thread,
+                                          args=(
+                                              results,
+                                              model,
+                                              metadata,
+                                          ))
+                ubatch_threads.append(thread)
+                thread.start()
+
+            ubatch_metadata[0].context.cpu_wait_event.set()
+            for thread in ubatch_threads:
+                thread.join()
+        sorted_results = [value for position, value in sorted(results)]
+        result = torch.cat(sorted_results, dim=0)
+        return result
+
+    def _run_model(self,
+                   attn_metadata: Optional[PerLayerAttnMetadata],
+                   num_scheduled_tokens: Optional[int],
+                   ubatch_slices: Optional[UBatchSlices] = None,
+                   scheduler_output: Optional["SchedulerOutput"] = None,
+                   is_dummy_run: bool = False,
+                   num_tokens_across_dp: Optional[torch.Tensor] = None,
+                   skip_cuda_graphs: bool = False):
 
         # run micro-batched
         if ubatch_slices is not None:
             assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
 
             compute_stream = torch.cuda.current_stream()
-            ubatch_metadata = _make_ubatch_metadata(
+            ubatch_metadata = self._make_ubatch_metadata(
                 ubatch_slices=ubatch_slices,
                 attn_metadata=attn_metadata,
                 compute_stream=compute_stream,
                 is_dummy_run=is_dummy_run,
                 num_tokens_across_dp=num_tokens_across_dp,
-                skip_cuda_graphs=skip_cuda_graphs
-            )
-            return _run_ubatches(ubatch_metadata, self.model)
+                skip_cuda_graphs=skip_cuda_graphs,
+                scheduler_output=scheduler_output)
+            return self._run_ubatches(ubatch_metadata, self.model)
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
-                model_inputs(slice(0, num_scheduled_tokens), is_dummy_run)
+                self.model_inputs(slice(0, num_scheduled_tokens), is_dummy_run,
+                                  scheduler_output)
             with set_forward_context(attn_metadata,
                                      vllm_config=self.vllm_config,
                                      num_tokens=num_scheduled_tokens or 1,
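The _run_ubatches flow above is easiest to see in isolation: one thread per ubatch, the main thread starts ubatch 0 by setting its wait event, and outputs are reassembled in ubatch-id order. A self-contained sketch of that handoff pattern (illustrative only; run_ubatches and worker are hypothetical names, and the real code alternates between the threads at many yield points rather than one):

import threading
import torch

def run_ubatches(chunks: list[torch.Tensor]) -> torch.Tensor:
    n = len(chunks)
    events = [threading.Event() for _ in range(n)]
    results: list[tuple[int, torch.Tensor]] = []

    def worker(i: int) -> None:
        events[i].wait()                    # block until it is our turn
        results.append((i, chunks[i] * 2))  # stand-in for the model forward
        events[(i + 1) % n].set()           # hand off to the next ubatch

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(n)]
    for t in threads:
        t.start()
    events[0].set()  # the analogue of cpu_wait_event.set() on ubatch 0
    for t in threads:
        t.join()
    # Sort by ubatch id, mirroring sorted(results) + torch.cat above.
    return torch.cat([value for _, value in sorted(results)], dim=0)

print(run_ubatches([torch.ones(4), torch.zeros(4)]))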
vllm/v1/worker/ubatching.py
@@ -5,6 +5,7 @@ from typing import Optional
 import torch
 
 from vllm import forward_context
+from vllm.forward_context import ForwardContext
 from vllm.utils import current_stream
 
 
@@ -18,7 +19,7 @@ class UBatchContext:
                  id: int,
                  comm_stream: torch.cuda.Stream,
                  compute_stream: torch.cuda.Stream,
-                 #fwd_ctx: forward_context.ForwardContext,
+                 forward_context: ForwardContext,
                  cpu_wait_event: threading.Event,
                  cpu_signal_event: threading.Event,
                  gpu_comm_done_event: torch.cuda.Event,
@@ -27,7 +28,7 @@ class UBatchContext:
         self.id = id
         self.comm_stream = comm_stream
         self.compute_stream = compute_stream
-        self.forward_context = None  #fwd_ctx
+        self.forward_context = forward_context
         self.cpu_wait_event = cpu_wait_event
         self.cpu_signal_event = cpu_signal_event
         self.current_stream = compute_stream
@@ -150,6 +151,7 @@ def yield_and_switch_from_comm_to_compute(schedule="default"):
 def make_ubatch_contexts(
     num_micro_batches: int,
     compute_stream: torch.cuda.Stream,
+    forward_contexts: list[ForwardContext],
     device: Optional[torch.device] = None,
     schedule: str = "default",
 ) -> list[UBatchContext]:
@@ -167,11 +169,14 @@ def make_ubatch_contexts(
     device = device or torch.cuda.current_device()
     comm_stream = torch.cuda.Stream(device)
 
+    assert len(forward_contexts) == 2
+
     ctxs = []
     for i in range(num_micro_batches):
         ctx = UBatchContext(id=i,
                             compute_stream=compute_stream,
                             comm_stream=comm_stream,
                             forward_context=forward_contexts[i],
                             cpu_wait_event=cpu_events[i],
                             cpu_signal_event=cpu_events[(i + 1) %
                                                         num_micro_batches],
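One detail worth noting in make_ubatch_contexts: the CPU events form a ring. Context i waits on cpu_events[i] and signals cpu_events[(i + 1) % num_micro_batches], so setting ubatch 0's event (as _run_ubatches does) lets the microbatches take turns. A tiny sketch of just that wiring (assuming cpu_events holds one threading.Event per microbatch, as the indexing implies):

import threading

num_micro_batches = 2
cpu_events = [threading.Event() for _ in range(num_micro_batches)]

# Context i blocks on its own event and signals its neighbor's, so at most
# one microbatch is runnable between handoffs.
wiring = [
    dict(id=i,
         cpu_wait_event=cpu_events[i],
         cpu_signal_event=cpu_events[(i + 1) % num_micro_batches])
    for i in range(num_micro_batches)
]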