diff --git a/docs/features/lora.md b/docs/features/lora.md
index 3e17c659655e..d72c0bb4160c 100644
--- a/docs/features/lora.md
+++ b/docs/features/lora.md
@@ -272,3 +272,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
     ]
 }
 ```
+
+## Default LoRA Models For Multimodal Models
+
+Some multimodal models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. This can be tedious to manage with the approaches above, since it requires the user to send a `LoRARequest` (offline) or to route requests between the base model and the LoRA model (server) depending on the content of each request's multimodal data.
+
+To this end, we allow registration of default multimodal LoRAs, where users can map each modality to a LoRA adapter that is automatically applied whenever the corresponding inputs are present. Note that currently we only allow one LoRA per prompt; if several modalities are provided, each of which is registered to its own LoRA, none of them will be applied.
+
+Example usage for offline inference:
+
+```python
+from transformers import AutoTokenizer
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+
+model_id = "ibm-granite/granite-speech-3.3-2b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+def get_prompt(question: str, has_audio: bool):
+    """Build the input prompt to send to vLLM."""
+    if has_audio:
+        question = f"<|audio|>{question}"
+    chat = [
+        {
+            "role": "user",
+            "content": question
+        }
+    ]
+    return tokenizer.apply_chat_template(chat, tokenize=False)
+
+
+model = LLM(
+    model=model_id,
+    enable_lora=True,
+    max_lora_rank=64,
+    max_model_len=2048,
+    limit_mm_per_prompt={"audio": 1},
+    # Will always pass a `LoRARequest` with the `model_id`
+    # whenever audio is contained in the request data.
+    default_mm_loras={"audio": model_id},
+    enforce_eager=True,
+)
+
+question = "can you transcribe the speech into a written format?"
+prompt_with_audio = get_prompt(
+    question=question,
+    has_audio=True,
+)
+audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
+
+inputs = {
+    "prompt": prompt_with_audio,
+    "multi_modal_data": {
+        "audio": audio,
+    }
+}
+
+
+outputs = model.generate(
+    inputs,
+    sampling_params=SamplingParams(
+        temperature=0.2,
+        max_tokens=64,
+    ),
+)
+```
+
+You can also pass a JSON dictionary to `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:
+
+```bash
+vllm serve ibm-granite/granite-speech-3.3-2b \
+    --max-model-len 2048 \
+    --enable-lora \
+    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
+    --max-lora-rank 64
+```
+
+Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
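For reference, here is a minimal client-side sketch of the chat-completions path described in the doc above. It is not part of the patch: it assumes the `vllm serve` command shown there is already running on the default port, and `https://example.com/sample.wav` is a placeholder for a reachable audio URL. Because the request carries audio content, the registered default audio LoRA is applied even though the request names the base model.

```python
from openai import OpenAI

# Talk to the locally running `vllm serve ... --default-mm-loras ...` server.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    # Base model name; the default audio LoRA is still applied server-side.
    model="ibm-granite/granite-speech-3.3-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you transcribe this audio?"},
            # Placeholder URL; any reachable audio file works here.
            {"type": "audio_url", "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
    temperature=0.2,
    max_completion_tokens=64,
)
print(chat_completion.choices[0].message.content)
```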
diff --git a/tests/entrypoints/openai/test_default_mm_loras.py b/tests/entrypoints/openai/test_default_mm_loras.py
new file mode 100644
index 000000000000..1fc87c8b42a7
--- /dev/null
+++ b/tests/entrypoints/openai/test_default_mm_loras.py
@@ -0,0 +1,107 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+
+import openai  # use the official client for correctness check
+import pytest
+import pytest_asyncio
+from huggingface_hub import snapshot_download
+
+from ...conftest import AudioTestAssets
+from ...utils import RemoteOpenAIServer
+
+# NOTE - the tests in this module are currently analogous to test_chat, but are
+# separated to avoid OOM killing due to module-scoped servers, since we
+# need a multimodal model for these tests.
+
+# Contains a modality specific lora alongside the base model
+MULTIMODAL_MODEL_NAME = snapshot_download(
+    "microsoft/Phi-4-multimodal-instruct")
+AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")
+
+ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501
+
+
+@pytest.fixture(scope="module")
+def monkeypatch_module():
+    from _pytest.monkeypatch import MonkeyPatch
+    mpatch = MonkeyPatch()
+    yield mpatch
+    mpatch.undo()
+
+
+@pytest.fixture(scope="module", params=[False, True])
+def multimodal_server(request, monkeypatch_module):  # noqa: F811
+
+    use_v1 = request.param
+    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')
+
+    args = [
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "half",
+        "--max-model-len",
+        "12800",
+        "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"speech={AUDIO_LORA_PATH}",
+        "--max-lora-rank",
+        "320",
+        "--max-num-seqs",
+        "2",
+        "--trust-remote-code",
+        "--gpu-memory-utilization",
+        "0.8",
+        "--default-mm-loras",
+        f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
+    ]
+
+    with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def multi_modal_client(multimodal_server):
+    async with multimodal_server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # base model with default lora should give the same response as lora model
+    "model_name",
+    [MULTIMODAL_MODEL_NAME, "speech"],
+)
+async def test_default_mm_lora_chat_completions(
+    model_name: str,
+    multi_modal_client: openai.AsyncOpenAI,
+    audio_assets: AudioTestAssets,
+):
+    messages = [{
+        "role":
+        "user",
+        "content": [{
+            "type": "text",
+            "text": "Can you transcribe this audio?",
+        }, {
+            "type": "audio_url",
+            "audio_url": {
+                "url": audio_assets[0].url
+            },
+        }]
+    }]
+
+    chat_completion = await multi_modal_client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_completion_tokens=128,
+        temperature=0.0)
+
+    assert len(chat_completion.choices) > 0
+
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+    assert message.content == ACTIVE_MM_LORA_RESPONSE
diff --git a/tests/lora/test_default_mm_loras.py b/tests/lora/test_default_mm_loras.py
new file mode 100644
index 000000000000..f615ceda76b5
--- /dev/null
+++ b/tests/lora/test_default_mm_loras.py
@@ -0,0 +1,118 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Tests for applying default registered multimodal loras.
+"""
+
+import os
+
+from huggingface_hub import snapshot_download
+
+from vllm.lora.request import LoRARequest
+
+from ..conftest import AudioTestAssets, VllmRunner
+
+MODEL_PATH = snapshot_download("microsoft/Phi-4-multimodal-instruct")
+AUDIO_LORA_PATH = os.path.join(MODEL_PATH, "speech-lora")
+IMAGE_LORA_PATH = os.path.join(MODEL_PATH, "vision-lora")
+
+AUDIO_PROMPT = "<|user|><|audio_1|>Can you transcribe this audio?<|end|><|assistant|>"  # noqa: E501
+
+# Responses are greedy decoded; we just check the end of
+# the generated text. If the lora is inactive, this model
+# generates commentary on the transcription.
+RESPONSE_SUFFIX_WITH_LORA = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501
+RESPONSE_SUFFIX_WITHOUT_LORA = "Certainly! Here is the transcription of the audio you provided:\n\nThe first words I spoke in the original phonograph record: A little piece of practical poetry. Mary had a little lamb; its fleece was white as snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501
+
+VLLM_RUNNER_BASE_KWARGS = {
+    "model_name": MODEL_PATH,
+    "dtype": "half",
+    "enable_lora": "True",
+    "max_num_seqs": 2,
+    "max_lora_rank": 320,
+    "max_model_len": 12800,
+    "gpu_memory_utilization": 0.8,
+    "limit_mm_per_prompt": {
+        "audio": 1
+    },
+    "enforce_eager": True,
+}
+
+
+def run_test(vllm_runner, audio_assets, lora_request, expected_suffix,
+             **kwargs):
+    inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])]
+
+    # Apply any additional kwargs as overrides to the base kwargs
+    vllm_runner_kwargs = {**VLLM_RUNNER_BASE_KWARGS, **kwargs}
+
+    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
+        vllm_outputs_with_default_lora = [
+            vllm_model.generate_greedy(
+                prompts,
+                max_tokens=128,
+                audios=audios,
+                lora_request=lora_request,
+            ) for prompts, audios in inputs
+        ]
+
+        assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(
+            expected_suffix)
+
+
+def test_active_default_mm_lora(
+    vllm_runner: type[VllmRunner],
+    audio_assets: AudioTestAssets,
+):
+    """Ensure that we can use the default audio lora."""
+    run_test(
+        vllm_runner,
+        audio_assets,
+        lora_request=None,
+        default_mm_loras={"audio": AUDIO_LORA_PATH},
+        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
+    )
+
+
+def test_inactive_default_mm_lora(
+    vllm_runner: type[VllmRunner],
+    audio_assets: AudioTestAssets,
+):
+    """Ensure that modalities are filtered properly."""
+    # Default image lora won't be active since we only pass audio
+    run_test(
+        vllm_runner,
+        audio_assets,
+        lora_request=None,
+        default_mm_loras={"image": IMAGE_LORA_PATH},
+        expected_suffix=RESPONSE_SUFFIX_WITHOUT_LORA,
+    )
+
+
+def test_default_mm_lora_succeeds_with_redundant_lora_request(
+    vllm_runner: type[VllmRunner],
+    audio_assets: AudioTestAssets,
+):
+    """Ensure that redundantly providing the lora works."""
+    run_test(
+        vllm_runner,
+        audio_assets,
+        lora_request=LoRARequest("audio", 1, AUDIO_LORA_PATH),
+        default_mm_loras={"audio": AUDIO_LORA_PATH},
+        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
+    )
+
+
+def test_default_mm_lora_fails_with_overridden_lora_request(
+    vllm_runner: type[VllmRunner],
+    audio_assets: AudioTestAssets,
+):
+    """Ensure that if the lora_request conflicts with default_mm_loras,
+    we use the lora_request."""
+    run_test(
+        vllm_runner,
+        audio_assets,
+        lora_request=LoRARequest("speech", 2, AUDIO_LORA_PATH),
+        default_mm_loras={"audio": IMAGE_LORA_PATH},
+        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
+    )
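The precedence behavior exercised by the last two tests can be summarized with a short offline sketch. It is not part of the patch; the LoRA path is a placeholder for the `speech-lora` folder shipped with the model, and the prompt/audio asset are reused from the tests above. A default mapping is only resolved when no explicit `LoRARequest` is given.

```python
from vllm import LLM
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest

# Placeholder for the speech-lora folder bundled with the model checkpoint.
AUDIO_LORA_PATH = "/path/to/Phi-4-multimodal-instruct/speech-lora"

llm = LLM(
    model="microsoft/Phi-4-multimodal-instruct",
    enable_lora=True,
    max_lora_rank=320,
    trust_remote_code=True,
    limit_mm_per_prompt={"audio": 1},
    # Applied automatically whenever a prompt carries audio data.
    default_mm_loras={"audio": AUDIO_LORA_PATH},
    enforce_eager=True,
)

prompt = {
    "prompt": "<|user|><|audio_1|>Can you transcribe this audio?<|end|><|assistant|>",
    "multi_modal_data": {"audio": AudioAsset("mary_had_lamb").audio_and_sample_rate},
}

# No lora_request: the default audio LoRA is resolved and applied.
outputs = llm.generate(prompt)

# An explicit LoRARequest always takes precedence over the default mapping.
outputs = llm.generate(prompt, lora_request=LoRARequest("speech", 2, AUDIO_LORA_PATH))
```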
diff --git a/vllm/config.py b/vllm/config.py
index b973bf20809b..1a3ff9d42ff6 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -33,6 +33,7 @@ import vllm.envs as envs
 from vllm import version
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2989,6 +2990,16 @@ class LoRAConfig:
     trained with those scaling factors to be used at the same time. If not
     specified, only adapters trained with the base model scaling factor are
     allowed."""
+    default_mm_loras: Optional[dict[str, str]] = None
+    """Dictionary mapping specific modalities to LoRA model paths; this field
+    is only applicable to multimodal models and should be leveraged when a
+    model always expects a LoRA to be active when a given modality is present.
+    Note that currently, if a request provides multiple additional
+    modalities, each of which has its own LoRA, we do NOT apply
+    default_mm_loras because we currently only support one lora adapter
+    per prompt. When run in offline mode, the lora IDs for n modalities
+    will be automatically assigned to 1-n with the names of the modalities
+    in alphabetical order."""
     bias_enabled: bool = False
     """Enable bias for LoRA adapters."""
 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index eb870d8e16c3..1b8dc640e056 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -395,6 +395,8 @@ class EngineArgs:
     enable_lora_bias: bool = LoRAConfig.bias_enabled
     max_loras: int = LoRAConfig.max_loras
     max_lora_rank: int = LoRAConfig.max_lora_rank
+    default_mm_loras: Optional[Dict[str, str]] = \
+        LoRAConfig.default_mm_loras
     fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
     max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
     lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
@@ -807,6 +809,8 @@
                                 **lora_kwargs["max_cpu_loras"])
         lora_group.add_argument("--fully-sharded-loras",
                                 **lora_kwargs["fully_sharded_loras"])
+        lora_group.add_argument("--default-mm-loras",
+                                **lora_kwargs["default_mm_loras"])
 
         # PromptAdapter related configs
         prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
@@ -1284,10 +1288,16 @@
             disable_hybrid_kv_cache_manager,
         )
 
+        if not model_config.is_multimodal_model and self.default_mm_loras:
+            raise ValueError(
+                "Default modality-specific LoRA(s) were provided for a "
+                "non-multimodal model")
+
         lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
+            default_mm_loras=self.default_mm_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
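A rough sketch of how the new field travels from the CLI into `LoRAConfig` (not part of the patch; the JSON string is the same format accepted by `--default-mm-loras`, and `create_engine_config()` is assumed to be the usual `EngineArgs` entry point that builds the config and runs the multimodal validation shown above):

```python
import json

from vllm.engine.arg_utils import EngineArgs

# Same JSON format as: --default-mm-loras '{"audio": "ibm-granite/granite-speech-3.3-2b"}'
default_mm_loras = json.loads('{"audio": "ibm-granite/granite-speech-3.3-2b"}')

engine_args = EngineArgs(
    model="ibm-granite/granite-speech-3.3-2b",
    enable_lora=True,
    max_lora_rank=64,
    default_mm_loras=default_mm_loras,
)

# The mapping is forwarded to LoRAConfig; a ValueError is raised instead if
# the target model is not multimodal.
vllm_config = engine_args.create_engine_config()
assert vllm_config.lora_config.default_mm_loras == default_mm_loras
```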
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index d5ecd7a864d6..c60a566f585d 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -499,6 +499,10 @@
         _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                   truncate_prompt_tokens, tokenization_kwargs)
 
+        # Add any modality specific loras to the corresponding prompts
+        lora_request = self._get_modality_specific_lora_reqs(
+            parsed_prompts, lora_request)
+
         self._validate_and_add_requests(
             prompts=parsed_prompts,
             params=sampling_params,
@@ -513,6 +517,83 @@
         outputs = self._run_engine(use_tqdm=use_tqdm)
         return self.engine_class.validate_outputs(outputs, RequestOutput)
 
+    def _get_modality_specific_lora_reqs(
+            self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
+            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
+        # Grab the lora config off the vllm config on the engine,
+        # since this is the same for both v0 & v1.
+        lora_config = self.llm_engine.vllm_config.lora_config
+
+        # If there's no lora config / default_mm_loras, or the model
+        # isn't multimodal, leave the lora as is.
+        if (lora_config is None
+                or not self.llm_engine.model_config.is_multimodal_model
+                or (lora_config and lora_config.default_mm_loras is None)):
+            return lora_request
+
+        if not isinstance(parsed_prompts, Sequence):
+            parsed_prompts = [parsed_prompts]
+
+        optional_loras = ([lora_request] * len(parsed_prompts)
+                          if not isinstance(lora_request, Sequence) else
+                          lora_request)
+
+        return [
+            self._resolve_single_prompt_mm_lora(
+                parsed_prompt,
+                opt_lora_req,
+                lora_config.default_mm_loras,
+            ) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
+                                                     optional_loras)
+        ]
+
+    def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
+                                       lora_request: Optional[LoRARequest],
+                                       default_mm_loras: Optional[dict[str,
+                                                                       str]]):
+        if (not default_mm_loras or not isinstance(parsed_prompt, dict)
+                or "multi_modal_data" not in parsed_prompt):
+            return lora_request
+
+        parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)
+
+        intersection = set(
+            parsed_prompt["multi_modal_data"].keys()).intersection(
+                default_mm_loras.keys())
+        if not intersection:
+            return lora_request
+        if len(intersection) > 1:
+            # TODO: Would be nice to be able to have multiple loras per prompt
+            logger.warning(
+                "Multiple modality specific loras were registered and would be"
+                " used by a single prompt consuming several modalities; "
+                " currently we only support one lora per request; as such,"
+                " lora(s) registered with modalities: %s"
+                " will be skipped", intersection)
+            return lora_request
+
+        # Build the LoRA request; the ID of the default mm lora is the
+        # index of the modality name sorted alphabetically + 1.
+        modality_name = intersection.pop()
+        modality_lora_path = default_mm_loras[modality_name]
+        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1
+
+        # Warn if the explicitly provided lora_request collides with the
+        # default's ID, but always prefer the explicit request.
+        if lora_request:
+            if lora_request.lora_int_id != modality_lora_id:
+                logger.warning(
+                    "A modality with a registered lora and a lora_request "
+                    "with a different ID were provided; falling back to the "
+                    "lora_request as we only apply one LoRARequest per prompt")
+            return lora_request
+
+        return LoRARequest(
+            modality_name,
+            modality_lora_id,
+            modality_lora_path,
+        )
+
     def collective_rpc(self,
                        method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
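The ID convention used in `_resolve_single_prompt_mm_lora` (and documented on `LoRAConfig.default_mm_loras`) can be shown in isolation. This is plain Python illustrating the sorted-index-plus-one rule, not a vLLM API; the paths are placeholders.

```python
def default_mm_lora_id(modality: str, default_mm_loras: dict[str, str]) -> int:
    """Modalities are sorted alphabetically and numbered starting from 1."""
    return sorted(default_mm_loras).index(modality) + 1


default_mm_loras = {
    "image": "/path/to/vision-lora",  # placeholder paths
    "audio": "/path/to/speech-lora",
}
assert default_mm_lora_id("audio", default_mm_loras) == 1
assert default_mm_lora_id("image", default_mm_loras) == 2
```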
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 2f8b31c8a7ba..f0c486317c23 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -87,6 +87,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    LoRAModulePath,
                                                     OpenAIServingModels)
 from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
@@ -1481,11 +1482,28 @@
             "This discrepancy may lead to performance degradation.",
             resolved_chat_template, args.model)
 
+    # Merge default_mm_loras into the static lora_modules
+    default_mm_loras = (vllm_config.lora_config.default_mm_loras
+                        if vllm_config.lora_config is not None else {})
+
+    lora_modules = args.lora_modules
+    if default_mm_loras:
+        default_mm_lora_paths = [
+            LoRAModulePath(
+                name=modality,
+                path=lora_path,
+            ) for modality, lora_path in default_mm_loras.items()
+        ]
+        if args.lora_modules is None:
+            lora_modules = default_mm_lora_paths
+        else:
+            lora_modules += default_mm_lora_paths
+
     state.openai_serving_models = OpenAIServingModels(
         engine_client=engine_client,
         model_config=model_config,
         base_model_paths=base_model_paths,
-        lora_modules=args.lora_modules,
+        lora_modules=lora_modules,
         prompt_adapters=args.prompt_adapters,
     )
     await state.openai_serving_models.init_static_loras()
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 451241d3f9f7..53509e8f65a7 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -153,7 +153,8 @@
             (
                 lora_request,
                 prompt_adapter_request,
-            ) = self._maybe_get_adapters(request)
+            ) = self._maybe_get_adapters(request,
+                                         supports_default_mm_loras=True)
 
             model_name = self._get_model_name(request.model, lora_request)
 
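Because the defaults are merged into the static `lora_modules`, each one should also be listed as an adapter named after its modality once the server is up. A quick way to check this against a running server (host and port assumed) is to list the served models; this sketch is not part of the patch.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# With --default-mm-loras '{"audio": "..."}' an adapter named "audio" is
# expected to show up alongside the base model and any --lora-modules entries.
served_model_ids = [model.id for model in client.models.list().data]
print(served_model_ids)
```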
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index ccd98ea75f54..7581ab6e63bb 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -458,20 +458,74 @@
                 err_type="NotFoundError",
                 status_code=HTTPStatus.NOT_FOUND)
 
+    def _get_active_default_mm_loras(
+            self, request: AnyRequest) -> Optional[LoRARequest]:
+        """Determine if there are any active default multimodal loras."""
+        # TODO: Currently this is only enabled for chat completions
+        # to be better aligned with only being enabled for .generate
+        # when run offline. It would be nice to support additional
+        # task types in the future.
+        message_types = self._get_message_types(request)
+        default_mm_loras = set()
+
+        for lora in self.models.lora_requests.values():
+            # Best effort match for default multimodal lora adapters;
+            # there is probably a better way to do this, but currently
+            # this matches against the set of 'types' in any content lists
+            # up until '_', e.g., to match audio_url -> audio
+            if lora.lora_name in message_types:
+                default_mm_loras.add(lora)
+
+        # Currently only support default modality specific loras if
+        # we have exactly one lora matched on the request.
+        if len(default_mm_loras) == 1:
+            return default_mm_loras.pop()
+        return None
+
     def _maybe_get_adapters(
-            self, request: AnyRequest
+        self,
+        request: AnyRequest,
+        supports_default_mm_loras: bool = False,
     ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
             None, PromptAdapterRequest]]:
-        if self._is_model_supported(request.model):
-            return None, None
         if request.model in self.models.lora_requests:
             return self.models.lora_requests[request.model], None
+
+        # Currently only support default modality specific loras
+        # if we have exactly one lora matched on the request.
+        if supports_default_mm_loras:
+            default_mm_lora = self._get_active_default_mm_loras(request)
+            if default_mm_lora is not None:
+                return default_mm_lora, None
+
+        if self._is_model_supported(request.model):
+            return None, None
+
         for prompt_adapter in self.models.prompt_adapter_requests:
             if request.model == prompt_adapter.prompt_adapter_name:
                 return None, prompt_adapter
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError(f"The model `{request.model}` does not exist.")
 
+    def _get_message_types(self, request: AnyRequest) -> set[str]:
+        """Retrieve the set of types from message content dicts up
+        until `_`; we use this to match potential multimodal data
+        with default per modality loras.
+        """
+        message_types: set[str] = set()
+
+        if not hasattr(request, "messages"):
+            return message_types
+
+        for message in request.messages:
+            if (isinstance(message, dict) and "content" in message
+                    and isinstance(message["content"], list)):
+                for content_dict in message["content"]:
+                    if "type" in content_dict:
+                        message_types.add(content_dict["type"].split("_")[0])
+        return message_types
+
     async def _normalize_prompt_text_to_input(
         self,
         request: AnyRequest,
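To make the best-effort matching above concrete, here is a standalone sketch (plain Python, not vLLM internals) of how a chat request's content `type` values are reduced to modality names and matched against the registered default adapters; a request only picks up a default LoRA when exactly one modality matches.

```python
from typing import Optional


def match_default_mm_lora(messages: list[dict],
                          registered_loras: dict[str, str]) -> Optional[str]:
    """Collect content "type" values up to the first underscore
    (e.g. "audio_url" -> "audio") and return the registered adapter
    name only when exactly one modality matches."""
    message_types: set[str] = set()
    for message in messages:
        content = message.get("content")
        if isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and "type" in part:
                    message_types.add(part["type"].split("_")[0])
    matches = message_types & set(registered_loras)
    return matches.pop() if len(matches) == 1 else None


messages = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Can you transcribe this audio?"},
        {"type": "audio_url", "audio_url": {"url": "https://example.com/a.wav"}},
    ],
}]
# "audio_url" -> "audio", which matches the registered default adapter.
assert match_default_mm_lora(messages, {"audio": "/path/to/speech-lora"}) == "audio"
```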