diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index aa218cc37af96..9e4fbe0b4c6c2 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -210,6 +210,8 @@ class Attention(nn.Module):
         if self.use_direct_call:
             forward_context: ForwardContext = get_forward_context()
             attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             self.impl.forward(self,
                               query,
@@ -226,6 +228,8 @@ class Attention(nn.Module):
         if self.use_direct_call:
             forward_context = get_forward_context()
             attn_metadata = forward_context.attn_metadata
+            if isinstance(attn_metadata, dict):
+                attn_metadata = attn_metadata[self.layer_name]
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             return self.impl.forward(self, query, key, value,
                                      self_kv_cache, attn_metadata)
@@ -343,7 +347,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
     attn_metadata = forward_context.attn_metadata
     if attn_metadata is None:
         return
-
+    assert isinstance(attn_metadata, dict)
     connector.wait_for_layer_load(layer_name)


@@ -360,8 +364,9 @@ maybe_save_kv_layer_to_connector(
     attn_metadata = forward_context.attn_metadata
     if attn_metadata is None:
         return
-
-    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+    assert isinstance(attn_metadata, dict)
+    connector.save_kv_layer(layer_name, kv_cache_layer,
+                            attn_metadata[layer_name])


 def unified_attention(
@@ -374,6 +379,8 @@

     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     output = self.impl.forward(self, query, key, value, kv_cache,
@@ -411,6 +418,8 @@ def unified_attention_with_output(
     wait_for_kv_layer_from_connector(layer_name)
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
+    if isinstance(attn_metadata, dict):
+        attn_metadata = attn_metadata[layer_name]
     self = forward_context.no_compile_layers[layer_name]
     kv_cache = self.kv_cache[forward_context.virtual_engine]
     self.impl.forward(self,
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index c75d8f088c5ba..9ddc3d1f2c518 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -4,7 +4,7 @@ import time
 from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 import torch.distributed as dist
@@ -38,8 +38,13 @@ class DPMetadata:
 class ForwardContext:
     # copy from vllm_config.compilation_config.static_forward_context
     no_compile_layers: dict[str, Any]
-    # TODO: extend to support per-layer dynamic forward context
-    attn_metadata: "AttentionMetadata"  # set dynamically for each forward pass
+    """
+    AttentionMetadata for v0,
+    dict[str, AttentionMetadata] for v1, mapping the layer_name of each
+    attention layer to its attention metadata.
+    Set dynamically for each forward pass.
+    """
+    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"]]
     # TODO: remove after making all virtual_engines share the same kv cache
     virtual_engine: int  # set dynamically for each forward pass
     # set dynamically for each forward pass
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index f986d797f2b03..db79269021548 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -18,6 +18,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -309,13 +310,11 @@ class FlashAttentionMetadataBuilder:
         return False

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
-                                                 non_blocking=True)
-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        seq_lens = seq_lens_cpu.to(self.runner.device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 6e964b471faef..0852e15f9c19c 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -18,6 +18,7 @@ from vllm.config import (VllmConfig, get_current_vllm_config,
                          get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -394,16 +395,15 @@ class FlashInferMetadataBuilder:
         )

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata):
         assert self._num_decodes + self._num_prefills == num_reqs
         assert (self._num_decode_tokens +
                 self._num_prefill_tokens == num_actual_tokens)
         page_size = self.runner.block_size
         device = self.runner.device
-        qo_indptr = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            self.runner.device, non_blocking=True)
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
-                                                          non_blocking=True)
+        qo_indptr = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 8b1875e7356b2..0d18a5639c2aa 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -207,6 +207,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata

 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -451,7 +452,8 @@ class MLACommonMetadataBuilder(Generic[M]):
         )

     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int) -> M:
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata) -> M:
         assert self._num_decodes + self._num_prefills == num_reqs

         # Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -460,15 +462,13 @@ class MLACommonMetadataBuilder(Generic[M]):
         device = self.runner.device
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
-        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            device, non_blocking=True)
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()

-        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
-        seq_lens = seq_lens_cpu.to(device, non_blocking=True)
+        query_start_loc = common_attn_metadata.query_start_loc
+        seq_lens = common_attn_metadata.seq_lens

         prefill_metadata = None
         if self._num_prefills > 0:
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
new file mode 100644
index 0000000000000..10a771e830b68
--- /dev/null
+++ b/vllm/v1/attention/backends/utils.py
@@ -0,0 +1,18 @@
+# SPDX-License-Identifier: Apache-2.0
+from dataclasses import dataclass
+
+import torch
+
+
+@dataclass
+class CommonAttentionMetadata:
+    """
+    Attention metadata attributes that can be shared by layers in different
+    KV cache groups and thus have different block tables.
+    """
+
+    query_start_loc: torch.Tensor
+    """(batch_size + 1,), the start location of each request in query Tensor"""
+    seq_lens: torch.Tensor
+    """(batch_size,), the length of each request including both computed tokens
+    and newly scheduled tokens"""
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index 6d71743c5e365..2293410e73cd1 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -2,7 +2,9 @@
 import torch
 import torch.nn as nn

-from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
+from vllm.attention.layer import Attention
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.loader import get_model_loader
@@ -276,6 +278,8 @@ class EagleProposer:
         loader = get_model_loader(self.vllm_config.load_config)
         target_layer_num = self.vllm_config.model_config.get_num_layers(
             self.vllm_config.parallel_config)
+        target_attn_layer_names = set(
+            get_layers_from_vllm_config(self.vllm_config, Attention).keys())

         draft_model_config = \
             self.vllm_config.speculative_config.draft_model_config
@@ -292,6 +296,11 @@ class EagleProposer:
             vllm_config=self.vllm_config,
             start_layer_id=target_layer_num).to(target_device)

+        draft_attn_layer_names = (
+            get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
+            target_attn_layer_names)
+        assert len(draft_attn_layer_names) == 1
+        self.attn_layer_name = next(iter(draft_attn_layer_names))
         loaded_weights = self.model.load_weights(
             loader.get_all_weights(draft_model_config, self.model))
         if self.vllm_config.speculative_config.method == "eagle3":
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 8137cb6b9b6bb..e0c3d05c7976d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -30,6 +30,7 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                         check_use_alibi, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
@@ -157,9 +158,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Sampler
         self.sampler = Sampler()

-        # Lazy initialization
+        # Lazy initializations
         # self.model: nn.Module  # Set after load_model
+        # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
+        # self.kv_cache_config: KVCacheConfig
+
         # req_id -> (input_id -> encoder_output)
         self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -488,7 +492,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[FlashAttentionMetadata, torch.Tensor,
+    ) -> tuple[dict[str, FlashAttentionMetadata], torch.Tensor,
               Optional[SpecDecodeMetadata]]:
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -585,20 +589,39 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 self.positions_cpu[:total_num_scheduled_tokens],
                 non_blocking=True)

-        # Prepare for cascade attention if enabled & beneficial.
-        common_prefix_len = 0
-        if self.cascade_attn_enabled:
-            common_prefix_len = self._compute_cascade_attn_prefix_len(
-                num_scheduled_tokens,
-                scheduler_output.num_common_prefix_blocks,
-            )
+        query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
+            self.device, non_blocking=True)
+        seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
+                                                   non_blocking=True)
+        common_attn_metadata = CommonAttentionMetadata(
+            query_start_loc=query_start_loc, seq_lens=seq_lens)

-        attn_metadata = self.attn_metadata_builder.build(
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            common_prefix_len=common_prefix_len,
-        )
+        attn_metadata: dict[str, FlashAttentionMetadata] = {}
+        # Prepare the attention metadata for each KV cache group and make
+        # layers in the same group share the same metadata.
+        # NOTE(Chen): there is exactly one KV cache group that contains all
+        # attention layers in the model for now, so the current logic for
+        # getting attn_metadata is not related to kv_cache_group information.
+        # Will extend this part to support multiple KV cache groups later.
+        for kv_cache_group_id, kv_cache_group_spec in enumerate(
+                self.kv_cache_config.kv_cache_groups):
+
+            # Prepare for cascade attention if enabled & beneficial.
+            common_prefix_len = 0
+            if self.cascade_attn_enabled:
+                common_prefix_len = self._compute_cascade_attn_prefix_len(
+                    num_scheduled_tokens,
+                    scheduler_output.num_common_prefix_blocks,
+                )
+
+            attn_metadata_i = self.attn_metadata_builder.build(
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata)
+            for layer_name in kv_cache_group_spec.layer_names:
+                attn_metadata[layer_name] = attn_metadata_i

         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -608,7 +631,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # from these partial requests, we do so for simplicity.
             # We will ignore the sampled tokens from the partial requests.
             # TODO: Support prompt logprobs.
-            logits_indices = attn_metadata.query_start_loc[1:] - 1
+            logits_indices = query_start_loc[1:] - 1
             spec_decode_metadata = None
         else:
             # Get the number of draft tokens for each request.
@@ -1230,6 +1253,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             next_token_ids = torch.tensor(next_token_ids,
                                           dtype=torch.int32,
                                           device=self.device)
+            eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]

             if spec_decode_metadata is None:
                 # input_ids can be None for multimodal models.
@@ -1241,8 +1265,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                         dim=-1)
                 else:
                     target_hidden_states = hidden_states[:num_scheduled_tokens]
-                target_slot_mapping = attn_metadata.slot_mapping
-                cu_num_tokens = attn_metadata.query_start_loc
+                target_slot_mapping = eagle_attn_metadata.slot_mapping
+                cu_num_tokens = eagle_attn_metadata.query_start_loc
             else:
                 # TODO(woosuk): Refactor this.
                 num_draft_tokens = spec_decode_metadata.num_draft_tokens
@@ -1256,7 +1280,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                     device=self.device,
                 )
                 cu_num_tokens, token_indices = self.drafter.prepare_inputs(
-                    attn_metadata.query_start_loc,
+                    eagle_attn_metadata.query_start_loc,
                     num_rejected_tokens,
                 )
                 target_token_ids = self.input_ids[token_indices]
@@ -1266,7 +1290,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                         [h[token_indices] for h in aux_hidden_states], dim=-1)
                 else:
                     target_hidden_states = hidden_states[token_indices]
-                target_slot_mapping = attn_metadata.slot_mapping[token_indices]
+                target_slot_mapping = eagle_attn_metadata.slot_mapping[
+                    token_indices]

             draft_token_ids = self.drafter.propose(
                 target_token_ids=target_token_ids,
@@ -1275,7 +1300,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 target_slot_mapping=target_slot_mapping,
                 next_token_ids=next_token_ids,
                 cu_num_tokens=cu_num_tokens,
-                block_table=attn_metadata.block_table,
+                block_table=eagle_attn_metadata.block_table,
                 sampling_metadata=sampling_metadata,
             )
             spec_token_ids = draft_token_ids.tolist()
@@ -1708,6 +1733,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             raise NotImplementedError(
                 "Hybrid models with more than one KV cache type are not "
                 "supported yet.")
+        self.kv_cache_config = kv_cache_config

         kv_caches: dict[str, torch.Tensor] = {}

diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 8e162d5170d69..f5626abb2a122 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -588,7 +588,14 @@ class TPUModelRunner:
         # Padded to avoid recompiling when `num_reqs` varies.
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
-        return attn_metadata, logits_indices, padded_num_reqs
+
+        layer_names = get_layers_from_vllm_config(self.vllm_config,
+                                                  Attention).keys()
+        per_layer_attn_metadata = {
+            layer_name: attn_metadata
+            for layer_name in layer_names
+        }
+        return per_layer_attn_metadata, logits_indices, padded_num_reqs

     def _scatter_placeholders(
         self,
@@ -956,7 +963,14 @@ class TPUModelRunner:
             torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)

-        with set_forward_context(attn_metadata, self.vllm_config, 0):
+        layer_names = get_layers_from_vllm_config(self.vllm_config,
+                                                  Attention).keys()
+        per_layer_attn_metadata = {
+            layer_name: attn_metadata
+            for layer_name in layer_names
+        }
+
+        with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
                              inputs_embeds=inputs_embeds)
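Editor's note (not part of the patch): the recurring pattern above is that forward_context.attn_metadata may now be either a single AttentionMetadata object (the v0 path) or a dict mapping each attention layer's name to its own metadata (the v1 path), while CommonAttentionMetadata carries the query_start_loc and seq_lens tensors shared across KV cache groups. The following is a minimal, self-contained Python sketch of that lookup pattern under stated assumptions: the CommonAttentionMetadata fields mirror the new utils.py, but SimpleAttentionMetadata and resolve_attn_metadata are hypothetical stand-ins for illustration, not vLLM APIs.

    # Illustrative sketch only: stand-in types showing the per-layer metadata
    # lookup that the patch adds to Attention.forward and unified_attention.
    from dataclasses import dataclass
    from typing import Optional, Union

    import torch


    @dataclass
    class CommonAttentionMetadata:
        # Shared by all KV cache groups within one forward pass (see utils.py).
        query_start_loc: torch.Tensor  # (batch_size + 1,)
        seq_lens: torch.Tensor  # (batch_size,)


    @dataclass
    class SimpleAttentionMetadata:
        # Hypothetical stand-in for a backend-specific metadata object.
        common: CommonAttentionMetadata
        common_prefix_len: int = 0


    def resolve_attn_metadata(
        attn_metadata: Optional[Union[SimpleAttentionMetadata,
                                      dict[str, SimpleAttentionMetadata]]],
        layer_name: str,
    ) -> Optional[SimpleAttentionMetadata]:
        # v0 path: one metadata object shared by every attention layer.
        # v1 path: a dict keyed by layer name; each layer picks its own entry.
        if isinstance(attn_metadata, dict):
            return attn_metadata[layer_name]
        return attn_metadata


    if __name__ == "__main__":
        common = CommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 3, 7], dtype=torch.int32),
            seq_lens=torch.tensor([3, 4], dtype=torch.int32),
        )
        per_layer = {
            "model.layers.0.self_attn.attn": SimpleAttentionMetadata(common),
            "model.layers.1.self_attn.attn": SimpleAttentionMetadata(common),
        }
        meta = resolve_attn_metadata(per_layer, "model.layers.1.self_attn.attn")
        assert meta is not None
        print(meta.common.seq_lens)  # tensor([3, 4], dtype=torch.int32)

Keeping the isinstance check inside the layer code is what lets the same Attention.forward and unified_attention serve both the v0 single-object path and the v1 per-layer dict without separate call sites; today every layer in the single KV cache group receives the same metadata object, but the dict leaves room for groups with different block tables later.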