mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-20 07:55:01 +08:00
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
126 lines
4.3 KiB
Python
126 lines
4.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
|
|
|
|
import torch
|
|
from transformers import PretrainedConfig
|
|
|
|
from vllm.attention.selector import get_env_variable_attn_backend
|
|
from vllm.logger import init_logger
|
|
from vllm.platforms import _Backend, current_platform
|
|
|
|
# Module logger, created through vLLM's `init_logger`.
logger = init_logger(__name__)

# Type variable for the HuggingFace config held by a `VisionEncoderInfo`;
# bound to transformers' `PretrainedConfig`.
_C = TypeVar("_C", bound=PretrainedConfig)
|
|
|
|
|
|
class VisionEncoderInfo(ABC, Generic[_C]):
    """Abstract interface describing a vision encoder's image geometry.

    Instances keep the full HuggingFace config (``hf_config``) together
    with its ``vision_config`` sub-config. Concrete subclasses implement
    the size/patch/token queries below.
    """

    def __init__(self, hf_config: _C) -> None:
        super().__init__()
        self.hf_config = hf_config
        # The encoder-specific settings live on the top-level config.
        self.vision_config = self.hf_config.vision_config

    @abstractmethod
    def get_num_image_tokens(self, *, image_width: int,
                             image_height: int) -> int:
        """Return the image-token count for an image of the given
        width and height. Must be implemented by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the encoder's image size. Must be implemented by
        subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def get_patch_size(self) -> int:
        """Return the encoder's patch size. Must be implemented by
        subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        """Return the encoder's patch grid length. Must be implemented by
        subclasses."""
        raise NotImplementedError()
|
|
|
|
|
|
class VisionLanguageConfig(Protocol):
    """Structural type for any config object exposing a ``vision_config``.

    ``get_vision_encoder_info`` selects an encoder-info implementation based
    on the type of this attribute.
    """

    # Sub-config describing the vision encoder; declared Final (read-only).
    vision_config: Final[PretrainedConfig]
|
|
|
|
|
|
def get_vision_encoder_info(
        hf_config: VisionLanguageConfig) -> VisionEncoderInfo:
    """Build the ``VisionEncoderInfo`` matching ``hf_config.vision_config``.

    Supported vision configs are CLIP, Pixtral and SigLIP. They are checked
    in that order with ``isinstance``, and the first match wins.

    Raises:
        NotImplementedError: if the vision config type is not supported.
    """
    # Imported lazily to avoid circular imports.
    from .clip import CLIPEncoderInfo, CLIPVisionConfig
    from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
    from .siglip import SiglipEncoderInfo, SiglipVisionConfig

    vision_config = hf_config.vision_config
    dispatch = (
        (CLIPVisionConfig, CLIPEncoderInfo),
        (PixtralVisionConfig, PixtralHFEncoderInfo),
        (SiglipVisionConfig, SiglipEncoderInfo),
    )
    for config_cls, info_cls in dispatch:
        if isinstance(vision_config, config_cls):
            return info_cls(hf_config)

    raise NotImplementedError(
        f"Unsupported vision config: {type(vision_config)}")
|
|
|
|
|
|
def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
    """Choose the attention backend for a Vision Transformer.

    A backend returned by ``get_env_variable_attn_backend()`` (presumably
    taken from an environment variable) takes precedence. Otherwise the
    choice is delegated to ``current_platform.get_vit_attn_backend``, which
    receives ``support_fa``.
    """
    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
    env_backend: Optional[_Backend] = get_env_variable_attn_backend()
    if env_backend is None:
        return current_platform.get_vit_attn_backend(support_fa)
    return env_backend
|
|
|
|
|
|
def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    feature_sample_layers: Optional[list[int]],
    post_layer_norm: Optional[torch.nn.LayerNorm],
    max_possible_layers: int,
) -> torch.Tensor:
    """Resolve visual encoder outputs into a single output tensor.

    The outputs may be the output of the last layer alone, or a list of
    hidden states to be stacked. Post normalization is applied where
    appropriate.

    Args:
        encoder_outputs: Output of the encoder's last layer, or all hidden
            states (the encoder inputs followed by each layer's output).
        feature_sample_layers: Optional layer indices to grab from the
            encoder outputs; if provided, encoder outputs must be a list.
            Negative indices are relative to the full visual encoder.
        post_layer_norm: Post norm to apply to the output of the encoder.
            When sampling layers, it is applied only to the last sampled
            hidden state, and only if that is the encoder's final layer.
        max_possible_layers: Total layers in the fully loaded visual encoder.

    Returns:
        The (optionally normalized) encoder output, or the selected hidden
        states concatenated along the last dimension.
    """
    if feature_sample_layers is None:
        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)
        return encoder_outputs

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
        for layer_idx in feature_sample_layers
    ]

    # Apply post-norm on the final hidden state if we are using it. Index 0
    # holds the encoder inputs, so the encoder's final layer is index
    # `max_possible_layers` (or -1). `len(hs_pool) - 1` would only count
    # the sampled layers, which is not a layer index.
    uses_last_layer = feature_sample_layers[-1] in (max_possible_layers, -1)
    if post_layer_norm is not None and uses_last_layer:
        # Normalize the sampled hidden state itself, not the whole list.
        hs_pool[-1] = post_layer_norm(hs_pool[-1])
    return torch.cat(hs_pool, dim=-1)
|