mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-06 18:17:55 +08:00
[Frontend] Add sagemaker_standards dynamic lora adapter and stateful session management decorators to vLLM OpenAI API server (#27892)
Signed-off-by: Zuyi Zhao <zhaozuy@amazon.com> Signed-off-by: Shen Teng <sheteng@amazon.com> Co-authored-by: Shen Teng <sheteng@amazon.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
parent
8d706cca90
commit
bca74e32b7
@ -49,3 +49,4 @@ cbor2 # Required for cross-language serialization of hashable objects
|
|||||||
setproctitle # Used to set process names for better debugging and monitoring
|
setproctitle # Used to set process names for better debugging and monitoring
|
||||||
openai-harmony >= 0.0.3 # Required for gpt-oss
|
openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||||
anthropic == 0.71.0
|
anthropic == 0.71.0
|
||||||
|
model-hosting-container-standards < 1.0.0
|
||||||
0
tests/entrypoints/sagemaker/__init__.py
Normal file
0
tests/entrypoints/sagemaker/__init__.py
Normal file
58
tests/entrypoints/sagemaker/conftest.py
Normal file
58
tests/entrypoints/sagemaker/conftest.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""Shared fixtures and utilities for SageMaker tests."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
# Model name constants used across tests
|
||||||
|
MODEL_NAME_ZEPHYR = "HuggingFaceH4/zephyr-7b-beta"
|
||||||
|
MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct"
|
||||||
|
LORA_ADAPTER_NAME_SMOLLM = "jekunz/smollm-135m-lora-fineweb-faroese"
|
||||||
|
|
||||||
|
# SageMaker header constants
|
||||||
|
HEADER_SAGEMAKER_CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id"
|
||||||
|
HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id"
|
||||||
|
HEADER_SAGEMAKER_NEW_SESSION_ID = "X-Amzn-SageMaker-New-Session-Id"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def smollm2_lora_files():
|
||||||
|
"""Download LoRA files once per test session."""
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
|
return snapshot_download(repo_id=LORA_ADAPTER_NAME_SMOLLM)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def basic_server_with_lora(smollm2_lora_files):
|
||||||
|
"""Basic server fixture with standard configuration."""
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"8192",
|
||||||
|
"--enforce-eager",
|
||||||
|
# lora config below
|
||||||
|
"--enable-lora",
|
||||||
|
"--max-lora-rank",
|
||||||
|
"256",
|
||||||
|
"--max-cpu-loras",
|
||||||
|
"2",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"64",
|
||||||
|
]
|
||||||
|
|
||||||
|
envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
|
||||||
|
with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=envs) as remote_server:
|
||||||
|
yield remote_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def async_client(basic_server_with_lora: RemoteOpenAIServer):
|
||||||
|
"""Async OpenAI client fixture for use with basic_server."""
|
||||||
|
async with basic_server_with_lora.get_async_client() as async_client:
|
||||||
|
yield async_client
|
||||||
734
tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py
Normal file
734
tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py
Normal file
@ -0,0 +1,734 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""Integration tests for handler override functionality.
|
||||||
|
|
||||||
|
Tests real customer usage scenarios:
|
||||||
|
- Using @custom_ping_handler and @custom_invocation_handler decorators
|
||||||
|
to override handlers
|
||||||
|
- Setting environment variables for handler specifications
|
||||||
|
- Writing customer scripts with custom_sagemaker_ping_handler() and
|
||||||
|
custom_sagemaker_invocation_handler() functions
|
||||||
|
- Priority: env vars > decorators > customer script files > framework
|
||||||
|
defaults
|
||||||
|
|
||||||
|
Note: These tests focus on validating server responses rather than directly calling
|
||||||
|
get_ping_handler() and get_invoke_handler() to ensure full integration testing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
from .conftest import (
|
||||||
|
MODEL_NAME_SMOLLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandlerOverrideIntegration:
|
||||||
|
"""Integration tests simulating real customer usage scenarios.
|
||||||
|
|
||||||
|
Each test simulates a fresh server startup where customers:
|
||||||
|
- Use @custom_ping_handler and @custom_invocation_handler decorators
|
||||||
|
- Set environment variables (CUSTOM_FASTAPI_PING_HANDLER, etc.)
|
||||||
|
- Write customer scripts with custom_sagemaker_ping_handler() and
|
||||||
|
custom_sagemaker_invocation_handler() functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup for each test - simulate fresh server startup."""
|
||||||
|
self._clear_caches()
|
||||||
|
self._clear_env_vars()
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Cleanup after each test."""
|
||||||
|
self._clear_env_vars()
|
||||||
|
|
||||||
|
def _clear_caches(self):
|
||||||
|
"""Clear handler registry and function loader cache."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.common.handler import (
|
||||||
|
handler_registry,
|
||||||
|
)
|
||||||
|
from model_hosting_container_standards.sagemaker.sagemaker_loader import (
|
||||||
|
SageMakerFunctionLoader,
|
||||||
|
)
|
||||||
|
|
||||||
|
handler_registry.clear()
|
||||||
|
SageMakerFunctionLoader._default_function_loader = None
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
def _clear_env_vars(self):
|
||||||
|
"""Clear SageMaker environment variables."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.common.fastapi.config import (
|
||||||
|
FastAPIEnvVars,
|
||||||
|
)
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear SageMaker env vars
|
||||||
|
for var in [
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME,
|
||||||
|
]:
|
||||||
|
os.environ.pop(var, None)
|
||||||
|
|
||||||
|
# Clear FastAPI env vars
|
||||||
|
for var in [
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER,
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER,
|
||||||
|
]:
|
||||||
|
os.environ.pop(var, None)
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_customer_script_functions_auto_loaded(self):
|
||||||
|
"""Test customer scenario: script functions automatically override
|
||||||
|
framework defaults."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Customer writes a script file with ping() and invoke() functions
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
async def custom_sagemaker_ping_handler():
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "customer_override",
|
||||||
|
"message": "Custom ping from customer script"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def custom_sagemaker_invocation_handler(request: Request):
|
||||||
|
return {
|
||||||
|
"predictions": ["Custom response from customer script"],
|
||||||
|
"source": "customer_override"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
# Customer sets SageMaker environment variables to point to their script
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
# Customer tests their server and sees their overrides work
|
||||||
|
# automatically
|
||||||
|
ping_response = requests.get(server.url_for("ping"))
|
||||||
|
assert ping_response.status_code == 200
|
||||||
|
ping_data = ping_response.json()
|
||||||
|
|
||||||
|
invoke_response = requests.post(
|
||||||
|
server.url_for("invocations"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert invoke_response.status_code == 200
|
||||||
|
invoke_data = invoke_response.json()
|
||||||
|
|
||||||
|
# Customer sees their functions are used
|
||||||
|
assert ping_data["source"] == "customer_override"
|
||||||
|
assert ping_data["message"] == "Custom ping from customer script"
|
||||||
|
assert invoke_data["source"] == "customer_override"
|
||||||
|
assert invoke_data["predictions"] == [
|
||||||
|
"Custom response from customer script"
|
||||||
|
]
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_customer_decorator_usage(self):
|
||||||
|
"""Test customer scenario: using @custom_ping_handler and
|
||||||
|
@custom_invocation_handler decorators."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Customer writes a script file with decorators
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
@sagemaker_standards.custom_ping_handler
|
||||||
|
async def my_ping():
|
||||||
|
return {
|
||||||
|
"type": "ping",
|
||||||
|
"source": "customer_decorator"
|
||||||
|
}
|
||||||
|
|
||||||
|
@sagemaker_standards.custom_invocation_handler
|
||||||
|
async def my_invoke(request: Request):
|
||||||
|
return {
|
||||||
|
"type": "invoke",
|
||||||
|
"source": "customer_decorator"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
ping_response = requests.get(server.url_for("ping"))
|
||||||
|
assert ping_response.status_code == 200
|
||||||
|
ping_data = ping_response.json()
|
||||||
|
|
||||||
|
invoke_response = requests.post(
|
||||||
|
server.url_for("invocations"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert invoke_response.status_code == 200
|
||||||
|
invoke_data = invoke_response.json()
|
||||||
|
|
||||||
|
# Customer sees their handlers are used by the server
|
||||||
|
assert ping_data["source"] == "customer_decorator"
|
||||||
|
assert invoke_data["source"] == "customer_decorator"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handler_priority_order(self):
|
||||||
|
"""Test priority: @custom_ping_handler/@custom_invocation_handler
|
||||||
|
decorators vs script functions."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Customer writes a script with both decorator and regular functions
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
# Customer uses @custom_ping_handler decorator (higher priority than script functions)
|
||||||
|
@sagemaker_standards.custom_ping_handler
|
||||||
|
async def decorated_ping():
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "ping_decorator_in_script",
|
||||||
|
"priority": "decorator"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Customer also has a regular function (lower priority than
|
||||||
|
# @custom_ping_handler decorator)
|
||||||
|
async def custom_sagemaker_ping_handler():
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "script_function",
|
||||||
|
"priority": "function"
|
||||||
|
}
|
||||||
|
|
||||||
|
# Customer has a regular invoke function
|
||||||
|
async def custom_sagemaker_invocation_handler(request: Request):
|
||||||
|
return {
|
||||||
|
"predictions": ["Script function response"],
|
||||||
|
"source": "script_invoke_function",
|
||||||
|
"priority": "function"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
ping_response = requests.get(server.url_for("ping"))
|
||||||
|
assert ping_response.status_code == 200
|
||||||
|
ping_data = ping_response.json()
|
||||||
|
|
||||||
|
invoke_response = requests.post(
|
||||||
|
server.url_for("invocations"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert invoke_response.status_code == 200
|
||||||
|
invoke_data = invoke_response.json()
|
||||||
|
|
||||||
|
# @custom_ping_handler decorator has higher priority than
|
||||||
|
# script function
|
||||||
|
assert ping_data["source"] == "ping_decorator_in_script"
|
||||||
|
assert ping_data["priority"] == "decorator"
|
||||||
|
|
||||||
|
# Script function is used for invoke
|
||||||
|
assert invoke_data["source"] == "script_invoke_function"
|
||||||
|
assert invoke_data["priority"] == "function"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_environment_variable_script_loading(self):
|
||||||
|
"""Test that environment variables correctly specify script location
|
||||||
|
and loading."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Customer writes a script in a specific directory
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
async def custom_sagemaker_ping_handler():
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "env_loaded_script",
|
||||||
|
"method": "environment_variable_loading"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def custom_sagemaker_invocation_handler(request: Request):
|
||||||
|
return {
|
||||||
|
"predictions": ["Loaded via environment variables"],
|
||||||
|
"source": "env_loaded_script",
|
||||||
|
"method": "environment_variable_loading"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
# Test environment variable script loading
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
ping_response = requests.get(server.url_for("ping"))
|
||||||
|
assert ping_response.status_code == 200
|
||||||
|
ping_data = ping_response.json()
|
||||||
|
|
||||||
|
invoke_response = requests.post(
|
||||||
|
server.url_for("invocations"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert invoke_response.status_code == 200
|
||||||
|
invoke_data = invoke_response.json()
|
||||||
|
|
||||||
|
# Verify that the script was loaded via environment variables
|
||||||
|
assert ping_data["source"] == "env_loaded_script"
|
||||||
|
assert ping_data["method"] == "environment_variable_loading"
|
||||||
|
assert invoke_data["source"] == "env_loaded_script"
|
||||||
|
assert invoke_data["method"] == "environment_variable_loading"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_framework_default_handlers(self):
|
||||||
|
"""Test that framework default handlers work when no customer
|
||||||
|
overrides exist."""
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Explicitly pass empty env_dict to ensure no SageMaker env vars are set
|
||||||
|
# This prevents pollution from previous tests
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.common.fastapi.config import (
|
||||||
|
FastAPIEnvVars,
|
||||||
|
)
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
|
||||||
|
env_dict = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: "",
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: "",
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: "",
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: "",
|
||||||
|
}
|
||||||
|
except ImportError:
|
||||||
|
env_dict = {}
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=env_dict) as server:
|
||||||
|
# Test that default ping works
|
||||||
|
ping_response = requests.get(server.url_for("ping"))
|
||||||
|
assert ping_response.status_code == 200
|
||||||
|
|
||||||
|
# Test that default invocations work
|
||||||
|
invoke_response = requests.post(
|
||||||
|
server.url_for("invocations"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert invoke_response.status_code == 200
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handler_env_var_override(self):
|
||||||
|
"""Test CUSTOM_FASTAPI_PING_HANDLER and CUSTOM_FASTAPI_INVOCATION_HANDLER
|
||||||
|
environment variable overrides."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.common.fastapi.config import (
|
||||||
|
FastAPIEnvVars,
|
||||||
|
)
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Create a script with both env var handlers and script functions
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
from fastapi import Request, Response
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def env_var_ping_handler(raw_request: Request) -> Response:
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "env_var_ping",
|
||||||
|
"method": "environment_variable"
|
||||||
|
}),
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def env_var_invoke_handler(raw_request: Request) -> Response:
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({
|
||||||
|
"predictions": ["Environment variable response"],
|
||||||
|
"source": "env_var_invoke",
|
||||||
|
"method": "environment_variable"
|
||||||
|
}),
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def custom_sagemaker_ping_handler():
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "script_ping",
|
||||||
|
"method": "script_function"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def custom_sagemaker_invocation_handler(request: Request):
|
||||||
|
return {
|
||||||
|
"predictions": ["Script function response"],
|
||||||
|
"source": "script_invoke",
|
||||||
|
"method": "script_function"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
# Set environment variables to override both handlers
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: (
|
||||||
|
f"{script_name}:env_var_ping_handler"
|
||||||
|
),
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: (
|
||||||
|
f"{script_name}:env_var_invoke_handler"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
# Test ping handler override
|
||||||
|
ping_response = requests.get(server.url_for("ping"))
|
||||||
|
assert ping_response.status_code == 200
|
||||||
|
ping_data = ping_response.json()
|
||||||
|
|
||||||
|
# Environment variable should override script function
|
||||||
|
assert ping_data["method"] == "environment_variable"
|
||||||
|
assert ping_data["source"] == "env_var_ping"
|
||||||
|
|
||||||
|
# Test invocation handler override
|
||||||
|
invoke_response = requests.post(
|
||||||
|
server.url_for("invocations"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert invoke_response.status_code == 200
|
||||||
|
invoke_data = invoke_response.json()
|
||||||
|
|
||||||
|
# Environment variable should override script function
|
||||||
|
assert invoke_data["method"] == "environment_variable"
|
||||||
|
assert invoke_data["source"] == "env_var_invoke"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_env_var_priority_over_decorator_and_script(self):
|
||||||
|
"""Test that environment variables have highest priority over decorators
|
||||||
|
and script functions for both ping and invocation handlers."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.common.fastapi.config import (
|
||||||
|
FastAPIEnvVars,
|
||||||
|
)
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Create a script with all three handler types for both ping and invocation
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||||
|
from fastapi import Request, Response
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Environment variable handlers (highest priority)
|
||||||
|
async def env_priority_ping(raw_request: Request) -> Response:
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "env_var",
|
||||||
|
"priority": "environment_variable"
|
||||||
|
}),
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def env_priority_invoke(raw_request: Request) -> Response:
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({
|
||||||
|
"predictions": ["Environment variable response"],
|
||||||
|
"source": "env_var",
|
||||||
|
"priority": "environment_variable"
|
||||||
|
}),
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decorator handlers (medium priority)
|
||||||
|
@sagemaker_standards.custom_ping_handler
|
||||||
|
async def decorator_ping(raw_request: Request) -> Response:
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "decorator",
|
||||||
|
"priority": "decorator"
|
||||||
|
}),
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
@sagemaker_standards.custom_invocation_handler
|
||||||
|
async def decorator_invoke(raw_request: Request) -> Response:
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({
|
||||||
|
"predictions": ["Decorator response"],
|
||||||
|
"source": "decorator",
|
||||||
|
"priority": "decorator"
|
||||||
|
}),
|
||||||
|
media_type="application/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Script functions (lowest priority)
|
||||||
|
async def custom_sagemaker_ping_handler():
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"source": "script",
|
||||||
|
"priority": "script_function"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def custom_sagemaker_invocation_handler(request: Request):
|
||||||
|
return {
|
||||||
|
"predictions": ["Script function response"],
|
||||||
|
"source": "script",
|
||||||
|
"priority": "script_function"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
# Set environment variables to specify highest priority handlers
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: (
|
||||||
|
f"{script_name}:env_priority_ping"
|
||||||
|
),
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: (
|
||||||
|
f"{script_name}:env_priority_invoke"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
# Test ping handler priority
|
||||||
|
ping_response = requests.get(server.url_for("ping"))
|
||||||
|
assert ping_response.status_code == 200
|
||||||
|
ping_data = ping_response.json()
|
||||||
|
|
||||||
|
# Environment variable has highest priority and should be used
|
||||||
|
assert ping_data["priority"] == "environment_variable"
|
||||||
|
assert ping_data["source"] == "env_var"
|
||||||
|
|
||||||
|
# Test invocation handler priority
|
||||||
|
invoke_response = requests.post(
|
||||||
|
server.url_for("invocations"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert invoke_response.status_code == 200
|
||||||
|
invoke_data = invoke_response.json()
|
||||||
|
|
||||||
|
# Environment variable has highest priority and should be used
|
||||||
|
assert invoke_data["priority"] == "environment_variable"
|
||||||
|
assert invoke_data["source"] == "env_var"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
171
tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py
Normal file
171
tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import openai # use the official async_client for correctness check
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
from .conftest import MODEL_NAME_SMOLLM
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sagemaker_load_adapter_happy_path(
|
||||||
|
async_client: openai.AsyncOpenAI,
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer,
|
||||||
|
smollm2_lora_files,
|
||||||
|
):
|
||||||
|
# The SageMaker standards library creates a POST /adapters endpoint
|
||||||
|
# that maps to the load_lora_adapter handler with request shape:
|
||||||
|
# {"lora_name": "body.name", "lora_path": "body.src"}
|
||||||
|
load_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("adapters"),
|
||||||
|
json={"name": "smollm2-lora-sagemaker", "src": smollm2_lora_files},
|
||||||
|
)
|
||||||
|
load_response.raise_for_status()
|
||||||
|
|
||||||
|
models = await async_client.models.list()
|
||||||
|
models = models.data
|
||||||
|
dynamic_lora_model = models[-1]
|
||||||
|
assert dynamic_lora_model.root == smollm2_lora_files
|
||||||
|
assert dynamic_lora_model.parent == MODEL_NAME_SMOLLM
|
||||||
|
assert dynamic_lora_model.id == "smollm2-lora-sagemaker"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sagemaker_unload_adapter_happy_path(
|
||||||
|
async_client: openai.AsyncOpenAI,
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer,
|
||||||
|
smollm2_lora_files,
|
||||||
|
):
|
||||||
|
# First, load an adapter
|
||||||
|
adapter_name = "smollm2-lora-sagemaker-unload"
|
||||||
|
load_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("adapters"),
|
||||||
|
json={"name": adapter_name, "src": smollm2_lora_files},
|
||||||
|
)
|
||||||
|
load_response.raise_for_status()
|
||||||
|
|
||||||
|
# Verify it's in the models list
|
||||||
|
models = await async_client.models.list()
|
||||||
|
adapter_ids = [model.id for model in models.data]
|
||||||
|
assert adapter_name in adapter_ids
|
||||||
|
|
||||||
|
# Now unload it using DELETE /adapters/{adapter_name}
|
||||||
|
# The SageMaker standards maps this to unload_lora_adapter with:
|
||||||
|
# {"lora_name": "path_params.adapter_name"}
|
||||||
|
unload_response = requests.delete(
|
||||||
|
basic_server_with_lora.url_for("adapters", adapter_name),
|
||||||
|
)
|
||||||
|
unload_response.raise_for_status()
|
||||||
|
|
||||||
|
# Verify it's no longer in the models list
|
||||||
|
models = await async_client.models.list()
|
||||||
|
adapter_ids = [model.id for model in models.data]
|
||||||
|
assert adapter_name not in adapter_ids
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sagemaker_load_adapter_not_found(
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer,
|
||||||
|
):
|
||||||
|
load_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("adapters"),
|
||||||
|
json={"name": "nonexistent-adapter", "src": "/path/does/not/exist"},
|
||||||
|
)
|
||||||
|
assert load_response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sagemaker_load_adapter_invalid_files(
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer,
|
||||||
|
tmp_path,
|
||||||
|
):
|
||||||
|
invalid_files = tmp_path / "invalid_adapter"
|
||||||
|
invalid_files.mkdir()
|
||||||
|
(invalid_files / "adapter_config.json").write_text("not valid json")
|
||||||
|
|
||||||
|
load_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("adapters"),
|
||||||
|
json={"name": "invalid-adapter", "src": str(invalid_files)},
|
||||||
|
)
|
||||||
|
assert load_response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sagemaker_unload_nonexistent_adapter(
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer,
|
||||||
|
):
|
||||||
|
# Attempt to unload an adapter that doesn't exist
|
||||||
|
unload_response = requests.delete(
|
||||||
|
basic_server_with_lora.url_for("adapters", "nonexistent-adapter-name"),
|
||||||
|
)
|
||||||
|
assert unload_response.status_code in (400, 404)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sagemaker_invocations_with_adapter(
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer,
|
||||||
|
smollm2_lora_files,
|
||||||
|
):
|
||||||
|
# First, load an adapter via SageMaker endpoint
|
||||||
|
adapter_name = "smollm2-lora-invoke-test"
|
||||||
|
load_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("adapters"),
|
||||||
|
json={"name": adapter_name, "src": smollm2_lora_files},
|
||||||
|
)
|
||||||
|
load_response.raise_for_status()
|
||||||
|
|
||||||
|
# Now test the /invocations endpoint with the adapter
|
||||||
|
invocation_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("invocations"),
|
||||||
|
headers={
|
||||||
|
"X-Amzn-SageMaker-Adapter-Identifier": adapter_name,
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"prompt": "Hello, how are you?",
|
||||||
|
"max_tokens": 10,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
invocation_response.raise_for_status()
|
||||||
|
invocation_output = invocation_response.json()
|
||||||
|
|
||||||
|
# Verify we got a valid completion response
|
||||||
|
assert "choices" in invocation_output
|
||||||
|
assert len(invocation_output["choices"]) > 0
|
||||||
|
assert "text" in invocation_output["choices"][0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sagemaker_multiple_adapters_load_unload(
|
||||||
|
async_client: openai.AsyncOpenAI,
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer,
|
||||||
|
smollm2_lora_files,
|
||||||
|
):
|
||||||
|
adapter_names = [f"sagemaker-adapter-{i}" for i in range(5)]
|
||||||
|
|
||||||
|
# Load all adapters
|
||||||
|
for adapter_name in adapter_names:
|
||||||
|
load_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("adapters"),
|
||||||
|
json={"name": adapter_name, "src": smollm2_lora_files},
|
||||||
|
)
|
||||||
|
load_response.raise_for_status()
|
||||||
|
|
||||||
|
# Verify all are in the models list
|
||||||
|
models = await async_client.models.list()
|
||||||
|
adapter_ids = [model.id for model in models.data]
|
||||||
|
for adapter_name in adapter_names:
|
||||||
|
assert adapter_name in adapter_ids
|
||||||
|
|
||||||
|
# Unload all adapters
|
||||||
|
for adapter_name in adapter_names:
|
||||||
|
unload_response = requests.delete(
|
||||||
|
basic_server_with_lora.url_for("adapters", adapter_name),
|
||||||
|
)
|
||||||
|
unload_response.raise_for_status()
|
||||||
|
|
||||||
|
# Verify all are removed from models list
|
||||||
|
models = await async_client.models.list()
|
||||||
|
adapter_ids = [model.id for model in models.data]
|
||||||
|
for adapter_name in adapter_names:
|
||||||
|
assert adapter_name not in adapter_ids
|
||||||
@ -0,0 +1,346 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""Integration test for middleware loader functionality.
|
||||||
|
|
||||||
|
Tests that customer middlewares get called correctly with a vLLM server.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
from .conftest import (
|
||||||
|
MODEL_NAME_SMOLLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMiddlewareIntegration:
|
||||||
|
"""Integration test for middleware with vLLM server."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup for each test - simulate fresh server startup."""
|
||||||
|
self._clear_caches()
|
||||||
|
|
||||||
|
def _clear_caches(self):
|
||||||
|
"""Clear middleware registry and function loader cache."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.common.fastapi.middleware import (
|
||||||
|
middleware_registry,
|
||||||
|
)
|
||||||
|
from model_hosting_container_standards.common.fastapi.middleware.source.decorator_loader import ( # noqa: E501
|
||||||
|
decorator_loader,
|
||||||
|
)
|
||||||
|
from model_hosting_container_standards.sagemaker.sagemaker_loader import (
|
||||||
|
SageMakerFunctionLoader,
|
||||||
|
)
|
||||||
|
|
||||||
|
middleware_registry.clear_middlewares()
|
||||||
|
decorator_loader.clear()
|
||||||
|
SageMakerFunctionLoader._default_function_loader = None
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_customer_middleware_with_vllm_server(self):
|
||||||
|
"""Test that customer middlewares work with actual vLLM server.
|
||||||
|
|
||||||
|
Tests decorator-based middlewares (@custom_middleware, @input_formatter,
|
||||||
|
@output_formatter)
|
||||||
|
on multiple endpoints (chat/completions, invocations).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Customer writes a middleware script with multiple decorators
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
from model_hosting_container_standards.common.fastapi.middleware import (
|
||||||
|
custom_middleware, input_formatter, output_formatter
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global flag to track if input formatter was called
|
||||||
|
_input_formatter_called = False
|
||||||
|
|
||||||
|
@input_formatter
|
||||||
|
async def customer_input_formatter(request):
|
||||||
|
# Process input - mark that input formatter was called
|
||||||
|
global _input_formatter_called
|
||||||
|
_input_formatter_called = True
|
||||||
|
return request
|
||||||
|
|
||||||
|
@custom_middleware("throttle")
|
||||||
|
async def customer_throttle_middleware(request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-Customer-Throttle"] = "applied"
|
||||||
|
order = response.headers.get("X-Middleware-Order", "")
|
||||||
|
response.headers["X-Middleware-Order"] = order + "throttle,"
|
||||||
|
return response
|
||||||
|
|
||||||
|
@output_formatter
|
||||||
|
async def customer_output_formatter(response):
|
||||||
|
global _input_formatter_called
|
||||||
|
response.headers["X-Customer-Processed"] = "true"
|
||||||
|
# Since input_formatter and output_formatter are combined into
|
||||||
|
# pre_post_process middleware,
|
||||||
|
# if output_formatter is called, input_formatter should have been called too
|
||||||
|
if _input_formatter_called:
|
||||||
|
response.headers["X-Input-Formatter-Called"] = "true"
|
||||||
|
order = response.headers.get("X-Middleware-Order", "")
|
||||||
|
response.headers["X-Middleware-Order"] = order + "output_formatter,"
|
||||||
|
return response
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
# Set environment variables to point to customer script
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
# Test 1: Middlewares applied to chat/completions endpoint
|
||||||
|
chat_response = requests.post(
|
||||||
|
server.url_for("v1/chat/completions"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert chat_response.status_code == 200
|
||||||
|
|
||||||
|
# Verify all middlewares were executed
|
||||||
|
assert "X-Customer-Throttle" in chat_response.headers
|
||||||
|
assert chat_response.headers["X-Customer-Throttle"] == "applied"
|
||||||
|
assert "X-Customer-Processed" in chat_response.headers
|
||||||
|
assert chat_response.headers["X-Customer-Processed"] == "true"
|
||||||
|
|
||||||
|
# Verify input formatter was called
|
||||||
|
assert "X-Input-Formatter-Called" in chat_response.headers
|
||||||
|
assert chat_response.headers["X-Input-Formatter-Called"] == "true"
|
||||||
|
|
||||||
|
# Verify middleware execution order
|
||||||
|
execution_order = chat_response.headers.get(
|
||||||
|
"X-Middleware-Order", ""
|
||||||
|
).rstrip(",")
|
||||||
|
order_parts = execution_order.split(",") if execution_order else []
|
||||||
|
assert "throttle" in order_parts
|
||||||
|
assert "output_formatter" in order_parts
|
||||||
|
|
||||||
|
# Test 2: Middlewares applied to invocations endpoint
|
||||||
|
invocations_response = requests.post(
|
||||||
|
server.url_for("invocations"),
|
||||||
|
json={
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
|
"max_tokens": 5,
|
||||||
|
"temperature": 0.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert invocations_response.status_code == 200
|
||||||
|
|
||||||
|
# Verify all middlewares were executed
|
||||||
|
assert "X-Customer-Throttle" in invocations_response.headers
|
||||||
|
assert invocations_response.headers["X-Customer-Throttle"] == "applied"
|
||||||
|
assert "X-Customer-Processed" in invocations_response.headers
|
||||||
|
assert invocations_response.headers["X-Customer-Processed"] == "true"
|
||||||
|
|
||||||
|
# Verify input formatter was called
|
||||||
|
assert "X-Input-Formatter-Called" in invocations_response.headers
|
||||||
|
assert (
|
||||||
|
invocations_response.headers["X-Input-Formatter-Called"] == "true"
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_middleware_with_ping_endpoint(self):
|
||||||
|
"""Test that middlewares work with SageMaker ping endpoint."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Customer writes a middleware script
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
from model_hosting_container_standards.common.fastapi.middleware import (
|
||||||
|
custom_middleware
|
||||||
|
)
|
||||||
|
|
||||||
|
@custom_middleware("pre_post_process")
|
||||||
|
async def ping_tracking_middleware(request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
if request.url.path == "/ping":
|
||||||
|
response.headers["X-Ping-Tracked"] = "true"
|
||||||
|
return response
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
# Test ping endpoint with middleware
|
||||||
|
response = requests.get(server.url_for("ping"))
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "X-Ping-Tracked" in response.headers
|
||||||
|
assert response.headers["X-Ping-Tracked"] == "true"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_middleware_env_var_override(self):
|
||||||
|
"""Test middleware environment variable overrides."""
|
||||||
|
try:
|
||||||
|
from model_hosting_container_standards.common.fastapi.config import (
|
||||||
|
FastAPIEnvVars,
|
||||||
|
)
|
||||||
|
from model_hosting_container_standards.sagemaker.config import (
|
||||||
|
SageMakerEnvVars,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
pytest.skip("model-hosting-container-standards not available")
|
||||||
|
|
||||||
|
# Create a script with middleware functions specified via env vars
|
||||||
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
|
||||||
|
f.write(
|
||||||
|
"""
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
# Global flag to track if pre_process was called
|
||||||
|
_pre_process_called = False
|
||||||
|
|
||||||
|
async def env_throttle_middleware(request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-Env-Throttle"] = "applied"
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def env_pre_process(request: Request) -> Request:
|
||||||
|
# Mark that pre_process was called
|
||||||
|
global _pre_process_called
|
||||||
|
_pre_process_called = True
|
||||||
|
return request
|
||||||
|
|
||||||
|
async def env_post_process(response):
|
||||||
|
global _pre_process_called
|
||||||
|
if hasattr(response, 'headers'):
|
||||||
|
response.headers["X-Env-Post-Process"] = "applied"
|
||||||
|
# Since pre_process and post_process are combined into
|
||||||
|
# pre_post_process middleware,
|
||||||
|
# if post_process is called, pre_process should have been called too
|
||||||
|
if _pre_process_called:
|
||||||
|
response.headers["X-Pre-Process-Called"] = "true"
|
||||||
|
return response
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
script_path = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
script_dir = os.path.dirname(script_path)
|
||||||
|
script_name = os.path.basename(script_path)
|
||||||
|
|
||||||
|
# Set environment variables for middleware
|
||||||
|
# Use script_name with .py extension as per plugin example
|
||||||
|
env_vars = {
|
||||||
|
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
|
||||||
|
SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
|
||||||
|
FastAPIEnvVars.CUSTOM_FASTAPI_MIDDLEWARE_THROTTLE: (
|
||||||
|
f"{script_name}:env_throttle_middleware"
|
||||||
|
),
|
||||||
|
FastAPIEnvVars.CUSTOM_PRE_PROCESS: f"{script_name}:env_pre_process",
|
||||||
|
FastAPIEnvVars.CUSTOM_POST_PROCESS: f"{script_name}:env_post_process",
|
||||||
|
}
|
||||||
|
|
||||||
|
args = [
|
||||||
|
"--dtype",
|
||||||
|
"bfloat16",
|
||||||
|
"--max-model-len",
|
||||||
|
"2048",
|
||||||
|
"--enforce-eager",
|
||||||
|
"--max-num-seqs",
|
||||||
|
"32",
|
||||||
|
]
|
||||||
|
|
||||||
|
with RemoteOpenAIServer(
|
||||||
|
MODEL_NAME_SMOLLM, args, env_dict=env_vars
|
||||||
|
) as server:
|
||||||
|
response = requests.get(server.url_for("ping"))
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Check if environment variable middleware was applied
|
||||||
|
headers = response.headers
|
||||||
|
|
||||||
|
# Verify that env var middlewares were applied
|
||||||
|
assert "X-Env-Throttle" in headers, (
|
||||||
|
"Throttle middleware should be applied via env var"
|
||||||
|
)
|
||||||
|
assert headers["X-Env-Throttle"] == "applied"
|
||||||
|
|
||||||
|
assert "X-Env-Post-Process" in headers, (
|
||||||
|
"Post-process middleware should be applied via env var"
|
||||||
|
)
|
||||||
|
assert headers["X-Env-Post-Process"] == "applied"
|
||||||
|
|
||||||
|
# Verify that pre_process was called
|
||||||
|
assert "X-Pre-Process-Called" in headers, (
|
||||||
|
"Pre-process should be called via env var"
|
||||||
|
)
|
||||||
|
assert headers["X-Pre-Process-Called"] == "true"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.unlink(script_path)
|
||||||
153
tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py
Normal file
153
tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
|
||||||
|
import openai # use the official client for correctness check
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
from .conftest import (
|
||||||
|
HEADER_SAGEMAKER_CLOSED_SESSION_ID,
|
||||||
|
HEADER_SAGEMAKER_NEW_SESSION_ID,
|
||||||
|
HEADER_SAGEMAKER_SESSION_ID,
|
||||||
|
MODEL_NAME_SMOLLM,
|
||||||
|
)
|
||||||
|
|
||||||
|
CLOSE_BADREQUEST_CASES = [
|
||||||
|
(
|
||||||
|
"nonexistent_session_id",
|
||||||
|
{"session_id": "nonexistent-session-id"},
|
||||||
|
{},
|
||||||
|
"session not found",
|
||||||
|
),
|
||||||
|
("malformed_close_request", {}, {"extra-field": "extra-field-data"}, None),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_session_badrequest(basic_server_with_lora: RemoteOpenAIServer):
|
||||||
|
bad_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("invocations"),
|
||||||
|
json={"requestType": "NEW_SESSION", "extra-field": "extra-field-data"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert bad_response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_name,session_id_change,request_body_change,expected_error",
|
||||||
|
CLOSE_BADREQUEST_CASES,
|
||||||
|
)
|
||||||
|
async def test_close_session_badrequest(
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer,
|
||||||
|
test_name: str,
|
||||||
|
session_id_change: dict[str, str],
|
||||||
|
request_body_change: dict[str, str],
|
||||||
|
expected_error: str | None,
|
||||||
|
):
|
||||||
|
# first attempt to create a session
|
||||||
|
url = basic_server_with_lora.url_for("invocations")
|
||||||
|
create_response = requests.post(url, json={"requestType": "NEW_SESSION"})
|
||||||
|
create_response.raise_for_status()
|
||||||
|
valid_session_id, expiration = create_response.headers.get(
|
||||||
|
HEADER_SAGEMAKER_NEW_SESSION_ID, ""
|
||||||
|
).split(";")
|
||||||
|
assert valid_session_id
|
||||||
|
|
||||||
|
close_request_json = {"requestType": "CLOSE"}
|
||||||
|
if request_body_change:
|
||||||
|
close_request_json.update(request_body_change)
|
||||||
|
bad_session_id = session_id_change.get("session_id")
|
||||||
|
bad_close_response = requests.post(
|
||||||
|
url,
|
||||||
|
headers={HEADER_SAGEMAKER_SESSION_ID: bad_session_id or valid_session_id},
|
||||||
|
json=close_request_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
# clean up created session, should succeed
|
||||||
|
clean_up_response = requests.post(
|
||||||
|
url,
|
||||||
|
headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id},
|
||||||
|
json={"requestType": "CLOSE"},
|
||||||
|
)
|
||||||
|
clean_up_response.raise_for_status()
|
||||||
|
|
||||||
|
assert bad_close_response.status_code == 400
|
||||||
|
if expected_error:
|
||||||
|
assert expected_error in bad_close_response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_session_invalidrequest(
|
||||||
|
basic_server_with_lora: RemoteOpenAIServer, async_client: openai.AsyncOpenAI
|
||||||
|
):
|
||||||
|
# first attempt to create a session
|
||||||
|
url = basic_server_with_lora.url_for("invocations")
|
||||||
|
create_response = requests.post(url, json={"requestType": "NEW_SESSION"})
|
||||||
|
create_response.raise_for_status()
|
||||||
|
valid_session_id, expiration = create_response.headers.get(
|
||||||
|
HEADER_SAGEMAKER_NEW_SESSION_ID, ""
|
||||||
|
).split(";")
|
||||||
|
assert valid_session_id
|
||||||
|
|
||||||
|
close_request_json = {"requestType": "CLOSE"}
|
||||||
|
invalid_close_response = requests.post(
|
||||||
|
url,
|
||||||
|
# no headers to specify session_id
|
||||||
|
json=close_request_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
# clean up created session, should succeed
|
||||||
|
clean_up_response = requests.post(
|
||||||
|
url,
|
||||||
|
headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id},
|
||||||
|
json={"requestType": "CLOSE"},
|
||||||
|
)
|
||||||
|
clean_up_response.raise_for_status()
|
||||||
|
|
||||||
|
assert invalid_close_response.status_code == 424
|
||||||
|
assert "invalid session_id" in invalid_close_response.json()["error"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session(basic_server_with_lora: RemoteOpenAIServer):
|
||||||
|
# first attempt to create a session
|
||||||
|
url = basic_server_with_lora.url_for("invocations")
|
||||||
|
create_response = requests.post(url, json={"requestType": "NEW_SESSION"})
|
||||||
|
create_response.raise_for_status()
|
||||||
|
valid_session_id, expiration = create_response.headers.get(
|
||||||
|
HEADER_SAGEMAKER_NEW_SESSION_ID, ""
|
||||||
|
).split(";")
|
||||||
|
assert valid_session_id
|
||||||
|
|
||||||
|
# test invocation with session id
|
||||||
|
|
||||||
|
request_args = {
|
||||||
|
"model": MODEL_NAME_SMOLLM,
|
||||||
|
"prompt": "what is 1+1?",
|
||||||
|
"max_completion_tokens": 5,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"logprobs": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
invocation_response = requests.post(
|
||||||
|
basic_server_with_lora.url_for("invocations"),
|
||||||
|
headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id},
|
||||||
|
json=request_args,
|
||||||
|
)
|
||||||
|
invocation_response.raise_for_status()
|
||||||
|
|
||||||
|
# close created session, should succeed
|
||||||
|
close_response = requests.post(
|
||||||
|
url,
|
||||||
|
headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id},
|
||||||
|
json={"requestType": "CLOSE"},
|
||||||
|
)
|
||||||
|
close_response.raise_for_status()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
close_response.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID)
|
||||||
|
== valid_session_id
|
||||||
|
)
|
||||||
57
vllm/entrypoints/dynamic_lora.py
Normal file
57
vllm/entrypoints/dynamic_lora.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
from fastapi.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.api_server import models, validate_json_request
|
||||||
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
ErrorResponse,
|
||||||
|
LoadLoRAAdapterRequest,
|
||||||
|
UnloadLoRAAdapterRequest,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def register_dynamic_lora_routes(router: APIRouter):
|
||||||
|
@sagemaker_standards.register_load_adapter_handler(
|
||||||
|
request_shape={
|
||||||
|
"lora_name": "body.name",
|
||||||
|
"lora_path": "body.src",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
|
||||||
|
async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
|
||||||
|
handler: OpenAIServingModels = models(raw_request)
|
||||||
|
response = await handler.load_lora_adapter(request)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
return JSONResponse(
|
||||||
|
content=response.model_dump(), status_code=response.error.code
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(status_code=200, content=response)
|
||||||
|
|
||||||
|
@sagemaker_standards.register_unload_adapter_handler(
|
||||||
|
request_shape={
|
||||||
|
"lora_name": "path_params.adapter_name",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@router.post(
|
||||||
|
"/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
|
||||||
|
)
|
||||||
|
async def unload_lora_adapter(
|
||||||
|
request: UnloadLoRAAdapterRequest, raw_request: Request
|
||||||
|
):
|
||||||
|
handler: OpenAIServingModels = models(raw_request)
|
||||||
|
response = await handler.unload_lora_adapter(request)
|
||||||
|
if isinstance(response, ErrorResponse):
|
||||||
|
return JSONResponse(
|
||||||
|
content=response.model_dump(), status_code=response.error.code
|
||||||
|
)
|
||||||
|
|
||||||
|
return Response(status_code=200, content=response)
|
||||||
|
|
||||||
|
return router
|
||||||
@ -19,6 +19,7 @@ from contextlib import asynccontextmanager
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
|
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||||
import prometheus_client
|
import prometheus_client
|
||||||
import pydantic
|
import pydantic
|
||||||
import regex as re
|
import regex as re
|
||||||
@ -65,7 +66,6 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
ErrorInfo,
|
ErrorInfo,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
IOProcessorResponse,
|
IOProcessorResponse,
|
||||||
LoadLoRAAdapterRequest,
|
|
||||||
PoolingBytesResponse,
|
PoolingBytesResponse,
|
||||||
PoolingRequest,
|
PoolingRequest,
|
||||||
PoolingResponse,
|
PoolingResponse,
|
||||||
@ -82,7 +82,6 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
TranscriptionResponse,
|
TranscriptionResponse,
|
||||||
TranslationRequest,
|
TranslationRequest,
|
||||||
TranslationResponse,
|
TranslationResponse,
|
||||||
UnloadLoRAAdapterRequest,
|
|
||||||
)
|
)
|
||||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||||
from vllm.entrypoints.openai.serving_classification import ServingClassification
|
from vllm.entrypoints.openai.serving_classification import ServingClassification
|
||||||
@ -387,13 +386,6 @@ async def get_server_load_metrics(request: Request):
|
|||||||
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
return JSONResponse(content={"server_load": request.app.state.server_load_metrics})
|
||||||
|
|
||||||
|
|
||||||
@router.get("/ping", response_class=Response)
|
|
||||||
@router.post("/ping", response_class=Response)
|
|
||||||
async def ping(raw_request: Request) -> Response:
|
|
||||||
"""Ping check. Endpoint required for SageMaker"""
|
|
||||||
return await health(raw_request)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
@router.post(
|
||||||
"/tokenize",
|
"/tokenize",
|
||||||
dependencies=[Depends(validate_json_request)],
|
dependencies=[Depends(validate_json_request)],
|
||||||
@ -1236,47 +1228,6 @@ INVOCATION_VALIDATORS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/invocations",
|
|
||||||
dependencies=[Depends(validate_json_request)],
|
|
||||||
responses={
|
|
||||||
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
|
|
||||||
HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
|
|
||||||
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
async def invocations(raw_request: Request):
|
|
||||||
"""For SageMaker, routes requests based on the request type."""
|
|
||||||
try:
|
|
||||||
body = await raw_request.json()
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
valid_endpoints = [
|
|
||||||
(validator, endpoint)
|
|
||||||
for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS
|
|
||||||
if get_handler(raw_request) is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
for request_validator, endpoint in valid_endpoints:
|
|
||||||
try:
|
|
||||||
request = request_validator.validate_python(body)
|
|
||||||
except pydantic.ValidationError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return await endpoint(request, raw_request)
|
|
||||||
|
|
||||||
type_names = [
|
|
||||||
t.__name__ if isinstance(t := validator._type, type) else str(t)
|
|
||||||
for validator, _ in valid_endpoints
|
|
||||||
]
|
|
||||||
msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
|
|
||||||
res = base(raw_request).create_error_response(message=msg)
|
|
||||||
return JSONResponse(content=res.model_dump(), status_code=res.error.code)
|
|
||||||
|
|
||||||
|
|
||||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Torch Profiler is enabled in the API server. This should ONLY be "
|
"Torch Profiler is enabled in the API server. This should ONLY be "
|
||||||
@ -1304,39 +1255,6 @@ if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
|
|||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "LoRA dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!"
    )

    def _lora_update_response(response) -> Response:
        # Shared shaping of load/unload results: an ErrorResponse becomes a
        # JSON payload with the handler-supplied status; anything else is a
        # plain 200 whose body echoes the handler result.
        if isinstance(response, ErrorResponse):
            return JSONResponse(
                content=response.model_dump(), status_code=response.error.code
            )
        return Response(status_code=200, content=response)

    @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
    async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
        """Dynamically load a LoRA adapter into the running server."""
        handler = models(raw_request)
        return _lora_update_response(await handler.load_lora_adapter(request))

    @router.post(
        "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
    )
    async def unload_lora_adapter(
        request: UnloadLoRAAdapterRequest, raw_request: Request
    ):
        """Dynamically unload a previously loaded LoRA adapter."""
        handler = models(raw_request)
        return _lora_update_response(await handler.unload_lora_adapter(request))
|
|
||||||
|
|
||||||
|
|
||||||
def load_log_config(log_config_file: str | None) -> dict | None:
|
def load_log_config(log_config_file: str | None) -> dict | None:
|
||||||
if not log_config_file:
|
if not log_config_file:
|
||||||
return None
|
return None
|
||||||
@ -1606,6 +1524,20 @@ def build_app(args: Namespace) -> FastAPI:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
|
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||||
|
logger.warning(
|
||||||
|
"LoRA dynamic loading & unloading is enabled in the API server. "
|
||||||
|
"This should ONLY be used for local development!"
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes
|
||||||
|
|
||||||
|
register_dynamic_lora_routes(router)
|
||||||
|
|
||||||
|
from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes
|
||||||
|
|
||||||
|
register_sagemaker_routes(router)
|
||||||
|
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
app.root_path = args.root_path
|
app.root_path = args.root_path
|
||||||
|
|
||||||
@ -1696,6 +1628,8 @@ def build_app(args: Namespace) -> FastAPI:
|
|||||||
f"Invalid middleware {middleware}. Must be a function or a class."
|
f"Invalid middleware {middleware}. Must be a function or a class."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
app = sagemaker_standards.bootstrap(app)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
4
vllm/entrypoints/sagemaker/__init__.py
Normal file
4
vllm/entrypoints/sagemaker/__init__.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""SageMaker-specific integration for vLLM."""
|
||||||
72
vllm/entrypoints/sagemaker/routes.py
Normal file
72
vllm/entrypoints/sagemaker/routes.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import json
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
import model_hosting_container_standards.sagemaker as sagemaker_standards
|
||||||
|
import pydantic
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse, Response
|
||||||
|
|
||||||
|
from vllm.entrypoints.openai.api_server import (
|
||||||
|
INVOCATION_VALIDATORS,
|
||||||
|
base,
|
||||||
|
health,
|
||||||
|
validate_json_request,
|
||||||
|
)
|
||||||
|
from vllm.entrypoints.openai.protocol import ErrorResponse
|
||||||
|
|
||||||
|
|
||||||
|
def register_sagemaker_routes(router: APIRouter):
    """Attach the SageMaker-required ``/ping`` and ``/invocations`` routes.

    The handlers are also registered with the SageMaker hosting-standards
    library so it can wire in session management and adapter injection.
    Returns the router to allow chained registration.
    """

    @router.post("/ping", response_class=Response)
    @router.get("/ping", response_class=Response)
    @sagemaker_standards.register_ping_handler
    async def ping(raw_request: Request) -> Response:
        """Ping check. Endpoint required for SageMaker"""
        # Delegate to the standard health endpoint; SageMaker only needs
        # the status code.
        return await health(raw_request)

    @router.post(
        "/invocations",
        dependencies=[Depends(validate_json_request)],
        responses={
            HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
            HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
            HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        },
    )
    @sagemaker_standards.register_invocation_handler
    @sagemaker_standards.stateful_session_manager()
    @sagemaker_standards.inject_adapter_id(adapter_path="model")
    async def invocations(raw_request: Request):
        """For SageMaker, routes requests based on the request type."""
        # Reject bodies that are not valid JSON before attempting dispatch.
        try:
            payload = await raw_request.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e

        # Restrict dispatch to request types whose handler is enabled on
        # this server instance.
        candidates = [
            (validator, endpoint)
            for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS
            if get_handler(raw_request) is not None
        ]

        # First schema that validates the payload wins.
        for schema_validator, endpoint in candidates:
            try:
                parsed = schema_validator.validate_python(payload)
            except pydantic.ValidationError:
                continue

            return await endpoint(parsed, raw_request)

        # No schema accepted the payload: list the types that were eligible.
        # NOTE(review): `_type` is a private pydantic TypeAdapter attribute.
        type_names = [
            t.__name__ if isinstance(t := validator._type, type) else str(t)
            for validator, _ in candidates
        ]
        msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
        res = base(raw_request).create_error_response(message=msg)
        return JSONResponse(content=res.model_dump(), status_code=res.error.code)

    return router
|
||||||
Loading…
x
Reference in New Issue
Block a user