[Misc] Move some model utils into vision file (#11848)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-01-09 01:04:46 +08:00 committed by GitHub
parent 78f4590b60
commit ca47e176af
8 changed files with 94 additions and 92 deletions
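In short: `get_vit_attn_backend` and `resolve_visual_encoder_outputs` move into the models' vision module, and all call sites switch their imports. As a minimal caller-side sketch (the absolute module path is inferred from the `.vision` relative imports in the hunks below, not stated in the diff):

# Old locations (removed in this commit):
#   from .utils import get_vit_attn_backend
#   from vllm.multimodal.utils import resolve_visual_encoder_outputs
#
# New location (added in this commit); absolute path assumed from `.vision`:
from vllm.model_executor.models.vision import (get_vit_attn_backend,
                                               resolve_visual_encoder_outputs)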


@@ -20,11 +20,10 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens,
-                                   resolve_visual_encoder_outputs)
+                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
 
-from .vision import VisionEncoderInfo
+from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
 
 
 def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:


@@ -31,14 +31,13 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   resolve_visual_encoder_outputs)
+                                   consecutive_placeholder_ranges)
 from vllm.sequence import IntermediateTensors, SequenceData
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
-from .vision import VisionEncoderInfo
+from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
 
 try:
     from xformers import ops as xops


@@ -66,8 +66,9 @@ from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.config import uses_mrope
 
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
+from .vision import get_vit_attn_backend
 
 logger = init_logger(__name__)
 


@@ -24,11 +24,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens,
-                                   resolve_visual_encoder_outputs)
+                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
 
-from .vision import VisionEncoderInfo
+from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
 
 
 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:


@@ -8,16 +8,12 @@ import torch.nn as nn
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
-import vllm.envs as envs
-from vllm.attention.selector import (backend_name_to_enum,
-                                     get_global_forced_attn_backend)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
-from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_pin_memory_available, print_warning_once
+from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -612,37 +608,6 @@ def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
     return make_empty_intermediate_tensors
 
 
-def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
-    """
-    Get the available attention backend for Vision Transformer.
-    """
-    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
-    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-    if selected_backend is None:
-        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-        if backend_by_env_var is not None:
-            selected_backend = backend_name_to_enum(backend_by_env_var)
-    if selected_backend is None:
-        # For Volta and Turing GPUs, use xformers instead.
-        device_available = current_platform.has_device_capability(80)
-        if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
-            else:
-                print_warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision module, "
-                    "so we use xformers backend instead. You can run "
-                    "`pip install flash-attn` to use flash-attention backend.")
-                selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu() or current_platform.is_rocm():
-            # ROCM doesn't support xformers
-            selected_backend = _Backend.TORCH_SDPA
-        else:
-            selected_backend = _Backend.XFORMERS
-    return selected_backend
-
-
 def maybe_prefix(prefix: str, name: str) -> str:
     """Add a prefix to a name if the prefix is non-empty.


@@ -1,8 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import Final, Generic, Protocol, TypeVar
+from typing import Final, Generic, Optional, Protocol, TypeVar, Union
 
+import torch
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
+from vllm.attention.selector import (backend_name_to_enum,
+                                     get_global_forced_attn_backend)
+from vllm.platforms import _Backend, current_platform
+from vllm.utils import print_warning_once
+
 _C = TypeVar("_C", bound=PretrainedConfig)
 
 
@@ -60,3 +67,77 @@ def get_vision_encoder_info(
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
+
+
+def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+    """
+    Get the available attention backend for Vision Transformer.
+    """
+    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
+    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+    if selected_backend is None:
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
+    if selected_backend is None:
+        # For Volta and Turing GPUs, use xformers instead.
+        device_available = current_platform.has_device_capability(80)
+        if device_available and support_fa:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                selected_backend = _Backend.FLASH_ATTN
+            else:
+                print_warning_once(
+                    "Current `vllm-flash-attn` has a bug inside vision module, "
+                    "so we use xformers backend instead. You can run "
+                    "`pip install flash-attn` to use flash-attention backend.")
+                selected_backend = _Backend.XFORMERS
+        elif current_platform.is_cpu() or current_platform.is_rocm():
+            # ROCM doesn't support xformers
+            selected_backend = _Backend.TORCH_SDPA
+        else:
+            selected_backend = _Backend.XFORMERS
+    return selected_backend
+
+
+def resolve_visual_encoder_outputs(
+    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
+    feature_sample_layers: Optional[list[int]],
+    post_layer_norm: Optional[torch.nn.LayerNorm],
+    max_possible_layers: int,
+) -> torch.Tensor:
+    """Given the outputs of a visual encoder module that may correspond to the
+    output of the last layer, or a list of hidden states to be stacked,
+    handle post normalization and resolve it into a single output tensor.
+
+    Args:
+        encoder_outputs: Output of encoder's last layer or all hidden states.
+        feature_sample_layers: Optional layer indices to grab from the encoder
+            outputs; if provided, encoder outputs must be a list.
+        post_layer_norm: Post norm to apply to the output of the encoder.
+        max_possible_layers: Total layers in the fully loaded visual encoder.
+    """
+    if feature_sample_layers is None:
+        if post_layer_norm is not None:
+            return post_layer_norm(encoder_outputs)
+        return encoder_outputs
+
+    # Get the hidden states corresponding to the layer indices.
+    # Negative values are relative to the full visual encoder,
+    # so offset them depending on how many layers were loaded.
+    # NOTE: this assumes that encoder_outputs contains a list
+    # of hidden states in the same order as the encoder layers
+    # that produced them.
+    offset = max_possible_layers - len(encoder_outputs)
+    hs_pool = [
+        encoder_outputs[layer_idx]
+        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
+        for layer_idx in feature_sample_layers
+    ]
+
+    # Apply post-norm on the final hidden state if we are using it
+    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
+    if post_layer_norm is not None and uses_last_layer:
+        hs_pool[-1] = post_layer_norm(encoder_outputs[-1])
+    return torch.cat(hs_pool, dim=-1)
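For orientation (not part of this diff), a small self-contained sketch of how the relocated helper behaves with dummy tensors, assuming the absolute path vllm.model_executor.models.vision inferred from the `.vision` relative imports:

import torch

from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

# Dummy hidden states for a fully loaded encoder: embeddings + 24 layers.
hidden_states = [torch.randn(1, 16, 8) for _ in range(25)]
post_norm = torch.nn.LayerNorm(8)

out = resolve_visual_encoder_outputs(
    encoder_outputs=hidden_states,
    feature_sample_layers=[-2, -1],  # pick the last two hidden states
    post_layer_norm=post_norm,       # applied to the final state only
    max_possible_layers=25,          # all layers loaded, so offset == 0
)
print(out.shape)  # torch.Size([1, 16, 16]): selected states concatenated on dim=-1

When fewer layers are loaded than `max_possible_layers` (e.g. the last encoder layer is skipped), negative indices in `feature_sample_layers` are shifted by the offset so they still refer to layers of the full encoder.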


@@ -99,6 +99,8 @@ class MultiModalDataBuiltins(TypedDict, total=False):
 
 MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
 """
 A dictionary containing an entry for each modality type to input.
+
+The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
 """
@@ -485,7 +487,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
 
 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
 """
-A dictionary containing placeholder ranges.
+A dictionary containing placeholder ranges for each modality.
 """
 


@@ -5,7 +5,6 @@ from urllib.parse import ParseResult, urlparse
 
 import numpy as np
 import numpy.typing as npt
-import torch
 from PIL import Image
 
 import vllm.envs as envs
@@ -285,49 +284,6 @@ def encode_video_base64(frames: npt.NDArray) -> str:
     return video_io.encode_base64(frames)
 
 
-def resolve_visual_encoder_outputs(
-    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
-    feature_sample_layers: Optional[list[int]],
-    post_layer_norm: Optional[torch.nn.LayerNorm],
-    max_possible_layers: int,
-) -> torch.Tensor:
-    """Given the outputs of a visual encoder module that may correspond to the
-    output of the last layer, or a list of hidden states to be stacked,
-    handle post normalization and resolve it into a single output tensor.
-
-    Args:
-        encoder_outputs: Output of encoder's last layer or all hidden states.
-        feature_sample_layers: Optional layer indices to grab from the encoder
-            outputs; if provided, encoder outputs must be a list.
-        post_layer_norm: Post norm to apply to the output of the encoder.
-        max_possible_layers: Total layers in the fully loaded visual encoder.
-    """
-    if feature_sample_layers is None:
-        if post_layer_norm is not None:
-            return post_layer_norm(encoder_outputs)
-        return encoder_outputs
-
-    # Get the hidden states corresponding to the layer indices.
-    # Negative values are relative to the full visual encoder,
-    # so offset them depending on how many layers were loaded.
-    # NOTE: this assumes that encoder_outputs contains a list
-    # of hidden states in the same order as the encoder layers
-    # that produced them.
-    offset = max_possible_layers - len(encoder_outputs)
-    hs_pool = [
-        encoder_outputs[layer_idx]
-        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
-        for layer_idx in feature_sample_layers
-    ]
-
-    # Apply post-norm on the final hidden state if we are using it
-    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
-    if post_layer_norm is not None and uses_last_layer:
-        hs_pool[-1] = post_layer_norm(encoder_outputs[-1])
-    return torch.cat(hs_pool, dim=-1)
-
-
 # Utilities for input processors
 
 _T = TypeVar("_T", str, int)