diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 889c4eb9d8e6..d92177d58a48 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            is_block_tables_empty)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
@@ -140,12 +140,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
 
-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: Dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)
 
diff --git a/vllm/config.py b/vllm/config.py
index a9f39ecceac0..fcbf962ac685 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -3445,7 +3445,8 @@ class CompilationConfig(BaseModel):
     compilation_time: float = PrivateAttr
 
     # Per-model forward context
-    # Map from layer name to the attention cls
+    # Map from layer name to layer objects that need to be accessed outside
+    # model code, e.g., Attention, FusedMOE when dp_size>1.
     static_forward_context: dict[str, Any] = PrivateAttr
 
     def compute_hash(self) -> str:
@@ -4079,3 +4080,16 @@ def assert_hashable(text):
         f"vLLM tried to hash some configs that may have Python objects ids "
         f"in them. This is a bug, please file an issue. "
         f"Text being hashed: {text}")
+
+
+T = TypeVar("T")
+
+
+def get_layers_from_vllm_config(vllm_config: VllmConfig,
+                                layer_type: type[T]) -> dict[str, T]:
+    return {
+        layer_name: layer
+        for layer_name, layer in
+        vllm_config.compilation_config.static_forward_context.items()
+        if isinstance(layer, layer_type)
+    }
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 17341ecfa4fe..bce446bd2b82 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -14,7 +14,8 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import (VllmConfig, get_current_vllm_config,
+                         get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
 
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
 
-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)
 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index b9b4ce4d19ac..5775701b941e 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -12,13 +12,13 @@ import torch.nn as nn
 
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1733,17 +1733,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         format. Layers that do not need KV cache are not included.
         """
 
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            if isinstance(attn_module, FusedMoE):
-                continue
-
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
+            # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 98b0ddcccb5d..67f8af29db0e 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -17,7 +17,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -429,11 +429,10 @@ class TPUModelRunner:
         format. Layers that do not need KV cache are not included.
         """
 
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(