# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration test for middleware loader functionality.

Tests that customer middlewares get called correctly with a vLLM server.
"""

import os
import tempfile

import pytest
import requests

from ...utils import RemoteOpenAIServer
from .conftest import (
    MODEL_NAME_SMOLLM,
)
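
# The tests below share one pattern (a sketch of the intended flow, not an
# exhaustive description of the model-hosting-container-standards API; names
# such as my_middleware and X-Custom are illustrative only):
#
#   1. Write a temporary customer script that registers middlewares, e.g.:
#
#          @custom_middleware("throttle")
#          async def my_middleware(request, call_next):
#              response = await call_next(request)
#              response.headers["X-Custom"] = "applied"
#              return response
#
#   2. Point the server at that script via the env vars referenced by
#      SageMakerEnvVars.SAGEMAKER_MODEL_PATH (the script's directory) and
#      SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME (the script's file name).
#
#   3. Start a RemoteOpenAIServer with env_dict=env_vars and assert on the
#      response headers that the customer middlewares actually ran.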


class TestMiddlewareIntegration:
    """Integration test for middleware with vLLM server."""

    def setup_method(self):
        """Setup for each test - simulate fresh server startup."""
        self._clear_caches()

    def _clear_caches(self):
        """Clear middleware registry and function loader cache."""
        try:
            from model_hosting_container_standards.common.fastapi.middleware import (
                middleware_registry,
            )
            from model_hosting_container_standards.common.fastapi.middleware.source.decorator_loader import (  # noqa: E501
                decorator_loader,
            )
            from model_hosting_container_standards.sagemaker.sagemaker_loader import (
                SageMakerFunctionLoader,
            )
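
            # NOTE: these registries/loaders appear to be process-wide singletons
            # in model_hosting_container_standards, so middlewares registered by
            # one test would otherwise leak into the next; clearing them here
            # simulates a fresh server startup.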
            middleware_registry.clear_middlewares()
            decorator_loader.clear()
            SageMakerFunctionLoader._default_function_loader = None
        except ImportError:
            pytest.skip("model-hosting-container-standards not available")

    @pytest.mark.asyncio
    async def test_customer_middleware_with_vllm_server(self):
        """Test that customer middlewares work with actual vLLM server.

        Tests decorator-based middlewares (@custom_middleware, @input_formatter,
        @output_formatter) on multiple endpoints (chat/completions, invocations).
        """
        try:
            from model_hosting_container_standards.sagemaker.config import (
                SageMakerEnvVars,
            )
        except ImportError:
            pytest.skip("model-hosting-container-standards not available")

        # Customer writes a middleware script with multiple decorators
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
            f.write(
                """
from model_hosting_container_standards.common.fastapi.middleware import (
    custom_middleware, input_formatter, output_formatter
)

# Global flag to track if input formatter was called
_input_formatter_called = False

@input_formatter
async def customer_input_formatter(request):
    # Process input - mark that input formatter was called
    global _input_formatter_called
    _input_formatter_called = True
    return request

@custom_middleware("throttle")
async def customer_throttle_middleware(request, call_next):
    response = await call_next(request)
    response.headers["X-Customer-Throttle"] = "applied"
    order = response.headers.get("X-Middleware-Order", "")
    response.headers["X-Middleware-Order"] = order + "throttle,"
    return response

@output_formatter
async def customer_output_formatter(response):
    global _input_formatter_called
    response.headers["X-Customer-Processed"] = "true"
    # Since input_formatter and output_formatter are combined into
    # pre_post_process middleware, if output_formatter is called,
    # input_formatter should have been called too
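    # (Roughly, the library is understood to chain them along the lines of:
    #      request = await customer_input_formatter(request)
    #      response = await call_next(request)
    #      response = await customer_output_formatter(response)
    #  so reaching this point implies the input formatter already ran.)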
    if _input_formatter_called:
        response.headers["X-Input-Formatter-Called"] = "true"
    order = response.headers.get("X-Middleware-Order", "")
    response.headers["X-Middleware-Order"] = order + "output_formatter,"
    return response
"""
            )
            script_path = f.name

        try:
            script_dir = os.path.dirname(script_path)
            script_name = os.path.basename(script_path)

            # Set environment variables to point to customer script
            env_vars = {
                SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
                SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
            }

            args = [
                "--dtype",
                "bfloat16",
                "--max-model-len",
                "2048",
                "--enforce-eager",
                "--max-num-seqs",
                "32",
            ]

            with RemoteOpenAIServer(
                MODEL_NAME_SMOLLM, args, env_dict=env_vars
            ) as server:
                # Test 1: Middlewares applied to chat/completions endpoint
                chat_response = requests.post(
                    server.url_for("v1/chat/completions"),
                    json={
                        "model": MODEL_NAME_SMOLLM,
                        "messages": [{"role": "user", "content": "Hello"}],
                        "max_tokens": 5,
                        "temperature": 0.0,
                    },
                )

                assert chat_response.status_code == 200

                # Verify all middlewares were executed
                assert "X-Customer-Throttle" in chat_response.headers
                assert chat_response.headers["X-Customer-Throttle"] == "applied"
                assert "X-Customer-Processed" in chat_response.headers
                assert chat_response.headers["X-Customer-Processed"] == "true"

                # Verify input formatter was called
                assert "X-Input-Formatter-Called" in chat_response.headers
                assert chat_response.headers["X-Input-Formatter-Called"] == "true"

                # Verify middleware execution order
                execution_order = chat_response.headers.get(
                    "X-Middleware-Order", ""
                ).rstrip(",")
                order_parts = execution_order.split(",") if execution_order else []
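                # The exact chain position presumably depends on how the standards
                # library orders its middlewares, so only membership is asserted.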
                assert "throttle" in order_parts
                assert "output_formatter" in order_parts

                # Test 2: Middlewares applied to invocations endpoint
                invocations_response = requests.post(
                    server.url_for("invocations"),
                    json={
                        "model": MODEL_NAME_SMOLLM,
                        "messages": [{"role": "user", "content": "Hello"}],
                        "max_tokens": 5,
                        "temperature": 0.0,
                    },
                )

                assert invocations_response.status_code == 200

                # Verify all middlewares were executed
                assert "X-Customer-Throttle" in invocations_response.headers
                assert invocations_response.headers["X-Customer-Throttle"] == "applied"
                assert "X-Customer-Processed" in invocations_response.headers
                assert invocations_response.headers["X-Customer-Processed"] == "true"

                # Verify input formatter was called
                assert "X-Input-Formatter-Called" in invocations_response.headers
                assert (
                    invocations_response.headers["X-Input-Formatter-Called"] == "true"
                )

        finally:
            os.unlink(script_path)

    @pytest.mark.asyncio
    async def test_middleware_with_ping_endpoint(self):
        """Test that middlewares work with SageMaker ping endpoint."""
        try:
            from model_hosting_container_standards.sagemaker.config import (
                SageMakerEnvVars,
            )
        except ImportError:
            pytest.skip("model-hosting-container-standards not available")

        # Customer writes a middleware script
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
            f.write(
                """
from model_hosting_container_standards.common.fastapi.middleware import (
    custom_middleware
)

@custom_middleware("pre_post_process")
async def ping_tracking_middleware(request, call_next):
    response = await call_next(request)
    if request.url.path == "/ping":
        response.headers["X-Ping-Tracked"] = "true"
    return response
"""
            )
            script_path = f.name

        try:
            script_dir = os.path.dirname(script_path)
            script_name = os.path.basename(script_path)

            env_vars = {
                SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
                SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
            }

            args = [
                "--dtype",
                "bfloat16",
                "--max-model-len",
                "2048",
                "--enforce-eager",
                "--max-num-seqs",
                "32",
            ]

            with RemoteOpenAIServer(
                MODEL_NAME_SMOLLM, args, env_dict=env_vars
            ) as server:
                # Test ping endpoint with middleware
                response = requests.get(server.url_for("ping"))

                assert response.status_code == 200
                assert "X-Ping-Tracked" in response.headers
                assert response.headers["X-Ping-Tracked"] == "true"

        finally:
            os.unlink(script_path)

    @pytest.mark.asyncio
    async def test_middleware_env_var_override(self):
        """Test middleware environment variable overrides."""
        try:
            from model_hosting_container_standards.common.fastapi.config import (
                FastAPIEnvVars,
            )
            from model_hosting_container_standards.sagemaker.config import (
                SageMakerEnvVars,
            )
        except ImportError:
            pytest.skip("model-hosting-container-standards not available")

        # Create a script with middleware functions specified via env vars
        with tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) as f:
            f.write(
                """
from fastapi import Request

# Global flag to track if pre_process was called
_pre_process_called = False

async def env_throttle_middleware(request, call_next):
    response = await call_next(request)
    response.headers["X-Env-Throttle"] = "applied"
    return response

async def env_pre_process(request: Request) -> Request:
    # Mark that pre_process was called
    global _pre_process_called
    _pre_process_called = True
    return request

async def env_post_process(response):
    global _pre_process_called
    if hasattr(response, 'headers'):
        response.headers["X-Env-Post-Process"] = "applied"
        # Since pre_process and post_process are combined into
        # pre_post_process middleware, if post_process is called,
        # pre_process should have been called too
        if _pre_process_called:
            response.headers["X-Pre-Process-Called"] = "true"
    return response
"""
            )
            script_path = f.name

        try:
            script_dir = os.path.dirname(script_path)
            script_name = os.path.basename(script_path)

            # Set environment variables for middleware
            # Use script_name with .py extension as per plugin example
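            # NOTE: the values below use what is understood to be the loader's
            # "<script filename>:<function name>" reference format; the
            # FastAPIEnvVars / SageMakerEnvVars attributes resolve to the env var
            # names actually read by model_hosting_container_standards.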
            env_vars = {
                SageMakerEnvVars.SAGEMAKER_MODEL_PATH: script_dir,
                SageMakerEnvVars.CUSTOM_SCRIPT_FILENAME: script_name,
                FastAPIEnvVars.CUSTOM_FASTAPI_MIDDLEWARE_THROTTLE: (
                    f"{script_name}:env_throttle_middleware"
                ),
                FastAPIEnvVars.CUSTOM_PRE_PROCESS: f"{script_name}:env_pre_process",
                FastAPIEnvVars.CUSTOM_POST_PROCESS: f"{script_name}:env_post_process",
            }

            args = [
                "--dtype",
                "bfloat16",
                "--max-model-len",
                "2048",
                "--enforce-eager",
                "--max-num-seqs",
                "32",
            ]

            with RemoteOpenAIServer(
                MODEL_NAME_SMOLLM, args, env_dict=env_vars
            ) as server:
                response = requests.get(server.url_for("ping"))
                assert response.status_code == 200

                # Check if environment variable middleware was applied
                headers = response.headers

                # Verify that env var middlewares were applied
                assert "X-Env-Throttle" in headers, (
                    "Throttle middleware should be applied via env var"
                )
                assert headers["X-Env-Throttle"] == "applied"

                assert "X-Env-Post-Process" in headers, (
                    "Post-process middleware should be applied via env var"
                )
                assert headers["X-Env-Post-Process"] == "applied"

                # Verify that pre_process was called
                assert "X-Pre-Process-Called" in headers, (
                    "Pre-process should be called via env var"
                )
                assert headers["X-Pre-Process-Called"] == "true"

        finally:
            os.unlink(script_path)