[Core] Add Support for Default Modality Specific LoRAs [generate / chat completions] (#19126)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
parent 3de2ed767f · commit 41060c6e08
@@ -272,3 +272,80 @@ The new format of `--lora-modules` is mainly to support the display of parent mo
    ]
}
```
## Default LoRA Models For Multimodal Models

Some models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-speech-3.3-8b) and [Phi-4-multimodal-instruct](https://huggingface.co/microsoft/Phi-4-multimodal-instruct), contain LoRA adapter(s) that are expected to always be applied when a given modality is present. This can be tedious to manage with the above approaches, as it requires the user to send the `LoRARequest` (offline) or to filter requests between the base model and LoRA model (server) depending on the content of the request's multimodal data.

To this end, we allow the registration of default multimodal LoRAs to handle this automatically: each modality can be mapped to a LoRA adapter, which is then applied automatically whenever the corresponding inputs are present in a request. Note that currently we only allow one LoRA per prompt; if several modalities are provided, each of which is mapped to its own default LoRA, none of them will be applied.

Example usage for offline inference:
```python
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_id = "ibm-granite/granite-speech-3.3-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)


def get_prompt(question: str, has_audio: bool):
    """Build the input prompt to send to vLLM."""
    if has_audio:
        question = f"<|audio|>{question}"
    chat = [
        {
            "role": "user",
            "content": question
        }
    ]
    return tokenizer.apply_chat_template(chat, tokenize=False)


model = LLM(
    model=model_id,
    enable_lora=True,
    max_lora_rank=64,
    max_model_len=2048,
    limit_mm_per_prompt={"audio": 1},
    # Will always pass a `LoRARequest` with the `model_id`
    # whenever audio is contained in the request data.
    default_mm_loras={"audio": model_id},
    enforce_eager=True,
)

question = "can you transcribe the speech into a written format?"
prompt_with_audio = get_prompt(
    question=question,
    has_audio=True,
)
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

inputs = {
    "prompt": prompt_with_audio,
    "multi_modal_data": {
        "audio": audio,
    }
}

outputs = model.generate(
    inputs,
    sampling_params=SamplingParams(
        temperature=0.2,
        max_tokens=64,
    ),
)
```

You can also pass a JSON dictionary via `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:

```bash
vllm serve ibm-granite/granite-speech-3.3-2b \
    --max-model-len 2048 \
    --enable-lora \
    --default-mm-loras '{"audio":"ibm-granite/granite-speech-3.3-2b"}' \
    --max-lora-rank 64
```

Note: Default multimodal LoRAs are currently only available for `.generate` and chat completions.
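
Once the server is running with a default audio LoRA, any chat completion request that includes audio content picks it up automatically. Below is a minimal sketch of such a request through the OpenAI-compatible chat completions API; the server address, API key, and audio URL are illustrative assumptions rather than fixed values.

```python
# Minimal sketch: assumes the server started above is reachable at
# localhost:8000 and the audio file URL is accessible to the server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

chat_completion = client.chat.completions.create(
    # Requesting the base model is enough; the default audio LoRA is applied
    # automatically because the message contains audio content.
    model="ibm-granite/granite-speech-3.3-2b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you transcribe this audio?"},
            # Hypothetical audio URL, used only for illustration.
            {"type": "audio_url", "audio_url": {"url": "https://example.com/sample.wav"}},
        ],
    }],
    max_completion_tokens=64,
    temperature=0.0,
)
print(chat_completion.choices[0].message.content)
```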
tests/entrypoints/openai/test_default_mm_loras.py (new file, +107 lines)
@@ -0,0 +1,107 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
from huggingface_hub import snapshot_download

from ...conftest import AudioTestAssets
from ...utils import RemoteOpenAIServer

# NOTE - the tests in this module are currently analogous to test_chat, but are
# separated to avoid OOM killing due to module-scoped servers, since we
# need a multimodal model for these tests.

# Contains a modality specific lora alongside the base model
MULTIMODAL_MODEL_NAME = snapshot_download(
    "microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MULTIMODAL_MODEL_NAME, "speech-lora")

ACTIVE_MM_LORA_RESPONSE = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501


@pytest.fixture(scope="module")
def monkeypatch_module():
    from _pytest.monkeypatch import MonkeyPatch
    mpatch = MonkeyPatch()
    yield mpatch
    mpatch.undo()


@pytest.fixture(scope="module", params=[False, True])
def multimodal_server(request, monkeypatch_module):  # noqa: F811

    use_v1 = request.param
    monkeypatch_module.setenv('VLLM_USE_V1', '1' if use_v1 else '0')

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--max-model-len",
        "12800",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        f"speech={AUDIO_LORA_PATH}",
        "--max-lora-rank",
        "320",
        "--max-num-seqs",
        "2",
        "--trust-remote-code",
        "--gpu-memory-utilization",
        "0.8",
        "--default-mm-loras",
        f"{{\"audio\": \"{AUDIO_LORA_PATH}\"}}",
    ]

    with RemoteOpenAIServer(MULTIMODAL_MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def multi_modal_client(multimodal_server):
    async with multimodal_server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # base model with default lora should give the same response as lora model
    "model_name",
    [MULTIMODAL_MODEL_NAME, "speech"],
)
async def test_default_mm_lora_chat_completions(
    model_name: str,
    multi_modal_client: openai.AsyncOpenAI,
    audio_assets: AudioTestAssets,
):
    messages = [{
        "role":
        "user",
        "content": [{
            "type": "text",
            "text": "Can you transcribe this audio?",
        }, {
            "type": "audio_url",
            "audio_url": {
                "url": audio_assets[0].url
            },
        }]
    }]

    chat_completion = await multi_modal_client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_completion_tokens=128,
        temperature=0.0)

    assert len(chat_completion.choices) > 0

    message = chat_completion.choices[0].message
    assert message.content is not None and len(message.content) >= 0
    assert message.content == ACTIVE_MM_LORA_RESPONSE
tests/lora/test_default_mm_loras.py (new file, +118 lines)
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for applying default registered multimodal loras.
"""

import os

from huggingface_hub import snapshot_download

from vllm.lora.request import LoRARequest

from ..conftest import AudioTestAssets, VllmRunner

MODEL_PATH = snapshot_download("microsoft/Phi-4-multimodal-instruct")
AUDIO_LORA_PATH = os.path.join(MODEL_PATH, "speech-lora")
IMAGE_LORA_PATH = os.path.join(MODEL_PATH, "vision-lora")

AUDIO_PROMPT = "<|user|><|audio_1|>Can you transcribe this audio?<|end|><|assistant|>"  # noqa: E501

# Responses are greedy decoded; we just check the end of
# the generated text. If the lora is inactive, this model
# generates commentary on the transcription.
RESPONSE_SUFFIX_WITH_LORA = "Spoken text: The first words I spoke in the original chronograph, a little piece of practical poetry. Mary had a little lamb, it slept with quite a snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501
RESPONSE_SUFFIX_WITHOUT_LORA = "Certainly! Here is the transcription of the audio you provided:\n\nThe first words I spoke in the original phonograph record: A little piece of practical poetry. Mary had a little lamb; its fleece was white as snow, and everywhere that Mary went, the lamb was sure to go."  # noqa: E501

VLLM_RUNNER_BASE_KWARGS = {
    "model_name": MODEL_PATH,
    "dtype": "half",
    "enable_lora": "True",
    "max_num_seqs": 2,
    "max_lora_rank": 320,
    "max_model_len": 12800,
    "gpu_memory_utilization": 0.8,
    "limit_mm_per_prompt": {
        "audio": 1
    },
    "enforce_eager": True,
}


def run_test(vllm_runner, audio_assets, lora_request, expected_suffix,
             **kwargs):
    inputs = [([AUDIO_PROMPT], [audio_assets[0].audio_and_sample_rate[0]])]

    # Apply any additional kwargs as overrides to the base kwargs
    vllm_runner_kwargs = {**VLLM_RUNNER_BASE_KWARGS, **kwargs}

    with vllm_runner(**vllm_runner_kwargs) as vllm_model:
        vllm_outputs_with_default_lora = [
            vllm_model.generate_greedy(
                prompts,
                max_tokens=128,
                audios=audios,
                lora_request=lora_request,
            ) for prompts, audios in inputs
        ]

        assert vllm_outputs_with_default_lora[-1][-1][-1].endswith(
            expected_suffix)


def test_active_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that we can use the default audio lora."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_inactive_default_mm_lora(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that modalities are filtered properly."""
    # Default image lora won't be active since we only pass audio
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=None,
        default_mm_loras={"image": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITHOUT_LORA,
    )


def test_default_mm_lora_succeeds_with_redundant_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that redundantly providing the lora works."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("audio", 1, AUDIO_LORA_PATH),
        default_mm_loras={"audio": AUDIO_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )


def test_default_mm_lora_fails_with_overridden_lora_request(
    vllm_runner: type[VllmRunner],
    audio_assets: AudioTestAssets,
):
    """Ensure that if the lora_request conflicts with default_mm_loras,
    we use the lora_request."""
    run_test(
        vllm_runner,
        audio_assets,
        lora_request=LoRARequest("speech", 2, AUDIO_LORA_PATH),
        default_mm_loras={"audio": IMAGE_LORA_PATH},
        expected_suffix=RESPONSE_SUFFIX_WITH_LORA,
    )
@@ -33,6 +33,7 @@ import vllm.envs as envs
from vllm import version
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
    ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2989,6 +2990,16 @@ class LoRAConfig:
    trained with those scaling factors to be used at the same time. If not
    specified, only adapters trained with the base model scaling factor are
    allowed."""
    default_mm_loras: Optional[dict[str, str]] = None
    """Dictionary mapping specific modalities to LoRA model paths; this field
    is only applicable to multimodal models and should be leveraged when a
    model always expects a LoRA to be active when a given modality is present.
    Note that currently, if a request provides multiple additional
    modalities, each of which have their own LoRA, we do NOT apply
    default_mm_loras because we currently only support one lora adapter
    per prompt. When run in offline mode, the lora IDs for n modalities
    will be automatically assigned to 1-n with the names of the modalities
    in alphabetic order."""
    bias_enabled: bool = False
    """Enable bias for LoRA adapters."""
@@ -395,6 +395,8 @@ class EngineArgs:
    enable_lora_bias: bool = LoRAConfig.bias_enabled
    max_loras: int = LoRAConfig.max_loras
    max_lora_rank: int = LoRAConfig.max_lora_rank
    default_mm_loras: Optional[Dict[str, str]] = \
        LoRAConfig.default_mm_loras
    fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
    max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
    lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
@@ -807,6 +809,8 @@ class EngineArgs:
                                **lora_kwargs["max_cpu_loras"])
        lora_group.add_argument("--fully-sharded-loras",
                                **lora_kwargs["fully_sharded_loras"])
        lora_group.add_argument("--default-mm-loras",
                                **lora_kwargs["default_mm_loras"])

        # PromptAdapter related configs
        prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig)
@@ -1284,10 +1288,16 @@ class EngineArgs:
                disable_hybrid_kv_cache_manager,
            )

        if not model_config.is_multimodal_model and self.default_mm_loras:
            raise ValueError(
                "Default modality-specific LoRA(s) were provided for a "
                "non multimodal model")

        lora_config = LoRAConfig(
            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            default_mm_loras=self.default_mm_loras,
            fully_sharded_loras=self.fully_sharded_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            long_lora_scaling_factors=self.long_lora_scaling_factors,
@@ -499,6 +499,10 @@ class LLM:
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            parsed_prompts, lora_request)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
@@ -513,6 +517,83 @@ class LLM:
        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def _get_modality_specific_lora_reqs(
            self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or (lora_config and lora_config.default_mm_loras is None)):
            return lora_request

        if not isinstance(parsed_prompts, Sequence):
            parsed_prompts = [parsed_prompts]

        optional_loras = ([lora_request] * len(parsed_prompts)
                          if not isinstance(lora_request, Sequence) else
                          lora_request)

        return [
            self._resolve_single_prompt_mm_lora(
                parsed_prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
                                                     optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        if (not default_mm_loras or not isinstance(parsed_prompt, dict)
                or "multi_modal_data" not in parsed_prompt):
            return lora_request

        parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)

        intersection = set(
            parsed_prompt["multi_modal_data"].keys()).intersection(
                default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # If we have a collision, warn if there is a collision,
        # but always send the explicitly provided request.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
@@ -87,6 +87,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    LoRAModulePath,
                                                    OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
@@ -1481,11 +1482,28 @@ async def init_app_state(
            "This discrepancy may lead to performance degradation.",
            resolved_chat_template, args.model)

    # Merge default_mm_loras into the static lora_modules
    default_mm_loras = (vllm_config.lora_config.default_mm_loras
                        if vllm_config.lora_config is not None else {})

    lora_modules = args.lora_modules
    if default_mm_loras:
        default_mm_lora_paths = [
            LoRAModulePath(
                name=modality,
                path=lora_path,
            ) for modality, lora_path in default_mm_loras.items()
        ]
        if args.lora_modules is None:
            lora_modules = default_mm_lora_paths
        else:
            lora_modules += default_mm_lora_paths

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=args.lora_modules,
        lora_modules=lora_modules,
        prompt_adapters=args.prompt_adapters,
    )
    await state.openai_serving_models.init_static_loras()
@@ -153,7 +153,8 @@ class OpenAIServingChat(OpenAIServing):
            (
                lora_request,
                prompt_adapter_request,
            ) = self._maybe_get_adapters(request)
            ) = self._maybe_get_adapters(request,
                                         supports_default_mm_loras=True)

            model_name = self._get_model_name(request.model, lora_request)
@@ -458,20 +458,74 @@ class OpenAIServing:
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

    def _get_active_default_mm_loras(
            self, request: AnyRequest) -> Optional[LoRARequest]:
        """Determine if there are any active default multimodal loras."""
        # TODO: Currently this is only enabled for chat completions
        # to be better aligned with only being enabled for .generate
        # when run offline. It would be nice to support additional
        # tasks types in the future.
        message_types = self._get_message_types(request)
        default_mm_loras = set()

        for lora in self.models.lora_requests.values():
            # Best effort match for default multimodal lora adapters;
            # There is probably a better way to do this, but currently
            # this matches against the set of 'types' in any content lists
            # up until '_', e.g., to match audio_url -> audio
            if lora.lora_name in message_types:
                default_mm_loras.add(lora)

        # Currently only support default modality specific loras if
        # we have exactly one lora matched on the request.
        if len(default_mm_loras) == 1:
            return default_mm_loras.pop()
        return None

    def _maybe_get_adapters(
        self, request: AnyRequest
        self,
        request: AnyRequest,
        supports_default_mm_loras: bool = False,
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
            None, PromptAdapterRequest]]:
        if self._is_model_supported(request.model):
            return None, None

        if request.model in self.models.lora_requests:
            return self.models.lora_requests[request.model], None

        # Currently only support default modality specific loras
        # if we have exactly one lora matched on the request.
        if supports_default_mm_loras:
            default_mm_lora = self._get_active_default_mm_loras(request)
            if default_mm_lora is not None:
                return default_mm_lora, None

        if self._is_model_supported(request.model):
            return None, None

        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _get_message_types(self, request: AnyRequest) -> set[str]:
        """Retrieve the set of types from message content dicts up
        until `_`; we use this to match potential multimodal data
        with default per modality loras.
        """
        message_types: set[str] = set()

        if not hasattr(request, "messages"):
            return message_types

        for message in request.messages:
            if (isinstance(message, dict) and "content" in message
                    and isinstance(message["content"], list)):
                for content_dict in message["content"]:
                    if "type" in content_dict:
                        message_types.add(content_dict["type"].split("_")[0])
        return message_types

    async def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,