[Bugfix] Get a specific type of layer from forward context (#17222)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>

commit 838cedade7
parent 4283a28c2f
Author: Chen Zhang <zhangch99@outlook.com> (committed by GitHub)
Date: 2025-04-27 15:58:05 +08:00
5 changed files with 28 additions and 23 deletions
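
The core of the fix is a typed lookup over `static_forward_context`: instead of iterating the whole context and asserting the layer class at each call site, callers request exactly the layer type they need. Below is a minimal self-contained sketch of that pattern; `DummyAttention`, `DummyFusedMoE`, and `get_layers_by_type` are hypothetical stand-ins, not the real vLLM classes or helper.

# Minimal sketch of the pattern this commit introduces.
# DummyAttention, DummyFusedMoE, and get_layers_by_type are stand-ins.
from typing import TypeVar

T = TypeVar("T")


class DummyAttention:
    pass


class DummyFusedMoE:
    pass


def get_layers_by_type(forward_context: dict[str, object],
                       layer_type: type[T]) -> dict[str, T]:
    # Keep only the entries whose value is an instance of layer_type.
    return {
        name: layer
        for name, layer in forward_context.items()
        if isinstance(layer, layer_type)
    }


forward_context = {
    "layers.0.attn": DummyAttention(),
    "layers.0.moe": DummyFusedMoE(),
    "layers.1.attn": DummyAttention(),
}

# Only the attention-like entries survive the filter.
assert set(get_layers_by_type(forward_context, DummyAttention)) == {
    "layers.0.attn", "layers.1.attn"
}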

View File

@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                             is_block_tables_empty)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)
@@ -140,12 +140,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: Dict[str, PerLayerParameters] = {}

     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)

View File

@@ -3445,7 +3445,8 @@ class CompilationConfig(BaseModel):
     compilation_time: float = PrivateAttr

     # Per-model forward context
-    # Map from layer name to the attention cls
+    # Map from layer name to layer objects that need to be accessed outside
+    # model code, e.g., Attention, FusedMOE when dp_size>1.
     static_forward_context: dict[str, Any] = PrivateAttr

     def compute_hash(self) -> str:
@@ -4079,3 +4080,16 @@ def assert_hashable(text):
         f"vLLM tried to hash some configs that may have Python objects ids "
         f"in them. This is a bug, please file an issue. "
         f"Text being hashed: {text}")
+
+
+T = TypeVar("T")
+
+
+def get_layers_from_vllm_config(vllm_config: VllmConfig,
+                                layer_type: type[T]) -> dict[str, T]:
+    return {
+        layer_name: layer
+        for layer_name, layer in
+        vllm_config.compilation_config.static_forward_context.items()
+        if isinstance(layer, layer_type)
+    }
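
With the helper now exported from `vllm.config`, call sites can request one layer type at a time. The following usage sketch assumes a vLLM build that includes this commit; the wrapper functions `collect_attention_layers` and `collect_moe_layers` are illustrative names, not part of the PR.

# Illustrative usage of the new helper (requires vLLM with this commit).
from vllm.attention.layer import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.fused_moe import FusedMoE


def collect_attention_layers(vllm_config: VllmConfig) -> dict[str, Attention]:
    # Only Attention modules registered in static_forward_context are returned.
    return get_layers_from_vllm_config(vllm_config, Attention)


def collect_moe_layers(vllm_config: VllmConfig) -> dict[str, FusedMoE]:
    # The same helper works for FusedMoE layers (e.g., when dp_size > 1).
    return get_layers_from_vllm_config(vllm_config, FusedMoE)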

View File

@@ -14,7 +14,8 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import (VllmConfig, get_current_vllm_config,
+                         get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention
@@ -81,12 +82,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: dict[str, PerLayerParameters] = {}

     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)

View File

@@ -12,13 +12,13 @@ import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -1733,17 +1733,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         format. Layers that do not need KV cache are not included.
         """
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            if isinstance(attn_module, FusedMoE):
-                continue
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
+            # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
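
Because `get_layers_from_vllm_config` already narrows the dict to `Attention` modules, the runner no longer skips `FusedMoE` entries or asserts the layer class inside the loop. A hedged, self-contained sketch of the resulting loop shape follows; `AttentionLike`, `DECODER`, and `classify_kv_cache_layers` are stand-ins, while the `attn_type` and `sliding_window` attributes are taken from the diff above.

# Sketch of the simplified KV-cache-spec loop; names here are stand-ins.
from typing import Optional, Protocol

DECODER = "decoder"  # stand-in for AttentionType.DECODER


class AttentionLike(Protocol):
    attn_type: str
    sliding_window: Optional[int]


def classify_kv_cache_layers(
        layers: dict[str, AttentionLike]) -> dict[str, str]:
    # No isinstance checks or FusedMoE skips are needed: the caller passes the
    # already-filtered result of get_layers_from_vllm_config(config, Attention).
    spec_kind: dict[str, str] = {}
    for layer_name, attn_module in layers.items():
        if attn_module.attn_type == DECODER:
            if attn_module.sliding_window is not None:
                spec_kind[layer_name] = "sliding_window"
            else:
                spec_kind[layer_name] = "full_attention"
    return spec_kind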

View File

@@ -17,7 +17,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -429,11 +429,10 @@ class TPUModelRunner:
         format. Layers that do not need KV cache are not included.
         """
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(