diff --git a/docs/features/README.md b/docs/features/README.md
index 10cc448cc2ee3..05ce0b57a9fc8 100644
--- a/docs/features/README.md
+++ b/docs/features/README.md
@@ -52,7 +52,7 @@ th:not(:first-child) {
| [mm](multimodal_inputs.md) | ✅ | ✅ | [🟠](gh-pr:4194)^ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | |
| best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | |
| beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | |
-| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](gh-issue:25096) | ? | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ? | ? | ❌ | ? | ? | ✅ |
+| [prompt-embeds](prompt_embeds.md) | ✅ | [❌](gh-issue:25096) | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❔ | ❔ | ❌ | ❔ | ❔ | ✅ |
\* Chunked prefill and prefix caching are only applicable to last-token pooling.
^ LoRA is only applicable to the language backbone of multimodal models.
diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 8e87a98e3d51d..d720fa2458e1d 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -403,7 +403,7 @@ th {
| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OLMo3ForCausalLM` | OLMo3 | TBA | ✅︎ | ✅︎ | ✅︎ |
| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ |
-| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ |
+| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ |
| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/tests/entrypoints/conftest.py b/tests/entrypoints/conftest.py
index da75806ccf4de..7daf62595b1b6 100644
--- a/tests/entrypoints/conftest.py
+++ b/tests/entrypoints/conftest.py
@@ -208,3 +208,11 @@ def zephyr_lora_files():
"""Download zephyr LoRA files once per test session."""
from huggingface_hub import snapshot_download
return snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
+
+
+@pytest.fixture(scope="session")
+def opt125_lora_files() -> str:
+ """Download opt-125m LoRA files once per test session."""
+ from huggingface_hub import snapshot_download
+ return snapshot_download(
+ repo_id="peft-internal-testing/opt-125m-dummy-lora")
diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py
index ae51025455b10..cad9142823061 100644
--- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py
+++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py
@@ -3,6 +3,7 @@
import base64
import io
+import json
import openai # use the official client for correctness check
import pytest
@@ -16,13 +17,15 @@ from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
+LORA_SERVING_MODEL_NAME = "opt125m-lora"
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
-@pytest.fixture(scope="module")
-def default_server_args() -> list[str]:
- return [
+@pytest.fixture(scope="module", params=["use-lora"])
+def default_server_args(request: pytest.FixtureRequest,
+ opt125_lora_files: str) -> list[str]:
+ args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
@@ -35,6 +38,25 @@ def default_server_args() -> list[str]:
"--enable-prompt-embeds",
]
+ if request.param == "use-lora":
+ lora_module_1 = {
+ "name": LORA_SERVING_MODEL_NAME,
+ "path": opt125_lora_files,
+ "base_model_name": MODEL_NAME
+ }
+
+ args.extend([
+ "--enable-lora",
+ "--lora-module",
+ json.dumps(lora_module_1),
+ "--max-lora-rank",
+ "64",
+ "--max-cpu-loras",
+ "2",
+ ])
+
+ return args
+
EXAMPLE_PROMPTS = [
"Hello, my name is",
@@ -74,7 +96,7 @@ async def client_with_prompt_embeds(server_with_prompt_embeds):
@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
async def test_completions_with_prompt_embeds(
example_prompt_embeds,
client_with_prompt_embeds: openai.AsyncOpenAI,
@@ -179,7 +201,7 @@ async def test_completions_with_prompt_embeds(
@pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
async def test_completions_errors_with_prompt_embeds(
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
# Test error case: invalid prompt_embeds
@@ -194,7 +216,7 @@ async def test_completions_errors_with_prompt_embeds(
@pytest.mark.asyncio
@pytest.mark.parametrize("logprobs_arg", [1, 0])
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize("model_name", [MODEL_NAME, LORA_SERVING_MODEL_NAME])
async def test_completions_with_logprobs_and_prompt_embeds(
example_prompt_embeds,
client_with_prompt_embeds: openai.AsyncOpenAI,
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 4c3ce9f61efb3..c4746166471c0 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors
-from .interfaces import SupportsPP
+from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -352,10 +352,9 @@ class OPTModel(nn.Module):
return loaded_params
-class OPTForCausalLM(nn.Module, SupportsPP):
+class OPTForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
- "gate_up_proj": ["gate_proj", "up_proj"]
}
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={