[Core] Add Support for Default Modality Specific LoRAs [generate / chat completions] (#19126)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Alex Brooks 2025-07-10 14:09:37 -06:00 committed by GitHub
parent 3de2ed767f
commit 41060c6e08
GPG Key ID: B5690EEEBB952194
9 changed files with 482 additions and 5 deletions

View File

@ -272,3 +272,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
]
}
```
## Default LoRA Models For Multimodal Models
Some multimodal models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. Managing this with the approaches above can be tedious, since it requires the user to send a `LoRARequest` (offline) or to route requests between the base model and the LoRA model (server) depending on the content of each request's multimodal data.

To this end, we allow registration of default multimodal LoRAs to handle this automatically: users can map each modality to a LoRA adapter, which is then applied whenever the corresponding inputs are present. Note that currently we only allow one LoRA per prompt; if a request contains several modalities, each of which is mapped to its own default LoRA, none of them will be applied.
Example usage for offline inference:
```python
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_id = "ibm-granite/granite-speech-3.3-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)


def get_prompt(question: str, has_audio: bool):
    """Build the input prompt to send to vLLM."""
    if has_audio:
        question = f"<|audio|>{question}"
    chat = [
        {
            "role": "user",
            "content": question
        }
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False)


model = LLM(
    model=model_id,
    enable_lora=True,
    max_lora_rank=64,
    max_model_len=2048,
    limit_mm_per_prompt={"audio": 1},
    # Will always pass a `LoRARequest` with the `model_id`
    # whenever audio is contained in the request data.
    default_mm_loras={"audio": model_id},
    enforce_eager=True,
)

question = "can you transcribe the speech into a written format?"
prompt_with_audio = get_prompt(
    question=question,
    has_audio=True,
)
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

inputs = {
    "prompt": prompt_with_audio,
    "multi_modal_data": {
        "audio": audio,
    },
}

outputs = model.generate(
    inputs,
    sampling_params=SamplingParams(
        temperature=0.2,
        max_tokens=64,
    ),
)
```
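The generated transcription can then be read off the returned `RequestOutput` objects, e.g.:

```python
for output in outputs:
    # Each RequestOutput carries the completion(s) for one prompt.
    print(output.outputs[0].text)
```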
You can also pass a JSON dictionary to `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:
```bash
vllm serve ibm-granite/granite-speech-3.3-2b \
    --max-model-len 2048 \
    --enable-lora \
    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
    --max-lora-rank 64
```
Note: Default multimodal LoRAs are currently only available for `.generate` (offline inference) and the chat completions API (server).
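For reference, the snippet below is a minimal client-side sketch (not part of this commit) showing that a chat completion request containing audio is routed through the default LoRA automatically; the base model name is used as-is, and the audio URL is a placeholder you would replace with your own:

```python
from openai import OpenAI

# Assumes the server started above is listening on the default port.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    # Base model name; the registered audio LoRA is applied automatically.
    model="ibm-granite/granite-speech-3.3-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you transcribe this audio?"},
            # Placeholder URL; point this at a real audio file.
            {"type": "audio_url", "audio_url": {"url": "https://example.com/audio.wav"}},
        ],
    }],
    max_completion_tokens=64,
    temperature=0.2,
)
print(chat_completion.choices[0].message.content)
```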

View File

@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download

from ...conftest import AudioTestAssets
from ...utils import RemoteOpenAIServer

# NOTE - the tests in this module are currently analogous to test_chat, but are
# separated to avoid OOM killing due to module-scoped servers, since we
# need a multimodal model for these tests.

# Contains a modality specific lora alongside the base model
MULTIMODAL_MODEL_NAME = snapshot_download(
    "microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")

ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501


@pytest.fixture(scope="module")
def monkeypatch_module():
    from _pytest.monkeypatch import MonkeyPatch
    mpatch = MonkeyPatch()
    yield mpatch
    mpatch.undo()


@pytest.fixture(scope="module", params=[False, True])
def multimodal_server(request, monkeypatch_module):  # noqa: F811
    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--max-model-len",
        "12800",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        f"speech={AUDIO_LORA_PATH}",
        "--max-lora-rank",
        "320",
        "--max-num-seqs",
        "2",
        "--trust-remote-code",
        "--gpu-memory-utilization",
        "0.8",
        "--default-mm-loras",
        f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
    ]

    with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def multi_modal_client(multimodal_server):
    async with multimodal_server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # base model with default lora should give the same response as lora model
    "model_name",
    [MULTIMODAL_MODEL_NAME, "speech"],
)
async def test_default_mm_lora_chat_completions(
    model_name: str,
    multi_modal_client: openai.AsyncOpenAI,
    audio_assets: AudioTestAssets,
):
    messages = [{
        "role": "user",
        "content": [{
            "type": "text",
            "text": "Can you transcribe this audio?",
        }, {
            "type": "audio_url",
            "audio_url": {
                "url": audio_assets[0].url
            },
        }]
    }]

    chat_completion = await multi_modal_client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=128,
        temperature=0.0)

    assert len(chat_completion.choices) > 0
    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
    assert message.content == ACTIVE_MM_LORA_RESPONSE

View File

@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for applying default registered multimodal loras.
"""
import os

from huggingface_hub import snapshot_download

from vllm.lora.request import LoRARequest

from ..conftest import AudioTestAssets, VllmRunner

MODEL_PATH = snapshot_download("microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MODEL_PATH, "speech-lora")
IMAGE_LORA_PATH = os.path.join(MODEL_PATH, "vision-lora")

AUDIO_PROMPT = "<|user|><|audio_1|>Can you transcribe this audio?<|end|><|assistant|>"  # noqa: E501

# Responses are greedy decoded; we just check the end of
# the generated text. If the lora is inactive, this model
# generates commentary on the transcription.
RESPONSE_SUFFIX_WITH_LORA = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501
RESPONSE_SUFFIX_WITHOUT_LORA = "Certainly! Here is the transcription of the audio you provided:\n\nThe first words I spoke in the original phonograph record: A little piece of practical poetry. Mary had a little lamb; its fleece was white as snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501

VLLM_RUNNER_BASE_KWARGS = {
    "model_name": MODEL_PATH,
    "dtype": "half",
    "enable_lora": "True",
    "max_num_seqs": 2,
    "max_lora_rank": 320,
    "max_model_len": 12800,
    "gpu_memory_utilization": 0.8,
    "limit_mm_per_prompt": {
        "audio": 1
    },
    "enforce_eager": True,
}


def run_test(vllm_runner, audio_assets, lora_request, expected_suffix,
             **kwargs):
    inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])]

    # Apply any additional kwargs as overrides to the base kwargs
    vllm_runner_kwargs = {**VLLM_RUNNER_BASE_KWARGS, **kwargs}

    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_with_default_lora = [
            vllm_model.generate_greedy(
                prompts,
                max_tokens=128,
                audios=audios,
                lora_request=lora_request,
            ) for prompts, audios in inputs
        ]

        assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(
            expected_suffix)


def test_active_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that we can use the default audio lora."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_inactive_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that modalities are filtered properly."""
    # Default image lora won't be active since we only pass audio
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"image": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITHOUT_LORA,
    )


def test_default_mm_lora_succeeds_with_redundant_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that redundantly providing the lora works."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("audio", 1, AUDIO_LORA_PATH),
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_default_mm_lora_fails_with_overridden_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that if the lora_request conflicts with default_mm_loras,
    we use the lora_request."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("speech", 2, AUDIO_LORA_PATH),
        default_mm_loras={"audio": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )

View File

@ -33,6 +33,7 @@ import vllm.envs as envs
from vllm import version
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
ConfigFormat, get_config, get_hf_image_processor_config,
@ -2989,6 +2990,16 @@ class LoRAConfig:
    trained with those scaling factors to be used at the same time. If not
    specified, only adapters trained with the base model scaling factor are
    allowed."""

    default_mm_loras: Optional[dict[str, str]] = None
    """Dictionary mapping specific modalities to LoRA model paths; this field
    is only applicable to multimodal models and should be leveraged when a
    model always expects a LoRA to be active when a given modality is present.
    Note that currently, if a request provides multiple additional
    modalities, each of which has its own LoRA, we do NOT apply
    default_mm_loras because we currently only support one LoRA adapter
    per prompt. When run in offline mode, the LoRA IDs for n modalities
    will be automatically assigned to 1-n with the names of the modalities
    in alphabetical order."""

    bias_enabled: bool = False
    """Enable bias for LoRA adapters."""

View File

@ -395,6 +395,8 @@ class EngineArgs:
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    default_mm_loras: Optional[Dict[str, str]] = \
        LoRAConfig.default_mm_loras
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
@ -807,6 +809,8 @@ class EngineArgs:
                                **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument("--fully-sharded-loras",
                                **lora_kwargs["fully_sharded_loras"])
        lora_group.add_argument("--default-mm-loras",
                                **lora_kwargs["default_mm_loras"])

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
@ -1284,10 +1288,16 @@ class EngineArgs:
            disable_hybrid_kv_cache_manager,
        )

        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model")

        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            default_mm_loras=self.default_mm_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
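To make the validation above concrete, a hypothetical offline sketch (model name and adapter path chosen only for illustration) would fail at engine construction with the error raised here:

```python
from vllm import LLM

# Expected to raise: default modality LoRAs require a multimodal model.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # text-only model
    enable_lora=True,
    default_mm_loras={"audio": "/path/to/speech-lora"},  # placeholder path
)
# ValueError: Default modality-specific LoRA(s) were provided for a non multimodal model
```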

View File

@ -499,6 +499,10 @@ class LLM:
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            parsed_prompts, lora_request)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
@ -513,6 +517,83 @@ class LLM:
        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def _get_modality_specific_lora_reqs(
            self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora request as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or (lora_config and lora_config.default_mm_loras is None)):
            return lora_request

        if not isinstance(parsed_prompts, Sequence):
            parsed_prompts = [parsed_prompts]

        optional_loras = ([lora_request] * len(parsed_prompts)
                          if not isinstance(lora_request, Sequence) else
                          lora_request)

        return [
            self._resolve_single_prompt_mm_lora(
                parsed_prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
                                                     optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        if (not default_mm_loras or not isinstance(parsed_prompt, dict)
                or "multi_modal_data" not in parsed_prompt):
            return lora_request

        parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)
        intersection = set(
            parsed_prompt["multi_modal_data"].keys()).intersection(
                default_mm_loras.keys())

        if not intersection:
            return lora_request

        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # If the explicitly provided lora_request has a different ID than
        # the default mm lora, warn, but always send the explicit request.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")

            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,

View File

@ -87,6 +87,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    LoRAModulePath,
                                                    OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
@ -1481,11 +1482,28 @@ async def init_app_state(
"This discrepancy may lead to performance degradation.",
resolved_chat_template, args.model)
# Merge default_mm_loras into the static lora_modules
default_mm_loras = (vllm_config.lora_config.default_mm_loras
if vllm_config.lora_config is not None else {})
lora_modules = args.lora_modules
if default_mm_loras:
default_mm_lora_paths = [
LoRAModulePath(
name=modality,
path=lora_path,
) for modality, lora_path in default_mm_loras.items()
]
if args.lora_modules is None:
lora_modules = default_mm_lora_paths
else:
lora_modules += default_mm_lora_paths
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
lora_modules=lora_modules,
prompt_adapters=args.prompt_adapters,
)
await state.openai_serving_models.init_static_loras()
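A side effect of the merge above is that each default modality LoRA is also exposed as a named LoRA module, so it appears under its modality name when listing served models. A minimal sketch (not part of the diff, assuming a server started with `--default-mm-loras '{"audio": ...}'` on the default port):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# The served model list should now include the base model plus an
# entry named after the modality, e.g. "audio".
print([model.id for model in client.models.list().data])
```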

View File

@ -153,7 +153,8 @@ class OpenAIServingChat(OpenAIServing):
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)
            ) = self._maybe_get_adapters(request,
                                         supports_default_mm_loras=True)

            model_name = self._get_model_name(request.model, lora_request)

View File

@ -458,20 +458,74 @@ class OpenAIServing:
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

    def _get_active_default_mm_loras(
            self, request: AnyRequest) -> Optional[LoRARequest]:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # task types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # there is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None
    def _maybe_get_adapters(
        self, request: AnyRequest
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
            None, PromptAdapterRequest]]:
        if self._is_model_supported(request.model):
            return None, None
        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model], None

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora, None

        if self._is_model_supported(request.model):
            return None, None

        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        for message in request.messages:
            if (isinstance(message, dict) and "content" in message
                    and isinstance(message["content"], list)):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,