mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-12 21:16:49 +08:00
format
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
8ea80fca4a
commit
92e0cc79a8
@ -2,18 +2,16 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import gc
|
import gc
|
||||||
import os
|
import threading
|
||||||
import time
|
import time
|
||||||
import weakref
|
import weakref
|
||||||
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
|
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
|
||||||
import contextlib
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
import torch.distributed
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from vllm.utils import current_stream
|
|
||||||
from vllm.attention import AttentionType, get_attn_backend
|
from vllm.attention import AttentionType, get_attn_backend
|
||||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
AttentionMetadataBuilder)
|
AttentionMetadataBuilder)
|
||||||
@ -27,7 +25,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
|||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group, get_tp_group, graph_capture,
|
get_pp_group, get_tp_group, graph_capture,
|
||||||
prepare_communication_buffer_for_model)
|
prepare_communication_buffer_for_model)
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context, create_forward_context, override_forward_context
|
from vllm.forward_context import (create_forward_context, get_forward_context,
|
||||||
|
override_forward_context,
|
||||||
|
set_forward_context)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
@ -38,7 +38,7 @@ from vllm.sampling_params import SamplingType
|
|||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||||
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
GiB_bytes, LazyLoader, cdiv, check_use_alibi,
|
||||||
is_pin_memory_available)
|
current_stream, is_pin_memory_available)
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
@ -59,14 +59,11 @@ from vllm.v1.utils import bind_kv_cache
|
|||||||
from vllm.v1.worker.block_table import BlockTable
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||||
from vllm.v1.worker.ubatching import make_ubatch_contexts, UBatchContext
|
from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
||||||
|
|
||||||
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs,
|
||||||
scatter_mm_placeholders)
|
scatter_mm_placeholders)
|
||||||
|
|
||||||
import threading
|
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import xgrammar as xgr
|
import xgrammar as xgr
|
||||||
|
|
||||||
@ -504,22 +501,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.refresh_sampling_metadata()
|
self.input_batch.refresh_sampling_metadata()
|
||||||
|
|
||||||
def _ubatch_split(
|
def _ubatch_split(
|
||||||
self,
|
self, query_start_loc_np: torch.Tensor,
|
||||||
query_start_loc_np: torch.Tensor,
|
max_num_scheduled_tokens: int,
|
||||||
max_num_scheduled_tokens: int,
|
scheduler_output: "SchedulerOutput") -> Optional[UBatchSlices]:
|
||||||
scheduler_output: "SchedulerOutput"
|
|
||||||
) -> Optional[UBatchSlices]:
|
|
||||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
|
|
||||||
if self.parallel_config.enable_microbatching and \
|
if self.parallel_config.enable_microbatching and \
|
||||||
total_num_scheduled_tokens >= self.parallel_config.microbatching_token_threshold \
|
total_num_scheduled_tokens >= \
|
||||||
and max_num_scheduled_tokens == 1:
|
self.parallel_config.microbatching_token_threshold \
|
||||||
|
and max_num_scheduled_tokens == 1:
|
||||||
# For pure decode we can just create ubatchs by cutting the request
|
# For pure decode we can just create ubatchs by cutting the request
|
||||||
# in half
|
# in half
|
||||||
b0_reqs_end = num_reqs // 2
|
b0_reqs_end = num_reqs // 2
|
||||||
b0_tokens_end = total_num_scheduled_tokens // 2
|
b0_tokens_end = total_num_scheduled_tokens // 2
|
||||||
assert b0_reqs_end < num_reqs and b0_tokens_end < total_num_scheduled_tokens
|
assert b0_reqs_end < num_reqs and \
|
||||||
|
b0_tokens_end < total_num_scheduled_tokens
|
||||||
return [
|
return [
|
||||||
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
|
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
|
||||||
(slice(b0_reqs_end, num_reqs),
|
(slice(b0_reqs_end, num_reqs),
|
||||||
@ -531,14 +528,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# TODO we can do something more advanced here to try to balance,
|
# TODO we can do something more advanced here to try to balance,
|
||||||
# i.e. split to the left of `total_num_scheduled_tokens // 2` if it
|
# i.e. split to the left of `total_num_scheduled_tokens // 2` if it
|
||||||
# is more balanced
|
# is more balanced
|
||||||
req_split_id = np.argmax(query_start_loc_np > (total_num_scheduled_tokens // 2))
|
req_split_id = np.argmax(
|
||||||
return [(slice(0, req_split_id), slice(0, query_start_loc_np[req_split_id])),
|
query_start_loc_np > (total_num_scheduled_tokens // 2))
|
||||||
(slice(req_split_id, num_reqs), slice(query_start_loc_np[req_split_id], total_num_scheduled_tokens))]
|
return [(slice(0, req_split_id),
|
||||||
|
slice(0, query_start_loc_np[req_split_id])),
|
||||||
|
(slice(req_split_id, num_reqs),
|
||||||
|
slice(query_start_loc_np[req_split_id],
|
||||||
|
total_num_scheduled_tokens))]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _is_dummy_ubatch(self, ubatch_slice: UBatchSlices) -> bool:
|
|
||||||
return ubatch_slice[1].start >= ubatch_slice[1].stop
|
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self, scheduler_output: "SchedulerOutput"
|
self, scheduler_output: "SchedulerOutput"
|
||||||
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
|
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
|
||||||
@ -630,7 +628,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
|
||||||
|
|
||||||
ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
|
ubatch_slices: Optional[UBatchSlices] = self._ubatch_split(
|
||||||
self.query_start_loc_np, max_num_scheduled_tokens, scheduler_output)
|
self.query_start_loc_np, max_num_scheduled_tokens,
|
||||||
|
scheduler_output)
|
||||||
|
|
||||||
self.seq_lens_np[:num_reqs] = (
|
self.seq_lens_np[:num_reqs] = (
|
||||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
|
||||||
@ -708,6 +707,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
))
|
))
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
assert type(attn_metadata) is list
|
assert type(attn_metadata) is list
|
||||||
|
assert attn_metadata_i is not None
|
||||||
|
# What if it's None? Do we still add it to the list?
|
||||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||||
else:
|
else:
|
||||||
attn_metadata_i = (
|
attn_metadata_i = (
|
||||||
@ -749,7 +750,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||||
|
|
||||||
return attn_metadata, logits_indices, spec_decode_metadata, ubatch_slices
|
return (attn_metadata, logits_indices, spec_decode_metadata,
|
||||||
|
ubatch_slices)
|
||||||
|
|
||||||
def _compute_cascade_attn_prefix_len(
|
def _compute_cascade_attn_prefix_len(
|
||||||
self,
|
self,
|
||||||
@ -1167,8 +1169,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if sync_self:
|
if sync_self:
|
||||||
assert intermediate_tensors is not None
|
assert intermediate_tensors is not None
|
||||||
for k, v in intermediate_tensors.items():
|
for k, v in intermediate_tensors.items():
|
||||||
is_scattered = "residual" and is_residual_scattered
|
_copy_slice = copy_slice(is_residual_scattered)
|
||||||
_copy_slice = copy_slice(is_scattered)
|
|
||||||
self.intermediate_tensors[k][_copy_slice].copy_(
|
self.intermediate_tensors[k][_copy_slice].copy_(
|
||||||
v[_copy_slice], non_blocking=True)
|
v[_copy_slice], non_blocking=True)
|
||||||
|
|
||||||
@ -1206,7 +1207,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
return input_ids, positions, inputs_embeds, intermediate_tensors
|
return input_ids, positions, inputs_embeds, intermediate_tensors
|
||||||
|
|
||||||
|
|
||||||
def _get_model_inputs(self, tokens_slice: slice,
|
def _get_model_inputs(self, tokens_slice: slice,
|
||||||
scheduler_output: "SchedulerOutput"):
|
scheduler_output: "SchedulerOutput"):
|
||||||
num_tokens = tokens_slice.stop - tokens_slice.start
|
num_tokens = tokens_slice.stop - tokens_slice.start
|
||||||
@ -1291,26 +1291,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
|
def model_inputs(tokens_slice: slice, use_dummy_input: bool) -> tuple:
|
||||||
if use_dummy_input:
|
if use_dummy_input:
|
||||||
|
assert num_dummy_tokens == 1
|
||||||
return self._get_dummy_model_inputs(num_dummy_tokens)
|
return self._get_dummy_model_inputs(num_dummy_tokens)
|
||||||
else:
|
else:
|
||||||
assert scheduler_output is not None
|
assert scheduler_output is not None
|
||||||
return self._get_model_inputs(tokens_slice, scheduler_output)
|
return self._get_model_inputs(tokens_slice, scheduler_output)
|
||||||
|
|
||||||
|
|
||||||
def _run(token_slice: slice, context, use_dummy_input: bool = False):
|
def _run(token_slice: slice, context, use_dummy_input: bool = False):
|
||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
model_inputs(token_slice, use_dummy_input)
|
model_inputs(token_slice, use_dummy_input)
|
||||||
with context:
|
with context:
|
||||||
# if isinstance(context, UBatchContext):
|
|
||||||
# print(f"Running ubatch {context.id} with input_ids {input_ids.shape} and positions {positions.shape} use_dummy_input {use_dummy_input} token_slice {token_slice}")
|
|
||||||
model_output = self.model(
|
model_output = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
# if isinstance(context, UBatchContext):
|
|
||||||
# print(f"Ran ubatch {context.id}putput {model_output.shape}")
|
|
||||||
if isinstance(context, UBatchContext):
|
if isinstance(context, UBatchContext):
|
||||||
# Clone before we leave the ubatch context
|
# Clone before we leave the ubatch context
|
||||||
model_output = model_output.clone()
|
model_output = model_output.clone()
|
||||||
@ -1318,21 +1314,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return model_output
|
return model_output
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results, use_dummy_input):
|
def _ubatch_thread(ubatch_ctx, token_slice, results, save_results,
|
||||||
|
use_dummy_input):
|
||||||
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
model_output = _run(token_slice, ubatch_ctx, use_dummy_input)
|
||||||
|
|
||||||
if save_results:
|
if save_results:
|
||||||
results.append(model_output)
|
results.append(model_output)
|
||||||
|
|
||||||
def _run_ubatches(ubatch_slices, attn_metadata, is_dummy_run):
|
def _run_ubatches(ubatch_slices, attn_metadata,
|
||||||
results = []
|
is_dummy_run) -> torch.Tensor:
|
||||||
|
results: list[torch.Tensor] = []
|
||||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||||
root_stream = current_stream()
|
root_stream = current_stream()
|
||||||
|
|
||||||
ubatch_ctxs = make_ubatch_contexts(
|
ubatch_ctxs = make_ubatch_contexts(len(ubatch_slices),
|
||||||
len(ubatch_slices),
|
compute_stream=root_stream,
|
||||||
compute_stream=root_stream,
|
device=self.device)
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
# Ubatches will manually manage the forward context, so we override
|
# Ubatches will manually manage the forward context, so we override
|
||||||
# it to None here so we can have it restored correctly later
|
# it to None here so we can have it restored correctly later
|
||||||
@ -1340,22 +1337,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
ubatch_threads = []
|
ubatch_threads = []
|
||||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||||
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
|
is_dummy_ubatch = tokens_slice.stop <= tokens_slice.start
|
||||||
assert not is_dummy_ubatch or i == len(ubatch_slices) - 1 or is_dummy_run
|
assert not is_dummy_ubatch or i == len(
|
||||||
|
ubatch_slices) - 1 or is_dummy_run
|
||||||
|
|
||||||
# print("ubatch", i, "tokens slice", tokens_slice, "is dummy ubatch", is_dummy_ubatch, "is dummy run", is_dummy_run)
|
num_tokens = num_dummy_tokens if is_dummy_ubatch or \
|
||||||
|
is_dummy_run else (tokens_slice.stop - tokens_slice.start)
|
||||||
num_tokens = num_dummy_tokens if is_dummy_ubatch or is_dummy_run else (tokens_slice.stop - tokens_slice.start)
|
|
||||||
ubatch_ctxs[i].forward_context = create_forward_context(
|
ubatch_ctxs[i].forward_context = create_forward_context(
|
||||||
attn_metadata[i] if attn_metadata is not None else None,
|
attn_metadata[i]
|
||||||
self.vllm_config, num_tokens=num_tokens)
|
if attn_metadata is not None else None,
|
||||||
|
self.vllm_config,
|
||||||
|
num_tokens=num_tokens)
|
||||||
|
|
||||||
thread = threading.Thread(target=_ubatch_thread, args=(
|
thread = threading.Thread(target=_ubatch_thread,
|
||||||
ubatch_ctxs[i],
|
args=(
|
||||||
tokens_slice,
|
ubatch_ctxs[i],
|
||||||
results,
|
tokens_slice,
|
||||||
not is_dummy_ubatch or is_dummy_run,
|
results,
|
||||||
is_dummy_ubatch or is_dummy_run,
|
not is_dummy_ubatch
|
||||||
))
|
or is_dummy_run,
|
||||||
|
is_dummy_ubatch
|
||||||
|
or is_dummy_run,
|
||||||
|
))
|
||||||
ubatch_threads.append(thread)
|
ubatch_threads.append(thread)
|
||||||
thread.start()
|
thread.start()
|
||||||
ubatch_ctxs[0].cpu_wait_event.set()
|
ubatch_ctxs[0].cpu_wait_event.set()
|
||||||
@ -1369,16 +1371,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# run micro-batched
|
# run micro-batched
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
model_output = _run_ubatches(
|
model_output = _run_ubatches(ubatch_slices, attn_metadata,
|
||||||
ubatch_slices, attn_metadata, is_dummy_run)
|
is_dummy_run)
|
||||||
# run single batch
|
# run single batch
|
||||||
else:
|
else:
|
||||||
model_output = _run(
|
model_output = _run(
|
||||||
slice(0, num_scheduled_tokens),
|
slice(0, num_scheduled_tokens),
|
||||||
set_forward_context(
|
set_forward_context(attn_metadata,
|
||||||
attn_metadata,
|
vllm_config=self.vllm_config,
|
||||||
vllm_config=self.vllm_config,
|
num_tokens=num_scheduled_tokens or 1),
|
||||||
num_tokens=num_scheduled_tokens or 1),
|
|
||||||
is_dummy_run)
|
is_dummy_run)
|
||||||
|
|
||||||
return model_output
|
return model_output
|
||||||
@ -1583,7 +1584,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if spec_decode_metadata is None:
|
if spec_decode_metadata is None:
|
||||||
# input_ids can be None for multimodal models.
|
# input_ids can be None for multimodal models.
|
||||||
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
target_token_ids = self.input_ids[:num_scheduled_tokens]
|
||||||
target_positions = positions[:num_scheduled_tokens]
|
#TODO(sage) make sure this works with mrope
|
||||||
|
target_positions = self.positions[:num_scheduled_tokens]
|
||||||
if self.use_aux_hidden_state_outputs:
|
if self.use_aux_hidden_state_outputs:
|
||||||
target_hidden_states = torch.cat(
|
target_hidden_states = torch.cat(
|
||||||
[h[:num_scheduled_tokens] for h in aux_hidden_states],
|
[h[:num_scheduled_tokens] for h in aux_hidden_states],
|
||||||
@ -1609,7 +1611,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_rejected_tokens,
|
num_rejected_tokens,
|
||||||
)
|
)
|
||||||
target_token_ids = self.input_ids[token_indices]
|
target_token_ids = self.input_ids[token_indices]
|
||||||
target_positions = positions[token_indices]
|
#TODO(sage) make sure this works with mrope
|
||||||
|
target_positions = self.positions[token_indices]
|
||||||
if self.use_aux_hidden_state_outputs:
|
if self.use_aux_hidden_state_outputs:
|
||||||
target_hidden_states = torch.cat(
|
target_hidden_states = torch.cat(
|
||||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||||
@ -1647,6 +1650,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
def kv_connector_no_forward(
|
def kv_connector_no_forward(
|
||||||
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
|
||||||
|
|
||||||
# KV send/recv even if no work to do.
|
# KV send/recv even if no work to do.
|
||||||
with set_forward_context(None, self.vllm_config):
|
with set_forward_context(None, self.vllm_config):
|
||||||
self.maybe_setup_kv_connector(scheduler_output)
|
self.maybe_setup_kv_connector(scheduler_output)
|
||||||
@ -1898,16 +1902,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
should_microbatch = (
|
should_microbatch = (
|
||||||
allow_microbatching
|
allow_microbatching
|
||||||
and self.vllm_config.parallel_config.enable_microbatching
|
and self.vllm_config.parallel_config.enable_microbatching
|
||||||
and self.vllm_config.parallel_config.always_microbatch_if_enabled
|
and self.vllm_config.parallel_config.always_microbatch_if_enabled)
|
||||||
)
|
dummy_microbatches = [(slice(0, 0), slice(0, 0)),
|
||||||
dummy_microbatches = [(slice(0, 0), slice(0, 0)), (slice(0, 0), slice(0, 0))]
|
(slice(0, 0), slice(0, 0))]
|
||||||
|
|
||||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||||
num_scheduled_tokens):
|
num_scheduled_tokens):
|
||||||
outputs = self._run_model(
|
outputs = self._run_model(
|
||||||
attn_metadata,
|
attn_metadata,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
ubatch_slices=None if not should_microbatch else dummy_microbatches,
|
ubatch_slices=None
|
||||||
|
if not should_microbatch else dummy_microbatches,
|
||||||
is_dummy_run=True,
|
is_dummy_run=True,
|
||||||
)
|
)
|
||||||
if self.use_aux_hidden_state_outputs:
|
if self.use_aux_hidden_state_outputs:
|
||||||
|
|||||||
@ -1,36 +1,36 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import threading
|
import threading
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._dynamo
|
import torch._dynamo
|
||||||
import torch.profiler as profiler
|
from torch.library import custom_op
|
||||||
import os
|
|
||||||
from typing import Optional
|
|
||||||
from torch.library import Library
|
|
||||||
from torch.library import custom_op, register_kernel
|
|
||||||
from vllm.distributed import (get_dp_group)
|
|
||||||
|
|
||||||
from vllm.utils import current_stream
|
|
||||||
from vllm import forward_context
|
from vllm import forward_context
|
||||||
|
from vllm.utils import current_stream
|
||||||
|
|
||||||
|
|
||||||
class UBatchContext:
|
class UBatchContext:
|
||||||
"""
|
"""
|
||||||
Context manager for micro-batching synchronization using threading events.
|
Context manager for micro-batching synchronization using threading events.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
id: int,
|
def __init__(
|
||||||
comm_stream: torch.cuda.Stream,
|
self,
|
||||||
compute_stream: torch.cuda.Stream,
|
id: int,
|
||||||
#fwd_ctx: forward_context.ForwardContext,
|
comm_stream: torch.cuda.Stream,
|
||||||
cpu_wait_event: threading.Event,
|
compute_stream: torch.cuda.Stream,
|
||||||
cpu_signal_event: threading.Event,
|
#fwd_ctx: forward_context.ForwardContext,
|
||||||
gpu_comm_done_event: torch.cuda.Event,
|
cpu_wait_event: threading.Event,
|
||||||
gpu_compute_done_event: torch.cuda.Event,
|
cpu_signal_event: threading.Event,
|
||||||
schedule: str = "default"):
|
gpu_comm_done_event: torch.cuda.Event,
|
||||||
|
gpu_compute_done_event: torch.cuda.Event,
|
||||||
|
schedule: str = "default"):
|
||||||
self.id = id
|
self.id = id
|
||||||
self.comm_stream = comm_stream
|
self.comm_stream = comm_stream
|
||||||
self.compute_stream = compute_stream
|
self.compute_stream = compute_stream
|
||||||
self.original_stream = current_stream()
|
self.original_stream = current_stream()
|
||||||
self.forward_context = None #fwd_ctx
|
self.forward_context = None #fwd_ctx
|
||||||
self.cpu_wait_event = cpu_wait_event
|
self.cpu_wait_event = cpu_wait_event
|
||||||
self.cpu_signal_event = cpu_signal_event
|
self.cpu_signal_event = cpu_signal_event
|
||||||
self.current_stream = compute_stream
|
self.current_stream = compute_stream
|
||||||
@ -47,8 +47,7 @@ class UBatchContext:
|
|||||||
self.cpu_wait_event.clear()
|
self.cpu_wait_event.clear()
|
||||||
self._restore_context()
|
self._restore_context()
|
||||||
# Assume we start on the compute stream
|
# Assume we start on the compute stream
|
||||||
assert current_stream() == self.compute_stream, \
|
assert current_stream() == self.compute_stream
|
||||||
"Expected to start on the compute stream, but found %s" % current_stream()
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
@ -74,6 +73,7 @@ class UBatchContext:
|
|||||||
# assert current_stream() == self.current_stream
|
# assert current_stream() == self.current_stream
|
||||||
# assert not self.cpu_wait_event.is_set()
|
# assert not self.cpu_wait_event.is_set()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _signal_comm_done(self):
|
def _signal_comm_done(self):
|
||||||
# self.ctx_valid_state()
|
# self.ctx_valid_state()
|
||||||
self.gpu_comm_done_event.record(self.comm_stream)
|
self.gpu_comm_done_event.record(self.comm_stream)
|
||||||
@ -115,31 +115,37 @@ class UBatchContext:
|
|||||||
def yield_and_switch_from_compute_to_comm(self):
|
def yield_and_switch_from_compute_to_comm(self):
|
||||||
assert current_stream() == self.compute_stream
|
assert current_stream() == self.compute_stream
|
||||||
# dp_rank = get_dp_group().rank_in_group
|
# dp_rank = get_dp_group().rank_in_group
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
|
# f"Yield and switch from {self.stream_string()}", flush=True)
|
||||||
# self.ctx_valid_state()
|
# self.ctx_valid_state()
|
||||||
self._signal_compute_done()
|
self._signal_compute_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
# self.ctx_valid_state()
|
# self.ctx_valid_state()
|
||||||
assert self.current_stream == self.compute_stream
|
assert self.current_stream == self.compute_stream
|
||||||
self.update_stream(self.comm_stream)
|
self.update_stream(self.comm_stream)
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
|
# f"Resuming on stream {self.stream_string()}", flush=True)
|
||||||
self._wait_compute_done()
|
self._wait_compute_done()
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute(self):
|
def yield_and_switch_from_comm_to_compute(self):
|
||||||
assert current_stream() == self.comm_stream
|
assert current_stream() == self.comm_stream
|
||||||
# dp_rank = get_dp_group().rank_in_group
|
# dp_rank = get_dp_group().rank_in_group
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} Yield and switch from {self.stream_string()}", flush=True)
|
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
|
# f"Yield and switch from {self.stream_string()}", flush=True)
|
||||||
# self.ctx_valid_state()
|
# self.ctx_valid_state()
|
||||||
self._signal_comm_done()
|
self._signal_comm_done()
|
||||||
self._cpu_yield()
|
self._cpu_yield()
|
||||||
# self.ctx_valid_state()
|
# self.ctx_valid_state()
|
||||||
assert self.current_stream == self.comm_stream
|
assert self.current_stream == self.comm_stream
|
||||||
self.update_stream(self.compute_stream)
|
self.update_stream(self.compute_stream)
|
||||||
# print(f"DP: {dp_rank} UB: {self.id} Resuming on stream {self.stream_string()}", flush=True)
|
# print(f"DP: {dp_rank} UB: {self.id} "
|
||||||
|
# f"Resuming on stream {self.stream_string()}", flush=True)
|
||||||
self._wait_comm_done()
|
self._wait_comm_done()
|
||||||
|
|
||||||
|
|
||||||
_CURRENT_CONTEXT: dict = {}
|
_CURRENT_CONTEXT: dict = {}
|
||||||
|
|
||||||
|
|
||||||
def get_current_ubatch_context() -> Optional[UBatchContext]:
|
def get_current_ubatch_context() -> Optional[UBatchContext]:
|
||||||
global _CURRENT_CONTEXT
|
global _CURRENT_CONTEXT
|
||||||
"""
|
"""
|
||||||
@ -147,6 +153,7 @@ def get_current_ubatch_context() -> Optional[UBatchContext]:
|
|||||||
"""
|
"""
|
||||||
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
return _CURRENT_CONTEXT.get(threading.get_ident(), None)
|
||||||
|
|
||||||
|
|
||||||
def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
|
def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
|
||||||
# Perform the barrier if a context exists for this thread
|
# Perform the barrier if a context exists for this thread
|
||||||
ctx = get_current_ubatch_context()
|
ctx = get_current_ubatch_context()
|
||||||
@ -154,50 +161,51 @@ def yield_and_switch_from_compute_to_comm_impl(schedule="default"):
|
|||||||
if ctx is not None and ctx.schedule == schedule:
|
if ctx is not None and ctx.schedule == schedule:
|
||||||
ctx.yield_and_switch_from_compute_to_comm()
|
ctx.yield_and_switch_from_compute_to_comm()
|
||||||
|
|
||||||
|
|
||||||
def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
|
def yield_and_switch_from_comm_to_compute_impl(schedule="default"):
|
||||||
# Perform the barrier if a context exists for this thread
|
# Perform the barrier if a context exists for this thread
|
||||||
ctx = get_current_ubatch_context()
|
ctx = get_current_ubatch_context()
|
||||||
if ctx is not None and ctx.schedule == schedule:
|
if ctx is not None and ctx.schedule == schedule:
|
||||||
ctx.yield_and_switch_from_comm_to_compute()
|
ctx.yield_and_switch_from_comm_to_compute()
|
||||||
|
|
||||||
|
|
||||||
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
|
# 2) Register kernel for CUDA, mark as mutating to prevent the compiler from
|
||||||
# optimizing it away (TODO: see if this is actually needed)
|
# optimizing it away (TODO: see if this is actually needed)
|
||||||
@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x",))
|
@custom_op("vllm::yield_and_switch_from_compute_to_comm", mutates_args=("x", ))
|
||||||
def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
|
def yield_and_switch_from_compute_to_comm(x: torch.Tensor,
|
||||||
|
schedule: str = "default") -> None:
|
||||||
yield_and_switch_from_compute_to_comm_impl(schedule)
|
yield_and_switch_from_compute_to_comm_impl(schedule)
|
||||||
|
|
||||||
|
|
||||||
# 3) Fake implementation for shape prop and FX tracing
|
# 3) Fake implementation for shape prop and FX tracing
|
||||||
@yield_and_switch_from_compute_to_comm.register_fake
|
@yield_and_switch_from_compute_to_comm.register_fake
|
||||||
def yield_and_switch_from_compute_to_comm(x: torch.Tensor, schedule: str="default") -> None:
|
def yield_and_switch_from_compute_to_comm_fake(x: torch.Tensor,
|
||||||
|
schedule: str = "default"
|
||||||
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x",))
|
|
||||||
def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
|
@custom_op("vllm::yield_and_switch_from_comm_to_compute", mutates_args=("x", ))
|
||||||
|
def yield_and_switch_from_comm_to_compute(x: torch.Tensor,
|
||||||
|
schedule: str = "default") -> None:
|
||||||
yield_and_switch_from_comm_to_compute_impl(schedule)
|
yield_and_switch_from_comm_to_compute_impl(schedule)
|
||||||
|
|
||||||
|
|
||||||
@yield_and_switch_from_comm_to_compute.register_fake
|
@yield_and_switch_from_comm_to_compute.register_fake
|
||||||
def yield_and_switch_from_comm_to_compute(x: torch.Tensor, schedule: str="default") -> None:
|
def yield_and_switch_from_comm_to_compute_fake(x: torch.Tensor,
|
||||||
|
schedule: str = "default"
|
||||||
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def dump_ubatching_state():
|
def dump_ubatching_state():
|
||||||
pass
|
pass
|
||||||
# """
|
|
||||||
# Dump the current UBatchContext state for debugging.
|
|
||||||
# """
|
|
||||||
|
|
||||||
# dp_rank = os.getenv("VLLM_DP_RANK", None)
|
|
||||||
|
|
||||||
# for ctx in _CURRENT_CONTEXT.values():
|
|
||||||
# print(f"UBatchContext: {ctx.id} (dp_rank {dp_rank})\n"
|
|
||||||
# f" Stream: {ctx.stream}, ({ctx.stream.query()})\n"
|
|
||||||
# f" Original Stream: {ctx.original_stream}, ({ctx.original_stream.query()})\n"
|
|
||||||
# f" CPU Wait Event: {ctx.cpu_wait_event}\n"
|
|
||||||
# f" GPU Wait Event: {ctx.gpu_wait_event} ({ctx.gpu_wait_event.query()})\n"
|
|
||||||
# f" CPU Signal Event: {ctx.cpu_signal_event}\n"
|
|
||||||
# f" GPU Signal Event: {ctx.gpu_signal_event} ({ctx.gpu_signal_event.query()})\n")
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def make_ubatch_contexts(
|
def make_ubatch_contexts(
|
||||||
num_micro_batches: int,
|
num_micro_batches: int,
|
||||||
compute_stream: torch.cuda.Stream,
|
compute_stream: torch.cuda.Stream,
|
||||||
@ -205,7 +213,6 @@ def make_ubatch_contexts(
|
|||||||
schedule: str = "default",
|
schedule: str = "default",
|
||||||
) -> list[UBatchContext]:
|
) -> list[UBatchContext]:
|
||||||
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
assert num_micro_batches == 2, "only been tested with 2 micro-batches"
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Create a context manager for micro-batching synchronization.
|
Create a context manager for micro-batching synchronization.
|
||||||
"""
|
"""
|
||||||
@ -225,11 +232,11 @@ def make_ubatch_contexts(
|
|||||||
compute_stream=compute_stream,
|
compute_stream=compute_stream,
|
||||||
comm_stream=comm_stream,
|
comm_stream=comm_stream,
|
||||||
cpu_wait_event=cpu_events[i],
|
cpu_wait_event=cpu_events[i],
|
||||||
cpu_signal_event=cpu_events[(i + 1) % num_micro_batches],
|
cpu_signal_event=cpu_events[(i + 1) %
|
||||||
|
num_micro_batches],
|
||||||
gpu_comm_done_event=gpu_comm_done_events[i],
|
gpu_comm_done_event=gpu_comm_done_events[i],
|
||||||
gpu_compute_done_event=gpu_compute_done_events[i],
|
gpu_compute_done_event=gpu_compute_done_events[i],
|
||||||
schedule=schedule
|
schedule=schedule)
|
||||||
)
|
|
||||||
ctxs.append(ctx)
|
ctxs.append(ctx)
|
||||||
|
|
||||||
return ctxs
|
return ctxs
|
||||||
Loading…
x
Reference in New Issue
Block a user