[Model] Move vision_feature_select_strategy into resolve_visual_encoder_outputs (#25938)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-09-30 19:24:57 +08:00 committed by GitHub
parent ef6e0e7132
commit d7e34b4210
12 changed files with 155 additions and 179 deletions
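
The change in one place: the positional `feature_sample_layers` argument becomes the keyword-only `select_layers`, and the per-model token-selection step (`_select_image_features`) moves into the shared helper as `feature_select_strategy`. A stub of the new signature, taken from the `vision` module hunk at the end of this diff (the argument comments are editorial):

def resolve_visual_encoder_outputs(
    encoder_outputs,               # last hidden state, or list of all hidden states
    post_layer_norm,               # optional LayerNorm for the final layer
    *,
    select_layers=None,            # was: feature_sample_layers (positional)
    max_possible_layers=None,      # required whenever select_layers is given
    feature_select_strategy=None,  # "class" | "default" | "full" | callable
): ...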

View File

@ -18,7 +18,7 @@ from vllm.utils import get_open_port, update_environment_variables
@pytest.mark.parametrize(
("feature_sample_layers", "num_layers_loaded", "max_possible_layers",
("select_layers", "num_layers_loaded", "max_possible_layers",
"expected_features"),
[
# All layers loaded
@ -28,8 +28,8 @@ from vllm.utils import get_open_port, update_environment_variables
([1, 10], 10, 20, [1, 10]),
([-20, -11], 10, 20, [1, 10]),
])
def test_resolve_visual_encoder_outputs(feature_sample_layers,
num_layers_loaded, max_possible_layers,
def test_resolve_visual_encoder_outputs(select_layers, num_layers_loaded,
max_possible_layers,
expected_features):
"""
Test that offsets are correctly handled for vision feature layers.
@ -39,9 +39,10 @@ def test_resolve_visual_encoder_outputs(feature_sample_layers,
]
output_tensor = resolve_visual_encoder_outputs(
encoder_outputs=encoder_outputs,
feature_sample_layers=feature_sample_layers,
post_layer_norm=None,
max_possible_layers=max_possible_layers)
select_layers=select_layers,
max_possible_layers=max_possible_layers,
)
assert torch.equal(torch.tensor(expected_features), output_tensor)
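
A minimal sketch of the updated call, mirroring the `([1, 10], 10, 20, [1, 10])` case above; the import path is inferred from the relative `from .vision import ...` statements elsewhere in this diff, and the fixture (hidden state i carries the value i) is assumed:

import torch

from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

# One dummy hidden state per loaded layer, plus the embedding output.
encoder_outputs = [torch.tensor([idx]) for idx in range(10 + 1)]

output = resolve_visual_encoder_outputs(
    encoder_outputs,
    None,                      # post_layer_norm
    select_layers=[1, 10],     # keyword-only now
    max_possible_layers=20,    # must accompany select_layers
)
assert torch.equal(output, torch.tensor([1, 10]))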

View File

@ -27,7 +27,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@ -350,29 +349,11 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self,
vision_tower: SiglipVisionModel,
pixel_values: torch.Tensor,
**kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
target_dtype: torch.dtype = \
vision_tower.get_input_embeddings().weight.dtype
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values.to(dtype=target_dtype), **kwargs)
def select_features(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)
return json_map_leaves(select_features, image_features)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
return vision_tower(
pixel_values.to(dtype=vision_tower.dtype),
feature_select_strategy=self.config.vision_feature_select_strategy,
)
def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
**kwargs) -> list[torch.Tensor]:
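
The per-model `_select_image_features` helper deleted here (and in the models below) is replaced by the shared selector inside the vision tower. A small self-contained check that the "default" strategy reproduces the old `image_features[:, 1:]` behaviour; the import path is inferred from this diff's relative imports:

import torch

from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

# Fake last hidden state: 2 images, 1 CLS token + 4 patch tokens, hidden dim 8.
image_features = torch.randn(2, 5, 8)

selected = resolve_visual_encoder_outputs(
    image_features,
    None,                               # no post layer norm
    feature_select_strategy="default",  # drop the CLS token
)
assert torch.equal(selected, image_features[:, 1:])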

View File

@ -19,7 +19,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsQuant
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
resolve_visual_encoder_outputs)
class CLIPEncoderInfo(VisionEncoderInfo[CLIPVisionConfig]):
@ -308,24 +309,29 @@ class CLIPVisionTransformer(nn.Module):
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
*,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
return_all_hidden_states = feature_sample_layers is not None
# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
# depending on if we have select_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states)
return_all_hidden_states=select_layers is not None,
)
# Handle post-norm (if applicable) and stacks feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)
encoder_outputs,
self.post_layernorm,
select_layers=select_layers,
max_possible_layers=self.config.num_hidden_layers,
feature_select_strategy=feature_select_strategy,
)
return encoder_outputs
@ -355,9 +361,14 @@ class CLIPVisionModel(nn.Module, SupportsQuant):
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
return self.vision_model(pixel_values, feature_sample_layers)
return self.vision_model(
pixel_values,
select_layers=select_layers,
feature_select_strategy=feature_select_strategy,
)
@property
def device(self):
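
Both new arguments are keyword-only. A schematic call, not a runnable script: `clip` stands for an already loaded CLIPVisionModel and `pixel_values` for a float tensor of shape [num_images, 3, H, W]; the layer index is illustrative:

# Penultimate-layer features with the CLS token dropped, roughly what
# LLaVA-style models request from their CLIP tower:
image_features = clip(
    pixel_values,
    select_layers=[-2],
    feature_select_strategy="default",
)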

View File

@ -33,7 +33,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
@ -604,16 +603,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
@ -622,16 +611,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values)
def select_features(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)
return json_map_leaves(select_features, image_features)
return vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
def _process_image_pixels(
self,

View File

@ -235,12 +235,12 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
# Determine the layer up to which we will initialize the vision tower
if isinstance(vision_feature_layer, int):
vision_hidden_size = config.vision_config.hidden_size
self.feature_sample_layers = None
self.select_layers = None
# Used for multimodal granite models to control encoder outputs
elif isinstance(vision_feature_layer, (list, tuple)):
vision_hidden_size = config.vision_config.hidden_size * len(
vision_feature_layer)
self.feature_sample_layers = vision_feature_layer
self.select_layers = vision_feature_layer
else:
raise TypeError(
f"vision_layer_feature type: {type(vision_feature_layer)}"
@ -312,30 +312,17 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(
pixel_values, feature_sample_layers=self.feature_sample_layers)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
return vision_tower(
pixel_values,
select_layers=self.select_layers,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
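
For the multi-layer (granite) case, the renamed attribute feeds straight into `select_layers`, and the projector input width grows with the number of selected layers because the resolved features are concatenated on the last dimension. A worked example with illustrative numbers, not taken from a real checkpoint:

vision_feature_layer = [-24, -20, -12, -1]   # list -> multi-layer selection
hidden_size = 1024                           # vision_config.hidden_size

select_layers = vision_feature_layer
vision_hidden_size = hidden_size * len(vision_feature_layer)
assert vision_hidden_size == 4096            # 4 selected layers, concatenated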

View File

@ -349,27 +349,16 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
"w": expected_w,
})
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _video_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
image_features = self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
image_features = vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
image_features = self.vision_resampler(image_features)
image_features = self.multi_modal_projector(image_features)

View File

@ -577,27 +577,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return mm_input_by_modality
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
return vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
@ -750,13 +739,11 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
video_features = vision_tower(pixel_values)
video_features = self._select_image_features(
video_features,
strategy=self.config.vision_feature_select_strategy,
video_features = vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
video_features = self.multi_modal_projector(video_features)
video_features = self.apply_pooling(video_features)

View File

@ -17,7 +17,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
@ -221,15 +220,6 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
@ -238,16 +228,10 @@ class MiniMaxVL01ForConditionalGeneration(nn.Module, SupportsMultiModal,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features: tuple[torch.Tensor, ...] = \
tuple(vision_tower(p) for p in pixel_values)
def select_features(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)
return json_map_leaves(select_features, image_features)
feature_select_strategy = self.config.vision_feature_select_strategy
return tuple(
vision_tower(p, feature_select_strategy=feature_select_strategy)
for p in pixel_values)
# adapted from https://huggingface.co/MiniMaxAI/MiniMax-VL-01/blob/main/modeling_minimax_vl_01.py#L616-L631
def pack_image_features(self, image_features: list[torch.Tensor],

View File

@ -51,7 +51,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
resolve_visual_encoder_outputs)
try:
from xformers import ops as xops
@ -1218,7 +1219,9 @@ class PixtralHFVisionModel(nn.Module):
def forward(
self,
pixel_values: list[torch.Tensor],
feature_sample_layers: Optional[list[int]] = None,
*,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> tuple[torch.Tensor, ...]:
"""
Args:
@ -1226,7 +1229,7 @@ class PixtralHFVisionModel(nn.Module):
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially
feature_sample_layers: Layer indices whose features should be
select_layers: Layer indices whose features should be
concatenated and used as the visual encoder output. If none
are provided, the last layer is used.
@ -1267,15 +1270,20 @@ class PixtralHFVisionModel(nn.Module):
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
return_all_hidden_states = feature_sample_layers is not None
out = self.transformer(
patch_embeds,
attention_mask,
position_embedding,
return_all_hidden_states=return_all_hidden_states)
return_all_hidden_states=select_layers is not None,
)
out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)
out = resolve_visual_encoder_outputs(
out,
None,
select_layers=select_layers,
max_possible_layers=self.config.num_hidden_layers,
feature_select_strategy=feature_select_strategy,
)
# squeeze dim 0 and split into separate tensors for each image
return torch.split(out.squeeze(0), embed_sizes)
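
Pixtral flattens the patches of all images into one sequence, so the resolved output is split back into per-image tensors using the patch counts computed above. A self-contained sketch of that final step with made-up sizes:

import torch

embed_sizes = [16, 9]                         # e.g. a 4x4 and a 3x3 patch grid
out = torch.randn(1, sum(embed_sizes), 1024)  # [1, total_patches, hidden_size]

per_image = torch.split(out.squeeze(0), embed_sizes)
assert [t.shape[0] for t in per_image] == embed_sizes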

View File

@ -23,7 +23,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
resolve_visual_encoder_outputs)
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
@ -415,28 +416,31 @@ class SiglipVisionTransformer(nn.Module):
def forward(
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = True,
feature_sample_layers: Optional[list[int]] = None,
*,
interpolate_pos_encoding: bool = False,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(
pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
)
return_all_hidden_states = feature_sample_layers is not None
# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
# depending on if we have select_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states,
return_all_hidden_states=select_layers is not None,
)
# Handle post-norm (if applicable) and stacks feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)
encoder_outputs,
self.post_layernorm,
select_layers=select_layers,
max_possible_layers=self.config.num_hidden_layers,
feature_select_strategy=feature_select_strategy,
)
# TODO: add this back when pooled_output is used in inference.
# if self.use_head:
@ -471,16 +475,22 @@ class SiglipVisionModel(nn.Module):
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
@property
def dtype(self):
return self.get_input_embeddings().weight.dtype
def forward(
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = False,
feature_sample_layers: Optional[list[int]] = None,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
return self.vision_model(
pixel_values=pixel_values,
interpolate_pos_encoding=interpolate_pos_encoding,
feature_sample_layers=feature_sample_layers,
select_layers=select_layers,
feature_select_strategy=feature_select_strategy,
)
def load_weights(self, weights: Iterable[tuple[str,
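
The new `dtype` property lets callers cast inputs without reaching into the embeddings (aya_vision above now writes `pixel_values.to(dtype=vision_tower.dtype)`). A minimal stand-in, not the real SiglipVisionModel, showing the same pattern:

import torch
from torch import nn

class TinyTower(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, 8, kernel_size=4, stride=4)

    @property
    def dtype(self) -> torch.dtype:
        # Mirrors SiglipVisionModel.dtype: report the patch-embedding weight dtype.
        return self.patch_embedding.weight.dtype

tower = TinyTower().to(torch.bfloat16)
pixel_values = torch.randn(1, 3, 16, 16).to(tower.dtype)
assert pixel_values.dtype == torch.bfloat16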

View File

@ -33,7 +33,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
@ -476,30 +475,16 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: Union[torch.Tensor, list[torch.Tensor]],
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
# From vLLM LLaVA, vision tower output handling
image_hidden_states: Union[torch.Tensor, tuple[torch.Tensor, ...]] = \
vision_tower(pixel_values)
def select_features_fn(leaf: torch.Tensor):
return self._select_image_features(
leaf,
strategy=self.config.vision_feature_select_strategy,
)
return json_map_leaves(select_features_fn, image_hidden_states)
return vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
def _add_tarsier_split_tokens(
self, projected_image_features: torch.Tensor) -> torch.Tensor:

View File

@ -4,10 +4,12 @@
import itertools
import math
from abc import ABC, abstractmethod
from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union
from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
TypeVar, Union)
import torch
from transformers import PretrainedConfig
from typing_extensions import assert_never
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@ -86,11 +88,39 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
return current_platform.get_vit_attn_backend(head_size, dtype)
VisionFeatureSelectStrategy = Union[
Literal["class", "default", "full"],
Callable[[torch.Tensor], torch.Tensor],
]
def _get_vision_feature_selector(
strategy: VisionFeatureSelectStrategy,
) -> Callable[[torch.Tensor], torch.Tensor]:
if callable(strategy):
return strategy
# https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
if strategy == "class":
return lambda feats: feats[:, 0, :]
# https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
if strategy == "default":
return lambda feats: feats[:, 1:, :]
if strategy == "full":
return lambda feats: feats
assert_never(strategy)
def resolve_visual_encoder_outputs(
encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
feature_sample_layers: Optional[list[int]],
post_layer_norm: Optional[torch.nn.LayerNorm],
max_possible_layers: int,
*,
select_layers: Optional[list[int]] = None,
max_possible_layers: Optional[int] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> torch.Tensor:
"""Given the outputs a visual encoder module that may correspond to the
output of the last layer, or a list of hidden states to be stacked,
@ -98,17 +128,32 @@ def resolve_visual_encoder_outputs(
Args:
encoder_outputs: Output of encoder's last layer or all hidden states.
feature_sample_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
post_layer_norm: Post norm to apply to the output of the encoder.
select_layers: Optional layer indices to grab from the encoder
outputs; if provided, encoder outputs must be a list.
max_possible_layers: Total layers in the fully loaded visual encoder.
feature_select_strategy: Defines how to select the hidden states
from each layer.
"""
if feature_sample_layers is None:
if select_layers is None:
if not isinstance(encoder_outputs, torch.Tensor):
raise ValueError("Expected only a single encoder output when "
"`select_layers` is not provided")
if feature_select_strategy is not None:
select_features = _get_vision_feature_selector(
feature_select_strategy)
encoder_outputs = select_features(encoder_outputs)
if post_layer_norm is not None:
return post_layer_norm(encoder_outputs)
return encoder_outputs
if max_possible_layers is None:
raise ValueError("`max_possible_layers` must be provided "
"alongside `select_layers`")
# Get the hidden states corresponding to the layer indices.
# Negative values are relative to the full visual encoder,
# so offset them depending on how many layers were loaded.
@ -120,13 +165,18 @@ def resolve_visual_encoder_outputs(
hs_pool = [
encoder_outputs[layer_idx]
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
for layer_idx in feature_sample_layers
for layer_idx in select_layers
]
if feature_select_strategy is not None:
select_features = _get_vision_feature_selector(feature_select_strategy)
hs_pool = [select_features(hs) for hs in hs_pool]
# Apply post-norm on the final hidden state if we are using it
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
uses_last_layer = select_layers[-1] in (max_possible_layers - 1, -1)
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
hs_pool[-1] = post_layer_norm(hs_pool[-1])
return torch.cat(hs_pool, dim=-1)
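
Putting the pieces together: layer selection and token selection now live in this one helper, the post norm is applied to the selected hidden state rather than to the raw encoder output, and a callable strategy is accepted as an escape hatch. A self-contained sketch (import path inferred from the relative imports above; sizes are illustrative, and the exact negative-index offset logic lives in lines not shown in this hunk):

import torch
from torch import nn

from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

# All hidden states of a fully loaded 4-layer encoder (embeddings + 4 layers):
# 1 image, 1 CLS token + 2 patch tokens, hidden size 2.
hidden_states = [torch.randn(1, 3, 2) for _ in range(4 + 1)]
post_norm = nn.LayerNorm(2)

out = resolve_visual_encoder_outputs(
    hidden_states,
    post_norm,                          # applied because the last layer is selected
    select_layers=[-3, -1],
    max_possible_layers=4,
    feature_select_strategy="default",  # drop the CLS token from each layer
)
assert out.shape == (1, 2, 2 * 2)       # two layers concatenated on dim -1

# A callable strategy is also accepted, e.g. mean pooling over tokens:
pooled = resolve_visual_encoder_outputs(
    torch.randn(1, 3, 2),
    None,
    feature_select_strategy=lambda feats: feats.mean(dim=1),
)
assert pooled.shape == (1, 2)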