Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
Lucas Wilkinson 2025-05-19 04:47:58 +00:00
parent 980a172474
commit 8293182c8c
4 changed files with 176 additions and 85 deletions

View File

@ -31,6 +31,8 @@ if current_platform.is_cuda():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
from vllm.v1.attention.backends.utils import slice_query_start_locs
logger = init_logger(__name__)
@ -322,23 +324,23 @@ class FlashAttentionMetadataBuilder:
scheduler_output: "SchedulerOutput") -> bool:
return False
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
def build_slice(self, max_query_len: int, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
req_slice: slice,
token_slice: slice) -> FlashAttentionMetadata:
num_reqs = req_slice.stop - req_slice.start
num_tokens = token_slice.stop - token_slice.start
max_seq_len = self.runner.seq_lens_np[req_slice].max()
query_start_loc = slice_query_start_locs(
common_attn_metadata.query_start_loc, req_slice)
seq_lens = common_attn_metadata.seq_lens[req_slice]
block_table = self.block_table
block_table_tensor = block_table.get_device_tensor()[:num_reqs]
block_table_tensor = block_table.get_device_tensor()[req_slice]
block_table.slot_mapping[:num_actual_tokens].copy_(
block_table.slot_mapping_cpu[:num_actual_tokens],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table.slot_mapping[num_actual_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_actual_tokens]
block_table.slot_mapping[token_slice].copy_(
block_table.slot_mapping_cpu[token_slice], non_blocking=True)
slot_mapping = block_table.slot_mapping[token_slice]
if self.aot_sliding_window is None:
self.aot_sliding_window = (-1, -1)
@ -380,8 +382,8 @@ class FlashAttentionMetadataBuilder:
seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
virt_block_table_tensor = make_local_attention_virtual_batches(
self.runner.attention_chunk_size,
self.runner.query_start_loc_np[:num_reqs + 1],
self.runner.seq_lens_np[:num_reqs],
query_start_loc,
seq_lens,
block_table_tensor,
self.block_size,
)
@ -411,20 +413,20 @@ class FlashAttentionMetadataBuilder:
use_cascade = common_prefix_len > 0
if use_cascade:
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
cu_prefix_query_lens = torch.tensor([0, num_tokens],
dtype=torch.int32,
device=self.runner.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.runner.device)
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
suffix_kv_lens = (self.runner.seq_lens_np[req_slice] -
common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
self.runner.device)
prefix_scheduler_metadata = schedule(
batch_size=1,
cu_query_lens=cu_prefix_query_lens,
max_query_len=num_actual_tokens,
max_query_len=num_tokens,
seqlens=prefix_kv_lens,
max_seq_len=common_prefix_len,
causal=False)
@ -448,7 +450,7 @@ class FlashAttentionMetadataBuilder:
causal=True)
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
@ -466,6 +468,17 @@ class FlashAttentionMetadataBuilder:
)
return attn_metadata
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
return self.build_slice(
max_query_len=max_query_len,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
req_slice=slice(0, num_reqs),
token_slice=slice(0, num_actual_tokens),
)
def use_cascade_attention(self, *args, **kwargs) -> bool:
return use_cascade_attention(*args, **kwargs)

View File

@ -16,3 +16,11 @@ class CommonAttentionMetadata:
seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
def slice_query_start_locs(
query_start_loc: torch.Tensor,
req_slice: slice,
) -> torch.Tensor:
return query_start_loc[req_slice.start: req_slice.stop + 1] -\
query_start_loc[req_slice.start]

View File

@ -4,7 +4,7 @@ import copy
import gc
import time
import weakref
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional, TypeAlias, Union
import numpy as np
import torch
@ -69,6 +69,11 @@ else:
logger = init_logger(__name__)
AttnMetadataDict: TypeAlias = dict[str, FlashAttentionMetadata]
# list when ubatching is enabled
PerLayerAttnMetadata: TypeAlias = Union[list[AttnMetadataDict],
AttnMetadataDict]
class GPUModelRunner(LoRAModelRunnerMixin):
@ -491,7 +496,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
ubatch_slices: Optional[list[tuple[
slice, slice]]] = None, # req_slice, token_slice
) -> tuple[PerLayerAttnMetadata, torch.Tensor,
Optional[SpecDecodeMetadata]]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@ -612,7 +619,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, seq_lens=seq_lens)
attn_metadata: dict[str, FlashAttentionMetadata] = {}
attn_metadata: PerLayerAttnMetadata = {}
if ubatch_slices is not None:
attn_metadata = [dict() for _ in range(len(ubatch_slices))]
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
@ -629,15 +639,37 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.attn_metadata_builders[kv_cache_group_id],
)
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode.
if self.vllm_config.compilation_config.full_cuda_graph:
self.input_batch.block_table[kv_cache_group_id]\
.slot_mapping.fill_(-1)
if ubatch_slices is not None:
for ubid, (req_slice, token_slice) in enumerate(ubatch_slices):
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].
build_slice(
max(tokens[req_slice]),
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
req_slice=req_slice,
token_slice=token_slice,
))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].build(
num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is dict
attn_metadata[layer_name] = attn_metadata_i
use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@ -1057,10 +1089,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
def sync_and_slice_intermediate_tensors(
self, num_tokens: int, intermediate_tensors: IntermediateTensors,
self, tokens_slice: slice,
intermediate_tensors: IntermediateTensors,
sync_self: bool) -> IntermediateTensors:
assert self.intermediate_tensors is not None
num_tokens = tokens_slice.stop - tokens_slice.start
tp = self.vllm_config.parallel_config.tensor_parallel_size
enabled_sp = self.vllm_config.compilation_config.pass_config. \
@ -1072,49 +1106,37 @@ class GPUModelRunner(LoRAModelRunnerMixin):
is_residual_scattered = tp > 1 and enabled_sp \
and num_tokens % tp == 0
def copy_slice(is_scattered: bool) -> slice:
if is_scattered:
return slice(tokens_slice.start // tp, tokens_slice.stop // tp)
else:
return tokens_slice
# When sequence parallelism is enabled, the "residual" tensor is sharded
# across tensor parallel ranks, so each rank only needs its own slice.
if sync_self:
assert intermediate_tensors is not None
for k, v in intermediate_tensors.items():
is_scattered = "residual" and is_residual_scattered
copy_len = num_tokens // tp if is_scattered else \
num_tokens
self.intermediate_tensors[k][:copy_len].copy_(
v[:copy_len], non_blocking=True)
_copy_slice = copy_slice(is_scattered)
self.intermediate_tensors[k][_copy_slice].copy_(
v[_copy_slice], non_blocking=True)
return IntermediateTensors({
k:
v[:num_tokens // tp]
if k == "residual" and is_residual_scattered else v[:num_tokens]
v[copy_slice(k == "residual" and is_residual_scattered)]
for k, v in self.intermediate_tensors.items()
})
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
def _get_model_inputs(self, tokens_slice: slice,
scheduler_output: "SchedulerOutput"):
num_tokens = tokens_slice.stop - tokens_slice.start
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
and num_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
tokens_slice = \
slice(tokens_slice.start, self.vllm_config.pad_for_cudagraph(num_tokens))
else:
# Eager mode.
# Pad tokens to multiple of tensor_parallel_size when
@ -1123,9 +1145,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.vllm_config.compilation_config.pass_config. \
enable_sequence_parallelism and tp_size > 1:
from vllm.utils import round_up
num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else:
num_input_tokens = num_scheduled_tokens
tokens_slice = slice(tokens_slice.start,
round_up(num_tokens, tp_size))
num_tokens = tokens_slice.stop - tokens_slice.start
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
@ -1140,51 +1162,94 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): To unify token ids and soft tokens (vision
# embeddings), we always use embeddings (rather than token ids)
# as input to the multimodal model, even when the input is text.
input_ids = self.input_ids[:num_scheduled_tokens]
input_ids = self.input_ids[tokens_slice]
if mm_embeds:
inputs_embeds = self.model.get_input_embeddings(
input_ids, mm_embeds)
else:
inputs_embeds = self.model.get_input_embeddings(input_ids)
# TODO(woosuk): Avoid the copy. Optimize.
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds[:num_input_tokens]
self.inputs_embeds[tokens_slice].copy_(inputs_embeds)
inputs_embeds = self.inputs_embeds[tokens_slice]
input_ids = None
else:
# For text-only models, we use token ids as input.
# While it is possible to use embeddings as input just like the
# multimodal models, it is not desirable for performance since
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens]
input_ids = self.input_ids[tokens_slice]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_input_tokens]
positions = self.mrope_positions[:, tokens_slice]
else:
positions = self.positions[:num_input_tokens]
positions = self.positions[tokens_slice]
if get_pp_group().is_first_rank:
intermediate_tensors = None
else:
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_input_tokens, intermediate_tensors, True)
tokens_slice, intermediate_tensors, True)
return input_ids, positions, inputs_embeds, intermediate_tensors
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, IntermediateTensors]:
ubatch_slices: Optional[list[tuple[slice, slice]]] = None
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
return self.kv_connector_no_forward(scheduler_output)
# Prepare the decoder inputs.
attn_metadata, logits_indices, spec_decode_metadata = (
self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
self.maybe_setup_kv_connector(scheduler_output)
self.maybe_setup_kv_connector(scheduler_output)
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
if ubatch_slices is not None:
for i, (_, tokens_slice) in enumerate(ubatch_slices):
input_ids, positions, inputs_embeds, intermediate_tensors = \
self._get_model_inputs(tokens_slice, scheduler_output)
num_input_token = tokens_slice.stop - tokens_slice.start
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
with set_forward_context(attn_metadata[i],
self.vllm_config,
num_tokens=num_input_tokens):
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
self._get_model_inputs(slice(0, num_scheduled_tokens),
scheduler_output)
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_scheduled_tokens):
model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output

View File

@ -0,0 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
class UBatchContext:
def __init__(self, ubatch_id: int):
self.ubatch_id = ubatch_id