[Model] Add Granite Speech Support (#16246)
Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
parent
aec9674dbe
commit
fa93cd9f60
@ -895,6 +895,13 @@ See [this page](#generative-models) for more information on how to use generative models.
   * ✅︎
   * ✅︎
   * ✅︎
+- * `GraniteSpeechForConditionalGeneration`
+  * Granite Speech
+  * T + A
+  * `ibm-granite/granite-speech-3.3-8b`
+  * ✅︎
+  * ✅︎
+  * ✅︎
 - * `H2OVLChatModel`
   * H2OVL
   * T + I<sup>E+</sup>
@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple):
 # Unless specified, these settings have been tested to work on a single L4.
 
 
+# Granite Speech
+def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
+    # NOTE - the settings in this example are somewhat different from what is
+    # optimal for granite speech, and it is generally recommended to use beam
+    # search. Check the model README for suggested settings.
+    # https://huggingface.co/ibm-granite/granite-speech-3.3-8b
+    model_name = "ibm-granite/granite-speech-3.3-8b"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=2048,
+        max_num_seqs=2,
+        enable_lora=True,
+        max_lora_rank=64,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
+
+    # The model has an audio-specific lora directly in its model dir;
+    # it should be enabled whenever you pass audio inputs to the model.
+    speech_lora_path = model_name
+    audio_placeholder = "<|audio|>" * audio_count
+    prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"  # noqa: E501
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompts,
+        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
+    )
+
+
 # MiniCPM-O
 def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
     model_name = "openbmb/MiniCPM-o-2_6"
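The example script's driver consumes the returned request data roughly as in the following sketch; the driver itself sits outside this hunk, so the exact wiring (question text, sampling settings, use of vars()) is an approximation rather than a copy of the file.

# Approximate consumption of ModelRequestData; values here are placeholders.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

req_data = run_granite_speech(
    "can you transcribe the speech into a written format?", audio_count=1)
llm = LLM(**vars(req_data.engine_args))
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate  # (array, sr) tuple

outputs = llm.generate(
    {"prompt": req_data.prompt, "multi_modal_data": {"audio": [audio]}},
    SamplingParams(temperature=0.2, max_tokens=64),
    lora_request=req_data.lora_requests[0],
)
print(outputs[0].outputs[0].text)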
@ -209,6 +240,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
 
 
 model_example_map = {
+    "granite_speech": run_granite_speech,
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,
     "qwen2_audio": run_qwen2_audio,
@ -21,6 +21,7 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass
 from tests.models.utils import (TokensTextLogprobs,
                                 TokensTextLogprobsPromptLogprobs)
 from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.config import TaskOption, _get_and_verify_dtype
@ -103,10 +104,25 @@ class _VideoAssets(_VideoAssetsBase):
         return [prompts["sample_demo_1"]]
 
 
+class _AudioAssetsBase(UserList[AudioAsset]):
+    pass
+
+
+class _AudioAssets(_AudioAssetsBase):
+
+    def __init__(self) -> None:
+        super().__init__([
+            AudioAsset("mary_had_lamb"),
+            AudioAsset("winning_call"),
+        ])
+
+
 IMAGE_ASSETS = _ImageAssets()
 """Singleton instance of :class:`_ImageAssets`."""
 VIDEO_ASSETS = _VideoAssets()
 """Singleton instance of :class:`_VideoAssets`."""
+AUDIO_ASSETS = _AudioAssets()
+"""Singleton instance of :class:`_AudioAssets`."""
 
 
 @pytest.fixture(scope="function", autouse=True)
@ -263,6 +279,11 @@ def video_assets() -> _VideoAssets:
     return VIDEO_ASSETS
 
 
+@pytest.fixture(scope="session")
+def audio_assets() -> _AudioAssets:
+    return AUDIO_ASSETS
+
+
 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
 _R = TypeVar("_R")
 
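With the singleton and fixture in place, a test can pull shared audio data like the following minimal sketch, which mirrors how the new Granite Speech test later in this commit uses it (the test name is illustrative).

def test_uses_shared_audio(audio_assets: _AudioAssets):
    # The fixture hands back AUDIO_ASSETS, so it indexes like a list of
    # AudioAsset objects ("mary_had_lamb", "winning_call").
    audio, sr = audio_assets[0].audio_and_sample_rate
    assert sr == 16000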
@ -390,10 +411,15 @@ class HfRunner:
                 processor_kwargs["images"] = image
             if videos is not None and (video := videos[i]) is not None:
                 processor_kwargs["videos"] = video
-            if audios is not None and (audio_tuple := audios[i]) is not None:
-                audio, sr = audio_tuple
-                processor_kwargs["audio"] = audio
-                processor_kwargs["sampling_rate"] = sr
+            if audios is not None and (audio_inputs := audios[i]) is not None:
+                # HACK - not all processors take sampling_rate; we should
+                # clean this up in the future.
+                if len(audio_inputs) == 2:
+                    audio, sr = audio_inputs
+                    processor_kwargs["audio"] = audio
+                    processor_kwargs["sampling_rate"] = sr
+                else:
+                    processor_kwargs["audio"] = audio_inputs
 
             inputs = self.processor(**processor_kwargs)
             if isinstance(inputs, BatchFeature):
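In other words, the HF runner now accepts audio in two shapes; a short comment-only sketch of both call forms, matching how the existing and new tests invoke the helpers:

# Inside a `with hf_runner(model, ...) as hf_model:` block (as in the tests),
# the `audios` argument can take either shape per prompt:
#
#   audios=[(audio, sr)]  -> unpacked; the processor also receives sampling_rate=sr
#   audios=[[audio]]      -> passed through unchanged, for processors (such as
#                            Granite Speech's) that take no sampling_rate kwarg
#
# e.g. hf_model.generate_greedy_logprobs_limit(prompts, max_tokens,
#                                              num_logprobs=5, audios=[[audio]])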
tests/models/decoder_only/audio_language/test_granite_speech.py (new file, 143 lines)
@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from typing import Optional

import pytest
from transformers import AutoModelForSpeechSeq2Seq

from vllm.lora.request import LoRARequest
from vllm.sequence import SampleLogprobs

from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close

HF_AUDIO_PROMPT = "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"  # noqa: E501


def vllm_to_hf_output(
    vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
) -> tuple[list[int], str, Optional[SampleLogprobs]]:
    """Sanitize vllm output to be comparable with hf output."""
    output_ids, output_str, out_logprobs = vllm_output

    hf_output_str = output_str + "<|end_of_text|>"

    return output_ids, hf_output_str, out_logprobs


MODEL_NAME = "ibm-granite/granite-speech-3.3-8b"
# Audio lora co-exists directly in the model directory, but
# currently still needs to be passed directly to vLLM.
audio_lora_path = MODEL_NAME
models = [MODEL_NAME]


def run_test(
    hf_runner: type[HfRunner],
    vllm_runner: type[VllmRunner],
    inputs: Sequence[tuple[list[str], PromptAudioInput]],
    model: str,
    *,
    max_model_len: int,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and vllm.

    All the audio fixtures for the test are from AUDIO_ASSETS.
    For huggingface runner, we provide the audio as input.
    For vllm runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by vllm contract.
    The text output is sanitized to be able to compare with hf.
    """
    # NOTE: take care of the order. run vLLM first, and then run HF.
    # vLLM needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
    # max_model_len should be greater than image_feature_size
    with vllm_runner(
            model,
            task="generate",
            max_model_len=max_model_len,
            max_num_seqs=1,
            dtype=dtype,
            limit_mm_per_prompt={"audio": 1},
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enable_lora=True,
            max_lora_rank=64,
            enforce_eager=True,
    ) as vllm_model:
        lora_request = LoRARequest("audio", 1, audio_lora_path)
        vllm_outputs_per_case = [
            vllm_model.generate_greedy_logprobs(prompts,
                                                max_tokens,
                                                num_logprobs=num_logprobs,
                                                audios=audios,
                                                lora_request=lora_request)
            for prompts, audios in inputs
        ]

    with hf_runner(model, dtype=dtype,
                   auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:

        hf_processor = hf_model.processor
        eos_token_id = hf_processor.tokenizer.eos_token_id

        hf_outputs_per_case = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    audios=[audios],
                                                    eos_token_id=eos_token_id)
            for prompts, audios in inputs
        ]

    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
                                        vllm_outputs_per_case):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=[
                vllm_to_hf_output(output) for output in vllm_outputs
            ],
            name_0="hf",
            name_1="vllm",
        )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets,
                dtype: str, max_model_len: int, max_tokens: int,
                num_logprobs: int) -> None:
    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip")

    audio, sr = audio_assets[0].audio_and_sample_rate
    # This model expects 16k sample rate, which our test audio
    # already is; if this changes, it may break this test,
    # so we check it directly
    assert sr == 16000
    run_test(
        hf_runner,
        vllm_runner,
        [
            ([HF_AUDIO_PROMPT], [audio]),
        ],
        model,
        dtype=dtype,
        max_model_len=max_model_len,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
@ -11,7 +11,7 @@ from transformers import AutoModel, AutoTokenizer
 from vllm.multimodal.audio import resample_audio_librosa
 from vllm.sequence import SampleLogprobs
 
-from ....conftest import HfRunner, VllmRunner
+from ....conftest import HfRunner, VllmRunner, _AudioAssets
 from ....utils import RemoteOpenAIServer
 from ...registry import HF_EXAMPLE_MODELS
 from ...utils import check_logprobs_close
@ -31,12 +31,6 @@ CHUNKED_PREFILL_KWARGS = {
 }
 
 
-@pytest.fixture(scope="session")
-def audio_assets():
-    from vllm.assets.audio import AudioAsset
-    return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
-
-
 @pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
 def audio(request):
     from vllm.assets.audio import AudioAsset
@ -59,7 +53,7 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
     pytest.param({}, marks=pytest.mark.cpu_model),
     pytest.param(CHUNKED_PREFILL_KWARGS),
 ])
-def server(request, audio_assets):
+def server(request, audio_assets: _AudioAssets):
     args = [
         "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
         "--limit-mm-per-prompt",
@ -230,8 +224,9 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
     pytest.param({}, marks=pytest.mark.cpu_model),
     pytest.param(CHUNKED_PREFILL_KWARGS),
 ])
-def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
-                                     max_tokens: int, num_logprobs: int,
+def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets,
+                                     dtype: str, max_tokens: int,
+                                     num_logprobs: int,
                                      vllm_kwargs: dict) -> None:
 
     vllm_prompt = _get_prompt(len(audio_assets),
@ -250,7 +245,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
 
 
 @pytest.mark.asyncio
-async def test_online_serving(client, audio_assets):
+async def test_online_serving(client, audio_assets: _AudioAssets):
     """Exercises online serving with/without chunked prefill enabled."""
 
     messages = [{
@ -254,6 +254,7 @@ def _test_processing_correctness_mistral(
     "adept/fuyu-8b",
     "google/gemma-3-4b-it",
     "THUDM/glm-4v-9b",
+    "ibm-granite/granite-speech-3.3-8b",
     "h2oai/h2ovl-mississippi-800m",
     "OpenGVLab/InternVL2-1B",
     "HuggingFaceM4/Idefics3-8B-Llama3",
@ -301,6 +301,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
                              hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}),  # noqa: E501
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
+    "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b",  # noqa: E501
+                                                             min_transformers_version="4.52.0"),  # noqa: E501
     "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
                                         trust_remote_code=True,
                                         hf_overrides={"architectures": ["GLM4VForCausalLM"]}),  # noqa: E501
@ -517,7 +517,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
 
             raise TypeError(f"Unknown {modality} model type: {model_type}")
         elif modality == "audio":
-            if model_type == "ultravox":
+            if model_type in ("ultravox", "granite_speech"):
                 return "<|audio|>"
             if model_type == "phi4mm":
                 return f"<|audio_{current_count}|>"
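This placeholder is what the chat frontend substitutes when a client sends an audio part to a granite_speech deployment. A hedged sketch of such a request is below; the server URL, data URL payload, and sampling settings are placeholders, and LoRA wiring on the server side is omitted.

# Sketch only: assumes an OpenAI-compatible vLLM server is already running
# with this model, and that audio is supplied as a base64 "audio_url" part
# as in the existing Ultravox online-serving tests.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="ibm-granite/granite-speech-3.3-8b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "audio_url",
             "audio_url": {"url": "data:audio/wav;base64,..."}},  # placeholder
            {"type": "text",
             "text": "can you transcribe the speech into a written format?"},
        ],
    }],
    max_tokens=128,
)
# For granite_speech, the audio part is rendered as "<|audio|>" in the prompt,
# mirroring the model-type check above.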
@ -60,6 +60,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
         is_cross_attention: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@ -139,7 +140,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
 
 class Blip2QFormerSelfOutput(nn.Module):
 
-    def __init__(self, config: Blip2QFormerConfig) -> None:
+    def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
         super().__init__()
 
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@ -167,6 +168,7 @@ class Blip2QFormerAttention(nn.Module):
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
         is_cross_attention: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@ -175,9 +177,10 @@ class Blip2QFormerAttention(nn.Module):
             quant_config=quant_config,
             cache_config=cache_config,
             is_cross_attention=is_cross_attention,
+            prefix=f"{prefix}.attention",
         )
 
-        self.output = Blip2QFormerSelfOutput(config)
+        self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output")
 
     def forward(
         self,
@ -195,7 +198,7 @@ class Blip2QFormerAttention(nn.Module):
 
 class Blip2QFormerIntermediate(nn.Module):
 
-    def __init__(self, config: Blip2QFormerConfig) -> None:
+    def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
         super().__init__()
 
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@ -209,7 +212,7 @@ class Blip2QFormerIntermediate(nn.Module):
 
 class Blip2QFormerOutput(nn.Module):
 
-    def __init__(self, config: Blip2QFormerConfig) -> None:
+    def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
         super().__init__()
 
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@ -237,6 +240,7 @@ class Blip2QFormerLayer(nn.Module):
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
         layer_idx: int,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@ -244,7 +248,8 @@ class Blip2QFormerLayer(nn.Module):
         self.seq_len_dim = 1
         self.attention = Blip2QFormerAttention(config,
                                                quant_config=quant_config,
-                                               cache_config=cache_config)
+                                               cache_config=cache_config,
+                                               prefix=f"{prefix}.attention")
 
         self.layer_idx = layer_idx
 
@ -253,13 +258,16 @@ class Blip2QFormerLayer(nn.Module):
                 config,
                 quant_config=quant_config,
                 cache_config=cache_config,
-                is_cross_attention=True)
+                is_cross_attention=True,
+                prefix=f"{prefix}.crossattention")
             self.has_cross_attention = True
         else:
             self.has_cross_attention = False
 
-        self.intermediate_query = Blip2QFormerIntermediate(config)
-        self.output_query = Blip2QFormerOutput(config)
+        self.intermediate_query = Blip2QFormerIntermediate(
+            config, prefix=f"{prefix}.intermediate_query")
+        self.output_query = Blip2QFormerOutput(config,
+                                               prefix=f"{prefix}.output_query")
 
     def forward(
         self,
@ -325,6 +333,7 @@ class Blip2QFormerEncoder(nn.Module):
         *,
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@ -334,7 +343,8 @@ class Blip2QFormerEncoder(nn.Module):
             Blip2QFormerLayer(config,
                               quant_config=quant_config,
                               cache_config=cache_config,
-                              layer_idx=layer_idx)
+                              layer_idx=layer_idx,
+                              prefix=f"{prefix}.layer.{layer_idx}")
             for layer_idx in range(config.num_hidden_layers)
         ])
 
@ -365,6 +375,7 @@ class Blip2QFormerModel(nn.Module):
         *,
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@ -376,7 +387,8 @@ class Blip2QFormerModel(nn.Module):
 
         self.encoder = Blip2QFormerEncoder(config,
                                            quant_config=quant_config,
-                                           cache_config=cache_config)
+                                           cache_config=cache_config,
+                                           prefix=f"{prefix}.encoder")
 
     def forward(
         self,
@ -511,7 +523,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
 
         self.qformer = Blip2QFormerModel(config.qformer_config,
                                          cache_config=cache_config,
-                                         quant_config=quant_config)
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.qformer")
 
         self.language_projection = nn.Linear(
             config.qformer_config.hidden_size,
vllm/model_executor/models/granite_speech.py (new file, 777 lines)
@ -0,0 +1,777 @@
# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2025 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM Granite speech model."""
import math
from typing import Iterable, Mapping, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import BatchFeature, PretrainedConfig

from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                   MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement,
                                        PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors

from .blip2 import Blip2QFormerModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, embed_multimodal,
                    init_vllm_registered_model, maybe_prefix)

### Audio Input
class GraniteSpeechAudioInputs(TypedDict):

    input_features: torch.Tensor
    """Shape: `(bsz, num_features, 160)`"""

    input_features_mask: torch.Tensor
    """Shape: `(bsz, num_features)`"""

    audio_embed_sizes: list[int]
    """List of length `bsz`"""


class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": 1}

    # There is no limit to the maximum number of audio tokens that can be
    # encoded as features; we pick ~5000 as a number that is probably higher
    # than we would expect to encounter. The sequence of length
    # get_max_audio_len() produces get_max_audio_tokens().
    def get_max_audio_tokens(self):
        return 5001

    def get_max_audio_len(self):
        return 8000000

### Input Processing & Multimodal utils
class GraniteSpeechMultiModalProcessor(
        BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]):

    def _get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.info.get_hf_processor().audio_processor
        sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
        return MultiModalDataParser(target_sr=sampling_rate)

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            input_features=MultiModalFieldConfig.batched("audio"),
            audio_embed_sizes=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        feature_extractor = processor.audio_processor
        vocab = tokenizer.get_vocab()

        # Use getattr with default to be compatible with transformers<4.48
        audio_token = getattr(processor, "audio_token", "<|audio|>")
        audio_token_id = vocab[audio_token]

        def get_replacement(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio = audios.get(item_idx)
            audio_length = audio.shape[-1]
            num_projector_features = feature_extractor._get_num_audio_features(
                [audio_length])[0]
            return [audio_token_id] * num_projector_features

        return [
            PromptReplacement(
                modality="audio",
                target=[audio_token_id],
                replacement=get_replacement,
            )
        ]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        if audios:
            # GraniteSpeechFeatureExtractor accepts "audio"
            mm_data["audio"] = audios

        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        if "audio" in mm_data:
            # Calculate the number of audio tokens per entry in the batch;
            # This is used to split the batch back out after padding.
            audio_token_index = self.info.get_hf_config().audio_token_index
            processed_outputs["audio_embed_sizes"] = [
                torch.sum(indices == audio_token_index).item()
                for indices in processed_outputs["input_ids"]
            ]

        return processed_outputs


class GraniteSpeechDummyInputsBuilder(
        BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]):

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        num_audios = mm_counts.get("audio", 0)
        return {
            "audio":
            self._get_dummy_audios(
                length=self.info.get_max_audio_len(),
                num_audios=num_audios,
            )
        }

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        hf_processor = self.info.get_hf_processor()
        audio_token = getattr(hf_processor, "audio_token", "<|audio|>")
        return audio_token * num_audios

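To make the prompt update above concrete, here is a tiny walk-through of get_replacement() for one audio item; both numbers are invented for illustration, and the real values come from the tokenizer vocab and the feature extractor's _get_num_audio_features.

# Hypothetical illustration only.
audio_token_id = 49155          # placeholder id for "<|audio|>"
num_projector_features = 375    # pretend _get_num_audio_features([len(audio)])[0] == 375

replacement = [audio_token_id] * num_projector_features
# The single <|audio|> token in the tokenized prompt is expanded in place into
# 375 copies of the audio token id, one per projected audio embedding.
assert len(replacement) == num_projector_features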
### QFormer Projector
class GraniteSpeechEncoderProjector(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: CacheConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.projector_config.hidden_size
        self.downsample_rate = config.downsample_rate
        self.window_size = config.window_size
        self.num_queries = config.window_size // config.downsample_rate

        self.query = nn.Parameter(
            torch.zeros(1, self.num_queries,
                        config.projector_config.hidden_size))

        # NOTE - this is implemented generically in transformers,
        # but for now we create the QFormer model directly since
        # all existing models use this for the projector.
        self.qformer = Blip2QFormerModel(
            config.projector_config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.qformer",
        )
        self.linear = nn.Linear(config.projector_config.hidden_size,
                                config.text_config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = hidden_states.size()
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad),
                                          "constant", 0)
        hidden_states = hidden_states.view(batch_size * nblocks,
                                           self.window_size, dim)

        last_hidden_state = self.qformer(
            query_embeds=self.query.data,
            encoder_hidden_states=hidden_states,
        )

        query_proj = self.linear(
            last_hidden_state.view(
                batch_size,
                nblocks * self.window_size // self.downsample_rate,
                -1,
            ))
        return query_proj

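A small shape walk-through of the forward pass above; the window and downsample values here are illustrative placeholders rather than the released config values.

import math

# Hypothetical config values for illustration only.
window_size, downsample_rate = 15, 5
num_queries = window_size // downsample_rate          # 3 learned queries per block

seq_len = 100                                         # encoder frames for one clip
nblocks = math.ceil(seq_len / window_size)            # 7 blocks
pad = nblocks * window_size - seq_len                 # 5 frames of right padding
out_len = nblocks * window_size // downsample_rate    # 21 projected embeddings
print(nblocks, pad, out_len)                          # 7 5 21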
# Encoder - conformer is adapted from: https://github.com/lucidrains/conformer.git
# NOTE - it would be nice to see if we can align this with other models using
# conformer in vLLM, e.g., phi4mm audio.
class GraniteSpeechConformerFeedForward(nn.Module):
    """Feedforward module for conformer encoder blocks."""

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.pre_norm = nn.LayerNorm(config.hidden_dim)

        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.hidden_dim * config.feedforward_mult,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.silu = nn.SiLU()

        self.down_proj = RowParallelLinear(
            input_size=config.hidden_dim * config.feedforward_mult,
            output_size=config.hidden_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        hidden_states, _ = self.up_proj(hidden_states)
        hidden_states = self.silu(hidden_states)
        hidden_states, _ = self.down_proj(hidden_states)
        return hidden_states


class GraniteSpeechConformerAttention(nn.Module):
    """Attention for conformer blocks using Shaw's relative positional
    embeddings. See the following [paper](https://arxiv.org/pdf/1803.02155)
    for more details.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()

        inner_dim = config.dim_head * config.num_heads
        self.max_pos_emb = config.max_pos_emb
        self.context_size = config.context_size
        self.num_heads = config.num_heads
        self.dim_head = config.dim_head
        self.scale = self.dim_head**-0.5
        self.pre_norm = nn.LayerNorm(config.hidden_dim)
        self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, config.hidden_dim)
        self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1,
                                        self.dim_head)

        if self.context_size <= 0 or self.context_size > self.max_pos_emb:
            raise ValueError(
                "Context size is either less than 0 or exceeds the max_pos_emb"
            )

    def forward(self, hidden_states: torch.Tensor,
                attention_dists: torch.Tensor) -> torch.Tensor:
        hidden_states = self.pre_norm(hidden_states)
        bsz, num_features, _ = hidden_states.shape

        num_blocks = math.ceil(num_features / self.context_size)
        remainder = num_features % self.context_size
        if remainder > 0:
            # right padding to reach block size
            hidden_states = torch.nn.functional.pad(
                hidden_states, (0, 0, 0, self.context_size - remainder))

        # NOTE: would be nice to try to use qkvparallellinear
        # here for this block attention implementation if possible
        query_states = self.to_q(hidden_states)
        key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)

        query_states = query_states.reshape(bsz, num_blocks, self.context_size,
                                            self.num_heads,
                                            -1).transpose(2, 3)
        key_states = key_states.reshape(bsz, num_blocks, self.context_size,
                                        self.num_heads, -1).transpose(2, 3)
        value_states = value_states.reshape(bsz, num_blocks, self.context_size,
                                            self.num_heads,
                                            -1).transpose(2, 3)

        # shaw's relative positional embedding
        dist = attention_dists.to(hidden_states.device)
        rel_pos_emb = self.rel_pos_emb(dist)
        rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
                                                list(rel_pos_emb.shape))
        pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
                             dim=-1) * self.scale

        if remainder > 0:
            # masked attention in the extended block
            mask = torch.ones(self.context_size,
                              self.context_size,
                              dtype=bool,
                              device=hidden_states.device)
            mask[:remainder, :remainder] = 0
            mask_value = -torch.finfo(pos_attn.dtype).max
            pos_attn[:, -1, :].masked_fill_(mask, mask_value)

        with torch.nn.attention.sdpa_kernel(
                torch.nn.attention.SDPBackend.MATH):
            out = F.scaled_dot_product_attention(query_states,
                                                 key_states,
                                                 value_states,
                                                 attn_mask=pos_attn,
                                                 scale=self.scale)
        out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
        return self.to_out(out[:, :num_features, :])

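The blocking and remainder masking in the attention forward can be illustrated with made-up sizes; the context_size below is a placeholder, not the model's actual value.

import math

# Hypothetical sizes for illustration.
context_size = 200
num_features = 450

num_blocks = math.ceil(num_features / context_size)   # 3 blocks
remainder = num_features % context_size                # 50 real frames in block 3
pad = context_size - remainder                         # 150 padded frames
print(num_blocks, remainder, pad)                      # 3 50 150
# In the last block, every query/key pair touching one of the 150 padded
# positions receives the large negative bias before softmax; the first two
# blocks are unaffected, and the padded outputs are sliced away at the end.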
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
    """Wrapper for padded 1D pointwise convolution."""

    def __init__(self,
                 chan_in: int,
                 chan_out: int,
                 kernel_size: int,
                 prefix: str = ""):
        super().__init__()
        # Padding for the 1D conv is symmetric or close (i.e., offset by one).
        pad = kernel_size // 2
        pad_offset = (kernel_size + 1) % 2
        self.padding = (pad, pad - pad_offset)

        self.conv = nn.Conv1d(chan_in,
                              chan_out,
                              kernel_size,
                              groups=chan_in,
                              bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = F.pad(hidden_states, self.padding)
        return self.conv(hidden_states)

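The padding comment above works out as follows for odd and even kernels; the kernel sizes are example values, and the helper simply re-derives the constructor arithmetic.

def depthwise_padding(kernel_size: int) -> tuple[int, int]:
    # Mirrors the constructor logic above.
    pad = kernel_size // 2
    pad_offset = (kernel_size + 1) % 2
    return (pad, pad - pad_offset)

print(depthwise_padding(15))  # (7, 7)  -- odd kernel: symmetric padding
print(depthwise_padding(4))   # (2, 1)  -- even kernel: offset by one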
class GraniteSpeechConformerConvModule(nn.Module):
    """Conformer conv module consisting of several 1D/depthwise 1D
    convolutional layers.
    """

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        inner_dim = config.hidden_dim * config.conv_expansion_factor

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.up_conv = nn.Conv1d(config.hidden_dim, inner_dim * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = GraniteSpeechConformerDepthWiseConv1d(
            inner_dim,
            inner_dim,
            kernel_size=config.conv_kernel_size,
            prefix=f"{prefix}.depth_conv",
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, config.hidden_dim, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.norm(hidden_states)
        hidden_states = self.up_conv(hidden_states.permute(0, 2, 1))
        hidden_states = self.glu(hidden_states)
        hidden_states = self.depth_conv(hidden_states)
        hidden_states = self.silu(self.batch_norm(hidden_states))
        hidden_states = self.down_conv(hidden_states).permute(0, 2, 1)
        return hidden_states


class GraniteSpeechConformerBlock(nn.Module):
    """Conformer block, consisting largely of linear layers,
    attention, and convolutional layers."""

    def __init__(self, config: PretrainedConfig, prefix: str = ""):
        super().__init__()
        self.ff1 = GraniteSpeechConformerFeedForward(config,
                                                     prefix=f"{prefix}.ff1")
        self.attn = GraniteSpeechConformerAttention(config,
                                                    prefix=f"{prefix}.attn")
        self.conv = GraniteSpeechConformerConvModule(config,
                                                     prefix=f"{prefix}.conv")
        self.ff2 = GraniteSpeechConformerFeedForward(config,
                                                     prefix=f"{prefix}.ff2")
        self.post_norm = nn.LayerNorm(config.hidden_dim)

    def forward(self, hidden_states: torch.Tensor,
                attention_dists: torch.Tensor) -> torch.Tensor:
        hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
        hidden_states = self.attn(
            hidden_states, attention_dists=attention_dists) + hidden_states
        hidden_states = self.conv(hidden_states) + hidden_states
        hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
        hidden_states = self.post_norm(hidden_states)
        return hidden_states

class GraniteSpeechCTCEncoder(nn.Module):
    """CTC Encoder comprising conformer blocks and additional linear layers."""

    def __init__(self,
                 config: PretrainedConfig,
                 prefix: str,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config

        # Precompute clamped relative positional encoding distances
        seq = torch.arange(config.context_size)
        relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
        self.attention_dists = torch.clamp(
            relpos_dist, -config.context_size,
            config.context_size) + config.max_pos_emb

        self.input_linear = nn.Linear(config.input_dim,
                                      config.hidden_dim,
                                      bias=True)
        self.layers = nn.ModuleList([
            GraniteSpeechConformerBlock(
                config,
                prefix=f"{prefix}.layers.{idx}",
            ) for idx in range(config.num_layers)
        ])

        self.out = ColumnParallelLinear(
            input_size=config.hidden_dim,
            output_size=config.output_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out",
        )

        self.out_mid = RowParallelLinear(
            input_size=config.output_dim,
            output_size=config.hidden_dim,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_mid",
        )
        self.softmax = nn.Softmax(dim=-1)
        self.num_layers = config.num_layers

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = self.input_linear(hidden_states)
        for idx, layer in enumerate(self.layers, start=1):
            hidden_states = layer(hidden_states,
                                  attention_dists=self.attention_dists)

            if idx == self.num_layers // 2:
                hidden_states_mid = hidden_states.clone()
                hidden_states_mid, _ = self.out(hidden_states_mid)
                hidden_states_mid = self.softmax(hidden_states_mid)
                hidden_states_mid, _ = self.out_mid(hidden_states_mid)
                hidden_states += hidden_states_mid
        return hidden_states

@MULTIMODAL_REGISTRY.register_processor(
    GraniteSpeechMultiModalProcessor,
    info=GraniteSpeechMultiModalProcessingInfo,
    dummy_inputs=GraniteSpeechDummyInputsBuilder)
class GraniteSpeechForConditionalGeneration(
        nn.Module,
        SupportsMultiModal,
        SupportsPP,
        SupportsLoRA,
):

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        cache_config = vllm_config.cache_config

        self.config = config
        self.quant_config = quant_config
        self.cache_config = cache_config
        self.sampler = get_sampler()

        # The language model is typically a Granite LLM
        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Conformer encoder
        self.encoder = GraniteSpeechCTCEncoder(
            config=config.encoder_config,
            quant_config=quant_config,
            prefix=f"{prefix}.encoder",
        )

        # Blip2 QFormer
        self.projector = GraniteSpeechEncoderProjector(
            config=config,
            quant_config=quant_config,
            cache_config=cache_config,
            prefix=f"{prefix}.projector",
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_audio_input(
        self,
        **kwargs: object,
    ) -> Optional[GraniteSpeechAudioInputs]:
        input_features = kwargs.pop("input_features", None)
        input_features_mask = kwargs.pop("input_features_mask", None)
        audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
        if input_features is None:
            return None

        # If we have a batch of variable feature length audio clips, we need
        # to mask the features; usually we would get an input_features_mask
        # from the processor, but we handle rebuilding it here since
        # vLLM generally processes everything independently + batches.
        if input_features_mask is None:
            input_features_mask = self._build_input_features_mask(
                audio_embed_sizes)

        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError("Incorrect type of audio input features. "
                             f"Got type: {type(input_features)}")

        if input_features_mask is not None and not isinstance(
                input_features_mask, torch.Tensor):
            raise ValueError("Incorrect type of audio input features mask. "
                             f"Got type: {type(input_features_mask)}")

        if isinstance(input_features, torch.Tensor):
            # Granite speech currently only allows one audio token per instance
            # and features are already unsqueezed in the processor, so one
            # instance will have shape [1, {num_features}, 160]. As such,
            # input features will usually be of shape
            # [bsz, 1, num_features, 160], which we squeeze to be 3D here.
            if len(input_features.shape) == 4:
                input_features = input_features.squeeze(1)
            if len(input_features.shape) != 3:
                raise ValueError(
                    "Squeezed input features should be 3D but are of shape "
                    f"{input_features.shape}")
            input_features = input_features.to(
                self.encoder.input_linear.weight.dtype)

        else:
            # Otherwise we have a list of tensors, which are almost certainly
            # differing in their respective numbers of audio features;
            # stack them into a 3D tensor of size [bsz, most_num_features, 160].
            input_features = self._pad_and_stack_input_features(
                input_features, ).to(self.encoder.input_linear.weight.dtype)

        return GraniteSpeechAudioInputs(
            input_features=input_features,
            input_features_mask=input_features_mask,
            audio_embed_sizes=audio_embed_sizes.flatten().tolist(),
        )

    def _build_input_features_mask(
        self,
        audio_embed_sizes: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate the input features mask, which will generally be used
        to mask the padded features for all entries in the batch except
        for those with the most audio features.

        Args:
            audio_embed_sizes: torch.Tensor
                Tensor of num features in each seq in the batch.
        Returns:
            torch.Tensor: Mask of shape (bsz, num_features) to be applied to
            the audio features prior to splitting the audio embeddings.
        """
        most_audio_features = torch.max(audio_embed_sizes).item()
        mask_indices = torch.arange(
            most_audio_features,
            device=audio_embed_sizes.device,
        ).view(1, -1)
        input_features_mask = mask_indices < audio_embed_sizes.view(-1, 1)
        return input_features_mask

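For a concrete picture of the mask built above, with batch sizes invented for the example:

import torch

# Two sequences with 2 and 4 audio embeddings respectively.
audio_embed_sizes = torch.tensor([2, 4])
most = torch.max(audio_embed_sizes).item()            # 4
mask_indices = torch.arange(most).view(1, -1)         # [[0, 1, 2, 3]]
mask = mask_indices < audio_embed_sizes.view(-1, 1)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])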
    def _pad_and_stack_input_features(
        self,
        input_features: list[torch.Tensor],
    ) -> torch.Tensor:
        """Given a list of input features of varying length, pad them to the
        same length and stack them into a torch.Tensor.

        NOTE: Usually, padding is done in the input processor/feature extractor
        and zero padded prior to the computation of the Mel features; the
        resulting values are only constant within a batch and generally nonzero
        (i.e., slightly negative nums); we should validate that this is okay
        since we don't use a feature attention mask, but the more important
        thing is that we apply the input_features_mask with variable len
        batches.

        Args:
            input_features: list[torch.Tensor]
                Input features to be coerced into a tensor.
        Returns:
            torch.Tensor: Tensor of shape [bsz, num_features, 160], where
            num_features is the max number of features of any entry in the
            batch.
        """
        # Input features are of shape [bsz, num_features, 160]
        feat_lens = [feats.shape[1] for feats in input_features]
        padding = [max(feat_lens) - length for length in feat_lens]
        # TODO (Alex) - Validate that it's okay to zero pad like this;
        # in transformers we zero pad prior to calculating the speech features,
        # so the value is not zero and is dependent on the batched features.
        padded = [
            torch.nn.functional.pad(feats, (0, 0, 0, pad, 0, 0))
            for feats, pad in zip(input_features, padding)
        ]
        stacked_features = torch.cat(padded, dim=0).to(input_features[0])
        return stacked_features

    def _process_audio_input(
        self,
        audio_input: GraniteSpeechAudioInputs,
    ) -> tuple[torch.Tensor]:
        """Compute the audio features to be merged into the LLM embeddings.

        Args:
            audio_input: GraniteSpeechAudioInputs
                Audio inputs object containing Mel features, an input features
                mask, and the (flattened) number of audio tokens per instance.
        Returns:
            tuple[torch.Tensor]: List of length bsz.
        """
        # TODO (Alex) - support embedding inputs
        encoder_embeds = self.encoder(audio_input["input_features"])
        # [bsz, <max feature size>, 4096]
        projected_embeds = self.projector(encoder_embeds)
        # Apply mask on variable length audio features
        masked_embeds = projected_embeds[audio_input["input_features_mask"]]
        # Split variable length features into a tuple
        return torch.split(masked_embeds, audio_input["audio_embed_sizes"])

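Putting the pieces together, the tensor shapes through _process_audio_input look roughly like the comment sketch below; the batch sizes are invented, and 4096 simply echoes the hidden-size comment above for the 8B text model.

# Hedged shape sketch for a batch of two clips that yield 300 and 500 audio
# embeddings after projection (so the batch is padded to 500):
#   input_features         : [2, num_mel_frames, 160]
#   encoder(...)           : [2, num_mel_frames, encoder_hidden]
#   projector(...)         : [2, 500, 4096]        (padded to the longest clip)
#   [input_features_mask]  : keeps 300 + 500 = 800 rows -> [800, 4096]
#   torch.split(..., [300, 500]) -> ([300, 4096], [500, 4096])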
    def get_multimodal_embeddings(
        self,
        **kwargs: object,
    ) -> Optional[MultiModalEmbeddings]:
        """Compute the audio embeddings if audio inputs are present."""
        audio_input = self._parse_and_validate_audio_input(**kwargs)
        if audio_input is None:
            return None
        audio_features = self._process_audio_input(audio_input)
        return audio_features

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Compute the merged LLM / audio embeddings."""
        if multimodal_embeddings is None:
            return self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.audio_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            audio_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, audio_embeds)
            input_ids = None

        model_output = self.language_model(input_ids, positions,
                                           intermediate_tensors, inputs_embeds)
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(
            hidden_states,
            sampling_metadata,
        )

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="projector",
            tower_model="encoder",
        )
@ -178,6 +178,7 @@ _MULTIMODAL_MODELS = {
     "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
     "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"),  # noqa: E501
     "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
+    "GraniteSpeechForConditionalGeneration": ("granite_speech", "GraniteSpeechForConditionalGeneration"),  # noqa: E501
     "H2OVLChatModel": ("h2ovl", "H2OVLChatModel"),
     "InternVLChatModel": ("internvl", "InternVLChatModel"),
     "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),