[Model][VLM] Add Qwen2.5-Omni model support (thinker only) (#15130)
Signed-off-by: fyabc <suyang.fy@alibaba-inc.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Xiong Wang <wangxiongts@163.com>
parent 5c9121203c · commit 2c1bd848a6
@@ -1040,6 +1040,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   * ✅︎
   * ✅︎
   * ✅︎
+- * `Qwen2_5OmniThinkerForConditionalGeneration`
+  * Qwen2.5-Omni
+  * T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup>
+  * `Qwen/Qwen2.5-Omni-7B`
+  *
+  * ✅︎
+  * ✅︎\*
 - * `SkyworkR1VChatModel`
   * Skywork-R1V-38B
   * T + I
@@ -1109,6 +1116,14 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
 Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
 :::

+:::{note}
+To use Qwen2.5-Omni, you must install a fork of the Hugging Face Transformers library from source via
+`pip install git+https://github.com/BakerBunker/transformers.git@qwen25omni`.
+
+Reading audio from video inputs during pre-processing is currently supported only on V0 (not V1), because V1 does not yet support overlapping modalities.
+Enable it with `--mm-processor-kwargs '{"use_audio_in_video": True}'`.
+:::
+
 ### Pooling Models

 See [this page](pooling-models) for more information on how to use pooling models.
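For offline runs, the switch from the note above can also be set when constructing the engine. The snippet below is a minimal sketch, not part of the diff: it assumes `mm_processor_kwargs` and `limit_mm_per_prompt` are accepted by the `LLM` constructor (as the examples added in this commit suggest) and that the V0 engine is selected before vLLM is used.

```python
# Minimal sketch (assumptions noted, not from the diff): enable
# use_audio_in_video for offline inference on the V0 engine.
import os

os.environ["VLLM_USE_V1"] = "0"  # the feature is V0-only per the note above

from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-Omni-7B",
    max_model_len=5632,
    limit_mm_per_prompt={"audio": 1, "video": 1},
    # Engine-level equivalent of --mm-processor-kwargs '{"use_audio_in_video": True}'
    mm_processor_kwargs={"use_audio_in_video": True},
)
```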
@@ -130,6 +130,36 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
     )


+# Qwen2.5-Omni
+def run_qwen2_5_omni(question: str, audio_count: int):
+    model_name = "Qwen/Qwen2.5-Omni-7B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
+
+    audio_in_prompt = "".join([
+        "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
+    ])
+
+    default_system = (
+        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+        "Group, capable of perceiving auditory and visual inputs, as well as "
+        "generating text and speech.")
+
+    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
+              "<|im_start|>user\n"
+              f"{audio_in_prompt}{question}<|im_end|>\n"
+              "<|im_start|>assistant\n")
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )
+
+
 # Ultravox 0.5-1B
 def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
     model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

@@ -182,6 +212,7 @@ model_example_map = {
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,
     "qwen2_audio": run_qwen2_audio,
+    "qwen2_5_omni": run_qwen2_5_omni,
     "ultravox": run_ultravox,
     "whisper": run_whisper,
 }
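To make the prompt construction in `run_qwen2_5_omni` concrete, the short sketch below reproduces the string it builds for two audio clips; the `question` value is arbitrary and the snippet is illustrative only.

```python
# Illustrative only: the prompt that run_qwen2_5_omni builds for audio_count=2.
audio_count = 2
question = "Are these two audio clips the same?"

default_system = (
    "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
    "Group, capable of perceiving auditory and visual inputs, as well as "
    "generating text and speech.")

# One <|AUDIO|> placeholder block per input audio clip.
audio_in_prompt = "".join(
    "<|audio_bos|><|AUDIO|><|audio_eos|>\n" for _ in range(audio_count))

prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
          "<|im_start|>user\n"
          f"{audio_in_prompt}{question}<|im_end|>\n"
          "<|im_start|>assistant\n")
print(prompt)
```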
examples/offline_inference/qwen2_5_omni/README.md · new file (32 lines)
@@ -0,0 +1,32 @@

# Qwen2.5-Omni Offline Inference Examples

This folder provides several example scripts showing how to run offline inference with Qwen2.5-Omni.

## Thinker Only

```bash
# Audio + image + video
python examples/offline_inference/qwen2_5_omni/only_thinker.py -q mixed_modalities

# Read vision and audio inputs from a single video file
# NOTE: V1 engine does not support interleaved modalities yet.
VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video

# Multiple audios
VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q multi_audios
```

This script runs the thinker part of Qwen2.5-Omni and generates a text response.

You can also test Qwen2.5-Omni on a single modality:

```bash
# Process audio inputs
python examples/offline_inference/audio_language.py --model-type qwen2_5_omni

# Process image inputs
python examples/offline_inference/vision_language.py --modality image --model-type qwen2_5_omni

# Process video inputs
python examples/offline_inference/vision_language.py --modality video --model-type qwen2_5_omni
```
examples/offline_inference/qwen2_5_omni/only_thinker.py · new file (160 lines)
@@ -0,0 +1,160 @@

# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen2.5-Omni (thinker only).
"""

from typing import NamedTuple

import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.utils import FlexibleArgumentParser


class QueryResult(NamedTuple):
    inputs: dict
    limit_mm_per_prompt: dict[str, int]


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

default_system = (
    "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
    "Group, capable of perceiving auditory and visual inputs, as well as "
    "generating text and speech.")


def get_mixed_modalities_query() -> QueryResult:
    question = ("What is recited in the audio? "
                "What is the content of this image? Why is this video funny?")
    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
              "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
              "<|vision_bos|><|IMAGE|><|vision_eos|>"
              "<|vision_bos|><|VIDEO|><|vision_eos|>"
              f"{question}<|im_end|>\n"
              f"<|im_start|>assistant\n")
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "audio":
                AudioAsset("mary_had_lamb").audio_and_sample_rate,
                "image":
                ImageAsset("cherry_blossom").pil_image.convert("RGB"),
                "video":
                VideoAsset(name="sample_demo_1.mp4",
                           num_frames=16).np_ndarrays,
            },
        },
        limit_mm_per_prompt={
            "audio": 1,
            "image": 1,
            "video": 1
        },
    )


def get_use_audio_in_video_query() -> QueryResult:
    question = ("Describe the content of the video, "
                "then convert what the baby say into text.")
    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
              "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
              f"{question}<|im_end|>\n"
              f"<|im_start|>assistant\n")
    asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)
    audio = asset.get_audio(sampling_rate=16000)
    assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. "
                                  "Please launch this example with "
                                  "`VLLM_USE_V1=0`.")
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "video": asset.np_ndarrays,
                "audio": audio,
            },
            "mm_processor_kwargs": {
                "use_audio_in_video": True,
            },
        },
        limit_mm_per_prompt={
            "audio": 1,
            "video": 1
        },
    )


def get_multi_audios_query() -> QueryResult:
    question = "Are these two audio clips the same?"
    prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
              "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
              "<|audio_bos|><|AUDIO|><|audio_eos|>"
              f"{question}<|im_end|>\n"
              f"<|im_start|>assistant\n")
    return QueryResult(
        inputs={
            "prompt": prompt,
            "multi_modal_data": {
                "audio": [
                    AudioAsset("winning_call").audio_and_sample_rate,
                    AudioAsset("mary_had_lamb").audio_and_sample_rate,
                ],
            },
        },
        limit_mm_per_prompt={
            "audio": 2,
        },
    )


query_map = {
    "mixed_modalities": get_mixed_modalities_query,
    "use_audio_in_video": get_use_audio_in_video_query,
    "multi_audios": get_multi_audios_query,
}


def main(args):
    model_name = "Qwen/Qwen2.5-Omni-7B"
    query_result = query_map[args.query_type]()

    llm = LLM(model=model_name,
              max_model_len=5632,
              max_num_seqs=5,
              limit_mm_per_prompt=query_result.limit_mm_per_prompt,
              seed=args.seed)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2, max_tokens=64)

    outputs = llm.generate(query_result.inputs,
                           sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'audio language models')
    parser.add_argument('--query-type',
                        '-q',
                        type=str,
                        default="mixed_modalities",
                        choices=query_map.keys(),
                        help='Query type.')
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="Set the seed when initializing `vllm.LLM`.")

    args = parser.parse_args()
    main(args)
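As a usage note, the sketch below (hypothetical, not part of the commit) runs all three query builders from `only_thinker.py` against a single shared `LLM` instance. It assumes the example module is importable (e.g. the repository root is on `PYTHONPATH`) and that the GPU can hold the union of the per-query multimodal limits.

```python
# Hypothetical smoke test: exercise every query type from only_thinker.py
# with one shared LLM instance (assumes the examples tree is importable).
import os

os.environ["VLLM_USE_V1"] = "0"  # use_audio_in_video requires the V0 engine

from vllm import LLM, SamplingParams

from examples.offline_inference.qwen2_5_omni.only_thinker import query_map

llm = LLM(
    model="Qwen/Qwen2.5-Omni-7B",
    max_model_len=5632,
    max_num_seqs=5,
    # Union of the per-query limits so every query fits.
    limit_mm_per_prompt={"audio": 2, "image": 1, "video": 1},
)
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)

for name, make_query in query_map.items():
    outputs = llm.generate(make_query().inputs,
                           sampling_params=sampling_params)
    print(f"[{name}] {outputs[0].outputs[0].text!r}")
```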
@@ -941,6 +941,42 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
     )


+# Qwen2.5-Omni
+def run_qwen2_5_omni(questions: list[str], modality: str):
+    model_name = "Qwen/Qwen2.5-Omni-7B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        mm_processor_kwargs={
+            "min_pixels": 28 * 28,
+            "max_pixels": 1280 * 28 * 28,
+            "fps": [1],
+        },
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+    )
+
+    if modality == "image":
+        placeholder = "<|IMAGE|>"
+    elif modality == "video":
+        placeholder = "<|VIDEO|>"
+
+    default_system = (
+        "You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
+        "Group, capable of perceiving auditory and visual inputs, as well as "
+        "generating text and speech.")
+
+    prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n"
+                f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
+                f"{question}<|im_end|>\n"
+                "<|im_start|>assistant\n") for question in questions]
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # SkyworkR1V
 def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"

@@ -1010,6 +1046,7 @@ model_example_map = {
     "qwen_vl": run_qwen_vl,
     "qwen2_vl": run_qwen2_vl,
     "qwen2_5_vl": run_qwen2_5_vl,
+    "qwen2_5_omni": run_qwen2_5_omni,
     "skywork_chat": run_skyworkr1v,
     "smolvlm": run_smolvlm,
 }
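The `min_pixels`/`max_pixels` values in the new `run_qwen2_5_omni` entry bound the resolution the vision encoder sees. As a back-of-the-envelope check (a sketch only; the 14-pixel patch size and 2x2 spatial merge are assumptions carried over from Qwen2.5-VL, whose vision tower the thinker reuses), `max_pixels = 1280 * 28 * 28` corresponds to at most about 1280 image tokens per image:

```python
# Rough estimate of the vision-token budget implied by the mm_processor_kwargs
# above. Patch size 14 and spatial_merge_size 2 are assumed (Qwen2.5-VL values).
patch_size = 14
spatial_merge_size = 2

max_pixels = 1280 * 28 * 28                                 # resize upper bound
pixels_per_token = (patch_size * spatial_merge_size) ** 2   # 28 * 28 = 784

print(max_pixels // pixels_per_token)  # -> 1280 image tokens at most
```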
@@ -139,6 +139,23 @@ VLM_TEST_SETTINGS = {
         image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
         marks=[pytest.mark.core_model, pytest.mark.cpu_model],
     ),
+    "qwen2_5_omni": VLMTestInfo(
+        models=["Qwen/Qwen2.5-Omni-7B"],
+        test_type=(
+            VLMTestType.IMAGE,
+            VLMTestType.MULTI_IMAGE,
+            VLMTestType.VIDEO
+        ),
+        prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
+        img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", # noqa: E501
+        video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForVision2Seq,
+        vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
+        image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+        marks=[pytest.mark.core_model, pytest.mark.cpu_model],
+    ),
     #### Extended model tests
     "aria": VLMTestInfo(
         models=["rhymes-ai/Aria"],
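To show what the new test entry's formatter lambdas produce, the sketch below composes a two-image prompt by hand. Only the lambdas come from the diff; the way the harness stitches placeholders and the question together is an assumption for illustration.

```python
# Illustrative composition of a multi-image test prompt from the lambdas above.
prompt_formatter = lambda img_prompt: (
    f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n")
img_idx_to_prompt = lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>"

question = "Describe both images."
img_prompt = "".join(img_idx_to_prompt(i) for i in range(2)) + question
print(prompt_formatter(img_prompt))
```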
@@ -280,6 +280,7 @@ def _test_processing_correctness_mistral(
         "Qwen/Qwen2-VL-2B-Instruct",
         "Qwen/Qwen2.5-VL-3B-Instruct",
         "Qwen/Qwen2-Audio-7B-Instruct",
+        "Qwen/Qwen2.5-Omni-7B",
         "Skywork/Skywork-R1V-38B",
         "fixie-ai/ultravox-v0_5-llama-3_2-1b",
         "openai/whisper-large-v3",
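The checkpoint is now part of the processing-correctness sweep above. As background for the keys handled later in this PR (`input_features`, `feature_attention_mask`, `audio_feature_lengths`, ...), here is a hedged sketch of inspecting the raw HF processor output; it assumes the forked Transformers build from the docs note is installed and that `AutoProcessor` resolves to `Qwen2_5OmniProcessor`, and the exact keyword names may differ between fork versions.

```python
# Hedged sketch: inspect what the HF processor returns for an audio-only prompt.
# Assumes the BakerBunker transformers fork referenced in this PR is installed.
import numpy as np
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

dummy_audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
inputs = processor(
    text="<|audio_bos|><|AUDIO|><|audio_eos|>What do you hear?",
    audio=[dummy_audio],
)
# Expect audio keys such as input_features / feature_attention_mask, which the
# vLLM processor in this PR renames to input_audio_features and
# audio_feature_lengths.
print(sorted(inputs.keys()))
```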
@@ -362,6 +362,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
     "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
                                                           min_transformers_version="4.49"), # noqa: E501
+    "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B", # noqa: E501
+                                        min_transformers_version="4.52"), # noqa: E501
     "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
     "SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
     "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
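`min_transformers_version="4.52"` means the registry example is skipped on older Transformers builds. A minimal sketch of an equivalent gate (vLLM's own helper may be implemented differently):

```python
# Sketch of a version gate equivalent to min_transformers_version="4.52".
import transformers
from packaging.version import Version


def transformers_supports_qwen2_5_omni(min_version: str = "4.52") -> bool:
    return Version(transformers.__version__) >= Version(min_version)


print(transformers_supports_qwen2_5_omni())
```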
@@ -2,7 +2,7 @@

 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Literal
+from typing import Literal, Optional

 import cv2
 import numpy as np

@@ -10,8 +10,15 @@ import numpy.typing as npt
 from huggingface_hub import hf_hub_download
 from PIL import Image

+from vllm.utils import PlaceholderModule
+
 from .base import get_cache_dir

+try:
+    import librosa
+except ImportError:
+    librosa = PlaceholderModule("librosa")  # type: ignore[assignment]
+

 @lru_cache
 def download_video_asset(filename: str) -> str:

@@ -85,3 +92,12 @@ class VideoAsset:
         video_path = download_video_asset(self.name)
         ret = video_to_ndarrays(video_path, self.num_frames)
         return ret
+
+    def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
+        """
+        Read audio data from the video asset, used in Qwen2.5-Omni examples.
+
+        See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
+        """
+        video_path = download_video_asset(self.name)
+        return librosa.load(video_path, sr=sampling_rate)[0]
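A usage note for the new `VideoAsset.get_audio` helper: it decodes the asset's audio track with librosa, so the result is a mono float32 waveform at the requested rate. The sketch below (which needs network access to fetch the asset) pairs it with the existing frame loader the way the `use_audio_in_video` example does.

```python
# Sketch: load synchronized frames and audio from the same video asset.
from vllm.assets.video import VideoAsset

asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)
frames = asset.np_ndarrays                       # stacked RGB frames
waveform = asset.get_audio(sampling_rate=16000)  # mono float32 at 16 kHz

print(frames.shape, waveform.shape)
```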
@@ -506,6 +506,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                 return "<|image|>"
             if model_type in ("qwen2_vl", "qwen2_5_vl"):
                 return "<|vision_start|><|image_pad|><|vision_end|>"
+            if model_type == "qwen2_5_omni":
+                return "<|vision_start|><|IMAGE|><|vision_end|>"
             if model_type == "molmo":
                 return ""
             if model_type == "aria":

@@ -521,7 +523,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
                 return "<|audio|>"
             if model_type == "phi4mm":
                 return "<|endoftext11|>"  # 200011 (see vocab.json in hf model)
-            if model_type == "qwen2_audio":
+            if model_type in ("qwen2_audio", "qwen2_5_omni"):
                 return (f"Audio {current_count}: "
                         f"<|audio_bos|><|AUDIO|><|audio_eos|>")
             if model_type == "minicpmo":

@@ -530,6 +532,8 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
         elif modality == "video":
             if model_type in ("qwen2_vl", "qwen2_5_vl"):
                 return "<|vision_start|><|video_pad|><|vision_end|>"
+            if model_type == "qwen2_5_omni":
+                return "<|vision_start|><|VIDEO|><|vision_end|>"
             if model_type in ("minicpmo", "minicpmv"):
                 return "(<video>./</video>)"
             if model_type.startswith("llava"):
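The effect of the placeholder changes above is easiest to see in one place. The function below is a purely illustrative stand-alone mirror of the `qwen2_5_omni` branches; the real logic lives in `BaseMultiModalItemTracker` and is not a public API.

```python
# Illustrative mirror of the qwen2_5_omni placeholder branches added above.
def omni_placeholder(modality: str, current_count: int) -> str:
    if modality == "image":
        return "<|vision_start|><|IMAGE|><|vision_end|>"
    if modality == "video":
        return "<|vision_start|><|VIDEO|><|vision_end|>"
    if modality == "audio":
        return (f"Audio {current_count}: "
                f"<|audio_bos|><|AUDIO|><|audio_eos|>")
    raise ValueError(f"Unsupported modality: {modality}")


print(omni_placeholder("audio", 1))
```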
@@ -747,7 +747,7 @@ def compute_hash() -> str:
     variables, ensure that it is included in the factors list if
     it affects the computation graph. For example, different values
     of VLLM_PP_LAYER_PARTITION will generate different computation
     graphs, so it is included in the factors list. The env vars that
     affect the choice of different kernels or attention backends should
     also be included in the factors list.
     """
@ -988,8 +988,9 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
def get_input_positions(
|
def get_input_positions(
|
||||||
|
cls,
|
||||||
input_tokens: List[int],
|
input_tokens: List[int],
|
||||||
hf_config: PretrainedConfig,
|
hf_config: PretrainedConfig,
|
||||||
image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
|
image_grid_thw: Optional[Union[List[List[int]], torch.Tensor]],
|
||||||
@ -997,6 +998,8 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
second_per_grid_ts: Optional[List[float]],
|
second_per_grid_ts: Optional[List[float]],
|
||||||
context_len: int = 0,
|
context_len: int = 0,
|
||||||
seq_len: Optional[int] = None,
|
seq_len: Optional[int] = None,
|
||||||
|
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||||
|
use_audio_in_video: bool = False,
|
||||||
) -> Tuple[List[List[int]], int]:
|
) -> Tuple[List[List[int]], int]:
|
||||||
"""Get mrope input positions and delta value."""
|
"""Get mrope input positions and delta value."""
|
||||||
|
|
||||||
@ -1006,7 +1009,48 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
second_per_grid_ts
|
second_per_grid_ts
|
||||||
|
|
||||||
llm_positions, mrope_position_delta = \
|
llm_positions, mrope_position_delta = \
|
||||||
MRotaryEmbedding.get_input_positions_tensor(
|
cls.get_input_positions_tensor(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
hf_config=hf_config,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
|
context_len=context_len,
|
||||||
|
seq_len=seq_len,
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
use_audio_in_video=use_audio_in_video,
|
||||||
|
)
|
||||||
|
|
||||||
|
return llm_positions.tolist(), mrope_position_delta
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_input_positions_tensor(
|
||||||
|
cls,
|
||||||
|
input_tokens: List[int],
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||||
|
video_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||||
|
second_per_grid_ts: List[float],
|
||||||
|
context_len: int = 0,
|
||||||
|
seq_len: Optional[int] = None,
|
||||||
|
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||||
|
use_audio_in_video: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, int]:
|
||||||
|
from vllm.transformers_utils.config import thinker_uses_mrope
|
||||||
|
if thinker_uses_mrope(hf_config):
|
||||||
|
return cls._omni_get_input_positions_tensor(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
hf_config=hf_config,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
|
context_len=context_len,
|
||||||
|
seq_len=seq_len,
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
use_audio_in_video=use_audio_in_video,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return cls._vl_get_input_positions_tensor(
|
||||||
input_tokens=input_tokens,
|
input_tokens=input_tokens,
|
||||||
hf_config=hf_config,
|
hf_config=hf_config,
|
||||||
image_grid_thw=image_grid_thw,
|
image_grid_thw=image_grid_thw,
|
||||||
@ -1016,10 +1060,9 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
seq_len=seq_len,
|
seq_len=seq_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
return llm_positions.tolist(), mrope_position_delta
|
@classmethod
|
||||||
|
def _vl_get_input_positions_tensor(
|
||||||
@staticmethod
|
cls,
|
||||||
def get_input_positions_tensor(
|
|
||||||
input_tokens: List[int],
|
input_tokens: List[int],
|
||||||
hf_config: PretrainedConfig,
|
hf_config: PretrainedConfig,
|
||||||
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||||
@ -1037,11 +1080,6 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
tokens_per_second = getattr(hf_config.vision_config,
|
tokens_per_second = getattr(hf_config.vision_config,
|
||||||
"tokens_per_second", 1.0)
|
"tokens_per_second", 1.0)
|
||||||
|
|
||||||
if isinstance(image_grid_thw, torch.Tensor):
|
|
||||||
image_grid_thw = image_grid_thw.tolist()
|
|
||||||
if isinstance(video_grid_thw, torch.Tensor):
|
|
||||||
video_grid_thw = video_grid_thw.tolist()
|
|
||||||
|
|
||||||
input_tokens_tensor = torch.tensor(input_tokens)
|
input_tokens_tensor = torch.tensor(input_tokens)
|
||||||
vision_start_indices = torch.argwhere(
|
vision_start_indices = torch.argwhere(
|
||||||
input_tokens_tensor == vision_start_token_id).squeeze(1)
|
input_tokens_tensor == vision_start_token_id).squeeze(1)
|
||||||
@ -1121,6 +1159,226 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
|
|
||||||
return llm_positions, mrope_position_delta
|
return llm_positions, mrope_position_delta
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _omni_get_input_positions_tensor(
|
||||||
|
cls,
|
||||||
|
input_tokens: List[int],
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
image_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||||
|
video_grid_thw: Union[List[List[int]], torch.Tensor],
|
||||||
|
second_per_grid_ts: Optional[List[float]] = None,
|
||||||
|
context_len: int = 0,
|
||||||
|
seq_len: Optional[int] = None,
|
||||||
|
audio_feature_lengths: Optional[torch.Tensor] = None,
|
||||||
|
use_audio_in_video: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, int]:
|
||||||
|
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
|
||||||
|
|
||||||
|
Differences from MRotaryEmbedding:
|
||||||
|
1. Add audio support (and related `audio_feature_lengths`).
|
||||||
|
2. Add `use_audio_in_video` option to read audio from video inputs.
|
||||||
|
In this case, audio and vision position ids will be split into
|
||||||
|
chunks and interleaved.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
(V_i are vision position ids, A_i are audio position ids)
|
||||||
|
|
||||||
|
|V_1 ... V_n|A_1 ... A_n|V_n+1 ... V_2n|A_n+1 ... A_2n|...
|
||||||
|
|vision chunk 1|audio chunk 1|vision chunk 2|audio chunk 2 |...
|
||||||
|
"""
|
||||||
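To make the interleaving scheme described in this docstring concrete, here is a small self-contained illustration of how temporal indices can be bucketed into chunks before vision and audio positions are interleaved. It re-implements the bucketing locally for the example; the real code uses `_split_list_into_ranges` defined later in this file, and the numbers are arbitrary.

```python
# Toy illustration of the per-chunk bucketing used for use_audio_in_video:
# temporal indices are grouped into buckets of `interval` positions, and the
# model emits one vision chunk followed by one audio chunk per bucket.
from typing import List


def split_into_ranges(t_index: List[int], interval: int) -> List[List[int]]:
    ranges: List[List[int]] = [[] for _ in range(max(t_index) // interval + 1)]
    for t in t_index:
        ranges[t // interval].append(t)
    return ranges


# e.g. tokens_per_second * seconds_per_chunk == 50
t_index = [0, 25, 50, 75, 100, 125, 150, 175]
print(split_into_ranges(t_index, interval=50))
# -> [[0, 25], [50, 75], [100, 125], [150, 175]]
```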
|
|
||||||
|
# TODO(fyabc): refactor and share more code with
|
||||||
|
# _vl_get_input_positions_tensor.
|
||||||
|
|
||||||
|
thinker_config = hf_config.thinker_config
|
||||||
|
audio_token_id = thinker_config.audio_token_index
|
||||||
|
image_token_id = thinker_config.image_token_index
|
||||||
|
video_token_id = thinker_config.video_token_index
|
||||||
|
audio_start_token_id = thinker_config.audio_start_token_id
|
||||||
|
audio_end_token_id = thinker_config.audio_end_token_id
|
||||||
|
vision_end_token_id = thinker_config.vision_end_token_id
|
||||||
|
seconds_per_chunk = thinker_config.seconds_per_chunk
|
||||||
|
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
|
||||||
|
tokens_per_second = getattr(thinker_config.vision_config,
|
||||||
|
"tokens_per_second", 25)
|
||||||
|
|
||||||
|
if isinstance(image_grid_thw, list):
|
||||||
|
image_grid_thw = torch.tensor(image_grid_thw)
|
||||||
|
if isinstance(video_grid_thw, list):
|
||||||
|
video_grid_thw = torch.tensor(video_grid_thw)
|
||||||
|
|
||||||
|
src_item = input_tokens
|
||||||
|
audio_seqlens = audio_feature_lengths
|
||||||
|
if not second_per_grid_ts:
|
||||||
|
second_per_grid_ts = [1] * video_grid_thw.shape[0]
|
||||||
|
audio_idx = 0
|
||||||
|
video_idx = 0
|
||||||
|
image_idx = 0
|
||||||
|
new_src_item: list[int] = []
|
||||||
|
llm_pos_ids_list: list[torch.Tensor] = []
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
while idx < len(src_item):
|
||||||
|
new_src_item_len = len(new_src_item)
|
||||||
|
start_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||||
|
llm_pos_ids_list) > 0 else 0
|
||||||
|
if src_item[idx] not in [
|
||||||
|
audio_token_id, video_token_id, image_token_id
|
||||||
|
]:
|
||||||
|
if src_item[idx] == vision_end_token_id and use_audio_in_video:
|
||||||
|
start_idx -= 1
|
||||||
|
new_src_item.append(src_item[idx])
|
||||||
|
llm_pos_ids = torch.tensor([start_idx],
|
||||||
|
dtype=torch.long).expand(3, -1)
|
||||||
|
llm_pos_ids_list.append(llm_pos_ids)
|
||||||
|
elif src_item[idx] == audio_token_id:
|
||||||
|
assert audio_seqlens is not None
|
||||||
|
audio_seqlen = audio_seqlens[audio_idx]
|
||||||
|
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1)
|
||||||
|
new_src_item.extend([audio_token_id] * place_num)
|
||||||
|
llm_pos_ids = torch.arange(place_num).expand(3, -1) + start_idx
|
||||||
|
llm_pos_ids_list.append(llm_pos_ids)
|
||||||
|
audio_idx += 1
|
||||||
|
elif src_item[idx] == image_token_id:
|
||||||
|
grid_t = image_grid_thw[image_idx][0]
|
||||||
|
grid_hs = image_grid_thw[:, 1]
|
||||||
|
grid_ws = image_grid_thw[:, 2]
|
||||||
|
t_index = (torch.arange(grid_t) * 1 * tokens_per_second).long()
|
||||||
|
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
|
||||||
|
start_idx, image_idx, spatial_merge_size, t_index, grid_hs,
|
||||||
|
grid_ws)
|
||||||
|
llm_pos_ids_list.append(llm_pos_ids)
|
||||||
|
vision_seqlen = image_grid_thw[image_idx].prod() // (
|
||||||
|
spatial_merge_size**2)
|
||||||
|
new_src_item.extend([image_token_id] * vision_seqlen)
|
||||||
|
image_idx += 1
|
||||||
|
elif src_item[idx] == video_token_id and not use_audio_in_video:
|
||||||
|
grid_t = video_grid_thw[video_idx][0]
|
||||||
|
grid_hs = video_grid_thw[:, 1]
|
||||||
|
grid_ws = video_grid_thw[:, 2]
|
||||||
|
t_index = (torch.arange(grid_t) *
|
||||||
|
second_per_grid_ts[video_idx] *
|
||||||
|
tokens_per_second).long()
|
||||||
|
llm_pos_ids = cls._get_llm_pos_ids_for_vision(
|
||||||
|
start_idx, video_idx, spatial_merge_size, t_index, grid_hs,
|
||||||
|
grid_ws)
|
||||||
|
llm_pos_ids_list.append(llm_pos_ids)
|
||||||
|
vision_seqlen = video_grid_thw[video_idx].prod() // (
|
||||||
|
spatial_merge_size**2)
|
||||||
|
new_src_item.extend([video_token_id] * vision_seqlen)
|
||||||
|
video_idx += 1
|
||||||
|
else:
|
||||||
|
# read audio from video
|
||||||
|
assert audio_seqlens is not None
|
||||||
|
audio_seqlen = audio_seqlens[audio_idx]
|
||||||
|
vision_seqlen = video_grid_thw[video_idx].prod() // (
|
||||||
|
spatial_merge_size**2)
|
||||||
|
grid_t = video_grid_thw[video_idx][0]
|
||||||
|
grid_h = video_grid_thw[video_idx][1]
|
||||||
|
grid_w = video_grid_thw[video_idx][2]
|
||||||
|
grid_hs = video_grid_thw[:, 1]
|
||||||
|
grid_ws = video_grid_thw[:, 2]
|
||||||
|
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
|
||||||
|
t_index = (torch.arange(grid_t) *
|
||||||
|
second_per_grid_ts[video_idx] *
|
||||||
|
tokens_per_second).long()
|
||||||
|
t_index_split_chunk = cls._split_list_into_ranges(
|
||||||
|
t_index, t_ntoken_per_chunk)
|
||||||
|
new_src_item.extend([audio_start_token_id])
|
||||||
|
start_idx -= 1
|
||||||
|
llm_pos_ids_list.extend([
|
||||||
|
torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
|
||||||
|
] * 1)
|
||||||
|
place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
|
||||||
|
pure_audio_len = place_num - 2
|
||||||
|
added_audio_len = 0
|
||||||
|
audio_llm_pos_ids_list: List[torch.Tensor] = []
|
||||||
|
for t_chunk in t_index_split_chunk:
|
||||||
|
vision_ntoken_per_chunk = len(
|
||||||
|
t_chunk) * grid_h * grid_w // (spatial_merge_size**2)
|
||||||
|
new_src_item.extend([video_token_id] *
|
||||||
|
vision_ntoken_per_chunk)
|
||||||
|
vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
|
||||||
|
start_idx + 1, video_idx, spatial_merge_size, t_chunk,
|
||||||
|
grid_hs, grid_ws).split(1, dim=1)
|
||||||
|
llm_pos_ids_list.extend(vision_llm_pos_ids_list)
|
||||||
|
new_src_item.extend(
|
||||||
|
min(t_ntoken_per_chunk, pure_audio_len -
|
||||||
|
added_audio_len) * [audio_token_id])
|
||||||
|
audio_start_idx = start_idx if len(
|
||||||
|
audio_llm_pos_ids_list
|
||||||
|
) == 0 else audio_llm_pos_ids_list[-1][0].item()
|
||||||
|
if min(t_ntoken_per_chunk,
|
||||||
|
pure_audio_len - added_audio_len) > 0:
|
||||||
|
audio_llm_pos_ids_list = (torch.arange(
|
||||||
|
min(t_ntoken_per_chunk, pure_audio_len -
|
||||||
|
added_audio_len)).expand(3, -1) +
|
||||||
|
audio_start_idx + 1).split(
|
||||||
|
1, dim=1)
|
||||||
|
else:
|
||||||
|
audio_llm_pos_ids_list = []
|
||||||
|
added_audio_len += min(t_ntoken_per_chunk,
|
||||||
|
pure_audio_len - added_audio_len)
|
||||||
|
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
|
||||||
|
if added_audio_len < pure_audio_len:
|
||||||
|
new_src_item.extend(
|
||||||
|
(pure_audio_len - added_audio_len) * [audio_token_id])
|
||||||
|
audio_llm_pos_ids_list = (
|
||||||
|
torch.arange(pure_audio_len - added_audio_len).expand(
|
||||||
|
3, -1) + llm_pos_ids_list[-1].max() + 1).split(
|
||||||
|
1, dim=1)
|
||||||
|
llm_pos_ids_list.extend(audio_llm_pos_ids_list)
|
||||||
|
llm_pos_ids_list.extend([
|
||||||
|
torch.tensor(
|
||||||
|
[llm_pos_ids_list[-1].max() + 1] * 3).unsqueeze(1)
|
||||||
|
] * 1)
|
||||||
|
new_src_item.extend([audio_end_token_id])
|
||||||
|
audio_idx += 1
|
||||||
|
video_idx += 1
|
||||||
|
# move to the next token
|
||||||
|
idx += len(new_src_item) - new_src_item_len
|
||||||
|
|
||||||
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1)
|
||||||
|
mrope_position_delta = torch.cat(llm_pos_ids_list,
|
||||||
|
dim=1).max() + 1 - len(src_item)
|
||||||
|
llm_positions = llm_positions[:, context_len:seq_len]
|
||||||
|
|
||||||
|
return llm_positions, mrope_position_delta
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_llm_pos_ids_for_vision(
|
||||||
|
start_idx: int,
|
||||||
|
vision_idx: int,
|
||||||
|
spatial_merge_size: int,
|
||||||
|
t_index: List[int],
|
||||||
|
grid_hs: torch.Tensor,
|
||||||
|
grid_ws: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
llm_pos_ids_list = []
|
||||||
|
llm_grid_h = grid_hs[vision_idx] // spatial_merge_size
|
||||||
|
llm_grid_w = grid_ws[vision_idx] // spatial_merge_size
|
||||||
|
h_index = (torch.arange(llm_grid_h).view(1, -1, 1).expand(
|
||||||
|
len(t_index), -1, llm_grid_w).flatten())
|
||||||
|
w_index = (torch.arange(llm_grid_w).view(1, 1, -1).expand(
|
||||||
|
len(t_index), llm_grid_h, -1).flatten())
|
||||||
|
t_index_tensor = torch.Tensor(t_index).to(llm_grid_h.device).view(
|
||||||
|
-1, 1).expand(-1, llm_grid_h * llm_grid_w).long().flatten()
|
||||||
|
_llm_pos_ids = torch.stack([t_index_tensor, h_index, w_index])
|
||||||
|
llm_pos_ids_list.append(_llm_pos_ids + start_idx)
|
||||||
|
llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
|
||||||
|
return llm_pos_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _split_list_into_ranges(lst: torch.Tensor,
|
||||||
|
interval: int) -> List[List[int]]:
|
||||||
|
ranges: List[List[int]] = [[]
|
||||||
|
for _ in range((max(lst) // interval) + 1)]
|
||||||
|
for num in lst:
|
||||||
|
index = num // interval
|
||||||
|
ranges[index].append(num)
|
||||||
|
return ranges
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_next_input_positions(
|
def get_next_input_positions(
|
||||||
mrope_position_delta: int,
|
mrope_position_delta: int,
|
||||||
@ -1144,6 +1402,58 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
mrope_position_delta + seq_len,
|
mrope_position_delta + seq_len,
|
||||||
).expand(3, -1)
|
).expand(3, -1)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def omni_get_updates_use_audio_in_video(
|
||||||
|
cls,
|
||||||
|
thinker_config: PretrainedConfig,
|
||||||
|
audio_len: int,
|
||||||
|
video_grid_thw: Union[List[int], torch.Tensor],
|
||||||
|
video_second_per_grid_t: float,
|
||||||
|
) -> List[int]:
|
||||||
|
"""Get video prompt updates when `use_audio_in_video` is True.
|
||||||
|
|
||||||
|
In this case, audio and vision update ids will be split into
|
||||||
|
chunks and interleaved (details in `_omni_get_input_positions_tensor`).
|
||||||
|
|
||||||
|
<|video_bos|><|VIDEO|><|video_eos|> =>
|
||||||
|
<|video_bos|><|audio_bos|>(... chunks ...)<|audio_eos|><|video_eos|>
|
||||||
|
"""
|
||||||
|
|
||||||
|
audio_token_id = thinker_config.audio_token_index
|
||||||
|
video_token_id = thinker_config.video_token_index
|
||||||
|
audio_start_token_id = thinker_config.audio_start_token_id
|
||||||
|
audio_end_token_id = thinker_config.audio_end_token_id
|
||||||
|
seconds_per_chunk = thinker_config.seconds_per_chunk
|
||||||
|
spatial_merge_size = thinker_config.vision_config.spatial_merge_size
|
||||||
|
tokens_per_second = getattr(thinker_config.vision_config,
|
||||||
|
"tokens_per_second", 25)
|
||||||
|
|
||||||
|
grid_t = video_grid_thw[0]
|
||||||
|
grid_h = video_grid_thw[1]
|
||||||
|
grid_w = video_grid_thw[2]
|
||||||
|
t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
|
||||||
|
t_index = (torch.arange(grid_t) * video_second_per_grid_t *
|
||||||
|
tokens_per_second).long()
|
||||||
|
t_index_split_chunk = cls._split_list_into_ranges(
|
||||||
|
t_index, t_ntoken_per_chunk)
|
||||||
|
|
||||||
|
updates = [audio_start_token_id]
|
||||||
|
added_audio_len = 0
|
||||||
|
for t_chunk in t_index_split_chunk:
|
||||||
|
vision_ntoken_per_chunk = len(t_chunk) * grid_h * grid_w // (
|
||||||
|
spatial_merge_size**2)
|
||||||
|
updates.extend([video_token_id] * vision_ntoken_per_chunk)
|
||||||
|
|
||||||
|
audio_chunk_size = min(t_ntoken_per_chunk,
|
||||||
|
audio_len - added_audio_len)
|
||||||
|
updates.extend(audio_chunk_size * [audio_token_id])
|
||||||
|
added_audio_len += audio_chunk_size
|
||||||
|
if added_audio_len < audio_len:
|
||||||
|
updates.extend((audio_len - added_audio_len) * [audio_token_id])
|
||||||
|
updates.extend([audio_end_token_id])
|
||||||
|
|
||||||
|
return updates
|
||||||
|
|
||||||
|
|
||||||
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
||||||
|
|
||||||
|
|||||||
@ -583,21 +583,21 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||||
modalities = {}
|
mm_input_by_modality = {}
|
||||||
|
|
||||||
# Preserve the order of modalities if there are multiple of them
|
# Preserve the order of modalities if there are multiple of them
|
||||||
# from the order of kwargs.
|
# from the order of kwargs.
|
||||||
for input_key in kwargs:
|
for input_key in kwargs:
|
||||||
if input_key in ("pixel_values",
|
if input_key in ("pixel_values", "image_embeds"
|
||||||
"image_embeds") and "images" not in modalities:
|
) and "image" not in mm_input_by_modality:
|
||||||
modalities["images"] = self._parse_and_validate_image_input(
|
mm_input_by_modality[
|
||||||
**kwargs)
|
"image"] = self._parse_and_validate_image_input(**kwargs)
|
||||||
if input_key in ("pixel_values_videos",
|
if input_key in ("pixel_values_videos", "video_embeds"
|
||||||
"video_embeds") and "videos" not in modalities:
|
) and "video" not in mm_input_by_modality:
|
||||||
modalities["videos"] = self._parse_and_validate_video_input(
|
mm_input_by_modality[
|
||||||
**kwargs)
|
"video"] = self._parse_and_validate_video_input(**kwargs)
|
||||||
|
|
||||||
return modalities
|
return mm_input_by_modality
|
||||||
|
|
||||||
def _select_image_features(self, image_features: torch.Tensor, *,
|
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||||
strategy: str) -> torch.Tensor:
|
strategy: str) -> torch.Tensor:
|
||||||
@ -848,8 +848,9 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
def get_multimodal_embeddings(
|
def get_multimodal_embeddings(
|
||||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
|
||||||
if not modalities:
|
**kwargs)
|
||||||
|
if not mm_input_by_modality:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# The result multimodal_embeddings is tuple of tensors, with each
|
# The result multimodal_embeddings is tuple of tensors, with each
|
||||||
@ -858,14 +859,13 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
# NOTE: It is important to iterate over the keys in this dictionary
|
# NOTE: It is important to iterate over the keys in this dictionary
|
||||||
# to preserve the order of the modalities.
|
# to preserve the order of the modalities.
|
||||||
for modality in modalities:
|
for modality in mm_input_by_modality:
|
||||||
if modality == "images":
|
multimodal_input = mm_input_by_modality[modality]
|
||||||
image_input = modalities["images"]
|
if modality == "image":
|
||||||
vision_embeddings = self._process_image_input(image_input)
|
vision_embeddings = self._process_image_input(multimodal_input)
|
||||||
multimodal_embeddings += tuple(vision_embeddings)
|
multimodal_embeddings += tuple(vision_embeddings)
|
||||||
if modality == "videos":
|
if modality == "video":
|
||||||
video_input = modalities["videos"]
|
video_embeddings = self._process_video_pixels(multimodal_input)
|
||||||
video_embeddings = self._process_video_pixels(video_input)
|
|
||||||
multimodal_embeddings += tuple(video_embeddings)
|
multimodal_embeddings += tuple(video_embeddings)
|
||||||
|
|
||||||
return multimodal_embeddings
|
return multimodal_embeddings
|
||||||
|
|||||||
vllm/model_executor/models/qwen2_5_omni_thinker.py · new file (977 lines)
@@ -0,0 +1,977 @@
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright 2024 The Qwen team.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only Qwen2.5-Omni model (thinker part)."""
|
||||||
|
|
||||||
|
from copy import copy
|
||||||
|
from functools import cached_property, partial
|
||||||
|
from typing import (Any, Dict, Iterable, List, Mapping, Optional, Sequence,
|
||||||
|
Set, Tuple, Union)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
|
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
|
||||||
|
Qwen2_5OmniConfig, Qwen2_5OmniThinkerConfig)
|
||||||
|
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
|
||||||
|
Qwen2_5OmniAudioEncoder)
|
||||||
|
from transformers.models.qwen2_5_omni.processing_qwen2_5_omni import (
|
||||||
|
Qwen2_5OmniProcessor)
|
||||||
|
from transformers.models.whisper import WhisperFeatureExtractor
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
|
from vllm.model_executor.models.qwen2_5_vl import (
|
||||||
|
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
|
||||||
|
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
|
||||||
|
Qwen2_5_VLProcessingInfo, Qwen2_5_VLVideoEmbeddingInputs,
|
||||||
|
Qwen2_5_VLVideoInputs, Qwen2_5_VLVideoPixelInputs)
|
||||||
|
from vllm.model_executor.models.qwen2_audio import (
|
||||||
|
Qwen2AudioInputs, Qwen2AudioProcessingInfo,
|
||||||
|
_get_feat_extract_output_lengths)
|
||||||
|
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
|
||||||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.hasher import MultiModalHasher
|
||||||
|
from vllm.multimodal.inputs import (ImageItem, ModalityData,
|
||||||
|
MultiModalDataDict, MultiModalFieldConfig,
|
||||||
|
MultiModalInputs, MultiModalKwargs,
|
||||||
|
NestedTensors)
|
||||||
|
from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
|
||||||
|
ModalityDataItems, MultiModalDataItems,
|
||||||
|
MultiModalDataParser)
|
||||||
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
|
PlaceholderFeaturesInfo,
|
||||||
|
PromptReplacement, PromptUpdate)
|
||||||
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
|
||||||
|
|
||||||
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||||
|
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||||
|
init_vllm_registered_model, maybe_prefix,
|
||||||
|
merge_multimodal_embeddings)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
flash_attn = None
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||||
|
audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
|
||||||
|
torch.empty((0, )))
|
||||||
|
|
||||||
|
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
|
||||||
|
image_grid_sizes = image_grid_thw.prod(-1)
|
||||||
|
|
||||||
|
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
|
||||||
|
video_grid_sizes = video_grid_thw.prod(-1)
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"audio", audio_feature_lengths, dim=1),
|
||||||
|
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
|
||||||
|
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
|
||||||
|
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"image", image_grid_sizes),
|
||||||
|
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"image", image_grid_sizes),
|
||||||
|
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||||
|
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"video", video_grid_sizes),
|
||||||
|
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"video", video_grid_sizes),
|
||||||
|
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||||
|
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
|
||||||
|
|
||||||
|
def _parse_audio_data(
|
||||||
|
self,
|
||||||
|
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
|
||||||
|
) -> ModalityDataItems[Any, Any]:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return DictEmbeddingItems(
|
||||||
|
data,
|
||||||
|
modality="audio",
|
||||||
|
required_fields={
|
||||||
|
"input_audio_features", "audio_feature_lengths"
|
||||||
|
},
|
||||||
|
fields_factory=_qwen2_5_omni_thinker_field_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return super()._parse_audio_data(data)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5OmniThinkerProcessingInfo(Qwen2AudioProcessingInfo,
|
||||||
|
Qwen2_5_VLProcessingInfo):
|
||||||
|
|
||||||
|
def get_hf_config(self):
|
||||||
|
return self.ctx.get_hf_config(Qwen2_5OmniConfig).thinker_config
|
||||||
|
|
||||||
|
def get_hf_processor(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sampling_rate: Optional[int] = None,
|
||||||
|
min_pixels: Optional[int] = None,
|
||||||
|
max_pixels: Optional[int] = None,
|
||||||
|
size: Optional[dict[str, int]] = None,
|
||||||
|
fps: Optional[Union[float, List[float]]] = None,
|
||||||
|
**kwargs: object,
|
||||||
|
) -> Qwen2_5OmniProcessor:
|
||||||
|
if fps is not None:
|
||||||
|
kwargs["fps"] = fps
|
||||||
|
processor = self.ctx.get_hf_processor(
|
||||||
|
Qwen2_5OmniProcessor,
|
||||||
|
image_processor=self.get_image_processor(min_pixels=min_pixels,
|
||||||
|
max_pixels=max_pixels,
|
||||||
|
size=size),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
if not hasattr(processor, "audio_token"):
|
||||||
|
processor.audio_token = "<|AUDIO|>"
|
||||||
|
if not hasattr(processor, "image_token"):
|
||||||
|
processor.image_token = "<|IMAGE|>"
|
||||||
|
if not hasattr(processor, "video_token"):
|
||||||
|
processor.video_token = "<|VIDEO|>"
|
||||||
|
return processor
|
||||||
|
|
||||||
|
def get_feature_extractor(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sampling_rate: Optional[int] = None,
|
||||||
|
**kwargs: object,
|
||||||
|
):
|
||||||
|
hf_processor = self.get_hf_processor(sampling_rate=sampling_rate)
|
||||||
|
feature_extractor = hf_processor.feature_extractor # type: ignore
|
||||||
|
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
||||||
|
return feature_extractor
|
||||||
|
|
||||||
|
def get_max_audio_tokens(self) -> int:
|
||||||
|
hf_config = self.get_hf_config()
|
||||||
|
max_source_position = hf_config.audio_config.max_source_positions
|
||||||
|
output_lengths = (max_source_position - 2) // 2 + 1
|
||||||
|
return output_lengths
|
||||||
|
|
||||||
|
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||||
|
return {"audio": None, "image": None, "video": None}
|
||||||
|
|
||||||
|
def get_mm_max_tokens_per_item(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> Mapping[str, int]:
|
||||||
|
return {
|
||||||
|
"audio": self.get_max_audio_tokens(),
|
||||||
|
"image": self.get_max_image_tokens(),
|
||||||
|
"video": self.get_max_video_tokens(seq_len, mm_counts),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5OmniThinkerDummyInputsBuilder(
|
||||||
|
BaseDummyInputsBuilder[Qwen2_5OmniThinkerProcessingInfo]):
|
||||||
|
|
||||||
|
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||||
|
num_audios = mm_counts.get("audio", 0)
|
||||||
|
num_images = mm_counts.get("image", 0)
|
||||||
|
num_videos = mm_counts.get("video", 0)
|
||||||
|
|
||||||
|
hf_processor = self.info.get_hf_processor()
|
||||||
|
|
||||||
|
audio_token: str = hf_processor.audio_token
|
||||||
|
image_token: str = hf_processor.image_token
|
||||||
|
video_token: str = hf_processor.video_token
|
||||||
|
|
||||||
|
return (audio_token * num_audios + image_token * num_images +
|
||||||
|
video_token * num_videos)
|
||||||
|
|
||||||
|
# TODO: @abstractmethod after transition
|
||||||
|
def get_dummy_mm_data(
|
||||||
|
self,
|
||||||
|
seq_len: int,
|
||||||
|
mm_counts: Mapping[str, int],
|
||||||
|
) -> MultiModalDataDict:
|
||||||
|
num_audios = mm_counts.get("audio", 0)
|
||||||
|
num_images = mm_counts.get("image", 0)
|
||||||
|
num_videos = mm_counts.get("video", 0)
|
||||||
|
|
||||||
|
feature_extractor = self.info.get_feature_extractor()
|
||||||
|
|
||||||
|
target_audio_length = min(
|
||||||
|
feature_extractor.chunk_length,
|
||||||
|
30,
|
||||||
|
) * feature_extractor.sampling_rate
|
||||||
|
target_width, target_height = \
|
||||||
|
self.info.get_image_size_with_most_features()
|
||||||
|
target_num_frames = \
|
||||||
|
self.info.get_num_frames_with_most_features(seq_len, mm_counts)
|
||||||
|
|
||||||
|
mm_data = {
|
||||||
|
"audio":
|
||||||
|
self._get_dummy_audios(length=target_audio_length,
|
||||||
|
num_audios=num_audios),
|
||||||
|
"image":
|
||||||
|
self._get_dummy_images(width=target_width,
|
||||||
|
height=target_height,
|
||||||
|
num_images=num_images),
|
||||||
|
"video":
|
||||||
|
self._get_dummy_videos(width=target_width,
|
||||||
|
height=target_height,
|
||||||
|
num_frames=target_num_frames,
|
||||||
|
num_videos=num_videos),
|
||||||
|
}
|
||||||
|
|
||||||
|
return mm_data
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5OmniThinkerMultiModalProcessor(
|
||||||
|
BaseMultiModalProcessor[Qwen2_5OmniThinkerProcessingInfo]):
|
||||||
|
|
||||||
|
def _get_data_parser(self) -> MultiModalDataParser:
|
||||||
|
feature_extractor = self.info.get_feature_extractor()
|
||||||
|
return Qwen2_5OmniThinkerMultiModalDataParser(
|
||||||
|
target_sr=feature_extractor.sampling_rate)
|
||||||
|
|
||||||
|
def _call_hf_processor(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
mm_data: Mapping[str, object],
|
||||||
|
mm_kwargs: Mapping[str, object],
|
||||||
|
) -> BatchFeature:
|
||||||
|
mm_data = dict(mm_data)
|
||||||
|
audios = mm_data.pop("audios", [])
|
||||||
|
|
||||||
|
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
|
||||||
|
if audios:
|
||||||
|
# NOTE: Qwen2.5-Omni processor accept "audio"
|
||||||
|
mm_data["audio"] = audios
|
||||||
|
mm_kwargs = dict(**mm_kwargs, )
|
||||||
|
|
||||||
|
hf_inputs = super()._call_hf_processor(
|
||||||
|
prompt=prompt,
|
||||||
|
mm_data=mm_data,
|
||||||
|
mm_kwargs=mm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_features = hf_inputs.pop('input_features', None)
|
||||||
|
feature_attention_mask = hf_inputs.get('feature_attention_mask', None)
|
||||||
|
if ('input_audio_features' not in hf_inputs
|
||||||
|
and input_features is not None):
|
||||||
|
if feature_attention_mask is not None:
|
||||||
|
input_features = input_features.permute(
|
||||||
|
0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
|
||||||
|
hf_inputs['input_audio_features'] = input_features
|
||||||
|
if ('audio_feature_lengths' not in hf_inputs
|
||||||
|
and feature_attention_mask is not None):
|
||||||
|
hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1)
|
||||||
|
return hf_inputs
|
||||||
|
|
||||||
|
def _get_mm_fields_config(
|
||||||
|
self,
|
||||||
|
hf_inputs: BatchFeature,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
return _qwen2_5_omni_thinker_field_config(hf_inputs)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, list[int]],
|
||||||
|
mm_data: MultiModalDataDict,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
return_mm_hashes: bool = False,
|
||||||
|
) -> MultiModalInputs:
|
||||||
|
"""
|
||||||
|
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
|
||||||
|
"""
|
||||||
|
mm_items = self._to_mm_items(mm_data)
|
||||||
|
|
||||||
|
# Create MM hashes to be returned (only used in V1)
|
||||||
|
# TODO: Use these hash keys for caching operations in apply_hf_processor
|
||||||
|
# instead of rehashing.
|
||||||
|
|
||||||
|
if return_mm_hashes:
|
||||||
|
model_id = self.info.model_id
|
||||||
|
mm_hashes = {
|
||||||
|
modality: [
|
||||||
|
MultiModalHasher.hash_kwargs(model_id=model_id,
|
||||||
|
**{modality: item},
|
||||||
|
**hf_processor_mm_kwargs)
|
||||||
|
for item in items
|
||||||
|
]
|
||||||
|
for modality, items in mm_items.items()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
mm_hashes = None
|
||||||
|
|
||||||
|
(
|
||||||
|
prompt_ids,
|
||||||
|
mm_kwargs,
|
||||||
|
is_update_applied,
|
||||||
|
) = self._cached_apply_hf_processor(
|
||||||
|
prompt,
|
||||||
|
mm_items,
|
||||||
|
hf_processor_mm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
unbound_prompt_updates = self._get_prompt_updates(
|
||||||
|
mm_items,
|
||||||
|
hf_processor_mm_kwargs,
|
||||||
|
mm_kwargs,
|
||||||
|
)
|
||||||
|
mm_prompt_updates = self._bind_and_group_updates(
|
||||||
|
unbound_prompt_updates)
|
||||||
|
|
||||||
|
mm_item_counts = mm_items.get_all_counts()
|
||||||
|
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
|
||||||
|
|
||||||
|
use_audio_in_video = hf_processor_mm_kwargs.get(
|
||||||
|
"use_audio_in_video", False)
|
||||||
|
|
||||||
|
if is_update_applied:
|
||||||
|
mm_placeholders = self._find_mm_placeholders(
|
||||||
|
mm_prompt_updates,
|
||||||
|
prompt_ids,
|
||||||
|
mm_item_counts,
|
||||||
|
)
|
||||||
|
self._validate_mm_placeholders(
|
||||||
|
mm_placeholders,
|
||||||
|
mm_item_counts,
|
||||||
|
use_audio_in_video=use_audio_in_video)
|
||||||
|
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
prompt = decode_tokens(tokenizer, prompt_ids)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
prompt_ids,
|
||||||
|
prompt,
|
||||||
|
mm_placeholders,
|
||||||
|
) = self._apply_prompt_updates(
|
||||||
|
prompt_ids,
|
||||||
|
mm_prompt_updates,
|
||||||
|
mm_item_counts,
|
||||||
|
)
|
||||||
|
self._validate_mm_placeholders(
|
||||||
|
mm_placeholders,
|
||||||
|
mm_item_counts,
|
||||||
|
use_audio_in_video=use_audio_in_video)
|
||||||
|
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
prompt = decode_tokens(tokenizer, prompt_ids)
|
||||||
|
|
||||||
|
mm_placeholder_ranges = {
|
||||||
|
modality: [item.to_range() for item in placeholders]
|
||||||
|
for modality, placeholders in mm_placeholders.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
if use_audio_in_video:
|
||||||
|
mm_kwargs["use_audio_in_video"] = True
|
||||||
|
|
||||||
|
return MultiModalInputs(
|
||||||
|
type="multimodal",
|
||||||
|
prompt=prompt,
|
||||||
|
prompt_token_ids=prompt_ids,
|
||||||
|
mm_kwargs=mm_kwargs,
|
||||||
|
mm_hashes=mm_hashes,
|
||||||
|
mm_placeholders=mm_placeholder_ranges,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_prompt_updates(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, Any],
|
||||||
|
out_mm_kwargs: MultiModalKwargs,
|
||||||
|
) -> Sequence[PromptUpdate]:
|
||||||
|
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
image_processor = self.info.get_image_processor(
|
||||||
|
**hf_processor_mm_kwargs)
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
|
||||||
|
audio_token = processor.audio_token
|
||||||
|
image_token = processor.image_token
|
||||||
|
video_token = processor.video_token
|
||||||
|
audio_token_id = vocab[audio_token]
|
||||||
|
image_token_id = vocab[image_token]
|
||||||
|
video_token_id = vocab[video_token]
|
||||||
|
|
||||||
|
audio_feature_lengths = out_mm_kwargs.get("audio_feature_lengths")
|
||||||
|
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
|
||||||
|
if audio_feature_lengths is None and feature_attention_mask is None:
|
||||||
|
audio_output_lengths = []
|
||||||
|
elif audio_feature_lengths is not None:
|
||||||
|
_, audio_output_lens = _get_feat_extract_output_lengths(
|
||||||
|
audio_feature_lengths)
|
||||||
|
audio_output_lengths = audio_output_lens.tolist()
|
||||||
|
elif feature_attention_mask is not None:
|
||||||
|
assert isinstance(feature_attention_mask, torch.Tensor)
|
||||||
|
_, audio_output_lens = _get_feat_extract_output_lengths(
|
||||||
|
feature_attention_mask.sum(-1))
|
||||||
|
audio_output_lengths = audio_output_lens.tolist()
|
||||||
|
|
||||||
|
# number of audios read from video.
|
||||||
|
audio_in_video_item_idx = 0
|
||||||
|
|
||||||
|
def get_replacement_qwen2_audio(item_idx: int):
|
||||||
|
item_idx += audio_in_video_item_idx
|
||||||
|
|
||||||
|
num_features = audio_output_lengths[item_idx]
|
||||||
|
if num_features == 0:
|
||||||
|
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||||
|
audio = audios.get(item_idx)
|
||||||
|
raise ValueError(
|
||||||
|
f"The audio {audio} (len={len(audio)}) is too short "
|
||||||
|
"to be represented inside the model")
|
||||||
|
|
||||||
|
return [audio_token_id] * num_features
|
||||||
|
|
||||||
|
def get_replacement_qwen2_vision(item_idx: int, modality: str):
|
||||||
|
grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
|
||||||
|
assert isinstance(grid_thw, torch.Tensor)
|
||||||
|
merge_length = image_processor.merge_size**2
|
||||||
|
|
||||||
|
token_id = image_token_id if modality == "image" else video_token_id
|
||||||
|
return [token_id] * (int(grid_thw.prod()) // merge_length)
|
||||||
|
|
||||||
|
use_audio_in_video = hf_processor_mm_kwargs.get(
|
||||||
|
"use_audio_in_video", False)
|
||||||
|
thinker_config = self.info.get_hf_config()
|
||||||
|
|
||||||
|
def get_replacement_qwen2_use_audio_in_video(item_idx: int):
|
||||||
|
nonlocal audio_in_video_item_idx
|
||||||
|
|
||||||
|
audio_num_features = audio_output_lengths[audio_in_video_item_idx +
|
||||||
|
item_idx]
|
||||||
|
video_grid_thw = out_mm_kwargs["video_grid_thw"][item_idx]
|
||||||
|
|
||||||
|
audio_in_video_item_idx += 1
|
||||||
|
|
||||||
|
second_per_grid_ts = hf_processor_mm_kwargs.get(
|
||||||
|
"second_per_grid_ts", None)
|
||||||
|
if second_per_grid_ts:
|
||||||
|
video_second_per_grid_t = second_per_grid_ts[item_idx]
|
||||||
|
else:
|
||||||
|
video_second_per_grid_t = 1.0
|
||||||
|
|
||||||
|
return MRotaryEmbedding.omni_get_updates_use_audio_in_video(
|
||||||
|
thinker_config=thinker_config,
|
||||||
|
audio_len=audio_num_features,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
video_second_per_grid_t=video_second_per_grid_t,
|
||||||
|
)
|
||||||
|
|
||||||
|
video_replacement_fn = (
|
||||||
|
get_replacement_qwen2_use_audio_in_video if use_audio_in_video else
|
||||||
|
partial(get_replacement_qwen2_vision, modality="video"))
|
||||||
|
|
||||||
|
return [
|
||||||
|
PromptReplacement(
|
||||||
|
modality="audio",
|
||||||
|
target=audio_token,
|
||||||
|
replacement=get_replacement_qwen2_audio,
|
||||||
|
),
|
||||||
|
PromptReplacement(
|
||||||
|
modality="image",
|
||||||
|
target=image_token,
|
||||||
|
replacement=partial(get_replacement_qwen2_vision,
|
||||||
|
modality="image"),
|
||||||
|
),
|
||||||
|
PromptReplacement(
|
||||||
|
modality="video",
|
||||||
|
target=video_token,
|
||||||
|
replacement=video_replacement_fn,
|
||||||
|
),
|
||||||
|
]
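To make the placeholder arithmetic in `get_replacement_qwen2_vision` concrete, here is a small worked example; the grid values and merge factor are illustrative, not taken from this diff.

import torch

grid_thw = torch.tensor([4, 36, 36])   # illustrative (frames, patch rows, patch cols)
merge_size = 2                          # assumed spatial merge factor
merge_length = merge_size ** 2

num_placeholder_tokens = int(grid_thw.prod()) // merge_length   # 5184 // 4
assert num_placeholder_tokens == 1296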
|
||||||
|
|
||||||
|
def _apply_hf_processor_main(
|
||||||
|
self,
|
||||||
|
prompt: Union[str, list[int]],
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
*,
|
||||||
|
enable_hf_prompt_update: bool,
|
||||||
|
) -> tuple[list[int], MultiModalKwargs, bool]:
|
||||||
|
"""
|
||||||
|
Qwen2.5-Omni reimplements this function to handle text only.
|
||||||
|
"""
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
if enable_hf_prompt_update:
|
||||||
|
return self._apply_hf_processor_text_mm(
|
||||||
|
prompt_text=prompt,
|
||||||
|
mm_items=mm_items,
|
||||||
|
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||||
|
)
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
prompt_ids = encode_tokens(tokenizer, prompt)
|
||||||
|
else:
|
||||||
|
prompt_ids = self._apply_hf_processor_tokens_only(prompt)
|
||||||
|
|
||||||
|
mm_kwargs = self._apply_hf_processor_mm_only(
|
||||||
|
mm_items=mm_items,
|
||||||
|
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt_ids, mm_kwargs, False
|
||||||
|
|
||||||
|
def _apply_hf_processor_mm_only(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> MultiModalKwargs:
|
||||||
|
"""
|
||||||
|
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
|
||||||
|
"""
|
||||||
|
mm_counts = mm_items.get_all_counts()
|
||||||
|
|
||||||
|
use_audio_in_video = hf_processor_mm_kwargs.get(
|
||||||
|
"use_audio_in_video", False)
|
||||||
|
if use_audio_in_video and "video" in mm_counts:
|
||||||
|
assert "audio" in mm_counts
|
||||||
|
mm_counts["audio"] -= mm_counts["video"]
|
||||||
|
|
||||||
|
_, mm_kwargs, _ = self._apply_hf_processor_text_mm(
|
||||||
|
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
|
||||||
|
mm_items=mm_items,
|
||||||
|
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return mm_kwargs
|
||||||
|
|
||||||
|
def _validate_mm_placeholders(
|
||||||
|
self,
|
||||||
|
mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
|
||||||
|
mm_item_counts: Mapping[str, int],
|
||||||
|
use_audio_in_video: bool = False,
|
||||||
|
) -> None:
|
||||||
|
if use_audio_in_video:
|
||||||
|
mm_item_counts = copy(mm_item_counts)
|
||||||
|
if "video" in mm_item_counts:
|
||||||
|
assert "audio" in mm_item_counts
|
||||||
|
mm_item_counts["audio"] -= mm_item_counts["video"]
|
||||||
|
super()._validate_mm_placeholders(mm_placeholders, mm_item_counts)
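The count adjustment applied in `_apply_hf_processor_mm_only` and `_validate_mm_placeholders` can be summarized with a small sketch: every video that contributes its own audio track consumes one audio slot, so the audio count checked against the placeholders is reduced by the video count. The request composition below is illustrative.

from copy import copy

mm_item_counts = {"video": 2, "audio": 3}   # e.g. 2 videos (with audio) + 1 standalone audio
use_audio_in_video = True

if use_audio_in_video and "video" in mm_item_counts:
    mm_item_counts = copy(mm_item_counts)
    mm_item_counts["audio"] -= mm_item_counts["video"]

assert mm_item_counts == {"video": 2, "audio": 1}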
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5OmniConditionalGenerationMixin:
|
||||||
|
|
||||||
|
def _validate_and_reshape_mm_tensor(self,
|
||||||
|
mm_input: object,
|
||||||
|
name: str,
|
||||||
|
dim: int = 0) -> torch.Tensor:
|
||||||
|
if not isinstance(mm_input, (torch.Tensor, list)):
|
||||||
|
raise ValueError(f"Incorrect type of {name}. "
|
||||||
|
f"Got type: {type(mm_input)}")
|
||||||
|
if isinstance(mm_input, torch.Tensor):
|
||||||
|
return torch.concat(list(mm_input), dim=dim)
|
||||||
|
else:
|
||||||
|
return torch.concat(mm_input, dim=dim)
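A small illustration of the `dim` argument: audio features are concatenated along the frame dimension (`dim=1`), as used for `input_audio_features` below. Shapes here are illustrative only.

import torch

feats = [torch.zeros(128, 30), torch.zeros(128, 50)]   # illustrative per-item mel features
stacked = torch.concat(feats, dim=1)                    # frame-wise concatenation
assert tuple(stacked.shape) == (128, 80)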
|
||||||
|
|
||||||
|
def _parse_and_validate_audio_input(
|
||||||
|
self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
|
||||||
|
input_audio_features = kwargs.pop('input_audio_features', None)
|
||||||
|
audio_feature_lengths = kwargs.pop('audio_feature_lengths', None)
|
||||||
|
feature_attention_mask = kwargs.pop('feature_attention_mask', None)
|
||||||
|
if input_audio_features is None:
|
||||||
|
return None
|
||||||
|
input_audio_features = self._validate_and_reshape_mm_tensor(
|
||||||
|
input_audio_features, 'input_audio_features', dim=1)
|
||||||
|
if feature_attention_mask is not None:
|
||||||
|
feature_attention_mask = self._validate_and_reshape_mm_tensor(
|
||||||
|
feature_attention_mask, 'feature_attention_mask')
|
||||||
|
if not isinstance(input_audio_features, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of audio input features. "
|
||||||
|
f"Got type: {type(input_audio_features)}")
|
||||||
|
return Qwen2AudioInputs(input_features=input_audio_features,
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
feature_attention_mask=feature_attention_mask)
|
||||||
|
|
||||||
|
def _parse_and_validate_image_input(
|
||||||
|
self,
|
||||||
|
**kwargs: Dict[str, Any],
|
||||||
|
) -> Optional[Qwen2_5_VLImageInputs]:
|
||||||
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
|
image_embeds = kwargs.pop("image_embeds", None)
|
||||||
|
image_grid_thw = kwargs.pop("image_grid_thw", None)
|
||||||
|
|
||||||
|
if pixel_values is None and image_embeds is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = self._validate_and_reshape_mm_tensor(
|
||||||
|
pixel_values, "image pixel values")
|
||||||
|
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
|
image_grid_thw, "image grid_thw")
|
||||||
|
|
||||||
|
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||||
|
raise ValueError("Incorrect type of image pixel values. "
|
||||||
|
f"Got type: {type(pixel_values)}")
|
||||||
|
|
||||||
|
return Qwen2_5_VLImagePixelInputs(type="pixel_values",
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_grid_thw=image_grid_thw)
|
||||||
|
|
||||||
|
if image_embeds is not None:
|
||||||
|
image_embeds = self._validate_and_reshape_mm_tensor(
|
||||||
|
image_embeds, "image embeds")
|
||||||
|
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
|
image_grid_thw, "image grid_thw")
|
||||||
|
|
||||||
|
if not isinstance(image_embeds, torch.Tensor):
|
||||||
|
raise ValueError("Incorrect type of image embeddings. "
|
||||||
|
f"Got type: {type(image_embeds)}")
|
||||||
|
return Qwen2_5_VLImageEmbeddingInputs(
|
||||||
|
type="image_embeds",
|
||||||
|
image_embeds=image_embeds,
|
||||||
|
image_grid_thw=image_grid_thw)
|
||||||
|
|
||||||
|
def _parse_and_validate_video_input(
|
||||||
|
self,
|
||||||
|
**kwargs: Dict[str, Any],
|
||||||
|
) -> Optional[Qwen2_5_VLVideoInputs]:
|
||||||
|
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
|
||||||
|
video_embeds = kwargs.pop("video_embeds", None)
|
||||||
|
video_grid_thw = kwargs.pop("video_grid_thw", None)
|
||||||
|
|
||||||
|
if pixel_values_videos is None and video_embeds is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if pixel_values_videos is not None:
|
||||||
|
pixel_values_videos = self._validate_and_reshape_mm_tensor(
|
||||||
|
pixel_values_videos, "video pixel values")
|
||||||
|
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
|
video_grid_thw, "video grid_thw")
|
||||||
|
|
||||||
|
return Qwen2_5_VLVideoPixelInputs(
|
||||||
|
type="pixel_values_videos",
|
||||||
|
pixel_values_videos=pixel_values_videos,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
)
|
||||||
|
|
||||||
|
if video_embeds is not None:
|
||||||
|
video_embeds = self._validate_and_reshape_mm_tensor(
|
||||||
|
video_embeds, "video embeds")
|
||||||
|
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
|
video_grid_thw, "video grid_thw")
|
||||||
|
|
||||||
|
if not isinstance(video_embeds, torch.Tensor):
|
||||||
|
raise ValueError("Incorrect type of video embeddings. "
|
||||||
|
f"Got type: {type(video_embeds)}")
|
||||||
|
return Qwen2_5_VLVideoEmbeddingInputs(
|
||||||
|
type="video_embeds",
|
||||||
|
video_embeds=video_embeds,
|
||||||
|
video_grid_thw=video_grid_thw)
|
||||||
|
|
||||||
|
def _process_audio_input(
|
||||||
|
self,
|
||||||
|
audio_input: Qwen2AudioInputs,
|
||||||
|
audio_hashes: Optional[List[str]] = None,
cached_audio_features: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
input_features = audio_input["input_features"]
|
||||||
|
audio_feature_lengths = audio_input["audio_feature_lengths"]
|
||||||
|
if input_features.ndim == 3:
|
||||||
|
assert input_features.shape[0] == 1
|
||||||
|
input_features = input_features.squeeze(0)
|
||||||
|
if audio_feature_lengths.ndim == 2:
|
||||||
|
assert (audio_feature_lengths.shape[0] == 1
or audio_feature_lengths.shape[1] == 1)
|
||||||
|
if audio_feature_lengths.shape[0] == 1:
|
||||||
|
audio_feature_lengths = audio_feature_lengths.squeeze(0)
|
||||||
|
else:
|
||||||
|
audio_feature_lengths = audio_feature_lengths.squeeze(1)
|
||||||
|
|
||||||
|
audio_feat_lengths, audio_output_lengths = (
|
||||||
|
self.audio_tower._get_feat_extract_output_lengths(
|
||||||
|
audio_feature_lengths))
|
||||||
|
|
||||||
|
audio_outputs = self.audio_tower(
|
||||||
|
input_features.to(self.audio_tower.dtype),
|
||||||
|
feature_lens=audio_feature_lengths,
|
||||||
|
aftercnn_lens=audio_feat_lengths,
|
||||||
|
)
|
||||||
|
audio_features = audio_outputs.last_hidden_state
|
||||||
|
return audio_features.split(audio_output_lengths.tolist())
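The `_get_feat_extract_output_lengths` helper used above is not shown in this chunk. For orientation, Qwen2-Audio-style encoders map mel-frame lengths to post-convolution lengths roughly as sketched below; this mirrors the Hugging Face Qwen2-Audio formula and is an assumption here, not code from this diff.

import torch

def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    """Assumed Whisper/Qwen2-Audio-style length mapping after the two conv layers."""
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths

# e.g. 30 s of audio -> 3000 mel frames -> 1500 conv frames -> 750 audio tokens
feat, out = _get_feat_extract_output_lengths(torch.tensor([3000]))
assert feat.item() == 1500 and out.item() == 750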
|
||||||
|
|
||||||
|
def _process_image_input(
|
||||||
|
self,
|
||||||
|
image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
|
||||||
|
if image_input["type"] == "image_embeds":
|
||||||
|
return image_input["image_embeds"].type(self.visual.dtype)
|
||||||
|
|
||||||
|
grid_thw = image_input["image_grid_thw"]
|
||||||
|
assert grid_thw.ndim == 2
|
||||||
|
|
||||||
|
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
|
||||||
|
# Split concatenated embeddings for each image item.
|
||||||
|
merge_size = self.visual.spatial_merge_size
|
||||||
|
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||||
|
|
||||||
|
return image_embeds.split(sizes.tolist())
|
||||||
|
|
||||||
|
def _process_video_input(
|
||||||
|
self,
|
||||||
|
video_input: Qwen2_5_VLVideoInputs,
|
||||||
|
video_hashes: Optional[List[str]] = None,
cached_video_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
if video_input["type"] == "video_embeds":
|
||||||
|
return video_input["video_embeds"].type(self.visual.dtype)
|
||||||
|
|
||||||
|
grid_thw = video_input["video_grid_thw"]
|
||||||
|
assert grid_thw.ndim == 2
|
||||||
|
|
||||||
|
pixel_values_videos = video_input["pixel_values_videos"].type(
|
||||||
|
self.visual.dtype)
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
|
||||||
|
# Split concatenated embeddings for each video item.
|
||||||
|
merge_size = self.visual.spatial_merge_size
|
||||||
|
sizes = grid_thw.prod(-1) // merge_size // merge_size
|
||||||
|
|
||||||
|
return video_embeds.split(sizes.tolist())
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
Qwen2_5OmniThinkerMultiModalProcessor,
|
||||||
|
info=Qwen2_5OmniThinkerProcessingInfo,
|
||||||
|
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
|
||||||
|
)
|
||||||
|
class Qwen2_5OmniThinkerForConditionalGeneration(
|
||||||
|
nn.Module, SupportsMultiModal, SupportsPP,
|
||||||
|
Qwen2_5OmniConditionalGenerationMixin):
|
||||||
|
hf_to_vllm_mapper = WeightsMapper(
|
||||||
|
orig_to_new_prefix={
|
||||||
|
"thinker.lm_head.": "language_model.lm_head.",
|
||||||
|
"thinker.model.": "language_model.model.",
|
||||||
|
"thinker.": "",
|
||||||
|
})
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super().__init__()
|
||||||
|
thinker_config: Qwen2_5OmniThinkerConfig = (
|
||||||
|
vllm_config.model_config.hf_config.thinker_config)
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
multimodal_config = vllm_config.model_config.multimodal_config
|
||||||
|
self.config = thinker_config
|
||||||
|
self.multimodal_config = multimodal_config
|
||||||
|
|
||||||
|
# force "use_flash_attention_2=True" to audio tower to align
|
||||||
|
# the results.
|
||||||
|
if flash_attn is not None:
|
||||||
|
audio_config = thinker_config.audio_config
|
||||||
|
audio_config._attn_implementation_autoset = True
|
||||||
|
audio_config._attn_implementation = "flash_attention_2"
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"flash_attn is not available, the model may not yield the "
|
||||||
|
"exactly same result as the transformers implementation "
|
||||||
|
"in the audio tower part.")
|
||||||
|
|
||||||
|
self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
|
||||||
|
self.visual = Qwen2_5_VisionTransformer(
|
||||||
|
vision_config=thinker_config.vision_config,
|
||||||
|
norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "visual"),
|
||||||
|
)
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.language_model = init_vllm_registered_model(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
prefix=maybe_prefix(prefix, "language_model"),
|
||||||
|
hf_config=thinker_config.text_config,
|
||||||
|
architectures=["Qwen2ForCausalLM"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
self.language_model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def sampler(self):
|
||||||
|
if hasattr(self.language_model, "sampler"):
|
||||||
|
return self.language_model.sampler
|
||||||
|
|
||||||
|
return get_sampler()
|
||||||
|
|
||||||
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||||
|
mm_input_by_modality = {}
|
||||||
|
|
||||||
|
# Preserve the order of modalities if there are multiple of them
|
||||||
|
# from the order of kwargs.
|
||||||
|
for input_key in kwargs:
|
||||||
|
if input_key in ("pixel_values", "image_embeds"
|
||||||
|
) and "image" not in mm_input_by_modality:
|
||||||
|
mm_input_by_modality[
|
||||||
|
"image"] = self._parse_and_validate_image_input(**kwargs)
|
||||||
|
if input_key in ("pixel_values_videos", "video_embeds"
|
||||||
|
) and "video" not in mm_input_by_modality:
|
||||||
|
mm_input_by_modality[
|
||||||
|
"video"] = self._parse_and_validate_video_input(**kwargs)
|
||||||
|
if input_key in ("input_audio_features"
|
||||||
|
) and "audio" not in mm_input_by_modality:
|
||||||
|
mm_input_by_modality[
|
||||||
|
"audio"] = self._parse_and_validate_audio_input(**kwargs)
|
||||||
|
return mm_input_by_modality
|
||||||
|
|
||||||
|
def get_multimodal_embeddings(
|
||||||
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||||
|
|
||||||
|
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
|
||||||
|
**kwargs)
|
||||||
|
if not mm_input_by_modality:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# The resulting multimodal_embeddings is a tuple of tensors, with each
# tensor corresponding to a multimodal data item (image, video or audio).
|
||||||
|
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||||
|
|
||||||
|
# NOTE: It is important to iterate over the keys in this dictionary
|
||||||
|
# to preserve the order of the modalities.
|
||||||
|
for modality in mm_input_by_modality:
|
||||||
|
multimodal_input = mm_input_by_modality[modality]
|
||||||
|
if modality == "image":
|
||||||
|
vision_embeddings = self._process_image_input(multimodal_input)
|
||||||
|
multimodal_embeddings += vision_embeddings
|
||||||
|
if modality == "video":
|
||||||
|
video_embeddings = self._process_video_input(multimodal_input)
|
||||||
|
multimodal_embeddings += video_embeddings
|
||||||
|
if modality == "audio":
|
||||||
|
audio_embeddings = self._process_audio_input(multimodal_input)
|
||||||
|
multimodal_embeddings += audio_embeddings
|
||||||
|
return multimodal_embeddings
|
||||||
|
|
||||||
|
def get_input_embeddings(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||||
|
if multimodal_embeddings is not None:
|
||||||
|
|
||||||
|
# TODO (ywang96): support overlapping modality embeddings so that
|
||||||
|
# `use_audio_in_video` will work on V1.
|
||||||
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
|
input_ids, inputs_embeds, multimodal_embeddings, [
|
||||||
|
self.config.image_token_index,
|
||||||
|
self.config.video_token_index,
|
||||||
|
self.config.audio_token_index
|
||||||
|
])
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def get_multimodal_embeddings_v0(
|
||||||
|
self, **kwargs: object) -> Optional[NestedTensors]:
|
||||||
|
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||||
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
|
video_input = self._parse_and_validate_video_input(**kwargs)
|
||||||
|
|
||||||
|
if audio_input is None and image_input is None and video_input is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
multimodal_embeddings: List[Tuple[NestedTensors, str]] = []
|
||||||
|
|
||||||
|
if audio_input is not None:
|
||||||
|
audio_embeds = self._process_audio_input(audio_input)
|
||||||
|
multimodal_embeddings.append((audio_embeds, "audio"))
|
||||||
|
if image_input is not None:
|
||||||
|
image_embeds = self._process_image_input(image_input)
|
||||||
|
multimodal_embeddings.append((image_embeds, "image"))
|
||||||
|
if video_input is not None:
|
||||||
|
video_embeds = self._process_video_input(video_input)
|
||||||
|
multimodal_embeddings.append((video_embeds, "video"))
|
||||||
|
return multimodal_embeddings
|
||||||
|
|
||||||
|
def get_input_embeddings_v0(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||||
|
if multimodal_embeddings is None:
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
for embeddings, modality in multimodal_embeddings:
|
||||||
|
if modality == "audio":
|
||||||
|
placeholder_token_id = self.config.audio_token_index
|
||||||
|
if modality == "image":
|
||||||
|
placeholder_token_id = self.config.image_token_index
|
||||||
|
if modality == "video":
|
||||||
|
placeholder_token_id = self.config.video_token_index
|
||||||
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
|
input_ids, inputs_embeds, embeddings, placeholder_token_id)
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs: object,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
if intermediate_tensors is not None:
|
||||||
|
inputs_embeds = None
|
||||||
|
|
||||||
|
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||||
|
# condition is for v0 compatibility.
|
||||||
|
elif inputs_embeds is None:
|
||||||
|
multimodal_embeddings = self.get_multimodal_embeddings_v0(**kwargs)
|
||||||
|
inputs_embeds = self.get_input_embeddings_v0(
|
||||||
|
input_ids, multimodal_embeddings)
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
hidden_states = self.language_model.model(input_ids,
|
||||||
|
positions,
|
||||||
|
intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
return self.language_model.compute_logits(hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
|
||||||
|
def sample(
|
||||||
|
self,
|
||||||
|
logits: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[SamplerOutput]:
|
||||||
|
return self.language_model.sample(logits, sampling_metadata)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[Tuple[str,
|
||||||
|
torch.Tensor]]) -> Set[str]:
|
||||||
|
loader = AutoWeightsLoader(
|
||||||
|
self,
|
||||||
|
skip_prefixes=["talker.", "token2wav."],
|
||||||
|
)
|
||||||
|
loaded_weights = loader.load_weights(weights,
|
||||||
|
mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
return loaded_weights
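The `hf_to_vllm_mapper` defined on the class strips the `thinker.` prefix from checkpoint names and reroutes the text backbone under `language_model.`. The sketch below illustrates the intended renaming with plain string handling; it is not the actual `WeightsMapper` API.

# Illustrative-only reimplementation of the prefix remapping above.
PREFIX_MAP = {
    "thinker.lm_head.": "language_model.lm_head.",
    "thinker.model.": "language_model.model.",
    "thinker.": "",
}

def remap(name: str) -> str:
    for old, new in PREFIX_MAP.items():   # most specific prefixes first
        if name.startswith(old):
            return new + name[len(old):]
    return name

assert remap("thinker.model.layers.0.mlp.gate_proj.weight") == \
    "language_model.model.layers.0.mlp.gate_proj.weight"
assert remap("thinker.audio_tower.conv1.weight") == "audio_tower.conv1.weight"
# Weights under "talker." / "token2wav." are skipped entirely by load_weights.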
|
||||||
@ -38,13 +38,14 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
|
|||||||
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
|
from vllm.distributed import parallel_state
|
||||||
from vllm.distributed import utils as dist_utils
|
from vllm.distributed import utils as dist_utils
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor import SamplingMetadata
|
from vllm.model_executor import SamplingMetadata
|
||||||
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||||
@ -195,6 +196,23 @@ class Qwen2_5_VisionMLP(nn.Module):
|
|||||||
return x_down
|
return x_down
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
|
||||||
|
"""All-gather the input tensor interleavely across model parallel group."""
|
||||||
|
import torch.distributed as dist
|
||||||
|
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
|
||||||
|
dist.all_gather(gathered_tensors, local_tensor)
|
||||||
|
|
||||||
|
gathered_tensors_split = [
|
||||||
|
torch.split(tensor, hidden_size // tp_size, -1)
|
||||||
|
for tensor in gathered_tensors
|
||||||
|
]
|
||||||
|
ordered_tensors = [
|
||||||
|
tensor for pair in zip(*gathered_tensors_split) for tensor in pair
|
||||||
|
]
|
||||||
|
result_tensor = torch.cat(ordered_tensors, dim=-1)
|
||||||
|
return result_tensor
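With tensor parallelism, each rank's fused QKV projection emits its Q, K and V shards back to back; a plain all-gather concatenation would therefore yield [q0 k0 v0 q1 k1 v1 ...], whereas the helper above reorders the gathered chunks so that all Q shards come first, then K, then V. A minimal list-based sketch of the reordering (illustrative only):

tp_size = 2
# Each rank's local fused-QKV output, viewed as chunks of size hidden_size // tp_size.
gathered = [["q0", "k0", "v0"], ["q1", "k1", "v1"]]

ordered = [chunk for group in zip(*gathered) for chunk in group]
assert ordered == ["q0", "q1", "k0", "k1", "v0", "v1"]   # full Q, then K, then V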
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VisionAttention(nn.Module):
|
class Qwen2_5_VisionAttention(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -214,10 +232,14 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
self.num_attention_heads_per_partition = dist_utils.divide(
|
self.num_attention_heads_per_partition = dist_utils.divide(
|
||||||
num_heads, self.tp_size)
|
num_heads, self.tp_size)
|
||||||
|
|
||||||
self.qkv = ColumnParallelLinear(input_size=embed_dim,
|
self.qkv = QKVParallelLinear(
|
||||||
output_size=3 * projection_size,
|
hidden_size=embed_dim,
|
||||||
quant_config=quant_config,
|
head_size=self.hidden_size_per_attention_head,
|
||||||
prefix=f"{prefix}.qkv")
|
total_num_heads=num_heads,
|
||||||
|
total_num_kv_heads=num_heads,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv")
|
||||||
self.proj = RowParallelLinear(input_size=projection_size,
|
self.proj = RowParallelLinear(input_size=projection_size,
|
||||||
output_size=embed_dim,
|
output_size=embed_dim,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@ -236,7 +258,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
qkv = tensor_model_parallel_all_gather(qkv)
|
qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
|
||||||
|
self.tp_size)
|
||||||
|
|
||||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||||
q, k, v = qkv.chunk(3, dim=2)
|
q, k, v = qkv.chunk(3, dim=2)
|
||||||
@ -694,9 +717,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
torch.Tensor]]) -> Set[str]:
|
torch.Tensor]]) -> Set[str]:
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
("qkv_proj", "q_proj", "q"),
|
("attn.qkv.", "attn.q.", "q"),
|
||||||
("qkv_proj", "k_proj", "k"),
|
("attn.qkv.", "attn.k.", "k"),
|
||||||
("qkv_proj", "v_proj", "v"),
|
("attn.qkv.", "attn.v.", "v"),
|
||||||
]
|
]
|
||||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||||
loaded_params: Set[str] = set()
|
loaded_params: Set[str] = set()
|
||||||
@ -952,20 +975,20 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
return video_embeds.split(sizes.tolist())
|
return video_embeds.split(sizes.tolist())
|
||||||
|
|
||||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||||
modalities = {}
|
mm_input_by_modality = {}
|
||||||
|
|
||||||
# Preserve the order of modalities if there are multiple of them
|
# Preserve the order of modalities if there are multiple of them
|
||||||
# from the order of kwargs.
|
# from the order of kwargs.
|
||||||
for input_key in kwargs:
|
for input_key in kwargs:
|
||||||
if input_key in ("pixel_values",
|
if input_key in ("pixel_values", "image_embeds"
|
||||||
"image_embeds") and "images" not in modalities:
|
) and "image" not in mm_input_by_modality:
|
||||||
modalities["images"] = self._parse_and_validate_image_input(
|
mm_input_by_modality[
|
||||||
**kwargs)
|
"image"] = self._parse_and_validate_image_input(**kwargs)
|
||||||
if input_key in ("pixel_values_videos",
|
if input_key in ("pixel_values_videos", "video_embeds"
|
||||||
"video_embeds") and "videos" not in modalities:
|
) and "video" not in mm_input_by_modality:
|
||||||
modalities["videos"] = self._parse_and_validate_video_input(
|
mm_input_by_modality[
|
||||||
**kwargs)
|
"video"] = self._parse_and_validate_video_input(**kwargs)
|
||||||
return modalities
|
return mm_input_by_modality
|
||||||
|
|
||||||
def get_language_model(self) -> torch.nn.Module:
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
return self.language_model
|
return self.language_model
|
||||||
@ -973,8 +996,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
def get_multimodal_embeddings(
|
def get_multimodal_embeddings(
|
||||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||||
|
|
||||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
|
||||||
if not modalities:
|
**kwargs)
|
||||||
|
if not mm_input_by_modality:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# The result multimodal_embeddings is tuple of tensors, with each
|
# The result multimodal_embeddings is tuple of tensors, with each
|
||||||
@ -983,14 +1007,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
|
|
||||||
# NOTE: It is important to iterate over the keys in this dictionary
|
# NOTE: It is important to iterate over the keys in this dictionary
|
||||||
# to preserve the order of the modalities.
|
# to preserve the order of the modalities.
|
||||||
for modality in modalities:
|
for modality in mm_input_by_modality:
|
||||||
if modality == "images":
|
multimodal_input = mm_input_by_modality[modality]
|
||||||
image_input = modalities["images"]
|
if modality == "image":
|
||||||
vision_embeddings = self._process_image_input(image_input)
|
vision_embeddings = self._process_image_input(multimodal_input)
|
||||||
multimodal_embeddings += vision_embeddings
|
multimodal_embeddings += vision_embeddings
|
||||||
if modality == "videos":
|
if modality == "video":
|
||||||
video_input = modalities["videos"]
|
video_embeddings = self._process_video_input(multimodal_input)
|
||||||
video_embeddings = self._process_video_input(video_input)
|
|
||||||
multimodal_embeddings += video_embeddings
|
multimodal_embeddings += video_embeddings
|
||||||
return multimodal_embeddings
|
return multimodal_embeddings
|
||||||
|
|
||||||
|
|||||||
@ -200,6 +200,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||||
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
|
"Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
|
||||||
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
|
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
|
||||||
|
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
|
||||||
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
"UltravoxModel": ("ultravox", "UltravoxModel"),
|
||||||
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
|
"Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"),
|
||||||
# [Encoder-decoder]
|
# [Encoder-decoder]
|
||||||
|
|||||||
@ -84,7 +84,7 @@ def replace_linear_class(
|
|||||||
) -> Union[ColumnParallelLinear, RowParallelLinear]:
|
) -> Union[ColumnParallelLinear, RowParallelLinear]:
|
||||||
"""
|
"""
|
||||||
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
linear (nn.Linear): `nn.Linear` to be replaced.
|
linear (nn.Linear): `nn.Linear` to be replaced.
|
||||||
style (str): Tensor parallel style of the new linear, e.g. "colwise".
|
style (str): Tensor parallel style of the new linear, e.g. "colwise".
|
||||||
|
|||||||
@ -320,7 +320,8 @@ class MultiModalFlatField(BaseMultiModalField):
|
|||||||
:func:`MultiModalFieldConfig.flat`
|
:func:`MultiModalFieldConfig.flat`
|
||||||
:func:`MultiModalFieldConfig.flat_from_sizes`
|
:func:`MultiModalFieldConfig.flat_from_sizes`
|
||||||
"""
|
"""
|
||||||
slices: Sequence[slice]
|
slices: Union[Sequence[slice], Sequence[Sequence[slice]]]
|
||||||
|
dim: int = 0
|
||||||
|
|
||||||
def build_elems(
|
def build_elems(
|
||||||
self,
|
self,
|
||||||
@ -329,7 +330,10 @@ class MultiModalFlatField(BaseMultiModalField):
|
|||||||
data: NestedTensors,
|
data: NestedTensors,
|
||||||
) -> Sequence[MultiModalFieldElem]:
|
) -> Sequence[MultiModalFieldElem]:
|
||||||
field_factory = self._field_factory(modality=modality, key=key)
|
field_factory = self._field_factory(modality=modality, key=key)
|
||||||
return [field_factory(data[s]) for s in self.slices]
|
if not is_list_of(self.slices, slice, check="all"):
|
||||||
|
assert isinstance(data, torch.Tensor), \
|
||||||
|
"torch.Tensor is required for multiple slices"
|
||||||
|
return [field_factory(data[cast(slice, s)]) for s in self.slices]
|
||||||
|
|
||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||||
@ -338,10 +342,16 @@ class MultiModalFlatField(BaseMultiModalField):
|
|||||||
# - produce exactly same result as `torch.concat(batch)`
|
# - produce exactly same result as `torch.concat(batch)`
|
||||||
# - will achieve zero-copy if the tensor is contiguous
|
# - will achieve zero-copy if the tensor is contiguous
|
||||||
return batch[0].contiguous()
|
return batch[0].contiguous()
|
||||||
first_shape = batch[0].shape
|
|
||||||
if all(elem.shape[1:] == first_shape[1:] for elem in batch):
|
|
||||||
return torch.concat(batch)
|
|
||||||
|
|
||||||
|
def _expect_same_shape(tensor: torch.Tensor):
|
||||||
|
return tensor.shape[:self.dim] + tensor.shape[self.dim + 1:]
|
||||||
|
|
||||||
|
first_shape = _expect_same_shape(batch[0])
|
||||||
|
|
||||||
|
if all(_expect_same_shape(elem) == first_shape for elem in batch):
|
||||||
|
return torch.concat(batch, dim=self.dim)
|
||||||
|
|
||||||
|
assert self.dim == 0, "dim == 0 is required for nested list"
|
||||||
return [e for elem in batch for e in elem]
|
return [e for elem in batch for e in elem]
|
||||||
|
|
||||||
|
|
||||||
@ -398,7 +408,9 @@ class MultiModalFieldConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flat(modality: str, slices: Sequence[slice]):
|
def flat(modality: str,
|
||||||
|
slices: Union[Sequence[slice], Sequence[Sequence[slice]]],
|
||||||
|
dim: int = 0):
|
||||||
"""
|
"""
|
||||||
Defines a field where an element in the batch is obtained by
|
Defines a field where an element in the batch is obtained by
|
||||||
slicing along the first dimension of the underlying data.
|
slicing along the first dimension of the underlying data.
|
||||||
@ -406,8 +418,10 @@ class MultiModalFieldConfig:
|
|||||||
Args:
|
Args:
|
||||||
modality: The modality of the multi-modal item that uses this
|
modality: The modality of the multi-modal item that uses this
|
||||||
keyword argument.
|
keyword argument.
|
||||||
slices: For each multi-modal item, a slice that is used to extract
|
slices: For each multi-modal item, a slice (dim=0) or a tuple of
|
||||||
the data corresponding to it.
|
slices (dim>0) that is used to extract the data corresponding
|
||||||
|
to it.
|
||||||
|
dim: The dimension to extract data, default to 0.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -423,14 +437,33 @@ class MultiModalFieldConfig:
|
|||||||
Element 1: [AAA]
|
Element 1: [AAA]
|
||||||
Element 2: [BBBB]
|
Element 2: [BBBB]
|
||||||
Element 3: [CC]
|
Element 3: [CC]
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
Given:
|
||||||
|
slices: [
|
||||||
|
(slice(None), slice(0, 3)),
|
||||||
|
(slice(None), slice(3, 7)),
|
||||||
|
(slice(None), slice(7, 9))]
|
||||||
|
dim: 1
|
||||||
|
|
||||||
|
Input:
|
||||||
|
Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]
|
||||||
|
|
||||||
|
Output:
|
||||||
|
Element 1: [[A],[A],[A]]
|
||||||
|
Element 2: [[B],[B],[B],[B]]
|
||||||
|
Element 3: [[C],[C]]
|
||||||
"""
|
"""
|
||||||
return MultiModalFieldConfig(
|
return MultiModalFieldConfig(
|
||||||
field=MultiModalFlatField(slices=slices),
|
field=MultiModalFlatField(slices=slices, dim=dim),
|
||||||
modality=modality,
|
modality=modality,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
|
def flat_from_sizes(modality: str,
|
||||||
|
size_per_item: torch.Tensor,
|
||||||
|
dim: int = 0):
|
||||||
"""
|
"""
|
||||||
Defines a field where an element in the batch is obtained by
|
Defines a field where an element in the batch is obtained by
|
||||||
slicing along the first dimension of the underlying data.
|
slicing along the first dimension of the underlying data.
|
||||||
@ -440,6 +473,7 @@ class MultiModalFieldConfig:
|
|||||||
keyword argument.
|
keyword argument.
|
||||||
slices: For each multi-modal item, the size of the slice that
|
slices: For each multi-modal item, the size of the slice that
|
||||||
is used to extract the data corresponding to it.
|
is used to extract the data corresponding to it.
|
||||||
|
dim: The dimension to slice, default to 0.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -455,6 +489,21 @@ class MultiModalFieldConfig:
|
|||||||
Element 1: [AAA]
|
Element 1: [AAA]
|
||||||
Element 2: [BBBB]
|
Element 2: [BBBB]
|
||||||
Element 3: [CC]
|
Element 3: [CC]
|
||||||
|
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
Given:
|
||||||
|
slices: [3, 4, 2]
|
||||||
|
dim: 1
|
||||||
|
|
||||||
|
Input:
|
||||||
|
Data: [[A],[A],[A],[B],[B],[B],[B],[C],[C]]
|
||||||
|
|
||||||
|
Output:
|
||||||
|
Element 1: [[A],[A],[A]]
|
||||||
|
Element 2: [[B],[B],[B],[B]]
|
||||||
|
Element 3: [[C],[C]]
|
||||||
|
|
||||||
See also:
|
See also:
|
||||||
:func:`MultiModalFieldConfig.flat`
|
:func:`MultiModalFieldConfig.flat`
|
||||||
@ -465,12 +514,11 @@ class MultiModalFieldConfig:
|
|||||||
f"but found shape: {size_per_item.shape}")
|
f"but found shape: {size_per_item.shape}")
|
||||||
|
|
||||||
slice_idxs = [0, *accumulate(size_per_item)]
|
slice_idxs = [0, *accumulate(size_per_item)]
|
||||||
slices = [
|
slices = [(slice(None, None, None), ) * dim +
|
||||||
slice(slice_idxs[i], slice_idxs[i + 1])
|
(slice(slice_idxs[i], slice_idxs[i + 1]), )
|
||||||
for i in range(len(size_per_item))
|
for i in range(len(size_per_item))]
|
||||||
]
|
|
||||||
|
|
||||||
return MultiModalFieldConfig.flat(modality, slices)
|
return MultiModalFieldConfig.flat(modality, slices, dim=dim)
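The new `dim` argument turns each per-item size into a tuple of slices that leaves the leading dimensions untouched. The snippet below reproduces the slice construction and applies it to a toy tensor; shapes are illustrative.

from itertools import accumulate

import torch

size_per_item = [3, 4, 2]
dim = 1

slice_idxs = [0, *accumulate(size_per_item)]
slices = [(slice(None), ) * dim + (slice(slice_idxs[i], slice_idxs[i + 1]), )
          for i in range(len(size_per_item))]
# slices == [(slice(None), slice(0, 3)), (slice(None), slice(3, 7)), (slice(None), slice(7, 9))]

data = torch.arange(2 * 9).reshape(2, 9)   # batch dim 0, item dim 1
chunks = [data[s] for s in slices]
assert [tuple(c.shape) for c in chunks] == [(2, 3), (2, 4), (2, 2)]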
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def shared(modality: str, batch_size: int):
|
def shared(modality: str, batch_size: int):
|
||||||
|
|||||||
@ -222,8 +222,7 @@ def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
|
|||||||
logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
|
logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
|
||||||
|
|
||||||
|
|
||||||
def uses_mrope(config: PretrainedConfig) -> bool:
|
def _uses_mrope(config: PretrainedConfig) -> bool:
|
||||||
"""Detect if the model with this config uses M-ROPE."""
|
|
||||||
rope_scaling = getattr(config, "rope_scaling", None)
|
rope_scaling = getattr(config, "rope_scaling", None)
|
||||||
if rope_scaling is None:
|
if rope_scaling is None:
|
||||||
return False
|
return False
|
||||||
@ -231,6 +230,24 @@ def uses_mrope(config: PretrainedConfig) -> bool:
|
|||||||
return "mrope_section" in rope_scaling
|
return "mrope_section" in rope_scaling
|
||||||
|
|
||||||
|
|
||||||
|
def uses_mrope(config: PretrainedConfig) -> bool:
|
||||||
|
"""Detect if the model with this config uses M-ROPE."""
|
||||||
|
return _uses_mrope(config) or thinker_uses_mrope(config)
|
||||||
|
|
||||||
|
|
||||||
|
def thinker_uses_mrope(config: PretrainedConfig) -> bool:
|
||||||
|
"""Detect if the model contains a thinker config and it uses M-ROPE."""
|
||||||
|
thinker_config = getattr(config, "thinker_config", None)
|
||||||
|
if thinker_config is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
thinker_text_config = getattr(thinker_config, "text_config", None)
|
||||||
|
if thinker_text_config is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return uses_mrope(thinker_text_config)
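A small self-contained check of the new detection logic, assuming the three helpers above are in scope; `SimpleNamespace` stands in for a `PretrainedConfig`, and the `mrope_section` value is illustrative.

from types import SimpleNamespace

text_config = SimpleNamespace(rope_scaling={"rope_type": "default",
                                            "mrope_section": [16, 24, 24]})
config = SimpleNamespace(rope_scaling=None,
                         thinker_config=SimpleNamespace(text_config=text_config))

assert not _uses_mrope(config)        # the top-level config itself has no mrope_section
assert thinker_uses_mrope(config)     # ...but its thinker's text config does
assert uses_mrope(config)             # so the model is treated as M-RoPE overall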
|
||||||
|
|
||||||
|
|
||||||
def is_encoder_decoder(config: PretrainedConfig) -> bool:
|
def is_encoder_decoder(config: PretrainedConfig) -> bool:
|
||||||
"""Detect if the model with this config is used as an encoder/decoder."""
|
"""Detect if the model with this config is used as an encoder/decoder."""
|
||||||
text_config = getattr(config, "text_config", None)
|
text_config = getattr(config, "text_config", None)
|
||||||
@ -740,6 +757,11 @@ def get_hf_text_config(config: PretrainedConfig):
|
|||||||
# if transformers config doesn't align with this assumption.
|
# if transformers config doesn't align with this assumption.
|
||||||
assert hasattr(config.text_config, "num_attention_heads")
|
assert hasattr(config.text_config, "num_attention_heads")
|
||||||
return config.text_config
|
return config.text_config
|
||||||
|
elif hasattr(config, "thinker_config"):
|
||||||
|
# TODO(suyang.fy): Refactor code.
|
||||||
|
# For Qwen2.5-Omni, change hf_text_config to
|
||||||
|
# thinker_config.text_config.
|
||||||
|
return config.thinker_config.text_config
|
||||||
else:
|
else:
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|||||||
@ -111,6 +111,55 @@ def cached_processor_from_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_feature_extractor(
|
||||||
|
processor_name: str,
|
||||||
|
*args: Any,
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
"""Load an audio feature extractor for the given model name
|
||||||
|
via HuggingFace."""
|
||||||
|
# don't put this import at the top level
|
||||||
|
# it will call torch.cuda.device_count()
|
||||||
|
from transformers import AutoFeatureExtractor
|
||||||
|
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||||
|
try:
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
|
processor_name,
|
||||||
|
*args,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
**kwargs)
|
||||||
|
except ValueError as e:
|
||||||
|
# If the error pertains to the processor class not existing or not
|
||||||
|
# currently being imported, suggest using the --trust-remote-code flag.
|
||||||
|
# Unlike AutoTokenizer, AutoFeatureExtractor does not separate such errors
|
||||||
|
if not trust_remote_code:
|
||||||
|
err_msg = (
|
||||||
|
"Failed to load the feature extractor. If the feature "
|
||||||
|
"extractor is a custom extractor not yet available in the "
|
||||||
|
"HuggingFace transformers library, consider setting "
|
||||||
|
"`trust_remote_code=True` in LLM or using the "
|
||||||
|
"`--trust-remote-code` flag in the CLI.")
|
||||||
|
raise RuntimeError(err_msg) from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
return cast(FeatureExtractionMixin, feature_extractor)
|
||||||
|
|
||||||
|
|
||||||
|
cached_get_feature_extractor = lru_cache(get_feature_extractor)
|
||||||
|
|
||||||
|
|
||||||
|
def cached_feature_extractor_from_config(
|
||||||
|
model_config: "ModelConfig",
|
||||||
|
**kwargs: Any,
|
||||||
|
):
|
||||||
|
return cached_get_feature_extractor(
|
||||||
|
model_config.model,
|
||||||
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
|
**_merge_mm_kwargs(model_config, **kwargs),
|
||||||
|
)
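Hypothetical usage of the new helper; the extractor class and its attributes are assumptions here (Qwen2.5-Omni is expected to ship a Whisper-style extractor exposing `sampling_rate` and `chunk_length`, which the Omni processing code above relies on).

# Assumed usage; requires a transformers build that knows this model.
feature_extractor = get_feature_extractor("Qwen/Qwen2.5-Omni-7B")
print(type(feature_extractor).__name__)   # expected: WhisperFeatureExtractor
print(feature_extractor.sampling_rate)    # expected: 16000
print(feature_extractor.chunk_length)     # expected: 30 (seconds)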
|
||||||
|
|
||||||
|
|
||||||
def get_image_processor(
|
def get_image_processor(
|
||||||
processor_name: str,
|
processor_name: str,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
|
|||||||
@ -355,6 +355,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
image_grid_thw = []
|
image_grid_thw = []
|
||||||
video_grid_thw = []
|
video_grid_thw = []
|
||||||
second_per_grid_ts = []
|
second_per_grid_ts = []
|
||||||
|
audio_feature_lengths = []
|
||||||
|
use_audio_in_video = False
|
||||||
for mm_input in self.requests[req_id].mm_inputs:
|
for mm_input in self.requests[req_id].mm_inputs:
|
||||||
if mm_input.get("image_grid_thw") is not None:
|
if mm_input.get("image_grid_thw") is not None:
|
||||||
image_grid_thw.extend(
|
image_grid_thw.extend(
|
||||||
@ -365,6 +367,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
if mm_input.get("second_per_grid_ts") is not None:
|
if mm_input.get("second_per_grid_ts") is not None:
|
||||||
second_per_grid_ts.extend(
|
second_per_grid_ts.extend(
|
||||||
mm_input["second_per_grid_ts"])
|
mm_input["second_per_grid_ts"])
|
||||||
|
if mm_input.get("audio_feature_lengths") is not None:
|
||||||
|
audio_feature_lengths.extend(
|
||||||
|
mm_input["audio_feature_lengths"])
|
||||||
|
if mm_input.get("use_audio_in_video") is True:
|
||||||
|
use_audio_in_video = True
|
||||||
|
|
||||||
hf_config = self.model_config.hf_config
|
hf_config = self.model_config.hf_config
|
||||||
|
|
||||||
@ -376,6 +383,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
image_grid_thw=image_grid_thw,
|
image_grid_thw=image_grid_thw,
|
||||||
video_grid_thw=video_grid_thw,
|
video_grid_thw=video_grid_thw,
|
||||||
second_per_grid_ts=second_per_grid_ts,
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
use_audio_in_video=use_audio_in_video,
|
||||||
)
|
)
|
||||||
|
|
||||||
req_ids_to_add.append(req_id)
|
req_ids_to_add.append(req_id)
|
||||||
|
|||||||
@ -382,11 +382,17 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
|||||||
|
|
||||||
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
|
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
|
||||||
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
|
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
|
||||||
assert image_grid_thw is not None or video_grid_thw is not None, (
|
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
|
||||||
"mrope embedding type requires multi-modal input mapper "
|
None)
|
||||||
"returns 'image_grid_thw' or 'video_grid_thw'.")
|
assert (
|
||||||
|
image_grid_thw is not None or video_grid_thw is not None
|
||||||
|
or audio_feature_lengths is not None), (
|
||||||
|
"mrope embedding type requires multi-modal input mapper "
|
||||||
|
"returns 'image_grid_thw' or 'video_grid_thw' or "
|
||||||
|
"'audio_feature_lengths'.")
|
||||||
|
|
||||||
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
|
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
|
||||||
|
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
|
||||||
hf_config = self.runner.model_config.hf_config
|
hf_config = self.runner.model_config.hf_config
|
||||||
token_ids = seq_data.get_token_ids()
|
token_ids = seq_data.get_token_ids()
|
||||||
|
|
||||||
@ -398,6 +404,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
|||||||
video_grid_thw=video_grid_thw,
|
video_grid_thw=video_grid_thw,
|
||||||
second_per_grid_ts=second_per_grid_ts,
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
context_len=computed_len,
|
context_len=computed_len,
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
use_audio_in_video=use_audio_in_video,
|
||||||
)
|
)
|
||||||
seq_data.mrope_position_delta = mrope_position_delta
|
seq_data.mrope_position_delta = mrope_position_delta
|
||||||
|
|
||||||
|
|||||||
@ -699,11 +699,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
if self.runner.model_config.uses_mrope:
|
if self.runner.model_config.uses_mrope:
|
||||||
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
|
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
|
||||||
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
|
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
|
||||||
assert image_grid_thw is not None or video_grid_thw is not None, (
|
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
|
||||||
"mrope embedding type requires multi-modal input mapper "
|
None)
|
||||||
"returns 'image_grid_thw' or 'video_grid_thw'.")
|
assert (
|
||||||
|
image_grid_thw is not None or video_grid_thw is not None
|
||||||
|
or audio_feature_lengths is not None), (
|
||||||
|
"mrope embedding type requires multi-modal input mapper "
|
||||||
|
"returns 'image_grid_thw' or 'video_grid_thw' or "
|
||||||
|
"'audio_feature_lengths'.")
|
||||||
|
|
||||||
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
|
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
|
||||||
|
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
|
||||||
hf_config = self.runner.model_config.hf_config
|
hf_config = self.runner.model_config.hf_config
|
||||||
|
|
||||||
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
|
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
|
||||||
@ -721,6 +727,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
second_per_grid_ts=second_per_grid_ts,
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
context_len=inter_data.context_lens[seq_idx],
|
context_len=inter_data.context_lens[seq_idx],
|
||||||
seq_len=inter_data.seq_lens[seq_idx],
|
seq_len=inter_data.seq_lens[seq_idx],
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
use_audio_in_video=use_audio_in_video,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq_data.mrope_position_delta = mrope_position_delta
|
seq_data.mrope_position_delta = mrope_position_delta
|
||||||
|
|||||||