mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 04:45:01 +08:00)

[Core] Refactor padding logic and pad for CUDA graphs before attention metadata building (#28579)

This commit is contained in: parent 430dd4d9eb, commit 56539cddac
@@ -84,12 +84,14 @@ See the following figures for a quick comparison between the previous and curren
```python
class BatchDescriptor(NamedTuple):
    num_tokens: int
    uniform_decode: bool = False
    num_reqs: int
    uniform: bool = False
    has_lora: bool = False
```

where `num_tokens` can be the padded token length, and `uniform_decode` is determined by whether `max_query_len` of a batch is equal to the desired `max_query_len` of a uniform decode, and `num_scheduled_tokens` is divisible by that desired `max_query_len`.
where `num_tokens` can be the padded token length, and `uniform` indicates whether all the requests have the same query length. Many attention backends only support full cudagraphs when the batches are uniform; pure decode batches are uniform but may not have query length 1 (i.e. `num_tokens == num_reqs`). This occurs in the validation pass of spec-decode, where "decode" batches have a query length of `1+num_spec_tokens`.

The goal of this structure is to uniquely identify a (padded) batch with the minimal set of items corresponding to a CUDA Graphs item. We can safely exclude items like `uniform_query_len` because it is currently a runtime constant for a given setup. For example, it is either `1` for a plain pure-decode batch or `1+num_spec_tokens` for the validation phase of speculative decoding.
The goal of this structure is to uniquely identify a (padded) batch with the minimal set of items corresponding to a CUDA Graphs item.

!!! note
    The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., by adding more items, like `uniform_query_len` to support multiple uniform decode length settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token-length aware (for example, some multi-modal inputs).
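For orientation, here is a minimal, self-contained sketch of how the new descriptor shape distinguishes a uniform (spec-decode style) batch from its relaxed, mixed-batch form. The class below is a local mirror of the fields shown above, not an import from vLLM, and the numbers are illustrative only.

```python
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    # Local mirror of the fields shown in the diff above.
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False
    has_lora: bool = False

    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
        # Drop the request count and uniformity so the key matches the more
        # general mixed prefill-decode (PIECEWISE) graphs.
        return BatchDescriptor(
            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
        )


# A uniform spec-decode "decode" batch: 4 requests, each with a query length
# of 1 + num_spec_tokens (here 1 + 3 = 4), so num_tokens == 16.
desc = BatchDescriptor(num_tokens=16, num_reqs=4, uniform=True)
print(desc.relax_for_mixed_batch_cudagraphs())
# BatchDescriptor(num_tokens=16, num_reqs=None, uniform=False, has_lora=False)
```

Dropping `num_reqs` and `uniform` is what lets a single PIECEWISE graph serve many different batch shapes with the same padded token count.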
@ -42,12 +42,24 @@ def _create_vllm_config(
|
||||
mock_config.compilation_config = compilation_config
|
||||
mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
|
||||
mock_config.parallel_config = ParallelConfig()
|
||||
mock_config.speculative_config = None # No speculative decoding
|
||||
if not lora_config:
|
||||
mock_config.lora_config = None
|
||||
# Mimic the behavior of VllmConfig.__post_init__()
|
||||
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
|
||||
compilation_config.set_splitting_ops_for_v1()
|
||||
|
||||
# mimic VllmConfig.__post_init__
|
||||
if compilation_config.cudagraph_capture_sizes:
|
||||
compilation_config.max_cudagraph_capture_size = (
|
||||
compilation_config.cudagraph_capture_sizes[-1]
|
||||
)
|
||||
|
||||
compilation_config.post_init_cudagraph_sizes()
|
||||
mock_config.pad_for_cudagraph = (
|
||||
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
|
||||
)
|
||||
|
||||
return mock_config
|
||||
|
||||
|
||||
@ -109,9 +121,11 @@ class TestCudagraphDispatcher:
|
||||
# 1. non-uniform batch, size in cudagraph size list
|
||||
desc_full_exact = BatchDescriptor(
|
||||
num_tokens=8,
|
||||
uniform_decode=False,
|
||||
uniform=False,
|
||||
)
|
||||
rt_mode, key = dispatcher.dispatch(
|
||||
num_tokens=8, uniform_decode=False, has_lora=False
|
||||
)
|
||||
rt_mode, key = dispatcher.dispatch(desc_full_exact)
|
||||
if cudagraph_mode_str == "FULL":
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_full_exact
|
||||
@ -122,32 +136,37 @@ class TestCudagraphDispatcher:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
|
||||
# 2. uniform decode batch, size in cudagraph size list
|
||||
desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
|
||||
rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
|
||||
desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
|
||||
rt_mode, key = dispatcher.dispatch(
|
||||
num_tokens=8, uniform_decode=True, has_lora=False
|
||||
)
|
||||
if cudagraph_mode_str == "FULL":
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_uniform_exact.non_uniform
|
||||
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
|
||||
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
|
||||
assert rt_mode == CUDAGraphMode.FULL
|
||||
assert key == desc_uniform_exact
|
||||
elif cudagraph_mode_str == "PIECEWISE":
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
assert key == desc_uniform_exact.non_uniform
|
||||
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
|
||||
else:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
|
||||
# 3. No key match
|
||||
desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
|
||||
rt_mode, key = dispatcher.dispatch(desc_no_match)
|
||||
rt_mode, key = dispatcher.dispatch(
|
||||
num_tokens=15, uniform_decode=False, has_lora=False
|
||||
)
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
assert key is None
|
||||
assert key == BatchDescriptor(num_tokens=15)
|
||||
|
||||
# 4. Cascade attention should have a fall back mode
|
||||
desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
|
||||
rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
|
||||
desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
|
||||
rt_mode, key = dispatcher.dispatch(
|
||||
num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
|
||||
)
|
||||
if "PIECEWISE" in cudagraph_mode_str: # string contains check
|
||||
assert rt_mode == CUDAGraphMode.PIECEWISE
|
||||
assert key == desc_full_exact.non_uniform
|
||||
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
|
||||
else:
|
||||
assert rt_mode == CUDAGraphMode.NONE
|
||||
|
||||
|
||||
@@ -35,23 +35,27 @@ class BatchDescriptor(NamedTuple):
"""

num_tokens: int
uniform_decode: bool = False
num_reqs: int | None = None
"""
False can also be used for a uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
the cudagraphs can handle any number of requests.
"""
uniform: bool = False
"""
True if all the requests in the batch have the same number of tokens.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
"""

@property
def non_uniform(self) -> "BatchDescriptor":
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a non-uniform version of the current batch descriptor.
Return a relaxed version of the current batch descriptor that is still compatible
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens, uniform_decode=False, has_lora=self.has_lora
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
)
@@ -930,30 +930,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

if num_decodes > 0:
pure_decode = num_prefills == 0
# possible required padding for cudagraph replay
use_cudagraph = (
self.enable_cuda_graph
and pure_decode
and num_decode_tokens <= self._decode_cudagraph_max_bs
)
if use_cudagraph:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_decode_tokens
)
# Carefully fill the padding region with reasonable values
# on the CPU.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[
1 + num_decodes : 1 + num_input_tokens
].fill_(paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
1
)

else:
num_input_tokens = num_decode_tokens

attn_metadata.decode_wrapper = self._get_decode_wrapper(
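As an aside, the padding trick that the removed FlashInfer block performed can be illustrated in isolation. The tensors below are toy stand-ins for `paged_kv_indptr_cpu` and `paged_kv_last_page_len_cpu`, with made-up values.

```python
import torch

num_decodes, num_input_tokens = 3, 6  # 3 real decodes padded up to a graph size of 6

# indptr over pages for the real requests (monotonically non-decreasing)
paged_kv_indptr = torch.tensor([0, 2, 5, 7], dtype=torch.int32)
indptr_padded = torch.empty(1 + num_input_tokens, dtype=torch.int32)
indptr_padded[: 1 + num_decodes] = paged_kv_indptr
# Padding rows repeat the last value so the indptr stays non-decreasing.
indptr_padded[1 + num_decodes :].fill_(paged_kv_indptr[-1])

# FlashInfer treats a last-page length of 0 as a *full* page, so pad with 1.
last_page_len = torch.tensor([16, 7, 3], dtype=torch.int32)
last_page_len_padded = torch.empty(num_input_tokens, dtype=torch.int32)
last_page_len_padded[:num_decodes] = last_page_len
last_page_len_padded[num_decodes:].fill_(1)

print(indptr_padded.tolist())         # [0, 2, 5, 7, 7, 7, 7]
print(last_page_len_padded.tolist())  # [16, 7, 3, 1, 1, 1]
```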
@@ -107,6 +107,8 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
# -1 in the case we have a padded request (0 seq-len)
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)

return (
block_idx_last_computed_token,

@@ -72,6 +72,7 @@ class CommonAttentionMetadata:

num_reqs: int
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
@@ -857,7 +858,9 @@ def split_decodes_and_prefills(
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
# 0-query len indicates a padded request; leave this at the back
# of the batch with the prefills
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)

if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
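A tiny illustration of the new prefill mask, assuming a `decode_threshold` of 1 and a batch whose tail contains zero-length (padded) requests:

```python
import torch

decode_threshold = 1
# Query lengths for 5 requests; the trailing zeros are padded (empty) requests.
query_lens = torch.tensor([1, 1, 4, 0, 0])

# New behavior: zero-length padded requests are grouped with the prefills at
# the back of the batch rather than being counted as decodes.
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
print(is_prefill.tolist())  # [False, False, True, True, True]
```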
@ -4,6 +4,9 @@ from itertools import product
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CudagraphDispatcher:
|
||||
@ -28,7 +31,11 @@ class CudagraphDispatcher:
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
self.uniform_decode_query_len = (
|
||||
1
|
||||
if not self.vllm_config.speculative_config
|
||||
else 1 + self.vllm_config.speculative_config.num_speculative_tokens
|
||||
)
|
||||
|
||||
# Dict to store valid cudagraph dispatching keys.
|
||||
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
|
||||
@ -36,25 +43,42 @@ class CudagraphDispatcher:
|
||||
CUDAGraphMode.FULL: set(),
|
||||
}
|
||||
|
||||
not_use_piecewise_compilation = (
|
||||
not self.cudagraph_mode.requires_piecewise_compilation()
|
||||
)
|
||||
|
||||
assert (
|
||||
not_use_piecewise_compilation
|
||||
not self.compilation_config.cudagraph_mode.requires_piecewise_compilation()
|
||||
or self.compilation_config.is_attention_compiled_piecewise()
|
||||
), (
|
||||
"Compilation mode should be CompilationMode.VLLM_COMPILE when "
|
||||
"cudagraph_mode piecewise cudagraphs is used, "
|
||||
"and attention should be in splitting_ops or "
|
||||
"inductor splitting should be used. "
|
||||
f"cudagraph_mode={self.cudagraph_mode}, "
|
||||
f"cudagraph_mode={self.compilation_config.cudagraph_mode}, "
|
||||
f"compilation_mode={self.compilation_config.mode}, "
|
||||
f"splitting_ops={self.compilation_config.splitting_ops}"
|
||||
)
|
||||
|
||||
self.keys_initialized = False
|
||||
|
||||
def _create_padded_batch_descriptor(
|
||||
self, num_tokens: int, uniform_decode: bool, has_lora: bool
|
||||
) -> BatchDescriptor:
|
||||
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
|
||||
uniform_decode_query_len = self.uniform_decode_query_len
|
||||
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
|
||||
if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
|
||||
num_reqs = num_tokens_padded // uniform_decode_query_len
|
||||
assert num_tokens_padded % uniform_decode_query_len == 0
|
||||
else:
|
||||
uniform_decode = False
|
||||
num_reqs = min(num_tokens_padded, max_num_seqs)
|
||||
|
||||
return BatchDescriptor(
|
||||
num_tokens=num_tokens_padded,
|
||||
num_reqs=num_reqs,
|
||||
uniform=uniform_decode,
|
||||
has_lora=has_lora,
|
||||
)
|
||||
|
||||
def add_cudagraph_key(
|
||||
self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
|
||||
):
|
||||
@ -66,7 +90,9 @@ class CudagraphDispatcher:
|
||||
def initialize_cudagraph_keys(
|
||||
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
|
||||
):
|
||||
# This should be called only after attention backend is initialized.
|
||||
# This should be called only after attention backend is initialized. So we can
|
||||
# get the correct cudagraph mode after backend support is resolved.
|
||||
self.cudagraph_mode = cudagraph_mode
|
||||
|
||||
# LoRA activation cases to specialize the cuda graphs on
|
||||
if self.vllm_config.lora_config:
|
||||
@ -86,9 +112,9 @@ class CudagraphDispatcher:
|
||||
):
|
||||
self.add_cudagraph_key(
|
||||
cudagraph_mode.mixed_mode(),
|
||||
BatchDescriptor(
|
||||
num_tokens=bs, uniform_decode=False, has_lora=has_lora
|
||||
),
|
||||
self._create_padded_batch_descriptor(
|
||||
bs, False, has_lora
|
||||
).relax_for_mixed_batch_cudagraphs(),
|
||||
)
|
||||
|
||||
# if decode cudagraph mode is FULL, and we don't already have mixed
|
||||
@@ -109,40 +135,49 @@ class CudagraphDispatcher:
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(
num_tokens=bs, uniform_decode=True, has_lora=has_lora
),
self._create_padded_batch_descriptor(bs, True, has_lora),
)

self.keys_initialized = True

def dispatch(
self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
use_cascade_attn: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions (e.g., the batch descriptor and whether cascade attention is used),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).
"""
# if not initialized, just skip dispatching.
if not self.keys_initialized:
return CUDAGraphMode.NONE, None
if (
not self.keys_initialized
or self.cudagraph_mode == CUDAGraphMode.NONE
or num_tokens > self.compilation_config.max_cudagraph_capture_size
):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()

non_uniform_key = batch_descriptor.non_uniform
# if a batch uses cascade attention, bypass checking full cudagraphs
if not use_cascade_attn:
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc

# otherwise, check if non-uniform key exists
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# otherwise, check if the relaxed key exists
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, relaxed_batch_desc

# also check if non-uniform key exists for more "general"
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, non_uniform_key
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc

# finally, just return no cudagraphs
return CUDAGraphMode.NONE, None
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
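The lookup order in `dispatch` can be summarized with a small standalone sketch. It deliberately ignores the LoRA and cascade-attention dimensions and uses simplified stand-in types (`Mode`, `Key`), so it is an approximation of the control flow rather than the real dispatcher.

```python
from enum import Enum, auto
from typing import NamedTuple


class Mode(Enum):
    NONE = auto()
    PIECEWISE = auto()
    FULL = auto()


class Key(NamedTuple):
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False

    def relaxed(self) -> "Key":
        return Key(self.num_tokens, None, False)


def dispatch(keys: dict[Mode, set[Key]], key: Key) -> tuple[Mode, Key]:
    # Same lookup order as the diff: exact FULL key, relaxed FULL key,
    # relaxed PIECEWISE key, then fall back to eager with a trivial key.
    relaxed = key.relaxed()
    if key in keys[Mode.FULL]:
        return Mode.FULL, key
    if relaxed in keys[Mode.FULL]:
        return Mode.FULL, relaxed
    if relaxed in keys[Mode.PIECEWISE]:
        return Mode.PIECEWISE, relaxed
    return Mode.NONE, Key(key.num_tokens)


keys = {
    Mode.FULL: {Key(num_tokens=8, num_reqs=8, uniform=True)},
    Mode.PIECEWISE: {Key(num_tokens=8), Key(num_tokens=16)},
}
print(dispatch(keys, Key(8, 8, True)))    # FULL with the exact (uniform) key
print(dispatch(keys, Key(16, 4, False)))  # PIECEWISE with the relaxed key
print(dispatch(keys, Key(32, 4, False)))  # NONE with a trivial key
```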
@@ -9,6 +9,7 @@ from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
from vllm.v1.worker.ubatch_utils import (
UBatchSlice,
UBatchSlices,
check_ubatch_thresholds,
create_ubatch_slices,
@@ -88,6 +89,17 @@ def _post_process_dp_padding(tensor: torch.Tensor, should_dp_pad: bool) -> torch
return num_tokens_across_dp.cpu()


# This just pads the second ubatch slice out to the total number of tokens
# (num_tokens + padding) since we do `create_ubatch_slices` before applying DP padding.
def _pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
padded_second_ubatch_slice = slice(
ubatch_slices[1].token_slice.start, num_total_tokens
)
ubatch_slices[1] = UBatchSlice(
padded_second_ubatch_slice, padded_second_ubatch_slice
)


def _synchronize_dp_ranks(
num_tokens_unpadded: int,
num_tokens_padded: int,
@@ -220,11 +232,14 @@ def coordinate_batch_across_dp(
# to the second ubatch in pad_out_ubatch_slice after attention
# metadata creation
assert num_tokens_after_padding is not None
token_split_point = int(num_tokens_after_padding[0].item()) // 2
num_tokens_padded = int(num_tokens_after_padding[0].item())
token_split_point = int(num_tokens_padded) // 2

assert num_scheduled_tokens_per_request is not None
ubatch_slices = create_ubatch_slices(
num_scheduled_tokens_per_request, token_split_point
)
ubatch_slices = _pad_out_ubatch_slice(ubatch_slices, num_tokens_padded)
assert sum(s.num_tokens for s in ubatch_slices) == num_tokens_padded

return (ubatch_slices, num_tokens_after_padding)
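To make `_pad_out_ubatch_slice` concrete, here is a toy example with plain Python slices (the real code wraps these in `UBatchSlice` objects and mutates `ubatch_slices[1]` in place):

```python
# Two micro-batches covering 10 real tokens; CUDA-graph / DP padding brings the
# total to 16 tokens, and all of that padding is absorbed by the second slice.
ubatch_token_slices = [slice(0, 6), slice(6, 10)]
num_total_tokens = 16

# Extend the second slice so the two slices together cover every padded token.
ubatch_token_slices[1] = slice(ubatch_token_slices[1].start, num_total_tokens)

num_covered = sum(s.stop - s.start for s in ubatch_token_slices)
print(ubatch_token_slices)  # [slice(0, 6, None), slice(6, 16, None)]
print(num_covered)          # 16
```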
@ -151,7 +151,6 @@ from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
|
||||
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.ubatch_utils import (
|
||||
UBatchSlice,
|
||||
UBatchSlices,
|
||||
check_ubatch_thresholds,
|
||||
)
|
||||
@ -1239,17 +1238,13 @@ class GPUModelRunner(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
max_num_scheduled_tokens: int,
|
||||
) -> tuple[
|
||||
torch.Tensor,
|
||||
SpecDecodeMetadata | None,
|
||||
UBatchSlices | None,
|
||||
torch.Tensor | None,
|
||||
]:
|
||||
"""
|
||||
:return: tuple[
|
||||
logits_indices, spec_decode_metadata,
|
||||
ubatch_slices, num_tokens_across_dp,
|
||||
]
|
||||
"""
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
@ -1364,28 +1359,6 @@ class GPUModelRunner(
|
||||
self.query_start_loc.copy_to_gpu()
|
||||
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
|
||||
|
||||
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
|
||||
num_tokens_padded = self._get_num_input_tokens(num_tokens_unpadded)
|
||||
uniform_decode = (
|
||||
max_num_scheduled_tokens == self.uniform_decode_query_len
|
||||
) and (total_num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
|
||||
|
||||
# Disable DP padding when running eager to avoid excessive padding when
|
||||
# running prefills. This lets us set enforce_eager on the prefiller in
|
||||
# a P/D setup and still use CUDA graphs (enabled by this padding) on the
|
||||
# decoder.
|
||||
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
|
||||
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens_unpadded,
|
||||
parallel_config=self.parallel_config,
|
||||
allow_microbatching=True,
|
||||
allow_dp_padding=allow_dp_padding,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
uniform_decode=uniform_decode,
|
||||
num_scheduled_tokens_per_request=num_scheduled_tokens,
|
||||
)
|
||||
|
||||
self.seq_lens.np[:num_reqs] = (
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs] + num_scheduled_tokens
|
||||
)
|
||||
@ -1486,15 +1459,15 @@ class GPUModelRunner(
|
||||
return (
|
||||
logits_indices,
|
||||
spec_decode_metadata,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
)
|
||||
|
||||
def _build_attention_metadata(
|
||||
self,
|
||||
total_num_scheduled_tokens: int,
|
||||
max_num_scheduled_tokens: int,
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
max_query_len: int,
|
||||
num_tokens_padded: int | None = None,
|
||||
num_reqs_padded: int | None = None,
|
||||
ubatch_slices: UBatchSlices | None = None,
|
||||
logits_indices: torch.Tensor | None = None,
|
||||
use_spec_decode: bool = False,
|
||||
@ -1505,6 +1478,9 @@ class GPUModelRunner(
|
||||
"""
|
||||
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
|
||||
"""
|
||||
num_tokens_padded = num_tokens_padded or num_tokens
|
||||
num_reqs_padded = num_reqs_padded or num_reqs
|
||||
|
||||
logits_indices_padded = None
|
||||
num_logits_indices = None
|
||||
if logits_indices is not None:
|
||||
@ -1522,28 +1498,13 @@ class GPUModelRunner(
|
||||
self.dcp_rank,
|
||||
self.parallel_config.cp_kv_cache_interleave_size,
|
||||
)
|
||||
self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
|
||||
self.dcp_local_seq_lens.cpu[num_reqs:].fill_(0)
|
||||
self.dcp_local_seq_lens.copy_to_gpu(num_reqs_padded)
|
||||
|
||||
attn_metadata: PerLayerAttnMetadata = {}
|
||||
if ubatch_slices is not None:
|
||||
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
|
||||
|
||||
# Used in the below loop
|
||||
query_start_loc = self.query_start_loc.gpu[: num_reqs + 1]
|
||||
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs + 1]
|
||||
seq_lens = self.seq_lens.gpu[:num_reqs]
|
||||
seq_lens_cpu = self.seq_lens.cpu[:num_reqs]
|
||||
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
|
||||
:num_reqs
|
||||
]
|
||||
|
||||
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
|
||||
if self.dcp_world_size > 1:
|
||||
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs]
|
||||
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs]
|
||||
|
||||
spec_decode_common_attn_metadata = None
|
||||
|
||||
if for_cudagraph_capture:
|
||||
# For some attention backends (e.g. FA) with sliding-window models we need
# to make sure the backend sees a max_seq_len that is larger than the sliding
|
||||
@ -1559,6 +1520,22 @@ class GPUModelRunner(
|
||||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||||
self.num_accepted_tokens.copy_to_gpu()
|
||||
|
||||
# Used in the below loop, uses padded shapes
|
||||
query_start_loc = self.query_start_loc.gpu[: num_reqs_padded + 1]
|
||||
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
|
||||
seq_lens = self.seq_lens.gpu[:num_reqs_padded]
|
||||
seq_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
|
||||
num_computed_tokens_cpu = self.input_batch.num_computed_tokens_cpu_tensor[
|
||||
:num_reqs_padded
|
||||
]
|
||||
|
||||
dcp_local_seq_lens, dcp_local_seq_lens_cpu = None, None
|
||||
if self.dcp_world_size > 1:
|
||||
dcp_local_seq_lens = self.dcp_local_seq_lens.gpu[:num_reqs_padded]
|
||||
dcp_local_seq_lens_cpu = self.dcp_local_seq_lens.cpu[:num_reqs_padded]
|
||||
|
||||
spec_decode_common_attn_metadata = None
|
||||
|
||||
# Prepare the attention metadata for each KV cache group and make layers
|
||||
# in the same group share the same metadata.
|
||||
for kv_cache_gid, kv_cache_group in enumerate(
|
||||
@ -1567,30 +1544,31 @@ class GPUModelRunner(
|
||||
encoder_seq_lens, encoder_seq_lens_cpu = self._get_encoder_seq_lens(
|
||||
num_scheduled_tokens or {},
|
||||
kv_cache_group.kv_cache_spec,
|
||||
num_reqs,
|
||||
num_reqs_padded,
|
||||
)
|
||||
|
||||
if isinstance(kv_cache_group.kv_cache_spec, EncoderOnlyAttentionSpec):
|
||||
# Encoder-only layers do not have KV cache, so we need to
|
||||
# create a dummy block table and slot mapping for them.
|
||||
blk_table_tensor = torch.zeros(
|
||||
(num_reqs, 1),
|
||||
(num_tokens_padded, 1),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
slot_mapping = torch.zeros(
|
||||
(total_num_scheduled_tokens,),
|
||||
(num_tokens_padded,),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
blk_table = self.input_batch.block_table[kv_cache_gid]
|
||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs)
|
||||
slot_mapping = blk_table.slot_mapping.gpu[:total_num_scheduled_tokens]
|
||||
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
|
||||
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
|
||||
|
||||
# Fill unused with -1. Needed for reshape_and_cache in full cuda
|
||||
# graph mode.
|
||||
blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1)
|
||||
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
|
||||
slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
|
||||
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
|
||||
|
||||
common_attn_metadata = CommonAttentionMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
@ -1598,9 +1576,9 @@ class GPUModelRunner(
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_cpu=seq_lens_cpu,
|
||||
num_computed_tokens_cpu=num_computed_tokens_cpu,
|
||||
num_reqs=num_reqs,
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
num_actual_tokens=num_tokens_padded,
|
||||
num_reqs=num_reqs_padded,
|
||||
max_query_len=max_query_len,
|
||||
max_seq_len=max_seq_len,
|
||||
block_table_tensor=blk_table_tensor,
|
||||
slot_mapping=slot_mapping,
|
||||
@ -1631,9 +1609,11 @@ class GPUModelRunner(
|
||||
extra_attn_metadata_args = {}
|
||||
if use_spec_decode and isinstance(builder, GDNAttentionMetadataBuilder):
|
||||
extra_attn_metadata_args = dict(
|
||||
num_accepted_tokens=self.num_accepted_tokens.gpu[:num_reqs],
|
||||
num_accepted_tokens=self.num_accepted_tokens.gpu[
|
||||
:num_reqs_padded
|
||||
],
|
||||
num_decode_draft_tokens_cpu=self.num_decode_draft_tokens.cpu[
|
||||
:num_reqs
|
||||
:num_reqs_padded
|
||||
],
|
||||
)
|
||||
|
||||
@ -1677,6 +1657,7 @@ class GPUModelRunner(
|
||||
def _compute_cascade_attn_prefix_lens(
|
||||
self,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_computed_tokens: np.ndarray,
|
||||
num_common_prefix_blocks: list[int],
|
||||
) -> list[list[int]] | None:
|
||||
"""
|
||||
@ -1699,6 +1680,7 @@ class GPUModelRunner(
|
||||
# 0 if cascade attention should not be used
|
||||
cascade_attn_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||
num_scheduled_tokens,
|
||||
num_computed_tokens,
|
||||
num_common_prefix_blocks[kv_cache_gid],
|
||||
attn_group.kv_cache_spec,
|
||||
attn_group.get_metadata_builder(),
|
||||
@ -1711,6 +1693,7 @@ class GPUModelRunner(
|
||||
def _compute_cascade_attn_prefix_len(
|
||||
self,
|
||||
num_scheduled_tokens: np.ndarray,
|
||||
num_computed_tokens: np.ndarray,
|
||||
num_common_prefix_blocks: int,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
attn_metadata_builder: AttentionMetadataBuilder,
|
||||
@ -1777,10 +1760,7 @@ class GPUModelRunner(
|
||||
# and the second kernel will get an empty input. While this is not
|
||||
# a fundamental problem, our current implementation does not support
|
||||
# this case.
|
||||
num_reqs = len(num_scheduled_tokens)
|
||||
common_prefix_len = min(
|
||||
common_prefix_len, self.input_batch.num_computed_tokens_cpu[:num_reqs].min()
|
||||
)
|
||||
common_prefix_len = min(common_prefix_len, num_computed_tokens.min())
|
||||
# common_prefix_len should be a multiple of the block size.
|
||||
common_prefix_len = (
|
||||
common_prefix_len // kv_cache_spec.block_size * kv_cache_spec.block_size
|
||||
@ -2334,19 +2314,6 @@ class GPUModelRunner(
|
||||
log_stats=self.parallel_config.eplb_config.log_balancedness,
|
||||
)
|
||||
|
||||
# This is where the second ubatch is adjusted to account for the padding.
|
||||
# Should be called after attention metadata creation. This just pads
|
||||
# the second ubatch slice out to the total number of tokens
|
||||
# (num_tokens + padding)
|
||||
@staticmethod
|
||||
def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
|
||||
padded_second_ubatch_slice = slice(
|
||||
ubatch_slices[1].token_slice.start, num_total_tokens
|
||||
)
|
||||
ubatch_slices[1] = UBatchSlice(
|
||||
padded_second_ubatch_slice, padded_second_ubatch_slice
|
||||
)
|
||||
|
||||
def _pool(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -2391,18 +2358,7 @@ class GPUModelRunner(
|
||||
pooler_output=pooler_output,
|
||||
)
|
||||
|
||||
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and hasattr(self, "cudagraph_batch_sizes")
|
||||
and self.cudagraph_batch_sizes
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]
|
||||
):
|
||||
# Use CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
return self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
|
||||
|
||||
# Eager mode.
|
||||
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
|
||||
# Pad tokens to multiple of tensor_parallel_size when
|
||||
# enabled collective fusion for SP
|
||||
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
@ -2738,6 +2694,87 @@ class GPUModelRunner(
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
def _determine_batch_execution_and_padding(
|
||||
self,
|
||||
num_tokens: int,
|
||||
num_reqs: int,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
max_num_scheduled_tokens: int,
|
||||
use_cascade_attn: bool,
|
||||
allow_microbatching: bool = True,
|
||||
force_eager: bool = False,
|
||||
# For cudagraph capture TODO(lucas): Refactor how we capture cudagraphs (will
|
||||
# be improved in model runner v2)
|
||||
force_uniform_decode: bool | None = None,
|
||||
force_has_lora: bool | None = None,
|
||||
) -> tuple[
|
||||
CUDAGraphMode, BatchDescriptor, UBatchSlices | None, torch.Tensor | None
|
||||
]:
|
||||
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
|
||||
uniform_decode = (
|
||||
(
|
||||
(max_num_scheduled_tokens == self.uniform_decode_query_len)
|
||||
and (num_tokens_padded == max_num_scheduled_tokens * num_reqs)
|
||||
)
|
||||
if force_uniform_decode is None
|
||||
else force_uniform_decode
|
||||
)
|
||||
|
||||
has_lora = (
|
||||
len(self.input_batch.lora_id_to_lora_request) > 0
|
||||
if force_has_lora is None
|
||||
else force_has_lora
|
||||
)
|
||||
|
||||
dispatch_cudagraph = (
|
||||
lambda num_tokens: self.cudagraph_dispatcher.dispatch(
|
||||
num_tokens=num_tokens,
|
||||
has_lora=has_lora,
|
||||
use_cascade_attn=use_cascade_attn,
|
||||
uniform_decode=uniform_decode,
|
||||
)
|
||||
if not force_eager
|
||||
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
|
||||
)
|
||||
|
||||
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
|
||||
num_tokens_padded = batch_descriptor.num_tokens
|
||||
|
||||
# Extra coordination when running data-parallel since we need to coordinate
|
||||
# across ranks
|
||||
ubatch_slices, num_tokens_across_dp = None, None
|
||||
if self.vllm_config.parallel_config.data_parallel_size > 1:
|
||||
# Disable DP padding when running eager to avoid excessive padding when
|
||||
# running prefills. This lets us set cudagraph_mode="NONE" on the prefiller
|
||||
# in a P/D setup and still use CUDA graphs (enabled by this padding) on the
|
||||
# decoder.
|
||||
allow_dp_padding = (
|
||||
self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
)
|
||||
|
||||
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=num_tokens_padded,
|
||||
parallel_config=self.parallel_config,
|
||||
allow_microbatching=allow_microbatching,
|
||||
allow_dp_padding=allow_dp_padding,
|
||||
num_tokens_padded=num_tokens_padded,
|
||||
uniform_decode=uniform_decode,
|
||||
num_scheduled_tokens_per_request=num_scheduled_tokens_np,
|
||||
)
|
||||
|
||||
# Extract DP padding if there is any
|
||||
if num_tokens_across_dp is not None:
|
||||
dp_rank = self.parallel_config.data_parallel_rank
|
||||
num_tokens_padded = int(num_tokens_across_dp[dp_rank].item())
|
||||
|
||||
# Re-dispatch with DP padding
|
||||
cudagraph_mode, batch_descriptor = dispatch_cudagraph(num_tokens_padded)
|
||||
# Assert that the agreed-upon token count is correct; otherwise
# num_tokens_across_dp will no longer be valid
|
||||
assert batch_descriptor.num_tokens == num_tokens_padded
|
||||
|
||||
return cudagraph_mode, batch_descriptor, ubatch_slices, num_tokens_across_dp
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@ -2790,7 +2827,7 @@ class GPUModelRunner(
|
||||
# returns True. before returning early here we call
|
||||
# dummy run to ensure coordinate_batch_across_dp
|
||||
# is called into to avoid out of sync issues.
|
||||
self._dummy_run(self._get_num_input_tokens(1))
|
||||
self._dummy_run(1)
|
||||
if not has_kv_transfer_group():
|
||||
# Return empty ModelRunnerOutput if no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
@ -2809,36 +2846,63 @@ class GPUModelRunner(
|
||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||
num_scheduled_tokens_np = np.array(tokens, dtype=np.int32)
|
||||
max_num_scheduled_tokens = int(num_scheduled_tokens_np.max())
|
||||
num_tokens_unpadded = scheduler_output.total_num_scheduled_tokens
|
||||
|
||||
(
|
||||
logits_indices,
|
||||
spec_decode_metadata,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
) = self._prepare_inputs(
|
||||
scheduler_output, num_scheduled_tokens_np, max_num_scheduled_tokens
|
||||
scheduler_output,
|
||||
num_scheduled_tokens_np,
|
||||
)
|
||||
|
||||
cascade_attn_prefix_lens = None
|
||||
# Disable cascade attention when using microbatching (DBO)
|
||||
if self.cascade_attn_enabled and ubatch_slices is None:
|
||||
if self.cascade_attn_enabled and not self.parallel_config.enable_dbo:
|
||||
# Pre-compute cascade attention prefix lengths
|
||||
# NOTE: Must be AFTER _prepare_inputs uses self.input_batch state
|
||||
cascade_attn_prefix_lens = self._compute_cascade_attn_prefix_lens(
|
||||
num_scheduled_tokens_np,
|
||||
self.input_batch.num_computed_tokens_cpu[:num_reqs],
|
||||
scheduler_output.num_common_prefix_blocks,
|
||||
)
|
||||
|
||||
# TODO(lucas): move cudagraph dispatching here:
|
||||
# https://github.com/vllm-project/vllm/issues/23789
|
||||
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
attn_metadata, spec_decode_common_attn_metadata = (
|
||||
self._build_attention_metadata(
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
max_num_scheduled_tokens=max_num_scheduled_tokens,
|
||||
(
|
||||
cudagraph_mode,
|
||||
batch_desc,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
) = self._determine_batch_execution_and_padding(
|
||||
num_tokens=num_tokens_unpadded,
|
||||
num_reqs=num_reqs,
|
||||
num_scheduled_tokens_np=num_scheduled_tokens_np,
|
||||
max_num_scheduled_tokens=max_num_scheduled_tokens,
|
||||
use_cascade_attn=cascade_attn_prefix_lens is not None,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Running batch with cudagraph_mode: %s, batch_descriptor: %s, "
|
||||
"ubatch_slices: %s, num_tokens_across_dp: %s",
|
||||
cudagraph_mode,
|
||||
batch_desc,
|
||||
ubatch_slices,
|
||||
num_tokens_across_dp,
|
||||
)
|
||||
|
||||
num_tokens_padded = batch_desc.num_tokens
|
||||
num_reqs_padded = (
|
||||
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
|
||||
)
|
||||
|
||||
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
|
||||
pad_attn = cudagraph_mode == CUDAGraphMode.FULL
|
||||
|
||||
(attn_metadata, spec_decode_common_attn_metadata) = (
|
||||
self._build_attention_metadata(
|
||||
num_tokens=num_tokens_unpadded,
|
||||
num_tokens_padded=num_tokens_padded if pad_attn else None,
|
||||
num_reqs=num_reqs,
|
||||
num_reqs_padded=num_reqs_padded if pad_attn else None,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
ubatch_slices=ubatch_slices,
|
||||
logits_indices=logits_indices,
|
||||
use_spec_decode=use_spec_decode,
|
||||
@ -2847,18 +2911,6 @@ class GPUModelRunner(
|
||||
)
|
||||
)
|
||||
|
||||
dp_rank = self.parallel_config.data_parallel_rank
|
||||
if ubatch_slices:
|
||||
assert num_tokens_across_dp is not None
|
||||
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
|
||||
self.pad_out_ubatch_slice(ubatch_slices, num_input_tokens)
|
||||
elif num_tokens_across_dp is not None:
|
||||
num_input_tokens = int(num_tokens_across_dp[dp_rank].item())
|
||||
else:
|
||||
num_input_tokens = self._get_num_input_tokens(
|
||||
scheduler_output.total_num_scheduled_tokens
|
||||
)
|
||||
|
||||
(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
@ -2867,29 +2919,14 @@ class GPUModelRunner(
|
||||
model_kwargs,
|
||||
ec_connector_output,
|
||||
) = self._preprocess(
|
||||
scheduler_output, num_input_tokens, intermediate_tensors
|
||||
)
|
||||
|
||||
uniform_decode = (
|
||||
max_num_scheduled_tokens == self.uniform_decode_query_len
|
||||
) and (num_scheduled_tokens == num_reqs * max_num_scheduled_tokens)
|
||||
batch_desc = BatchDescriptor(
|
||||
num_tokens=num_input_tokens,
|
||||
uniform_decode=uniform_decode,
|
||||
has_lora=len(self.input_batch.lora_id_to_lora_request) > 0,
|
||||
)
|
||||
cudagraph_runtime_mode, batch_descriptor = (
|
||||
self.cudagraph_dispatcher.dispatch(
|
||||
batch_desc,
|
||||
use_cascade_attn=cascade_attn_prefix_lens is not None,
|
||||
)
|
||||
scheduler_output, num_tokens_padded, intermediate_tensors
|
||||
)
|
||||
|
||||
# Set cudagraph mode to none if calc_kv_scales is true.
|
||||
# KV scales calculation involves dynamic operations that are incompatible
|
||||
# with CUDA graph capture.
|
||||
if self.calculate_kv_scales:
|
||||
cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
cudagraph_mode = CUDAGraphMode.NONE
|
||||
# Mark KV scales as calculated after the first forward pass
|
||||
self.calculate_kv_scales = False
|
||||
|
||||
@ -2899,10 +2936,10 @@ class GPUModelRunner(
|
||||
set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens=num_tokens_padded,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
cudagraph_runtime_mode=cudagraph_mode,
|
||||
batch_descriptor=batch_desc,
|
||||
ubatch_slices=ubatch_slices,
|
||||
),
|
||||
record_function_or_nullcontext("gpu_model_runner: forward"),
|
||||
@ -2952,7 +2989,7 @@ class GPUModelRunner(
|
||||
if not get_pp_group().is_last_rank:
|
||||
all_gather_tensors = {
|
||||
"residual": not is_residual_scattered_for_sp(
|
||||
self.vllm_config, num_input_tokens
|
||||
self.vllm_config, num_tokens_padded
|
||||
)
|
||||
}
|
||||
get_pp_group().send_tensor_dict(
|
||||
@ -3841,52 +3878,44 @@ class GPUModelRunner(
|
||||
assert sum(num_scheduled_tokens_list) == num_tokens
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32)
|
||||
total_num_scheduled_tokens = int(num_scheduled_tokens.sum())
|
||||
num_tokens_unpadded = int(num_scheduled_tokens.sum())
|
||||
|
||||
num_sampled_tokens = np.ones(num_reqs, dtype=np.int32)
|
||||
|
||||
# Disable DP padding when running eager
|
||||
allow_dp_padding = self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
|
||||
# We currently only microbatch if the number of tokens is
|
||||
# over a certain threshold.
|
||||
ubatch_slices, num_tokens_across_dp = coordinate_batch_across_dp(
|
||||
num_tokens_unpadded=total_num_scheduled_tokens,
|
||||
parallel_config=self.vllm_config.parallel_config,
|
||||
_cudagraph_mode, batch_desc, ubatch_slices, num_tokens_across_dp = (
|
||||
self._determine_batch_execution_and_padding(
|
||||
num_tokens=num_tokens_unpadded,
|
||||
num_reqs=num_reqs,
|
||||
num_scheduled_tokens_np=num_scheduled_tokens,
|
||||
max_num_scheduled_tokens=max_query_len,
|
||||
use_cascade_attn=False,
|
||||
allow_microbatching=allow_microbatching,
|
||||
allow_dp_padding=allow_dp_padding,
|
||||
num_tokens_padded=total_num_scheduled_tokens,
|
||||
uniform_decode=uniform_decode,
|
||||
num_scheduled_tokens_per_request=num_scheduled_tokens,
|
||||
force_eager=is_profile
|
||||
or (cudagraph_runtime_mode == CUDAGraphMode.NONE),
|
||||
# `force_uniform_decode` is used for cudagraph capture; because for
|
||||
# capturing mixed prefill-decode batches, we sometimes use
|
||||
# num_tokens == num_reqs which looks like a uniform decode batch to the
|
||||
# dispatcher; but we actually want to capture a piecewise cudagraph
|
||||
force_uniform_decode=uniform_decode,
|
||||
# `force_has_lora` is used for cudagraph capture; because LoRA is
|
||||
# activated later in the context manager, but we need to know the
|
||||
# LoRA state when determining the batch descriptor for capture
|
||||
force_has_lora=activate_lora,
|
||||
)
|
||||
)
|
||||
num_tokens_after_padding = num_tokens
|
||||
if num_tokens_across_dp is not None:
|
||||
dp_rank = self.parallel_config.data_parallel_rank
|
||||
num_tokens_after_padding = int(num_tokens_across_dp[dp_rank])
|
||||
|
||||
# filter out the valid batch descriptor
|
||||
_cg_mode, batch_descriptor = (
|
||||
self.cudagraph_dispatcher.dispatch(
|
||||
BatchDescriptor(
|
||||
num_tokens=num_tokens_after_padding,
|
||||
uniform_decode=uniform_decode,
|
||||
has_lora=activate_lora and self.lora_config is not None,
|
||||
)
|
||||
)
|
||||
if not is_profile
|
||||
else (CUDAGraphMode.NONE, None)
|
||||
)
|
||||
if cudagraph_runtime_mode is not None:
|
||||
# we allow forcing NONE when the dispatcher disagrees to support
|
||||
# warm ups for cudagraph capture
|
||||
assert (
|
||||
cudagraph_runtime_mode == CUDAGraphMode.NONE
|
||||
or cudagraph_runtime_mode == _cg_mode
|
||||
), (
|
||||
f"Cudagraph runtime mode mismatch at dummy_run. "
|
||||
f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}."
|
||||
)
|
||||
if cudagraph_runtime_mode is None:
|
||||
cudagraph_runtime_mode = _cudagraph_mode
|
||||
else:
|
||||
cudagraph_runtime_mode = _cg_mode
|
||||
assert cudagraph_runtime_mode == _cudagraph_mode, (
|
||||
f"Cudagraph runtime mode mismatch in dummy_run. "
|
||||
f"Expected {_cudagraph_mode}, but got {cudagraph_runtime_mode}."
|
||||
)
|
||||
|
||||
num_tokens_padded = batch_desc.num_tokens
|
||||
num_reqs_padded = (
|
||||
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
|
||||
)
|
||||
|
||||
attn_metadata: PerLayerAttnMetadata | None = None
|
||||
|
||||
@ -3909,9 +3938,9 @@ class GPUModelRunner(
|
||||
self.query_start_loc.copy_to_gpu()
|
||||
|
||||
attn_metadata, _ = self._build_attention_metadata(
|
||||
total_num_scheduled_tokens=num_tokens,
|
||||
max_num_scheduled_tokens=max_query_len,
|
||||
num_reqs=num_reqs,
|
||||
num_tokens=num_tokens_unpadded,
|
||||
num_reqs=num_reqs_padded,
|
||||
max_query_len=max_query_len,
|
||||
ubatch_slices=ubatch_slices,
|
||||
for_cudagraph_capture=True,
|
||||
)
|
||||
@ -3924,29 +3953,29 @@ class GPUModelRunner(
|
||||
remove_lora,
|
||||
):
|
||||
# Make sure padding doesn't exceed max_num_tokens
|
||||
assert num_tokens_after_padding <= self.max_num_tokens
|
||||
model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
|
||||
assert num_tokens_padded <= self.max_num_tokens
|
||||
model_kwargs = self._init_model_kwargs(num_tokens_padded)
|
||||
if self.supports_mm_inputs and not self.model_config.is_encoder_decoder:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
|
||||
model_kwargs = {
|
||||
**model_kwargs,
|
||||
**self._dummy_mm_kwargs(num_reqs),
|
||||
}
|
||||
elif self.enable_prompt_embeds:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_after_padding]
|
||||
model_kwargs = self._init_model_kwargs(num_tokens_after_padding)
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
|
||||
model_kwargs = self._init_model_kwargs(num_tokens_padded)
|
||||
else:
|
||||
input_ids = self.input_ids.gpu[:num_tokens_after_padding]
|
||||
input_ids = self.input_ids.gpu[:num_tokens_padded]
|
||||
inputs_embeds = None
|
||||
|
||||
if self.uses_mrope:
|
||||
positions = self.mrope_positions.gpu[:, :num_tokens_after_padding]
|
||||
positions = self.mrope_positions.gpu[:, :num_tokens_padded]
|
||||
elif self.uses_xdrope_dim > 0:
|
||||
positions = self.xdrope_positions.gpu[:, :num_tokens_after_padding]
|
||||
positions = self.xdrope_positions.gpu[:, :num_tokens_padded]
|
||||
else:
|
||||
positions = self.positions.gpu[:num_tokens_after_padding]
|
||||
positions = self.positions.gpu[:num_tokens_padded]
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
intermediate_tensors = None
|
||||
@ -3961,26 +3990,26 @@ class GPUModelRunner(
|
||||
)
|
||||
|
||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||
num_tokens_after_padding, None, False
|
||||
num_tokens_padded, None, False
|
||||
)
|
||||
|
||||
if ubatch_slices is not None:
|
||||
# Adjust values to reflect a single ubatch.
|
||||
# TODO(sage,lucas): this is cruft that should be addressed in
|
||||
# the padding refactor.
|
||||
num_tokens_after_padding = ubatch_slices[0].num_tokens
|
||||
num_tokens_padded = ubatch_slices[0].num_tokens
|
||||
if num_tokens_across_dp is not None:
|
||||
num_tokens_across_dp[:] = num_tokens_after_padding
|
||||
num_tokens_across_dp[:] = num_tokens_padded
|
||||
|
||||
with (
|
||||
self.maybe_randomize_inputs(input_ids),
|
||||
set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens_after_padding,
|
||||
num_tokens=num_tokens_padded,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=cudagraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor,
|
||||
batch_descriptor=batch_desc,
|
||||
ubatch_slices=ubatch_slices,
|
||||
),
|
||||
):
|
||||
@ -4706,8 +4735,7 @@ class GPUModelRunner(
|
||||
|
||||
# Trigger cudagraph dispatching keys initialization after
|
||||
# resolved cudagraph mode.
|
||||
cudagraph_mode = self.compilation_config.cudagraph_mode
|
||||
assert cudagraph_mode is not None
|
||||
self.compilation_config.cudagraph_mode = cudagraph_mode
|
||||
self.cudagraph_dispatcher.initialize_cudagraph_keys(
|
||||
cudagraph_mode, self.uniform_decode_query_len
|
||||
)
|
||||
|
||||
@ -8,12 +8,13 @@ from contextlib import AbstractContextManager, nullcontext
|
||||
from types import NoneType
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.distributed import (
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
@ -487,6 +488,7 @@ class Worker(WorkerBase):
|
||||
hidden_states, last_hidden_states = self.model_runner._dummy_run(
|
||||
num_tokens=max_num_reqs,
|
||||
skip_eplb=True,
|
||||
cudagraph_runtime_mode=CUDAGraphMode.NONE,
|
||||
)
|
||||
if self.model_runner.is_pooling_model:
|
||||
self.model_runner._dummy_pooler_run(hidden_states)
|
||||
@ -534,12 +536,39 @@ class Worker(WorkerBase):
|
||||
intermediate_tensors = None
|
||||
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
num_input_tokens = self.model_runner._get_num_input_tokens(num_scheduled_tokens)
|
||||
all_gather_tensors = {}
|
||||
compilation_config = self.vllm_config.compilation_config
|
||||
parallel_config = self.vllm_config.parallel_config
|
||||
|
||||
if (
|
||||
parallel_config.pipeline_parallel_size > 1
|
||||
and compilation_config.pass_config.enable_sequence_parallelism
|
||||
and forward_pass
|
||||
):
|
||||
# currently only supported by V1 GPUModelRunner
|
||||
assert isinstance(self.model_runner, GPUModelRunner)
|
||||
num_scheduled_tokens_np = np.array(
|
||||
list(scheduler_output.num_scheduled_tokens.values()),
|
||||
dtype=np.int32,
|
||||
)
|
||||
# TODO(lucas): This is pretty gross; ideally we should only ever call
|
||||
# `_determine_batch_execution_and_padding` once (will get called again
|
||||
# in `execute_model`) but this requires a larger refactor of PP.
|
||||
_, batch_desc, _, _ = (
|
||||
self.model_runner._determine_batch_execution_and_padding(
|
||||
num_tokens=num_scheduled_tokens,
|
||||
num_reqs=len(num_scheduled_tokens_np),
|
||||
num_scheduled_tokens_np=num_scheduled_tokens_np,
|
||||
max_num_scheduled_tokens=num_scheduled_tokens_np.max(),
|
||||
use_cascade_attn=False, # TODO(lucas): Handle cascade attention
|
||||
)
|
||||
)
|
||||
all_gather_tensors = {
|
||||
"residual": not is_residual_scattered_for_sp(
|
||||
self.vllm_config, num_input_tokens
|
||||
self.vllm_config, batch_desc.num_tokens
|
||||
)
|
||||
}
|
||||
|
||||
if forward_pass and not get_pp_group().is_first_rank:
|
||||
tensor_dict = get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group(),
|
||||
|
||||