mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-12 01:47:14 +08:00
format
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
dc1b6af362
commit
bfa828f399
@ -48,12 +48,13 @@ class DPMetadata:
|
|||||||
return num_tokens_tensor
|
return num_tokens_tensor
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def should_ubatch_across_dp(should_ubatch: bool, dp_size: int, dp_rank: int) -> bool:
|
def should_ubatch_across_dp(should_ubatch: bool, dp_size: int,
|
||||||
|
dp_rank: int) -> bool:
|
||||||
should_ubatch_across_dp = [0] * dp_size
|
should_ubatch_across_dp = [0] * dp_size
|
||||||
should_ubatch_across_dp[dp_rank] = 1 if should_ubatch else 0
|
should_ubatch_across_dp[dp_rank] = 1 if should_ubatch else 0
|
||||||
should_ubatch_tensor = torch.tensor(should_ubatch_across_dp,
|
should_ubatch_tensor = torch.tensor(should_ubatch_across_dp,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
from vllm.distributed.parallel_state import get_dp_group
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)
|
dist.all_reduce(should_ubatch_tensor, group=get_dp_group().cpu_group)
|
||||||
|
|
||||||
@ -61,8 +62,8 @@ class DPMetadata:
|
|||||||
# num_tokens_across_dp. If there's an incorrect ordering of ARs
|
# num_tokens_across_dp. If there's an incorrect ordering of ARs
|
||||||
# across DP ranks, this tensor can end up containing the number
|
# across DP ranks, this tensor can end up containing the number
|
||||||
# of padded tokens for a DP rank.
|
# of padded tokens for a DP rank.
|
||||||
|
assert torch.all((should_ubatch_tensor == 0)
|
||||||
assert torch.all((should_ubatch_tensor == 0) | (should_ubatch_tensor == 1))
|
| (should_ubatch_tensor == 1))
|
||||||
|
|
||||||
result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
|
result: bool = bool(torch.all(should_ubatch_tensor == 1).item())
|
||||||
return result
|
return result
|
||||||
|
|||||||
@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.platforms.interface import CpuArchEnum
|
from vllm.platforms.interface import CpuArchEnum
|
||||||
from vllm.v1.worker.ubatching import get_current_ubatch_context
|
|
||||||
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
|
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
|
||||||
|
from vllm.v1.worker.ubatching import get_current_ubatch_context
|
||||||
|
|
||||||
if current_platform.is_cuda_alike():
|
if current_platform.is_cuda_alike():
|
||||||
from .fused_batched_moe import BatchedTritonExperts
|
from .fused_batched_moe import BatchedTritonExperts
|
||||||
@ -1571,15 +1571,22 @@ class FusedMoE(torch.nn.Module):
|
|||||||
ubatch_ctx = get_current_ubatch_context()
|
ubatch_ctx = get_current_ubatch_context()
|
||||||
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
ubatch_id = ubatch_ctx.id if ubatch_ctx is not None else -1
|
||||||
batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
|
batch_buffer_idx = 0 if ubatch_id == -1 else ubatch_id
|
||||||
batched_hidden_states = self.batched_hidden_states[batch_buffer_idx, :]
|
|
||||||
batched_router_logits = self.batched_router_logits[batch_buffer_idx, :]
|
assert self.batched_hidden_states is not None
|
||||||
|
assert self.batched_router_logits is not None
|
||||||
|
batched_hidden_states = self.batched_hidden_states[
|
||||||
|
batch_buffer_idx, :]
|
||||||
|
batched_router_logits = self.batched_router_logits[
|
||||||
|
batch_buffer_idx, :]
|
||||||
|
|
||||||
assert (batched_hidden_states.size(0) # type: ignore
|
assert (batched_hidden_states.size(0) # type: ignore
|
||||||
>= chunk_size)
|
>= chunk_size)
|
||||||
assert (batched_router_logits.size(0) # type: ignore
|
assert (batched_router_logits.size(0) # type: ignore
|
||||||
>= chunk_size)
|
>= chunk_size)
|
||||||
staged_hidden_states = batched_hidden_states[:chunk_size, :] # type: ignore
|
staged_hidden_states = batched_hidden_states[:
|
||||||
staged_router_logits = batched_router_logits[:chunk_size, :] # type: ignore
|
chunk_size, :] # type: ignore
|
||||||
|
staged_router_logits = batched_router_logits[:
|
||||||
|
chunk_size, :] # type: ignore
|
||||||
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
staged_hidden_states.copy_(hidden_states, non_blocking=True)
|
||||||
staged_router_logits.copy_(router_logits, non_blocking=True)
|
staged_router_logits.copy_(router_logits, non_blocking=True)
|
||||||
|
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
|
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
|
||||||
make_local_attention_virtual_batches)
|
make_local_attention_virtual_batches, slice_query_start_locs)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
from vllm.v1.worker.block_table import BlockTable
|
from vllm.v1.worker.block_table import BlockTable
|
||||||
|
|
||||||
|
|||||||
@ -60,6 +60,20 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def build_slice(
|
||||||
|
self,
|
||||||
|
req_slice: slice,
|
||||||
|
token_slice: slice,
|
||||||
|
max_query_len: int,
|
||||||
|
common_prefix_len: int,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
) -> M:
|
||||||
|
"""
|
||||||
|
Should only be called on builders that support attention slicing
|
||||||
|
for micro batching
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def can_run_in_cudagraph(
|
def can_run_in_cudagraph(
|
||||||
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -105,6 +119,7 @@ def slice_query_start_locs(
|
|||||||
return query_start_loc[req_slice.start: req_slice.stop + 1] -\
|
return query_start_loc[req_slice.start: req_slice.stop + 1] -\
|
||||||
query_start_loc[req_slice.start]
|
query_start_loc[req_slice.start]
|
||||||
|
|
||||||
|
|
||||||
def validate_kv_sharing_target(current_layer_name, target_layer_name,
|
def validate_kv_sharing_target(current_layer_name, target_layer_name,
|
||||||
static_forward_context):
|
static_forward_context):
|
||||||
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
|
error_msg = (f"Specified KV sharing target layer for {current_layer_name} "
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import dataclasses
|
||||||
import gc
|
import gc
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@ -29,8 +30,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
|||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
get_pp_group, get_tp_group, graph_capture,
|
get_pp_group, get_tp_group, graph_capture,
|
||||||
prepare_communication_buffer_for_model)
|
prepare_communication_buffer_for_model)
|
||||||
from vllm.forward_context import (create_forward_context, get_forward_context,
|
from vllm.forward_context import (DPMetadata, create_forward_context,
|
||||||
override_forward_context, DPMetadata,
|
get_forward_context,
|
||||||
|
override_forward_context,
|
||||||
set_forward_context)
|
set_forward_context)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
|
||||||
@ -48,8 +50,8 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
|||||||
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
|
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
|
||||||
check_use_alibi, get_dtype_size,
|
check_use_alibi, get_dtype_size,
|
||||||
is_pin_memory_available, round_up)
|
is_pin_memory_available, round_up)
|
||||||
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||||
|
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
|
||||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
CommonAttentionMetadata)
|
CommonAttentionMetadata)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
@ -75,7 +77,6 @@ from vllm.v1.worker.ubatching import UBatchContext, make_ubatch_contexts
|
|||||||
|
|
||||||
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
|
from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
|
||||||
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import xgrammar as xgr
|
import xgrammar as xgr
|
||||||
@ -99,6 +100,7 @@ PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
|
|||||||
UbatchSlice: TypeAlias = tuple[slice, slice]
|
UbatchSlice: TypeAlias = tuple[slice, slice]
|
||||||
UBatchSlices: TypeAlias = list[UbatchSlice]
|
UBatchSlices: TypeAlias = list[UbatchSlice]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class UbatchMetadata:
|
class UbatchMetadata:
|
||||||
context: UBatchContext
|
context: UBatchContext
|
||||||
@ -577,10 +579,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.refresh_sampling_metadata()
|
self.input_batch.refresh_sampling_metadata()
|
||||||
|
|
||||||
def _ubatch_split(
|
def _ubatch_split(
|
||||||
self,
|
self, max_num_scheduled_tokens: int,
|
||||||
max_num_scheduled_tokens: int,
|
scheduler_output: "SchedulerOutput"
|
||||||
scheduler_output: "SchedulerOutput"
|
) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
|
||||||
) -> tuple[Optional[UBatchSlices], int, Optional[torch.Tensor]]:
|
|
||||||
# Don't bother with the should_ubatch handshaking unless microbatching
|
# Don't bother with the should_ubatch handshaking unless microbatching
|
||||||
# is enabled
|
# is enabled
|
||||||
if not self.parallel_config.enable_microbatching:
|
if not self.parallel_config.enable_microbatching:
|
||||||
@ -607,15 +608,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
b0_tokens_end < total_num_scheduled_tokens
|
b0_tokens_end < total_num_scheduled_tokens
|
||||||
ubatch_slices = [
|
ubatch_slices = [
|
||||||
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
|
(slice(0, b0_reqs_end), slice(0, b0_tokens_end)),
|
||||||
(slice(b0_reqs_end, num_reqs),
|
(slice(b0_reqs_end,
|
||||||
slice(b0_tokens_end, total_num_scheduled_tokens)),
|
num_reqs), slice(b0_tokens_end,
|
||||||
|
total_num_scheduled_tokens)),
|
||||||
]
|
]
|
||||||
|
|
||||||
# Compute ubatch padding. This currently only accounts for DP padding
|
# Compute ubatch padding. This currently only accounts for DP padding
|
||||||
num_pad_tokens = 0
|
num_pad_tokens = 0
|
||||||
num_tokens_after_padding = None
|
num_tokens_after_padding = None
|
||||||
ubatch_abort = False
|
ubatch_abort = False
|
||||||
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(ubatch_slices)
|
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding_ubatch(
|
||||||
|
ubatch_slices)
|
||||||
if num_pad_tokens > 0:
|
if num_pad_tokens > 0:
|
||||||
# Check if the padding would result in an empty second ubatch.
|
# Check if the padding would result in an empty second ubatch.
|
||||||
# If so abort ubatching
|
# If so abort ubatching
|
||||||
@ -624,10 +627,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
ubatch_abort = True
|
ubatch_abort = True
|
||||||
|
|
||||||
# Note that if we are attempting to ubatch by this point then we know that no
|
# Note that if we are attempting to ubatch by this point then we know
|
||||||
# DP ranks are doing dummy runs. Meaning, we don't need a second call to
|
# that no DP ranks are doing dummy runs. Meaning, we don't need a
|
||||||
# should_ubatch in _dummy_run
|
# second call to should_ubatch in _dummy_run
|
||||||
should_ubatch = self.should_ubatch(False if ubatch_abort else True)
|
should_ubatch = self.should_ubatch(not ubatch_abort)
|
||||||
if not should_ubatch:
|
if not should_ubatch:
|
||||||
return (None, 0, None)
|
return (None, 0, None)
|
||||||
return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
|
return (ubatch_slices, num_pad_tokens, num_tokens_after_padding)
|
||||||
@ -653,12 +656,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
return cu_num_tokens, arange
|
return cu_num_tokens, arange
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self, scheduler_output: "SchedulerOutput"
|
||||||
scheduler_output: "SchedulerOutput"
|
) -> tuple[PerLayerAttnMetadata, bool, torch.Tensor,
|
||||||
) -> tuple[dict[str, Any], bool, torch.Tensor,
|
|
||||||
Optional[SpecDecodeMetadata], np.ndarray,
|
Optional[SpecDecodeMetadata], np.ndarray,
|
||||||
Optional[UBatchSlices],
|
Optional[UBatchSlices], int, Optional[torch.Tensor]]:
|
||||||
int, Optional[torch.Tensor]]:
|
|
||||||
"""
|
"""
|
||||||
:return: tuple[
|
:return: tuple[
|
||||||
attn_metadata: layer-to-attention_metadata mapping,
|
attn_metadata: layer-to-attention_metadata mapping,
|
||||||
@ -874,8 +875,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
self.set_active_loras(self.input_batch, num_scheduled_tokens)
|
||||||
|
|
||||||
return (attn_metadata, attention_cuda_graphs, logits_indices,
|
return (attn_metadata, attention_cuda_graphs, logits_indices,
|
||||||
spec_decode_metadata, num_scheduled_tokens, ubatch_slices, num_pad_tokens,
|
spec_decode_metadata, num_scheduled_tokens, ubatch_slices,
|
||||||
num_tokens_after_padding)
|
num_pad_tokens, num_tokens_after_padding)
|
||||||
|
|
||||||
def _compute_cascade_attn_prefix_len(
|
def _compute_cascade_attn_prefix_len(
|
||||||
self,
|
self,
|
||||||
@ -1343,8 +1344,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||||
|
|
||||||
def get_padding(self,
|
def get_padding(
|
||||||
num_tokens_unpadded: int) -> tuple[int, Optional[torch.Tensor]]:
|
self,
|
||||||
|
num_tokens_unpadded: int) -> tuple[int, Optional[torch.Tensor]]:
|
||||||
|
|
||||||
num_tokens_padded = num_tokens_unpadded
|
num_tokens_padded = num_tokens_unpadded
|
||||||
|
|
||||||
@ -1352,7 +1354,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
|
and num_tokens_unpadded <= self.cudagraph_batch_sizes[-1]):
|
||||||
# Use piecewise CUDA graphs.
|
# Use piecewise CUDA graphs.
|
||||||
# Add padding to the batch size.
|
# Add padding to the batch size.
|
||||||
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens_unpadded)
|
num_tokens_padded = self.vllm_config.pad_for_cudagraph(
|
||||||
|
num_tokens_unpadded)
|
||||||
else:
|
else:
|
||||||
# Eager mode.
|
# Eager mode.
|
||||||
# Pad tokens to multiple of tensor_parallel_size when
|
# Pad tokens to multiple of tensor_parallel_size when
|
||||||
@ -1364,12 +1367,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
|
num_tokens_padded = round_up(num_tokens_unpadded, tp_size)
|
||||||
|
|
||||||
num_pad_tokens = num_tokens_padded - num_tokens_unpadded
|
num_pad_tokens = num_tokens_padded - num_tokens_unpadded
|
||||||
num_dp_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_tokens_padded)
|
num_dp_pad_tokens, num_tokens_after_padding = self.get_dp_padding(
|
||||||
|
num_tokens_padded)
|
||||||
|
|
||||||
return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding
|
return num_dp_pad_tokens + num_pad_tokens, num_tokens_after_padding
|
||||||
|
|
||||||
def get_dp_padding_ubatch(self,
|
def get_dp_padding_ubatch(
|
||||||
ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
|
self,
|
||||||
|
ubatch_slices: UBatchSlices) -> tuple[int, Optional[torch.Tensor]]:
|
||||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||||
|
|
||||||
if dp_size == 1:
|
if dp_size == 1:
|
||||||
@ -1379,21 +1384,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
first_ubatch_slice = ubatch_slices[0]
|
first_ubatch_slice = ubatch_slices[0]
|
||||||
second_ubatch_slice = ubatch_slices[1]
|
second_ubatch_slice = ubatch_slices[1]
|
||||||
|
|
||||||
first_ubatch_num_tokens = first_ubatch_slice[1].stop - first_ubatch_slice[1].start
|
first_ubatch_num_tokens = first_ubatch_slice[
|
||||||
second_ubatch_num_tokens = second_ubatch_slice[1].stop - second_ubatch_slice[1].start
|
1].stop - first_ubatch_slice[1].start
|
||||||
|
second_ubatch_num_tokens = second_ubatch_slice[
|
||||||
|
1].stop - second_ubatch_slice[1].start
|
||||||
# We don't support prefills yet so the two ubatches should only differ
|
# We don't support prefills yet so the two ubatches should only differ
|
||||||
# by at most one token
|
# by at most one token
|
||||||
assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
|
assert abs(first_ubatch_num_tokens - second_ubatch_num_tokens) <= 1
|
||||||
|
|
||||||
from vllm.utils import round_up
|
num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
|
||||||
|
|
||||||
num_tokens_unpadded = first_ubatch_num_tokens + second_ubatch_num_tokens
|
|
||||||
num_tokens_padded = round_up(num_tokens_unpadded, 2)
|
num_tokens_padded = round_up(num_tokens_unpadded, 2)
|
||||||
|
|
||||||
num_tokens_per_ubatch = num_tokens_padded // 2
|
num_tokens_per_ubatch = num_tokens_padded // 2
|
||||||
|
|
||||||
# Note that we compute the number of padded tokens per ubatch
|
# Note that we compute the number of padded tokens per ubatch
|
||||||
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_tokens_per_ubatch)
|
num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(
|
||||||
|
num_tokens_per_ubatch)
|
||||||
|
|
||||||
num_pad_tokens = ((num_pad_tokens + num_tokens_per_ubatch) * 2) - \
|
num_pad_tokens = ((num_pad_tokens + num_tokens_per_ubatch) * 2) - \
|
||||||
num_tokens_unpadded
|
num_tokens_unpadded
|
||||||
@ -1407,26 +1413,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_pad_tokens: int):
|
num_pad_tokens: int):
|
||||||
original_num_tokens = ubatch_slices[1][1].stop
|
original_num_tokens = ubatch_slices[1][1].stop
|
||||||
assert num_pad_tokens < original_num_tokens
|
assert num_pad_tokens < original_num_tokens
|
||||||
total_num_tokens_per_ubatch = (original_num_tokens + num_pad_tokens) // 2
|
total_num_tokens_per_ubatch = (original_num_tokens +
|
||||||
|
num_pad_tokens) // 2
|
||||||
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
|
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
|
||||||
padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch, original_num_tokens)
|
padded_second_ubatch_slice = slice(total_num_tokens_per_ubatch,
|
||||||
|
original_num_tokens)
|
||||||
|
|
||||||
ubatch_slices[0] = (padded_first_ubatch_slice, padded_first_ubatch_slice)
|
ubatch_slices[0] = (padded_first_ubatch_slice,
|
||||||
ubatch_slices[1] = (padded_second_ubatch_slice, padded_second_ubatch_slice)
|
padded_first_ubatch_slice)
|
||||||
|
ubatch_slices[1] = (padded_second_ubatch_slice,
|
||||||
|
padded_second_ubatch_slice)
|
||||||
|
|
||||||
# This is where the second ubatch is adjusted to account for the padding.
|
# This is where the second ubatch is adjusted to account for the padding.
|
||||||
# Should be called after attention metadata creation. This just pads
|
# Should be called after attention metadata creation. This just pads
|
||||||
# the second ubatch slice out to the total number of tokens
|
# the second ubatch slice out to the total number of tokens
|
||||||
# (num_tokens + padding)
|
# (num_tokens + padding)
|
||||||
def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
|
def pad_out_ubatch_second_stage(self, ubatch_slices: UBatchSlices,
|
||||||
|
num_total_tokens: int):
|
||||||
# TODO Add asserts to make sure stage one ran
|
# TODO Add asserts to make sure stage one ran
|
||||||
padded_second_ubatch_slice = slice(ubatch_slices[1][1].start, num_total_tokens)
|
padded_second_ubatch_slice = slice(ubatch_slices[1][1].start,
|
||||||
ubatch_slices[1] = (padded_second_ubatch_slice, padded_second_ubatch_slice)
|
num_total_tokens)
|
||||||
|
ubatch_slices[1] = (padded_second_ubatch_slice,
|
||||||
|
padded_second_ubatch_slice)
|
||||||
|
|
||||||
def should_ubatch(self, should_ubatch: bool) -> bool:
|
def should_ubatch(self, should_ubatch: bool) -> bool:
|
||||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||||
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||||
return DPMetadata.should_ubatch_across_dp(should_ubatch, dp_size, dp_rank)
|
return DPMetadata.should_ubatch_across_dp(should_ubatch, dp_size,
|
||||||
|
dp_rank)
|
||||||
|
|
||||||
def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
|
def _get_dummy_model_inputs(self, num_tokens: int) -> tuple:
|
||||||
# Dummy batch. (hopefully we are the last one so we can just
|
# Dummy batch. (hopefully we are the last one so we can just
|
||||||
@ -1455,8 +1469,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
device=self.device))
|
device=self.device))
|
||||||
|
|
||||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||||
slice(0, num_tokens), None, False)
|
slice(0, num_tokens), None, False)
|
||||||
|
|
||||||
|
|
||||||
return input_ids, positions, inputs_embeds, intermediate_tensors
|
return input_ids, positions, inputs_embeds, intermediate_tensors
|
||||||
|
|
||||||
@ -1509,55 +1522,50 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
|
def model_inputs(self, tokens_slice: slice, use_dummy_input: bool,
|
||||||
scheduler_output: Optional["SchedulerOutput"]) -> tuple:
|
scheduler_output: Optional["SchedulerOutput"]) -> tuple:
|
||||||
if use_dummy_input:
|
if use_dummy_input:
|
||||||
return self._get_dummy_model_inputs(tokens_slice.stop - tokens_slice.start)
|
return self._get_dummy_model_inputs(tokens_slice.stop -
|
||||||
|
tokens_slice.start)
|
||||||
else:
|
else:
|
||||||
assert scheduler_output is not None
|
assert scheduler_output is not None
|
||||||
return self._get_model_inputs(tokens_slice, scheduler_output)
|
return self._get_model_inputs(tokens_slice, scheduler_output)
|
||||||
|
|
||||||
|
def _make_ubatch_metadata(self, ubatch_slices, attn_metadata,
|
||||||
|
compute_stream, is_dummy_run,
|
||||||
def _make_ubatch_metadata(self,
|
num_tokens_across_dp, skip_cuda_graphs,
|
||||||
ubatch_slices,
|
|
||||||
attn_metadata,
|
|
||||||
compute_stream,
|
|
||||||
is_dummy_run,
|
|
||||||
num_tokens_across_dp,
|
|
||||||
skip_cuda_graphs,
|
|
||||||
scheduler_output) -> list[UbatchMetadata]:
|
scheduler_output) -> list[UbatchMetadata]:
|
||||||
|
|
||||||
# Create one forward context per ubatch
|
# Create one forward context per ubatch
|
||||||
forward_contexts = []
|
forward_contexts = []
|
||||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||||
num_tokens = (tokens_slice.stop - tokens_slice.start)
|
num_tokens = (tokens_slice.stop - tokens_slice.start)
|
||||||
forward_contexts.append(create_forward_context(
|
forward_contexts.append(
|
||||||
attn_metadata[i]
|
create_forward_context(
|
||||||
if attn_metadata is not None else None,
|
attn_metadata[i] if attn_metadata is not None else None,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
num_tokens=num_tokens,
|
num_tokens=num_tokens,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
skip_cuda_graphs=skip_cuda_graphs))
|
skip_cuda_graphs=skip_cuda_graphs))
|
||||||
|
|
||||||
ubatch_ctxs = make_ubatch_contexts(num_micro_batches=len(ubatch_slices),
|
ubatch_ctxs = make_ubatch_contexts(
|
||||||
compute_stream=compute_stream,
|
num_micro_batches=len(ubatch_slices),
|
||||||
forward_contexts=forward_contexts,
|
compute_stream=compute_stream,
|
||||||
device=self.device)
|
forward_contexts=forward_contexts,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
ubatch_metadata: list[UbatchMetadata] = []
|
ubatch_metadata: list[UbatchMetadata] = []
|
||||||
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
for i, (_, tokens_slice) in enumerate(ubatch_slices):
|
||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
|
self.model_inputs(tokens_slice, is_dummy_run, scheduler_output)
|
||||||
ubatch_metadata.append(UbatchMetadata(
|
ubatch_metadata.append(
|
||||||
context=ubatch_ctxs[i],
|
UbatchMetadata(context=ubatch_ctxs[i],
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
intermediate_tensors=intermediate_tensors
|
intermediate_tensors=intermediate_tensors))
|
||||||
))
|
|
||||||
|
|
||||||
return ubatch_metadata
|
return ubatch_metadata
|
||||||
|
|
||||||
|
|
||||||
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
|
def _run_ubatches(self, ubatch_metadata, model) -> torch.Tensor:
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _ubatch_thread(results, model, ubatch_metadata):
|
def _ubatch_thread(results, model, ubatch_metadata):
|
||||||
with ubatch_metadata.context:
|
with ubatch_metadata.context:
|
||||||
@ -1578,11 +1586,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
ubatch_threads = []
|
ubatch_threads = []
|
||||||
for metadata in ubatch_metadata:
|
for metadata in ubatch_metadata:
|
||||||
thread = threading.Thread(target=_ubatch_thread,
|
thread = threading.Thread(target=_ubatch_thread,
|
||||||
args=(
|
args=(
|
||||||
results,
|
results,
|
||||||
model,
|
model,
|
||||||
metadata,
|
metadata,
|
||||||
))
|
))
|
||||||
ubatch_threads.append(thread)
|
ubatch_threads.append(thread)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
@ -1602,37 +1610,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||||
skip_cuda_graphs: bool = False):
|
skip_cuda_graphs: bool = False):
|
||||||
|
|
||||||
|
|
||||||
# run micro-batched
|
# run micro-batched
|
||||||
if ubatch_slices is not None:
|
if ubatch_slices is not None:
|
||||||
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
assert len(ubatch_slices) == 2, "Only two ubatches has been tested"
|
||||||
|
|
||||||
compute_stream = torch.cuda.current_stream()
|
compute_stream = torch.cuda.current_stream()
|
||||||
ubatch_metadata = self._make_ubatch_metadata(
|
ubatch_metadata = self._make_ubatch_metadata(
|
||||||
ubatch_slices=ubatch_slices,
|
ubatch_slices=ubatch_slices,
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
compute_stream=compute_stream,
|
compute_stream=compute_stream,
|
||||||
is_dummy_run=is_dummy_run,
|
is_dummy_run=is_dummy_run,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
skip_cuda_graphs=skip_cuda_graphs,
|
skip_cuda_graphs=skip_cuda_graphs,
|
||||||
scheduler_output=scheduler_output
|
scheduler_output=scheduler_output)
|
||||||
)
|
|
||||||
return self._run_ubatches(ubatch_metadata, self.model)
|
return self._run_ubatches(ubatch_metadata, self.model)
|
||||||
# run normal batch
|
# run normal batch
|
||||||
else:
|
else:
|
||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
self.model_inputs(slice(0, num_scheduled_tokens), is_dummy_run, scheduler_output)
|
self.model_inputs(slice(0, num_scheduled_tokens),
|
||||||
|
is_dummy_run,
|
||||||
|
scheduler_output)
|
||||||
with set_forward_context(attn_metadata,
|
with set_forward_context(attn_metadata,
|
||||||
vllm_config=self.vllm_config,
|
vllm_config=self.vllm_config,
|
||||||
num_tokens=num_scheduled_tokens or 1,
|
num_tokens=num_scheduled_tokens or 1,
|
||||||
num_tokens_across_dp=num_tokens_across_dp,
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
skip_cuda_graphs=skip_cuda_graphs):
|
skip_cuda_graphs=skip_cuda_graphs):
|
||||||
return self.model(
|
return self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _pool(
|
def _pool(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -1693,18 +1702,20 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
return self.kv_connector_no_forward(scheduler_output)
|
return self.kv_connector_no_forward(scheduler_output)
|
||||||
|
|
||||||
# num_scheduled_tokens_old = scheduler_output.total_num_scheduled_tokens
|
|
||||||
# num_pad_tokens, num_tokens_after_padding = self.get_dp_padding(num_scheduled_tokens_old)
|
|
||||||
# Prepare the decoder inputs.
|
# Prepare the decoder inputs.
|
||||||
attn_metadata, attention_cuda_graphs, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, ubatch_slices, num_pad_tokens, num_tokens_after_padding = (
|
(attn_metadata, attention_cuda_graphs, logits_indices,
|
||||||
self._prepare_inputs(scheduler_output))
|
spec_decode_metadata, num_scheduled_tokens_np, ubatch_slices,
|
||||||
|
num_pad_tokens,
|
||||||
|
num_tokens_after_padding) = self._prepare_inputs(scheduler_output)
|
||||||
|
|
||||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
num_input_tokens = num_scheduled_tokens
|
num_input_tokens = num_scheduled_tokens
|
||||||
if ubatch_slices and num_pad_tokens > 0:
|
if ubatch_slices and num_pad_tokens > 0:
|
||||||
num_input_tokens += num_pad_tokens
|
num_input_tokens += num_pad_tokens
|
||||||
self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens)
|
self.pad_out_ubatch_second_stage(ubatch_slices, num_input_tokens)
|
||||||
elif ubatch_slices is None:
|
elif ubatch_slices is None:
|
||||||
num_pad, num_tokens_after_padding = self.get_padding(num_input_tokens)
|
num_pad, num_tokens_after_padding = self.get_padding(
|
||||||
|
num_input_tokens)
|
||||||
num_input_tokens += num_pad
|
num_input_tokens += num_pad
|
||||||
|
|
||||||
# Some attention backends only support CUDA Graphs in pure decode.
|
# Some attention backends only support CUDA Graphs in pure decode.
|
||||||
@ -1856,6 +1867,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Speculative decoding is not enabled.
|
# Speculative decoding is not enabled.
|
||||||
spec_token_ids = None
|
spec_token_ids = None
|
||||||
else:
|
else:
|
||||||
|
assert not ubatch_slices
|
||||||
|
assert isinstance(attn_metadata, dict)
|
||||||
spec_token_ids = self.propose_draft_token_ids(
|
spec_token_ids = self.propose_draft_token_ids(
|
||||||
scheduler_output,
|
scheduler_output,
|
||||||
valid_sampled_token_ids,
|
valid_sampled_token_ids,
|
||||||
@ -2354,7 +2367,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
|
|
||||||
|
|
||||||
with self.maybe_dummy_run_with_lora(self.lora_config,
|
with self.maybe_dummy_run_with_lora(self.lora_config,
|
||||||
num_scheduled_tokens):
|
num_scheduled_tokens):
|
||||||
outputs = self._run_model(
|
outputs = self._run_model(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user