From bca74e32b7ef03515cda508ba88151e2e547bdc9 Mon Sep 17 00:00:00 2001 From: Zuyi Zhao Date: Mon, 10 Nov 2025 20:57:01 -0800 Subject: [PATCH] [Frontend] Add sagemaker_standards dynamic lora adapter and stateful session management decorators to vLLM OpenAI API server (#27892) Signed-off-by: Zuyi Zhao Signed-off-by: Shen Teng Co-authored-by: Shen Teng Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> --- requirements/common.txt | 1 + tests/entrypoints/sagemaker/__init__.py | 0 tests/entrypoints/sagemaker/conftest.py | 58 ++ .../test_sagemaker_handler_overrides.py | 734 ++++++++++++++++++ .../sagemaker/test_sagemaker_lora_adapters.py | 171 ++++ .../test_sagemaker_middleware_integration.py | 346 +++++++++ .../test_sagemaker_stateful_sessions.py | 153 ++++ vllm/entrypoints/dynamic_lora.py | 57 ++ vllm/entrypoints/openai/api_server.py | 100 +-- vllm/entrypoints/sagemaker/__init__.py | 4 + vllm/entrypoints/sagemaker/routes.py | 72 ++ 11 files changed, 1613 insertions(+), 83 deletions(-) create mode 100644 tests/entrypoints/sagemaker/__init__.py create mode 100644 tests/entrypoints/sagemaker/conftest.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py create mode 100644 tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py create mode 100644 vllm/entrypoints/dynamic_lora.py create mode 100644 vllm/entrypoints/sagemaker/__init__.py create mode 100644 vllm/entrypoints/sagemaker/routes.py diff --git a/requirements/common.txt b/requirements/common.txt index 8009581f62a4..90efb79a845d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -49,3 +49,4 @@ cbor2 # Required for cross-language serialization of hashable objects setproctitle # Used to set process names for better debugging and monitoring openai-harmony >= 0.0.3 # Required for gpt-oss anthropic == 0.71.0 +model-hosting-container-standards < 1.0.0 \ No newline at end of file diff --git a/tests/entrypoints/sagemaker/__init__.py b/tests/entrypoints/sagemaker/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/entrypoints/sagemaker/conftest.py b/tests/entrypoints/sagemaker/conftest.py new file mode 100644 index 000000000000..4c859c2527d2 --- /dev/null +++ b/tests/entrypoints/sagemaker/conftest.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Shared fixtures and utilities for SageMaker tests.""" + +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# Model name constants used across tests +MODEL_NAME_ZEPHYR = "HuggingFaceH4/zephyr-7b-beta" +MODEL_NAME_SMOLLM = "HuggingFaceTB/SmolLM2-135M-Instruct" +LORA_ADAPTER_NAME_SMOLLM = "jekunz/smollm-135m-lora-fineweb-faroese" + +# SageMaker header constants +HEADER_SAGEMAKER_CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id" +HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id" +HEADER_SAGEMAKER_NEW_SESSION_ID = "X-Amzn-SageMaker-New-Session-Id" + + +@pytest.fixture(scope="session") +def smollm2_lora_files(): + """Download LoRA files once per test session.""" + from huggingface_hub import snapshot_download + + return snapshot_download(repo_id=LORA_ADAPTER_NAME_SMOLLM) + + +@pytest.fixture(scope="module") +def basic_server_with_lora(smollm2_lora_files): + """Basic server fixture with standard 
configuration.""" + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--enable-lora", + "--max-lora-rank", + "256", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "64", + ] + + envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"} + with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=envs) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def async_client(basic_server_with_lora: RemoteOpenAIServer): + """Async OpenAI client fixture for use with basic_server.""" + async with basic_server_with_lora.get_async_client() as async_client: + yield async_client diff --git a/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py b/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py new file mode 100644 index 000000000000..0d4f8e885824 --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_handler_overrides.py @@ -0,0 +1,734 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration tests for handler override functionality. + +Tests real customer usage scenarios: +- Using @custom_ping_handler and @custom_invocation_handler decorators + to override handlers +- Setting environment variables for handler specifications +- Writing customer scripts with custom_sagemaker_ping_handler() and + custom_sagemaker_invocation_handler() functions +- Priority: env vars > decorators > customer script files > framework + defaults + +Note: These tests focus on validating server responses rather than directly calling +get_ping_handler() and get_invoke_handler() to ensure full integration testing. +""" + +import os +import tempfile + +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + MODEL_NAME_SMOLLM, +) + + +class TestHandlerOverrideIntegration: + """Integration tests simulating real customer usage scenarios. + + Each test simulates a fresh server startup where customers: + - Use @custom_ping_handler and @custom_invocation_handler decorators + - Set environment variables (CUSTOM_FASTAPI_PING_HANDLER, etc.) 
+ - Write customer scripts with custom_sagemaker_ping_handler() and + custom_sagemaker_invocation_handler() functions + """ + + def setup_method(self): + """Setup for each test - simulate fresh server startup.""" + self._clear_caches() + self._clear_env_vars() + + def teardown_method(self): + """Cleanup after each test.""" + self._clear_env_vars() + + def _clear_caches(self): + """Clear handler registry and function loader cache.""" + try: + from model_hosting_container_standards.common.handler import ( + handler_registry, + ) + from model_hosting_container_standards.sagemaker.sagemaker_loader import ( + SageMakerFunctionLoader, + ) + + handler_registry.clear() + SageMakerFunctionLoader._default_function_loader = None + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + def _clear_env_vars(self): + """Clear SageMaker environment variables.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + + # Clear SageMaker env vars + for var in [ + SageMakerEnvVars.SAGEMAKER_MODEL_PATH, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME, + ]: + os.environ.pop(var, None) + + # Clear FastAPI env vars + for var in [ + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER, + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER, + ]: + os.environ.pop(var, None) + except ImportError: + pass + + @pytest.mark.asyncio + async def test_customer_script_functions_auto_loaded(self): + """Test customer scenario: script functions automatically override + framework defaults.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script file with ping() and invoke() functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "customer_override", + "message": "Custom ping from customer script" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Custom response from customer script"], + "source": "customer_override" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Customer sets SageMaker environment variables to point to their script + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Customer tests their server and sees their overrides work + # automatically + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Customer sees their functions are used + assert ping_data["source"] == "customer_override" + assert ping_data["message"] == "Custom ping from customer script" + 
assert invoke_data["source"] == "customer_override" + assert invoke_data["predictions"] == [ + "Custom response from customer script" + ] + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_customer_decorator_usage(self): + """Test customer scenario: using @custom_ping_handler and + @custom_invocation_handler decorators.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script file with decorators + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request + +@sagemaker_standards.custom_ping_handler +async def my_ping(): + return { + "type": "ping", + "source": "customer_decorator" + } + +@sagemaker_standards.custom_invocation_handler +async def my_invoke(request: Request): + return { + "type": "invoke", + "source": "customer_decorator" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Customer sees their handlers are used by the server + assert ping_data["source"] == "customer_decorator" + assert invoke_data["source"] == "customer_decorator" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_handler_priority_order(self): + """Test priority: @custom_ping_handler/@custom_invocation_handler + decorators vs script functions.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script with both decorator and regular functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request + +# Customer uses @custom_ping_handler decorator (higher priority than script functions) +@sagemaker_standards.custom_ping_handler +async def decorated_ping(): + return { + "status": "healthy", + "source": "ping_decorator_in_script", + "priority": "decorator" + } + +# Customer also has a regular function (lower priority than +# @custom_ping_handler decorator) +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script_function", + "priority": "function" + } + +# Customer has a regular invoke function +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script_invoke_function", + "priority": "function" + } +""" + ) + 
script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # @custom_ping_handler decorator has higher priority than + # script function + assert ping_data["source"] == "ping_decorator_in_script" + assert ping_data["priority"] == "decorator" + + # Script function is used for invoke + assert invoke_data["source"] == "script_invoke_function" + assert invoke_data["priority"] == "function" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_environment_variable_script_loading(self): + """Test that environment variables correctly specify script location + and loading.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a script in a specific directory + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "env_loaded_script", + "method": "environment_variable_loading" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Loaded via environment variables"], + "source": "env_loaded_script", + "method": "environment_variable_loading" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Test environment variable script loading + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Verify that the script was loaded via environment variables + assert ping_data["source"] == "env_loaded_script" + assert ping_data["method"] == "environment_variable_loading" + assert invoke_data["source"] == "env_loaded_script" + assert invoke_data["method"] == "environment_variable_loading" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_framework_default_handlers(self): + """Test that framework default handlers work when no 
customer + overrides exist.""" + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + # Explicitly pass empty env_dict to ensure no SageMaker env vars are set + # This prevents pollution from previous tests + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + + env_dict = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: "", + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: "", + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: "", + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: "", + } + except ImportError: + env_dict = {} + + with RemoteOpenAIServer(MODEL_NAME_SMOLLM, args, env_dict=env_dict) as server: + # Test that default ping works + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + + # Test that default invocations work + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + + @pytest.mark.asyncio + async def test_handler_env_var_override(self): + """Test CUSTOM_FASTAPI_PING_HANDLER and CUSTOM_FASTAPI_INVOCATION_HANDLER + environment variable overrides.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with both env var handlers and script functions + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request, Response +import json + +async def env_var_ping_handler(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "env_var_ping", + "method": "environment_variable" + }), + media_type="application/json" + ) + +async def env_var_invoke_handler(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Environment variable response"], + "source": "env_var_invoke", + "method": "environment_variable" + }), + media_type="application/json" + ) + +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script_ping", + "method": "script_function" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script_invoke", + "method": "script_function" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to override both handlers + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: ( + f"{script_name}:env_var_ping_handler" + ), + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: ( + f"{script_name}:env_var_invoke_handler" + ), + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping handler override + ping_response = requests.get(server.url_for("ping")) + 
assert ping_response.status_code == 200 + ping_data = ping_response.json() + + # Environment variable should override script function + assert ping_data["method"] == "environment_variable" + assert ping_data["source"] == "env_var_ping" + + # Test invocation handler override + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Environment variable should override script function + assert invoke_data["method"] == "environment_variable" + assert invoke_data["source"] == "env_var_invoke" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_env_var_priority_over_decorator_and_script(self): + """Test that environment variables have highest priority over decorators + and script functions for both ping and invocation handlers.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with all three handler types for both ping and invocation + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import Request, Response +import json + +# Environment variable handlers (highest priority) +async def env_priority_ping(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "env_var", + "priority": "environment_variable" + }), + media_type="application/json" + ) + +async def env_priority_invoke(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Environment variable response"], + "source": "env_var", + "priority": "environment_variable" + }), + media_type="application/json" + ) + +# Decorator handlers (medium priority) +@sagemaker_standards.custom_ping_handler +async def decorator_ping(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "status": "healthy", + "source": "decorator", + "priority": "decorator" + }), + media_type="application/json" + ) + +@sagemaker_standards.custom_invocation_handler +async def decorator_invoke(raw_request: Request) -> Response: + return Response( + content=json.dumps({ + "predictions": ["Decorator response"], + "source": "decorator", + "priority": "decorator" + }), + media_type="application/json" + ) + +# Script functions (lowest priority) +async def custom_sagemaker_ping_handler(): + return { + "status": "healthy", + "source": "script", + "priority": "script_function" + } + +async def custom_sagemaker_invocation_handler(request: Request): + return { + "predictions": ["Script function response"], + "source": "script", + "priority": "script_function" + } +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to specify highest priority handlers + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_PING_HANDLER: ( + f"{script_name}:env_priority_ping" + ), + FastAPIEnvVars.CUSTOM_FASTAPI_INVOCATION_HANDLER: ( + 
f"{script_name}:env_priority_invoke" + ), + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping handler priority + ping_response = requests.get(server.url_for("ping")) + assert ping_response.status_code == 200 + ping_data = ping_response.json() + + # Environment variable has highest priority and should be used + assert ping_data["priority"] == "environment_variable" + assert ping_data["source"] == "env_var" + + # Test invocation handler priority + invoke_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + }, + ) + assert invoke_response.status_code == 200 + invoke_data = invoke_response.json() + + # Environment variable has highest priority and should be used + assert invoke_data["priority"] == "environment_variable" + assert invoke_data["source"] == "env_var" + + finally: + os.unlink(script_path) diff --git a/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py b/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py new file mode 100644 index 000000000000..a2867efdc584 --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_lora_adapters.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import openai # use the official async_client for correctness check +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import MODEL_NAME_SMOLLM + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_happy_path( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # The SageMaker standards library creates a POST /adapters endpoint + # that maps to the load_lora_adapter handler with request shape: + # {"lora_name": "body.name", "lora_path": "body.src"} + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "smollm2-lora-sagemaker", "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + models = await async_client.models.list() + models = models.data + dynamic_lora_model = models[-1] + assert dynamic_lora_model.root == smollm2_lora_files + assert dynamic_lora_model.parent == MODEL_NAME_SMOLLM + assert dynamic_lora_model.id == "smollm2-lora-sagemaker" + + +@pytest.mark.asyncio +async def test_sagemaker_unload_adapter_happy_path( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # First, load an adapter + adapter_name = "smollm2-lora-sagemaker-unload" + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Verify it's in the models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + assert adapter_name in adapter_ids + + # Now unload it using DELETE /adapters/{adapter_name} + # The SageMaker standards maps this to unload_lora_adapter with: + # {"lora_name": "path_params.adapter_name"} + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", adapter_name), + ) + unload_response.raise_for_status() + + # Verify it's no longer in the models list + models = await async_client.models.list() + adapter_ids = 
[model.id for model in models.data] + assert adapter_name not in adapter_ids + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_not_found( + basic_server_with_lora: RemoteOpenAIServer, +): + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "nonexistent-adapter", "src": "/path/does/not/exist"}, + ) + assert load_response.status_code == 404 + + +@pytest.mark.asyncio +async def test_sagemaker_load_adapter_invalid_files( + basic_server_with_lora: RemoteOpenAIServer, + tmp_path, +): + invalid_files = tmp_path / "invalid_adapter" + invalid_files.mkdir() + (invalid_files / "adapter_config.json").write_text("not valid json") + + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": "invalid-adapter", "src": str(invalid_files)}, + ) + assert load_response.status_code == 400 + + +@pytest.mark.asyncio +async def test_sagemaker_unload_nonexistent_adapter( + basic_server_with_lora: RemoteOpenAIServer, +): + # Attempt to unload an adapter that doesn't exist + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", "nonexistent-adapter-name"), + ) + assert unload_response.status_code in (400, 404) + + +@pytest.mark.asyncio +async def test_sagemaker_invocations_with_adapter( + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + # First, load an adapter via SageMaker endpoint + adapter_name = "smollm2-lora-invoke-test" + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Now test the /invocations endpoint with the adapter + invocation_response = requests.post( + basic_server_with_lora.url_for("invocations"), + headers={ + "X-Amzn-SageMaker-Adapter-Identifier": adapter_name, + }, + json={ + "prompt": "Hello, how are you?", + "max_tokens": 10, + }, + ) + invocation_response.raise_for_status() + invocation_output = invocation_response.json() + + # Verify we got a valid completion response + assert "choices" in invocation_output + assert len(invocation_output["choices"]) > 0 + assert "text" in invocation_output["choices"][0] + + +@pytest.mark.asyncio +async def test_sagemaker_multiple_adapters_load_unload( + async_client: openai.AsyncOpenAI, + basic_server_with_lora: RemoteOpenAIServer, + smollm2_lora_files, +): + adapter_names = [f"sagemaker-adapter-{i}" for i in range(5)] + + # Load all adapters + for adapter_name in adapter_names: + load_response = requests.post( + basic_server_with_lora.url_for("adapters"), + json={"name": adapter_name, "src": smollm2_lora_files}, + ) + load_response.raise_for_status() + + # Verify all are in the models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + for adapter_name in adapter_names: + assert adapter_name in adapter_ids + + # Unload all adapters + for adapter_name in adapter_names: + unload_response = requests.delete( + basic_server_with_lora.url_for("adapters", adapter_name), + ) + unload_response.raise_for_status() + + # Verify all are removed from models list + models = await async_client.models.list() + adapter_ids = [model.id for model in models.data] + for adapter_name in adapter_names: + assert adapter_name not in adapter_ids diff --git a/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py b/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py new file mode 100644 index 000000000000..f1ed0c7e2897 --- /dev/null +++ 
b/tests/entrypoints/sagemaker/test_sagemaker_middleware_integration.py @@ -0,0 +1,346 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Integration test for middleware loader functionality. + +Tests that customer middlewares get called correctly with a vLLM server. +""" + +import os +import tempfile + +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + MODEL_NAME_SMOLLM, +) + + +class TestMiddlewareIntegration: + """Integration test for middleware with vLLM server.""" + + def setup_method(self): + """Setup for each test - simulate fresh server startup.""" + self._clear_caches() + + def _clear_caches(self): + """Clear middleware registry and function loader cache.""" + try: + from model_hosting_container_standards.common.fastapi.middleware import ( + middleware_registry, + ) + from model_hosting_container_standards.common.fastapi.middleware.source.decorator_loader import ( # noqa: E501 + decorator_loader, + ) + from model_hosting_container_standards.sagemaker.sagemaker_loader import ( + SageMakerFunctionLoader, + ) + + middleware_registry.clear_middlewares() + decorator_loader.clear() + SageMakerFunctionLoader._default_function_loader = None + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + @pytest.mark.asyncio + async def test_customer_middleware_with_vllm_server(self): + """Test that customer middlewares work with actual vLLM server. + + Tests decorator-based middlewares (@custom_middleware, @input_formatter, + @output_formatter) + on multiple endpoints (chat/completions, invocations). + """ + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a middleware script with multiple decorators + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from model_hosting_container_standards.common.fastapi.middleware import ( + custom_middleware, input_formatter, output_formatter +) + +# Global flag to track if input formatter was called +_input_formatter_called = False + +@input_formatter +async def customer_input_formatter(request): + # Process input - mark that input formatter was called + global _input_formatter_called + _input_formatter_called = True + return request + +@custom_middleware("throttle") +async def customer_throttle_middleware(request, call_next): + response = await call_next(request) + response.headers["X-Customer-Throttle"] = "applied" + order = response.headers.get("X-Middleware-Order", "") + response.headers["X-Middleware-Order"] = order + "throttle," + return response + +@output_formatter +async def customer_output_formatter(response): + global _input_formatter_called + response.headers["X-Customer-Processed"] = "true" + # Since input_formatter and output_formatter are combined into + # pre_post_process middleware, + # if output_formatter is called, input_formatter should have been called too + if _input_formatter_called: + response.headers["X-Input-Formatter-Called"] = "true" + order = response.headers.get("X-Middleware-Order", "") + response.headers["X-Middleware-Order"] = order + "output_formatter," + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables to point to customer script + env_vars = { + 
SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test 1: Middlewares applied to chat/completions endpoint + chat_response = requests.post( + server.url_for("v1/chat/completions"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "temperature": 0.0, + }, + ) + + assert chat_response.status_code == 200 + + # Verify all middlewares were executed + assert "X-Customer-Throttle" in chat_response.headers + assert chat_response.headers["X-Customer-Throttle"] == "applied" + assert "X-Customer-Processed" in chat_response.headers + assert chat_response.headers["X-Customer-Processed"] == "true" + + # Verify input formatter was called + assert "X-Input-Formatter-Called" in chat_response.headers + assert chat_response.headers["X-Input-Formatter-Called"] == "true" + + # Verify middleware execution order + execution_order = chat_response.headers.get( + "X-Middleware-Order", "" + ).rstrip(",") + order_parts = execution_order.split(",") if execution_order else [] + assert "throttle" in order_parts + assert "output_formatter" in order_parts + + # Test 2: Middlewares applied to invocations endpoint + invocations_response = requests.post( + server.url_for("invocations"), + json={ + "model": MODEL_NAME_SMOLLM, + "messages": [{"role": "user", "content": "Hello"}], + "max_tokens": 5, + "temperature": 0.0, + }, + ) + + assert invocations_response.status_code == 200 + + # Verify all middlewares were executed + assert "X-Customer-Throttle" in invocations_response.headers + assert invocations_response.headers["X-Customer-Throttle"] == "applied" + assert "X-Customer-Processed" in invocations_response.headers + assert invocations_response.headers["X-Customer-Processed"] == "true" + + # Verify input formatter was called + assert "X-Input-Formatter-Called" in invocations_response.headers + assert ( + invocations_response.headers["X-Input-Formatter-Called"] == "true" + ) + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_middleware_with_ping_endpoint(self): + """Test that middlewares work with SageMaker ping endpoint.""" + try: + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Customer writes a middleware script + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from model_hosting_container_standards.common.fastapi.middleware import ( + custom_middleware +) + +@custom_middleware("pre_post_process") +async def ping_tracking_middleware(request, call_next): + response = await call_next(request) + if request.url.path == "/ping": + response.headers["X-Ping-Tracked"] = "true" + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + # Test ping endpoint with 
middleware + response = requests.get(server.url_for("ping")) + + assert response.status_code == 200 + assert "X-Ping-Tracked" in response.headers + assert response.headers["X-Ping-Tracked"] == "true" + + finally: + os.unlink(script_path) + + @pytest.mark.asyncio + async def test_middleware_env_var_override(self): + """Test middleware environment variable overrides.""" + try: + from model_hosting_container_standards.common.fastapi.config import ( + FastAPIEnvVars, + ) + from model_hosting_container_standards.sagemaker.config import ( + SageMakerEnvVars, + ) + except ImportError: + pytest.skip("model-hosting-container-standards not available") + + # Create a script with middleware functions specified via env vars + with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f: + f.write( + """ +from fastapi import Request + +# Global flag to track if pre_process was called +_pre_process_called = False + +async def env_throttle_middleware(request, call_next): + response = await call_next(request) + response.headers["X-Env-Throttle"] = "applied" + return response + +async def env_pre_process(request: Request) -> Request: + # Mark that pre_process was called + global _pre_process_called + _pre_process_called = True + return request + +async def env_post_process(response): + global _pre_process_called + if hasattr(response, 'headers'): + response.headers["X-Env-Post-Process"] = "applied" + # Since pre_process and post_process are combined into + # pre_post_process middleware, + # if post_process is called, pre_process should have been called too + if _pre_process_called: + response.headers["X-Pre-Process-Called"] = "true" + return response +""" + ) + script_path = f.name + + try: + script_dir = os.path.dirname(script_path) + script_name = os.path.basename(script_path) + + # Set environment variables for middleware + # Use script_name with .py extension as per plugin example + env_vars = { + SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir, + SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name, + FastAPIEnvVars.CUSTOM_FASTAPI_MIDDLEWARE_THROTTLE: ( + f"{script_name}:env_throttle_middleware" + ), + FastAPIEnvVars.CUSTOM_PRE_PROCESS: f"{script_name}:env_pre_process", + FastAPIEnvVars.CUSTOM_POST_PROCESS: f"{script_name}:env_post_process", + } + + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--enforce-eager", + "--max-num-seqs", + "32", + ] + + with RemoteOpenAIServer( + MODEL_NAME_SMOLLM, args, env_dict=env_vars + ) as server: + response = requests.get(server.url_for("ping")) + assert response.status_code == 200 + + # Check if environment variable middleware was applied + headers = response.headers + + # Verify that env var middlewares were applied + assert "X-Env-Throttle" in headers, ( + "Throttle middleware should be applied via env var" + ) + assert headers["X-Env-Throttle"] == "applied" + + assert "X-Env-Post-Process" in headers, ( + "Post-process middleware should be applied via env var" + ) + assert headers["X-Env-Post-Process"] == "applied" + + # Verify that pre_process was called + assert "X-Pre-Process-Called" in headers, ( + "Pre-process should be called via env var" + ) + assert headers["X-Pre-Process-Called"] == "true" + + finally: + os.unlink(script_path) diff --git a/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py b/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py new file mode 100644 index 000000000000..6206000385bd --- /dev/null +++ b/tests/entrypoints/sagemaker/test_sagemaker_stateful_sessions.py @@ -0,0 
+1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import openai # use the official client for correctness check +import pytest +import requests + +from ...utils import RemoteOpenAIServer +from .conftest import ( + HEADER_SAGEMAKER_CLOSED_SESSION_ID, + HEADER_SAGEMAKER_NEW_SESSION_ID, + HEADER_SAGEMAKER_SESSION_ID, + MODEL_NAME_SMOLLM, +) + +CLOSE_BADREQUEST_CASES = [ + ( + "nonexistent_session_id", + {"session_id": "nonexistent-session-id"}, + {}, + "session not found", + ), + ("malformed_close_request", {}, {"extra-field": "extra-field-data"}, None), +] + + +@pytest.mark.asyncio +async def test_create_session_badrequest(basic_server_with_lora: RemoteOpenAIServer): + bad_response = requests.post( + basic_server_with_lora.url_for("invocations"), + json={"requestType": "NEW_SESSION", "extra-field": "extra-field-data"}, + ) + + assert bad_response.status_code == 400 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_name,session_id_change,request_body_change,expected_error", + CLOSE_BADREQUEST_CASES, +) +async def test_close_session_badrequest( + basic_server_with_lora: RemoteOpenAIServer, + test_name: str, + session_id_change: dict[str, str], + request_body_change: dict[str, str], + expected_error: str | None, +): + # first attempt to create a session + url = basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + close_request_json = {"requestType": "CLOSE"} + if request_body_change: + close_request_json.update(request_body_change) + bad_session_id = session_id_change.get("session_id") + bad_close_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: bad_session_id or valid_session_id}, + json=close_request_json, + ) + + # clean up created session, should succeed + clean_up_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + clean_up_response.raise_for_status() + + assert bad_close_response.status_code == 400 + if expected_error: + assert expected_error in bad_close_response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_close_session_invalidrequest( + basic_server_with_lora: RemoteOpenAIServer, async_client: openai.AsyncOpenAI +): + # first attempt to create a session + url = basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + close_request_json = {"requestType": "CLOSE"} + invalid_close_response = requests.post( + url, + # no headers to specify session_id + json=close_request_json, + ) + + # clean up created session, should succeed + clean_up_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + clean_up_response.raise_for_status() + + assert invalid_close_response.status_code == 424 + assert "invalid session_id" in invalid_close_response.json()["error"]["message"] + + +@pytest.mark.asyncio +async def test_session(basic_server_with_lora: RemoteOpenAIServer): + # first attempt to create a session + url = 
basic_server_with_lora.url_for("invocations") + create_response = requests.post(url, json={"requestType": "NEW_SESSION"}) + create_response.raise_for_status() + valid_session_id, expiration = create_response.headers.get( + HEADER_SAGEMAKER_NEW_SESSION_ID, "" + ).split(";") + assert valid_session_id + + # test invocation with session id + + request_args = { + "model": MODEL_NAME_SMOLLM, + "prompt": "what is 1+1?", + "max_completion_tokens": 5, + "temperature": 0.0, + "logprobs": False, + } + + invocation_response = requests.post( + basic_server_with_lora.url_for("invocations"), + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json=request_args, + ) + invocation_response.raise_for_status() + + # close created session, should succeed + close_response = requests.post( + url, + headers={HEADER_SAGEMAKER_SESSION_ID: valid_session_id}, + json={"requestType": "CLOSE"}, + ) + close_response.raise_for_status() + + assert ( + close_response.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID) + == valid_session_id + ) diff --git a/vllm/entrypoints/dynamic_lora.py b/vllm/entrypoints/dynamic_lora.py new file mode 100644 index 000000000000..cc0f437e5c77 --- /dev/null +++ b/vllm/entrypoints/dynamic_lora.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import model_hosting_container_standards.sagemaker as sagemaker_standards +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import models, validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, + LoadLoRAAdapterRequest, + UnloadLoRAAdapterRequest, +) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def register_dynamic_lora_routes(router: APIRouter): + @sagemaker_standards.register_load_adapter_handler( + request_shape={ + "lora_name": "body.name", + "lora_path": "body.src", + }, + ) + @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) + async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): + handler: OpenAIServingModels = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + + return Response(status_code=200, content=response) + + @sagemaker_standards.register_unload_adapter_handler( + request_shape={ + "lora_name": "path_params.adapter_name", + } + ) + @router.post( + "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] + ) + async def unload_lora_adapter( + request: UnloadLoRAAdapterRequest, raw_request: Request + ): + handler: OpenAIServingModels = models(raw_request) + response = await handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse( + content=response.model_dump(), status_code=response.error.code + ) + + return Response(status_code=200, content=response) + + return router diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 51191879e478..fbb2d32a229d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -19,6 +19,7 @@ from contextlib import asynccontextmanager from http import HTTPStatus from typing import Annotated, Any, Literal +import 
model_hosting_container_standards.sagemaker as sagemaker_standards import prometheus_client import pydantic import regex as re @@ -65,7 +66,6 @@ from vllm.entrypoints.openai.protocol import ( ErrorInfo, ErrorResponse, IOProcessorResponse, - LoadLoRAAdapterRequest, PoolingBytesResponse, PoolingRequest, PoolingResponse, @@ -82,7 +82,6 @@ from vllm.entrypoints.openai.protocol import ( TranscriptionResponse, TranslationRequest, TranslationResponse, - UnloadLoRAAdapterRequest, ) from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_classification import ServingClassification @@ -387,13 +386,6 @@ async def get_server_load_metrics(request: Request): return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) -@router.get("/ping", response_class=Response) -@router.post("/ping", response_class=Response) -async def ping(raw_request: Request) -> Response: - """Ping check. Endpoint required for SageMaker""" - return await health(raw_request) - - @router.post( "/tokenize", dependencies=[Depends(validate_json_request)], @@ -1236,47 +1228,6 @@ INVOCATION_VALIDATORS = [ ] -@router.post( - "/invocations", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def invocations(raw_request: Request): - """For SageMaker, routes requests based on the request type.""" - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}" - ) from e - - valid_endpoints = [ - (validator, endpoint) - for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS - if get_handler(raw_request) is not None - ] - - for request_validator, endpoint in valid_endpoints: - try: - request = request_validator.validate_python(body) - except pydantic.ValidationError: - continue - - return await endpoint(request, raw_request) - - type_names = [ - t.__name__ if isinstance(t := validator._type, type) else str(t) - for validator, _ in valid_endpoints - ] - msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" - res = base(raw_request).create_error_response(message=msg) - return JSONResponse(content=res.model_dump(), status_code=res.error.code) - - if envs.VLLM_TORCH_PROFILER_DIR: logger.warning_once( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -1304,39 +1255,6 @@ if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE: return Response(status_code=200) -if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - logger.warning( - "LoRA dynamic loading & unloading is enabled in the API server. " - "This should ONLY be used for local development!" 
- ) - - @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)]) - async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request): - handler = models(raw_request) - response = await handler.load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse( - content=response.model_dump(), status_code=response.error.code - ) - - return Response(status_code=200, content=response) - - @router.post( - "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)] - ) - async def unload_lora_adapter( - request: UnloadLoRAAdapterRequest, raw_request: Request - ): - handler = models(raw_request) - response = await handler.unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse( - content=response.model_dump(), status_code=response.error.code - ) - - return Response(status_code=200, content=response) - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None @@ -1606,6 +1524,20 @@ def build_app(args: Namespace) -> FastAPI: ) else: app = FastAPI(lifespan=lifespan) + + if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!" + ) + from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes + + register_dynamic_lora_routes(router) + + from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes + + register_sagemaker_routes(router) + app.include_router(router) app.root_path = args.root_path @@ -1696,6 +1628,8 @@ def build_app(args: Namespace) -> FastAPI: f"Invalid middleware {middleware}. Must be a function or a class." ) + app = sagemaker_standards.bootstrap(app) + return app diff --git a/vllm/entrypoints/sagemaker/__init__.py b/vllm/entrypoints/sagemaker/__init__.py new file mode 100644 index 000000000000..c1767137e4ea --- /dev/null +++ b/vllm/entrypoints/sagemaker/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""SageMaker-specific integration for vLLM.""" diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py new file mode 100644 index 000000000000..498b7294f0d8 --- /dev/null +++ b/vllm/entrypoints/sagemaker/routes.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +from http import HTTPStatus + +import model_hosting_container_standards.sagemaker as sagemaker_standards +import pydantic +from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi.responses import JSONResponse, Response + +from vllm.entrypoints.openai.api_server import ( + INVOCATION_VALIDATORS, + base, + health, + validate_json_request, +) +from vllm.entrypoints.openai.protocol import ErrorResponse + + +def register_sagemaker_routes(router: APIRouter): + @router.post("/ping", response_class=Response) + @router.get("/ping", response_class=Response) + @sagemaker_standards.register_ping_handler + async def ping(raw_request: Request) -> Response: + """Ping check. 
Endpoint required for SageMaker""" + return await health(raw_request) + + @router.post( + "/invocations", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, + ) + @sagemaker_standards.register_invocation_handler + @sagemaker_standards.stateful_session_manager() + @sagemaker_standards.inject_adapter_id(adapter_path="model") + async def invocations(raw_request: Request): + """For SageMaker, routes requests based on the request type.""" + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + + valid_endpoints = [ + (validator, endpoint) + for validator, (get_handler, endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None + ] + + for request_validator, endpoint in valid_endpoints: + try: + request = request_validator.validate_python(body) + except pydantic.ValidationError: + continue + + return await endpoint(request, raw_request) + + type_names = [ + t.__name__ if isinstance(t := validator._type, type) else str(t) + for validator, _ in valid_endpoints + ] + msg = f"Cannot find suitable handler for request. Expected one of: {type_names}" + res = base(raw_request).create_error_response(message=msg) + return JSONResponse(content=res.model_dump(), status_code=res.error.code) + + return router
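
A minimal client-side sketch of the routes this patch adds, assuming a server started with --enable-lora and VLLM_ALLOW_RUNTIME_LORA_UPDATING=True as in tests/entrypoints/sagemaker/conftest.py. The endpoint paths, request shapes, and header names below are taken from the tests in this patch; the base URL, adapter name, and adapter path are placeholders, and POST /adapters and DELETE /adapters/{name} are the SageMaker-standards mappings onto load_lora_adapter/unload_lora_adapter described in test_sagemaker_lora_adapters.py.

import requests

BASE_URL = "http://localhost:8000"     # placeholder; point at the running vLLM server
ADAPTER_SRC = "/path/to/lora/adapter"  # placeholder; local LoRA adapter files

# Load a LoRA adapter through the SageMaker-standards route
# (maps {"lora_name": "body.name", "lora_path": "body.src"} onto load_lora_adapter).
requests.post(
    f"{BASE_URL}/adapters",
    json={"name": "my-adapter", "src": ADAPTER_SRC},
).raise_for_status()

# Open a stateful session on /invocations.
resp = requests.post(f"{BASE_URL}/invocations", json={"requestType": "NEW_SESSION"})
resp.raise_for_status()
session_id = resp.headers["X-Amzn-SageMaker-New-Session-Id"].split(";")[0]

# Invoke with the session id and the loaded adapter.
out = requests.post(
    f"{BASE_URL}/invocations",
    headers={
        "X-Amzn-SageMaker-Session-Id": session_id,
        "X-Amzn-SageMaker-Adapter-Identifier": "my-adapter",
    },
    json={"prompt": "what is 1+1?", "max_tokens": 5, "temperature": 0.0},
)
out.raise_for_status()
print(out.json()["choices"][0]["text"])

# Close the session, then unload the adapter.
requests.post(
    f"{BASE_URL}/invocations",
    headers={"X-Amzn-SageMaker-Session-Id": session_id},
    json={"requestType": "CLOSE"},
).raise_for_status()
requests.delete(f"{BASE_URL}/adapters/my-adapter").raise_for_status()

Handler, middleware, and formatter overrides are supplied the same way the tests do it: a customer script pointed to by SAGEMAKER_MODEL_PATH and CUSTOM_SCRIPT_FILENAME, with environment-variable handler specs taking priority over @custom_ping_handler/@custom_invocation_handler decorators, which in turn take priority over plain custom_sagemaker_ping_handler()/custom_sagemaker_invocation_handler() script functions.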