diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py
new file mode 100644
index 000000000000..d64c0e6d4e43
--- /dev/null
+++ b/tests/models/test_vision.py
@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+import torch
+
+from vllm.model_executor.models.vision import resolve_visual_encoder_outputs
+
+
+@pytest.mark.parametrize(
+    ("feature_sample_layers", "num_layers_loaded", "max_possible_layers",
+     "expected_features"),
+    [
+        # All layers loaded
+        ([1, 10], 10, 10, [1, 10]),
+        ([-10, -1], 10, 10, [1, 10]),
+        # Some layers not loaded
+        ([1, 10], 10, 20, [1, 10]),
+        ([-20, -11], 10, 20, [1, 10]),
+    ])
+def test_resolve_visual_encoder_outputs(feature_sample_layers,
+                                        num_layers_loaded, max_possible_layers,
+                                        expected_features):
+    """
+    Test that offsets are correctly handled for vision feature layers.
+    """
+    encoder_outputs = [
+        torch.tensor([idx]) for idx in range(num_layers_loaded + 1)
+    ]
+    output_tensor = resolve_visual_encoder_outputs(
+        encoder_outputs=encoder_outputs,
+        feature_sample_layers=feature_sample_layers,
+        post_layer_norm=None,
+        max_possible_layers=max_possible_layers)
+    assert torch.equal(torch.tensor(expected_features), output_tensor)
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 73c109a27ac7..dc3aa9cbe86b 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -251,7 +251,7 @@ class CLIPEncoder(nn.Module):
     def forward(
         self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
     ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        hidden_states_pool = []
+        hidden_states_pool = [inputs_embeds]
         hidden_states = inputs_embeds
 
         for encoder_layer in self.layers:
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index dcd90474e936..6a4277adb6bf 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -428,7 +428,7 @@ def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
 
 
 def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
-    """Given an signed vision feature layer, get the number of hidden layers
+    """Given a signed vision feature layer, get the number of hidden layers
     needed to leverage it.
 
     Args:
@@ -438,7 +438,7 @@ def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
     """
     if feature_layer_index < 0:
         return num_hidden_layers + feature_layer_index + 1
-    return feature_layer_index + 1
+    return feature_layer_index
 
 
 def init_vision_tower_for_llava(
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index e78e8d62cc47..44fca852805a 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -969,7 +969,7 @@ class PixtralHFTransformer(nn.Module):
         position_embeddings: torch.Tensor,
         return_all_hidden_states: bool,
     ) -> torch.Tensor:
-        hidden_states_pool = []
+        hidden_states_pool = [x]
 
         for layer in self.layers:
             x = layer(x, attention_mask, position_embeddings)
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index ddae78d7739e..2892f696107b 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -378,7 +378,7 @@ class SiglipEncoder(nn.Module):
         inputs_embeds: torch.Tensor,
         return_all_hidden_states: bool,
     ) -> Union[torch.Tensor, list[torch.Tensor]]:
-        hidden_states_pool = []
+        hidden_states_pool = [inputs_embeds]
         hidden_states = inputs_embeds
 
         for encoder_layer in self.layers:
diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py
index 0d67ee7bb5dd..9a6fac2eec56 100644
--- a/vllm/model_executor/models/vision.py
+++ b/vllm/model_executor/models/vision.py
@@ -132,10 +132,11 @@ def resolve_visual_encoder_outputs(
     # Get the hidden states corresponding to the layer indices.
     # Negative values are relative to the full visual encoder,
     # so offset them depending on how many layers were loaded.
-    # NOTE: this assumes that encoder_outputs contains a list
-    # of hidden states in the same order as the encoder layers
-    # that produced them.
-    offset = max_possible_layers - len(encoder_outputs)
+    # NOTE: this assumes that encoder_outputs is a list containing
+    # the inputs to the visual encoder, followed by the hidden states
+    # of each layer.
+    num_loaded_layers = len(encoder_outputs) - 1
+    offset = max_possible_layers - num_loaded_layers
     hs_pool = [
         encoder_outputs[layer_idx] if layer_idx >= 0 else
         encoder_outputs[layer_idx + offset]
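
For reference, the offset arithmetic the updated NOTE describes can be illustrated with a small standalone sketch. This is not vLLM's implementation: toy_resolve is a hypothetical stand-in for resolve_visual_encoder_outputs, kept minimal to show how negative feature layer indices are shifted once encoder_outputs begins with the pre-layer input embeddings.

# A minimal, self-contained sketch of the indexing convention this change
# relies on (assumptions: toy_resolve is hypothetical, not vLLM's API).
import torch


def toy_resolve(encoder_outputs, feature_sample_layers, max_possible_layers):
    # encoder_outputs[0] is the input embedding; entries 1..N are the hidden
    # states of the N loaded layers, so N = len(encoder_outputs) - 1.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[idx] if idx >= 0 else encoder_outputs[idx + offset]
        for idx in feature_sample_layers
    ]
    return torch.cat(hs_pool, dim=-1)


# Mirrors the "some layers not loaded" test case: 10 layers loaded out of a
# possible 20, so layer -11 resolves to encoder_outputs[-11 + 10], i.e. the
# last loaded hidden state, and layer -20 resolves to encoder_outputs[-10],
# the first layer's hidden state.
outputs = [torch.tensor([float(i)]) for i in range(10 + 1)]
print(toy_resolve(outputs, [-20, -11], max_possible_layers=20))
# tensor([ 1., 10.])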