[Core] Add Support for Default Modality Specific LoRAs [generate / chat completions] (#19126)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in: parent 3de2ed767f, commit 41060c6e08
@@ -272,3 +272,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
     ]
 }
 ```
+
+## Default LoRA Models For Multimodal Models
+
+Some models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. This can be tedious to manage with the approaches above, as it requires the user to send the `LoRARequest` (offline) or to filter requests between the base model and the LoRA model (server) depending on the content of each request's multimodal data.
+
+To this end, we allow registration of default multimodal LoRAs to handle this automatically: users can map each modality to a LoRA adapter, which is applied whenever the corresponding inputs are present. Note that currently we only allow one LoRA per prompt; if several modalities are provided, each of which is registered to its own LoRA, none of them will be applied.
+
+Example usage for offline inference:
+
+```python
+from transformers import AutoTokenizer
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+
+model_id = "ibm-granite/granite-speech-3.3-2b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+
+def get_prompt(question: str, has_audio: bool):
+    """Build the input prompt to send to vLLM."""
+    if has_audio:
+        question = f"<|audio|>{question}"
+    chat = [
+        {
+            "role": "user",
+            "content": question
+        }
+    ]
+    return tokenizer.apply_chat_template(chat, tokenize=False)
+
+
+model = LLM(
+    model=model_id,
+    enable_lora=True,
+    max_lora_rank=64,
+    max_model_len=2048,
+    limit_mm_per_prompt={"audio": 1},
+    # Will always pass a `LoRARequest` with the `model_id`
+    # whenever audio is contained in the request data.
+    default_mm_loras={"audio": model_id},
+    enforce_eager=True,
+)
+
+question = "can you transcribe the speech into a written format?"
+prompt_with_audio = get_prompt(
+    question=question,
+    has_audio=True,
+)
+audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
+
+inputs = {
+    "prompt": prompt_with_audio,
+    "multi_modal_data": {
+        "audio": audio,
+    }
+}
+
+outputs = model.generate(
+    inputs,
+    sampling_params=SamplingParams(
+        temperature=0.2,
+        max_tokens=64,
+    ),
+)
+```
+
+You can also pass a JSON dictionary to `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:
+
+```bash
+vllm serve ibm-granite/granite-speech-3.3-2b \
+    --max-model-len 2048 \
+    --enable-lora \
+    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
+    --max-lora-rank 64
+```
+
+Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
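
For illustration (not part of the diff): once a server is started with `--default-mm-loras` as above, a chat completion request whose content includes audio should pick up the default audio LoRA automatically. A minimal sketch using the official `openai` client; the base URL and audio URL below are placeholders:

```python
import openai

# Assumes the vLLM server started above is reachable at this address.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    model="ibm-granite/granite-speech-3.3-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you transcribe this audio?"},
            # The audio content is what triggers the LoRA registered for the
            # "audio" modality via --default-mm-loras.
            {"type": "audio_url", "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
    max_completion_tokens=64,
    temperature=0.2,
)
print(chat_completion.choices[0].message.content)
```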
tests/entrypoints/openai/test_default_mm_loras.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download

from ...conftest import AudioTestAssets
from ...utils import RemoteOpenAIServer

# NOTE - the tests in this module are currently analogous to test_chat, but are
# separated to avoid OOM killing due to module-scoped servers, since we
# need a multimodal model for these tests.

# Contains a modality specific lora alongside the base model
MULTIMODAL_MODEL_NAME = snapshot_download(
    "microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")

ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501


@pytest.fixture(scope="module")
def monkeypatch_module():
    from _pytest.monkeypatch import MonkeyPatch
    mpatch = MonkeyPatch()
    yield mpatch
    mpatch.undo()


@pytest.fixture(scope="module", params=[False, True])
def multimodal_server(request, monkeypatch_module):  # noqa: F811

    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--max-model-len",
        "12800",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        f"speech={AUDIO_LORA_PATH}",
        "--max-lora-rank",
        "320",
        "--max-num-seqs",
        "2",
        "--trust-remote-code",
        "--gpu-memory-utilization",
        "0.8",
        "--default-mm-loras",
        f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
    ]

    with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def multi_modal_client(multimodal_server):
    async with multimodal_server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # base model with default lora should give the same response as lora model
    "model_name",
    [MULTIMODAL_MODEL_NAME, "speech"],
)
async def test_default_mm_lora_chat_completions(
    model_name: str,
    multi_modal_client: openai.AsyncOpenAI,
    audio_assets: AudioTestAssets,
):
    messages = [{
        "role":
        "user",
        "content": [{
            "type": "text",
            "text": "Can you transcribe this audio?",
        }, {
            "type": "audio_url",
            "audio_url": {
                "url": audio_assets[0].url
            },
        }]
    }]

    chat_completion = await multi_modal_client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=128,
        temperature=0.0)

    assert len(chat_completion.choices) > 0

    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
    assert message.content == ACTIVE_MM_LORA_RESPONSE
tests/lora/test_default_mm_loras.py (new file, 118 lines)
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for applying default registered multimodal loras.
"""

import os

from huggingface_hub import snapshot_download

from vllm.lora.request import LoRARequest

from ..conftest import AudioTestAssets, VllmRunner

MODEL_PATH = snapshot_download("microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MODEL_PATH, "speech-lora")
IMAGE_LORA_PATH = os.path.join(MODEL_PATH, "vision-lora")

AUDIO_PROMPT = "<|user|><|audio_1|>Can you transcribe this audio?<|end|><|assistant|>"  # noqa: E501

# Responses are greedy decoded; we just check the end of
# the generated text. If the lora is inactive, this model
# generates commentary on the transcription.
RESPONSE_SUFFIX_WITH_LORA = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501
RESPONSE_SUFFIX_WITHOUT_LORA = "Certainly! Here is the transcription of the audio you provided:\n\nThe first words I spoke in the original phonograph record: A little piece of practical poetry. Mary had a little lamb; its fleece was white as snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501

VLLM_RUNNER_BASE_KWARGS = {
    "model_name": MODEL_PATH,
    "dtype": "half",
    "enable_lora": "True",
    "max_num_seqs": 2,
    "max_lora_rank": 320,
    "max_model_len": 12800,
    "gpu_memory_utilization": 0.8,
    "limit_mm_per_prompt": {
        "audio": 1
    },
    "enforce_eager": True,
}


def run_test(vllm_runner, audio_assets, lora_request, expected_suffix,
             **kwargs):
    inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])]

    # Apply any additional kwargs as overrides to the base kwargs
    vllm_runner_kwargs = {**VLLM_RUNNER_BASE_KWARGS, **kwargs}

    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_with_default_lora = [
            vllm_model.generate_greedy(
                prompts,
                max_tokens=128,
                audios=audios,
                lora_request=lora_request,
            ) for prompts, audios in inputs
        ]

        assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(
            expected_suffix)


def test_active_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that we can use the default audio lora."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_inactive_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that modalities are filtered properly."""
    # Default image lora won't be active since we only pass audio
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"image": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITHOUT_LORA,
    )


def test_default_mm_lora_succeeds_with_redundant_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that redundantly providing the lora works."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("audio", 1, AUDIO_LORA_PATH),
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_default_mm_lora_fails_with_overridden_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that if the lora_request conflicts with default_mm_loras,
    we use the lora_request."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("speech", 2, AUDIO_LORA_PATH),
        default_mm_loras={"audio": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )
@@ -33,6 +33,7 @@ import vllm.envs as envs
 from vllm import version
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2989,6 +2990,16 @@ class LoRAConfig:
     trained with those scaling factors to be used at the same time. If not
     specified, only adapters trained with the base model scaling factor are
     allowed."""
+    default_mm_loras: Optional[dict[str, str]] = None
+    """Dictionary mapping specific modalities to LoRA model paths; this field
+    is only applicable to multimodal models and should be leveraged when a
+    model always expects a LoRA to be active when a given modality is present.
+    Note that currently, if a request provides multiple additional
+    modalities, each of which have their own LoRA, we do NOT apply
+    default_mm_loras because we currently only support one lora adapter
+    per prompt. When run in offline mode, the lora IDs for n modalities
+    will be automatically assigned to 1-n with the names of the modalities
+    in alphabetic order."""
     bias_enabled: bool = False
     """Enable bias for LoRA adapters."""
 
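
For illustration (not part of the diff): the offline ID assignment described in the docstring above follows the alphabetical order of the registered modality names. A minimal sketch; the paths below are placeholders:

```python
# Mirrors `sorted(default_mm_loras).index(modality_name) + 1` from this change.
default_mm_loras = {"image": "/path/to/vision-lora", "audio": "/path/to/speech-lora"}

for modality in sorted(default_mm_loras):
    lora_id = sorted(default_mm_loras).index(modality) + 1
    print(modality, "->", lora_id)  # audio -> 1, image -> 2
```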
@@ -395,6 +395,8 @@ class EngineArgs:
     enable_lora_bias: bool = LoRAConfig.bias_enabled
     max_loras: int = LoRAConfig.max_loras
     max_lora_rank: int = LoRAConfig.max_lora_rank
+    default_mm_loras: Optional[Dict[str, str]] = \
+        LoRAConfig.default_mm_loras
     fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
     max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
     lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
@@ -807,6 +809,8 @@ class EngineArgs:
                                 **lora_kwargs["max_cpu_loras"])
         lora_group.add_argument("--fully-sharded-loras",
                                 **lora_kwargs["fully_sharded_loras"])
+        lora_group.add_argument("--default-mm-loras",
+                                **lora_kwargs["default_mm_loras"])
 
         # PromptAdapter related configs
         prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
@@ -1284,10 +1288,16 @@ class EngineArgs:
             disable_hybrid_kv_cache_manager,
         )
 
+        if not model_config.is_multimodal_model and self.default_mm_loras:
+            raise ValueError(
+                "Default modality-specific LoRA(s) were provided for a "
+                "non multimodal model")
+
         lora_config = LoRAConfig(
             bias_enabled=self.enable_lora_bias,
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
+            default_mm_loras=self.default_mm_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
             long_lora_scaling_factors=self.long_lora_scaling_factors,
@@ -499,6 +499,10 @@ class LLM:
         _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                   truncate_prompt_tokens, tokenization_kwargs)
 
+        # Add any modality specific loras to the corresponding prompts
+        lora_request = self._get_modality_specific_lora_reqs(
+            parsed_prompts, lora_request)
+
         self._validate_and_add_requests(
             prompts=parsed_prompts,
             params=sampling_params,
@@ -513,6 +517,83 @@ class LLM:
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs, RequestOutput)
 
+    def _get_modality_specific_lora_reqs(
+            self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
+            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
+        # Grab the lora config off the vllm config on the engine,
+        # since this is the same for both v0 & v1.
+        lora_config = self.llm_engine.vllm_config.lora_config
+
+        # If there's no lora config / default_mm_loras, or the model
+        # isn't multimodal, leave the lora as is.
+        if (lora_config is None
+                or not self.llm_engine.model_config.is_multimodal_model
+                or (lora_config and lora_config.default_mm_loras is None)):
+            return lora_request
+
+        if not isinstance(parsed_prompts, Sequence):
+            parsed_prompts = [parsed_prompts]
+
+        optional_loras = ([lora_request] * len(parsed_prompts)
+                          if not isinstance(lora_request, Sequence) else
+                          lora_request)
+
+        return [
+            self._resolve_single_prompt_mm_lora(
+                parsed_prompt,
+                opt_lora_req,
+                lora_config.default_mm_loras,
+            ) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
+                                                     optional_loras)
+        ]
+
+    def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
+                                       lora_request: Optional[LoRARequest],
+                                       default_mm_loras: Optional[dict[str,
+                                                                       str]]):
+        if (not default_mm_loras or not isinstance(parsed_prompt, dict)
+                or "multi_modal_data" not in parsed_prompt):
+            return lora_request
+
+        parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)
+
+        intersection = set(
+            parsed_prompt["multi_modal_data"].keys()).intersection(
+                default_mm_loras.keys())
+        if not intersection:
+            return lora_request
+        if len(intersection) > 1:
+            # TODO: Would be nice to be able to have multiple loras per prompt
+            logger.warning(
+                "Multiple modality specific loras were registered and would be"
+                " used by a single prompt consuming several modalities; "
+                " currently we only support one lora per request; as such,"
+                " lora(s) registered with modalities: %s"
+                " will be skipped", intersection)
+            return lora_request
+
+        # Build the LoRA request; the ID of the default mm lora is the
+        # index of the modality name sorted alphabetically + 1.
+        modality_name = intersection.pop()
+        modality_lora_path = default_mm_loras[modality_name]
+        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1
+
+        # If there is a collision, warn,
+        # but always send the explicitly provided request.
+        if lora_request:
+            if lora_request.lora_int_id != modality_lora_id:
+                logger.warning(
+                    "A modality with a registered lora and a lora_request "
+                    "with a different ID were provided; falling back to the "
+                    "lora_request as we only apply one LoRARequest per prompt")
+            return lora_request
+
+        return LoRARequest(
+            modality_name,
+            modality_lora_id,
+            modality_lora_path,
+        )
+
     def collective_rpc(self,
                        method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
@@ -87,6 +87,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                     LoRAModulePath,
                                                      OpenAIServingModels)
 from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
@@ -1481,11 +1482,28 @@ async def init_app_state(
             "This discrepancy may lead to performance degradation.",
             resolved_chat_template, args.model)
 
+    # Merge default_mm_loras into the static lora_modules
+    default_mm_loras = (vllm_config.lora_config.default_mm_loras
+                        if vllm_config.lora_config is not None else {})
+
+    lora_modules = args.lora_modules
+    if default_mm_loras:
+        default_mm_lora_paths = [
+            LoRAModulePath(
+                name=modality,
+                path=lora_path,
+            ) for modality, lora_path in default_mm_loras.items()
+        ]
+        if args.lora_modules is None:
+            lora_modules = default_mm_lora_paths
+        else:
+            lora_modules += default_mm_lora_paths
+
     state.openai_serving_models = OpenAIServingModels(
         engine_client=engine_client,
         model_config=model_config,
         base_model_paths=base_model_paths,
-        lora_modules=args.lora_modules,
+        lora_modules=lora_modules,
         prompt_adapters=args.prompt_adapters,
     )
     await state.openai_serving_models.init_static_loras()
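
For illustration (not part of the diff): because each default modality LoRA is merged into `lora_modules` under its modality name above, it should also show up as a selectable LoRA once the server is running. A minimal sketch, assuming the server address below and a `--default-mm-loras '{"audio": ...}'` launch:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# The merged default audio LoRA should be listed alongside the base model,
# under the name of its modality, e.g. "audio".
for model in client.models.list():
    print(model.id)
```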
@@ -153,7 +153,8 @@ class OpenAIServingChat(OpenAIServing):
             (
                 lora_request,
                 prompt_adapter_request,
-            ) = self._maybe_get_adapters(request)
+            ) = self._maybe_get_adapters(request,
+                                         supports_default_mm_loras=True)
 
             model_name = self._get_model_name(request.model, lora_request)
 
@@ -458,20 +458,74 @@ class OpenAIServing:
                 err_type="NotFoundError",
                 status_code=HTTPStatus.NOT_FOUND)
 
+    def _get_active_default_mm_loras(
+            self, request: AnyRequest) -> Optional[LoRARequest]:
+        """Determine if there are any active default multimodal loras."""
+        # TODO: Currently this is only enabled for chat completions
+        # to be better aligned with only being enabled for .generate
+        # when run offline. It would be nice to support additional
+        # task types in the future.
+        message_types = self._get_message_types(request)
+        default_mm_loras = set()
+
+        for lora in self.models.lora_requests.values():
+            # Best effort match for default multimodal lora adapters;
+            # There is probably a better way to do this, but currently
+            # this matches against the set of 'types' in any content lists
+            # up until '_', e.g., to match audio_url -> audio
+            if lora.lora_name in message_types:
+                default_mm_loras.add(lora)
+
+        # Currently only support default modality specific loras if
+        # we have exactly one lora matched on the request.
+        if len(default_mm_loras) == 1:
+            return default_mm_loras.pop()
+        return None
+
     def _maybe_get_adapters(
-        self, request: AnyRequest
+        self,
+        request: AnyRequest,
+        supports_default_mm_loras: bool = False,
     ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
             None, PromptAdapterRequest]]:
-        if self._is_model_supported(request.model):
-            return None, None
         if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model], None
+
+        # Currently only support default modality specific loras
+        # if we have exactly one lora matched on the request.
+        if supports_default_mm_loras:
+            default_mm_lora = self._get_active_default_mm_loras(request)
+            if default_mm_lora is not None:
+                return default_mm_lora, None
+
+        if self._is_model_supported(request.model):
+            return None, None
+
         for prompt_adapter in self.models.prompt_adapter_requests:
             if request.model == prompt_adapter.prompt_adapter_name:
                 return None, prompt_adapter
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")
+
+    def _get_message_types(self, request: AnyRequest) -> set[str]:
+        """Retrieve the set of types from message content dicts up
+        until `_`; we use this to match potential multimodal data
+        with default per modality loras.
+        """
+        message_types: set[str] = set()
+
+        if not hasattr(request, "messages"):
+            return message_types
+
+        for message in request.messages:
+            if (isinstance(message, dict) and "content" in message
+                    and isinstance(message["content"], list)):
+                for content_dict in message["content"]:
+                    if "type" in content_dict:
+                        message_types.add(content_dict["type"].split("_")[0])
+        return message_types
+
     async def _normalize_prompt_text_to_input(
         self,
         request: AnyRequest,