[feat]: Create interface for model-specific M-RoPE (#24194)

Signed-off-by: AzizCode92 <azizbenothman76@gmail.com>
Signed-off-by: Aziz <azizbenothman76@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

This commit is contained in:
  parent 064cac7bb7
  commit 38db529f66
vllm/model_executor/models/__init__.py
@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
-                         SupportsPP, SupportsTranscription, SupportsV0Only,
-                         has_inner_state, supports_lora, supports_multimodal,
-                         supports_pp, supports_transcription, supports_v0_only)
+from .interfaces import (HasInnerState, SupportsLoRA, SupportsMRoPE,
+                         SupportsMultiModal, SupportsPP, SupportsTranscription,
+                         SupportsV0Only, has_inner_state, supports_lora,
+                         supports_mrope, supports_multimodal, supports_pp,
+                         supports_transcription, supports_v0_only)
 from .interfaces_base import (VllmModelForPooling, VllmModelForTextGeneration,
                               is_pooling_model, is_text_generation_model)
 from .registry import ModelRegistry
@@ -21,6 +22,8 @@ __all__ = [
     "supports_lora",
     "SupportsMultiModal",
     "supports_multimodal",
+    "SupportsMRoPE",
+    "supports_mrope",
     "SupportsPP",
     "supports_pp",
     "SupportsTranscription",
vllm/model_executor/models/interfaces.py
@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
 import numpy as np
 import torch
 from torch import Tensor
+from transformers import PretrainedConfig
 from transformers.models.whisper.tokenization_whisper import LANGUAGES
 from typing_extensions import Self, TypeIs
 
@@ -852,3 +853,70 @@ def supports_eagle3(
     model: Union[type[object], object],
 ) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
     return isinstance(model, SupportsEagle3)
+
+
+@runtime_checkable
+class SupportsMRoPE(Protocol):
+    """The interface required for all models that support M-RoPE."""
+
+    supports_mrope: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports M-RoPE.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """
+        Get M-RoPE input positions and delta value for this specific model.
+
+        This method should be implemented by each model that supports M-RoPE
+        to provide model-specific logic for computing input positions.
+
+        Args:
+            input_tokens: List of input token IDs
+            hf_config: HuggingFace model configuration
+            image_grid_thw: Image grid dimensions (t, h, w)
+            video_grid_thw: Video grid dimensions (t, h, w)
+            second_per_grid_ts: Seconds per grid timestep for videos
+            context_len: Context length
+            seq_len: Sequence length
+            audio_feature_lengths: Audio feature lengths for multimodal models
+            use_audio_in_video: Whether to use audio in video for interleaving
+
+        Returns:
+            Tuple of (llm_positions, mrope_position_delta)
+            - llm_positions: Tensor of shape [3, num_tokens]
+              with T/H/W positions
+            - mrope_position_delta: Delta for position calculations
+        """
+        ...
+
+
+@overload
+def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]:
+    ...
+
+
+@overload
+def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]:
+    ...
+
+
+def supports_mrope(
+    model: Union[type[object], object],
+) -> Union[TypeIs[type[SupportsMRoPE]], TypeIs[SupportsMRoPE]]:
+    return isinstance(model, SupportsMRoPE)
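
Note (not part of the diff): the helper added above is a plain isinstance() check against a runtime_checkable Protocol, so a model opts in simply by inheriting SupportsMRoPE (or by providing the same members) and overriding get_mrope_input_positions(). A minimal sketch, assuming vLLM with this change installed and the import path shown in the diff; MyMRoPEModel is a hypothetical class used only for illustration:

from typing import Optional, Union

import torch
from transformers import PretrainedConfig

from vllm.model_executor.models.interfaces import (SupportsMRoPE,
                                                   supports_mrope)


class MyMRoPEModel(SupportsMRoPE):
    """Hypothetical model that supplies its own M-RoPE positions."""

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
        second_per_grid_ts: Optional[list[float]] = None,
        context_len: int = 0,
        seq_len: Optional[int] = None,
        audio_feature_lengths: Optional[torch.Tensor] = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        # Toy text-only behavior: identical T/H/W indices, zero delta.
        positions = torch.arange(len(input_tokens)).view(1, -1).expand(3, -1)
        return positions[:, context_len:seq_len], 0


assert supports_mrope(MyMRoPEModel())   # detected through the protocol
assert not supports_mrope(object())     # unrelated objects are rejected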
vllm/model_executor/models/qwen2_vl.py
@@ -32,7 +32,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from transformers import AutoConfig, BatchFeature
+from transformers import AutoConfig, BatchFeature, PretrainedConfig
 from transformers.models.qwen2_vl import (Qwen2VLImageProcessor,
                                           Qwen2VLProcessor)
 from transformers.models.qwen2_vl.configuration_qwen2_vl import (
@@ -73,7 +73,7 @@ from vllm.transformers_utils.config import uses_mrope
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMRoPE,
                          SupportsMultiModal, SupportsPP)
 from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
@@ -1096,7 +1096,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
                                         info=Qwen2VLProcessingInfo,
                                         dummy_inputs=Qwen2VLDummyInputsBuilder)
 class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsLoRA, SupportsPP):
+                                      SupportsLoRA, SupportsPP, SupportsMRoPE):
 
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -1109,6 +1109,118 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             "model.": "language_model.model.",
         })
 
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        video_grid_thw: Optional[Union[list[list[int]], torch.Tensor]],
+        second_per_grid_ts: Optional[list[float]] = None,
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+        audio_feature_lengths: Optional[torch.Tensor] = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Get M-RoPE input positions for Qwen2-VL model."""
+        if image_grid_thw is None:
+            image_grid_thw = []
+        if video_grid_thw is None:
+            video_grid_thw = []
+        if second_per_grid_ts is None:
+            second_per_grid_ts = []
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        vision_start_token_id = hf_config.vision_start_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+        tokens_per_second = getattr(hf_config.vision_config,
+                                    "tokens_per_second", 1.0)
+
+        input_tokens_tensor = torch.tensor(input_tokens)
+        vision_start_indices = torch.argwhere(
+            input_tokens_tensor == vision_start_token_id).squeeze(1)
+        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
+        image_nums = (vision_tokens == image_token_id).sum()
+        video_nums = (vision_tokens == video_token_id).sum()
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_videos = image_nums, video_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + video_nums):
+            video_second_per_grid_t = 0.0
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_videos > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_second_per_grid_t = 1.0
+                if second_per_grid_ts:
+                    video_second_per_grid_t = second_per_grid_ts[video_index]
+                video_index += 1
+                remain_videos -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = \
+                t, h // spatial_merge_size, w // spatial_merge_size
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+            t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
+                -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
+                       tokens_per_second).long().flatten()
+
+            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                llm_grid_t, -1, llm_grid_w).flatten()
+            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                llm_grid_t, llm_grid_h, -1).flatten()
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
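
Note (not part of the diff): in the method above, each text span receives identical T/H/W indices, while each vision span enumerates a (temporal, height, width) grid shrunk by spatial_merge_size and, for video, scaled by tokens_per_second. Below is a self-contained sketch of just that index arithmetic with assumed toy values (one image, a 4x4 patch grid, spatial_merge_size=2); it uses only torch and is illustrative, not vLLM code:

import torch

spatial_merge_size = 2          # assumed toy value
t, h, w = 1, 4, 4               # assumed toy image grid (thw)
tokens_per_second = 1.0
video_second_per_grid_t = 0.0   # stays 0.0 for images, per the diff above

llm_grid_t = t
llm_grid_h = h // spatial_merge_size
llm_grid_w = w // spatial_merge_size

# Temporal index: constant per frame, scaled for video timing.
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
    -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
           tokens_per_second).long().flatten()
# Height/width indices enumerate the merged 2x2 patch grid.
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
    llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
    llm_grid_t, llm_grid_h, -1).flatten()

positions = torch.stack([t_index, h_index, w_index])
print(positions.shape)  # torch.Size([3, 4]): one column per vision token
print(positions)
# tensor([[0, 0, 0, 0],
#         [0, 0, 1, 1],
#         [0, 1, 0, 1]])

Offsetting a block like this by text_len + st_idx and concatenating it with the preceding text block is what produces the final [3, num_tokens] tensor returned above.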
vllm/v1/worker/gpu_model_runner.py
@@ -42,6 +42,7 @@ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
 from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
                                                    supports_eagle3,
+                                                   supports_mrope,
                                                    supports_transcription)
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
@@ -730,16 +731,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if mm_input.get("use_audio_in_video") is True:
                 use_audio_in_video = True
 
-        req_state.mrope_positions, req_state.mrope_position_delta = \
-            MRotaryEmbedding.get_input_positions_tensor(
-                req_state.prompt_token_ids,
-                hf_config=self.model_config.hf_config,
-                image_grid_thw=image_grid_thw,
-                video_grid_thw=video_grid_thw,
-                second_per_grid_ts=second_per_grid_ts,
-                audio_feature_lengths=audio_feature_lengths,
-                use_audio_in_video=use_audio_in_video,
-            )
+        if supports_mrope(self.model):
+            req_state.mrope_positions, req_state.mrope_position_delta = \
+                self.model.get_mrope_input_positions(
+                    req_state.prompt_token_ids,
+                    hf_config=self.model_config.hf_config,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    second_per_grid_ts=second_per_grid_ts,
+                    audio_feature_lengths=audio_feature_lengths,
+                    use_audio_in_video=use_audio_in_video,
+                )
+        else:
+            req_state.mrope_positions, req_state.mrope_position_delta = \
+                MRotaryEmbedding.get_input_positions_tensor(
+                    req_state.prompt_token_ids,
+                    hf_config=self.model_config.hf_config,
+                    image_grid_thw=image_grid_thw,
+                    video_grid_thw=video_grid_thw,
+                    second_per_grid_ts=second_per_grid_ts,
+                    audio_feature_lengths=audio_feature_lengths,
+                    use_audio_in_video=use_audio_in_video,
+                )
 
     def _extract_mm_kwargs(
         self,
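
Note (not part of the diff): because supports_mrope() returns a TypeIs guard, a static type checker narrows model to SupportsMRoPE inside the positive branch, which is why the runner above can call self.model.get_mrope_input_positions() directly and still type-check; models without the interface keep using the generic MRotaryEmbedding path. A minimal sketch of the same narrowing, assuming vLLM with this change installed; describe() and its argument are hypothetical:

from vllm.model_executor.models.interfaces import supports_mrope


def describe(model: object) -> str:
    if supports_mrope(model):
        # `model` is narrowed to SupportsMRoPE here, so accessing the
        # interface method type-checks without a cast.
        return f"model-specific M-RoPE: {model.get_mrope_input_positions}"
    return "generic M-RoPE fallback"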
vllm/worker/model_runner.py
@@ -41,7 +41,8 @@ from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
                                                 get_sampler)
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-from vllm.model_executor.models import supports_lora, supports_multimodal
+from vllm.model_executor.models import (supports_lora, supports_mrope,
+                                        supports_multimodal)
 from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs, MultiModalPlaceholderMap,
@@ -670,18 +671,33 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                 inter_data.seq_ids[seq_idx]]
             token_ids = seq_data.get_token_ids()
 
-            mrope_input_positions, mrope_position_delta = \
-                MRotaryEmbedding.get_input_positions(
-                    token_ids,
-                    hf_config=hf_config,
-                    image_grid_thw=image_grid_thw,
-                    video_grid_thw=video_grid_thw,
-                    second_per_grid_ts=second_per_grid_ts,
-                    context_len=inter_data.context_lens[seq_idx],
-                    seq_len=inter_data.seq_lens[seq_idx],
-                    audio_feature_lengths=audio_feature_lengths,
-                    use_audio_in_video=use_audio_in_video,
-                )
+            if supports_mrope(self.runner.model):
+                mrope_input_positions, mrope_position_delta = \
+                    self.runner.model.get_mrope_input_positions(
+                        token_ids,
+                        hf_config=hf_config,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        context_len=inter_data.context_lens[seq_idx],
+                        seq_len=inter_data.seq_lens[seq_idx],
+                        audio_feature_lengths=audio_feature_lengths,
+                        use_audio_in_video=use_audio_in_video,
+                    )
+                mrope_input_positions = mrope_input_positions.tolist()
+            else:
+                mrope_input_positions, mrope_position_delta = \
+                    MRotaryEmbedding.get_input_positions(
+                        token_ids,
+                        hf_config=hf_config,
+                        image_grid_thw=image_grid_thw,
+                        video_grid_thw=video_grid_thw,
+                        second_per_grid_ts=second_per_grid_ts,
+                        context_len=inter_data.context_lens[seq_idx],
+                        seq_len=inter_data.seq_lens[seq_idx],
+                        audio_feature_lengths=audio_feature_lengths,
+                        use_audio_in_video=use_audio_in_video,
+                    )
 
             seq_data.mrope_position_delta = mrope_position_delta
             inter_data.mrope_input_positions[