mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-18 12:45:31 +08:00
Support SSL Key Rotation in HTTP Server (#13495)
This commit is contained in:
parent
2382ad29d1
commit
8db1b9d0a1
@ -37,3 +37,4 @@ einops # Required for Qwen2-VL.
|
|||||||
compressed-tensors == 0.9.2 # required for compressed-tensors
|
compressed-tensors == 0.9.2 # required for compressed-tensors
|
||||||
depyf==0.18.0 # required for profiling and debugging with compilation config
|
depyf==0.18.0 # required for profiling and debugging with compilation config
|
||||||
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
cloudpickle # allows pickling lambda functions in model_executor/models/registry.py
|
||||||
|
watchfiles # required for http server to monitor the updates of TLS files
|
||||||
|
|||||||
72
tests/entrypoints/test_ssl_cert_refresher.py
Normal file
72
tests/entrypoints/test_ssl_cert_refresher.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
import asyncio
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from ssl import SSLContext
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||||
|
|
||||||
|
|
||||||
|
class MockSSLContext(SSLContext):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.load_cert_chain_count = 0
|
||||||
|
self.load_ca_count = 0
|
||||||
|
|
||||||
|
def load_cert_chain(
|
||||||
|
self,
|
||||||
|
certfile,
|
||||||
|
keyfile=None,
|
||||||
|
password=None,
|
||||||
|
):
|
||||||
|
self.load_cert_chain_count += 1
|
||||||
|
|
||||||
|
def load_verify_locations(
|
||||||
|
self,
|
||||||
|
cafile=None,
|
||||||
|
capath=None,
|
||||||
|
cadata=None,
|
||||||
|
):
|
||||||
|
self.load_ca_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
def create_file() -> str:
|
||||||
|
with tempfile.NamedTemporaryFile(dir='/tmp', delete=False) as f:
|
||||||
|
return f.name
|
||||||
|
|
||||||
|
|
||||||
|
def touch_file(path: str) -> None:
|
||||||
|
Path(path).touch()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ssl_refresher():
|
||||||
|
ssl_context = MockSSLContext()
|
||||||
|
key_path = create_file()
|
||||||
|
cert_path = create_file()
|
||||||
|
ca_path = create_file()
|
||||||
|
ssl_refresher = SSLCertRefresher(ssl_context, key_path, cert_path, ca_path)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
assert ssl_context.load_cert_chain_count == 0
|
||||||
|
assert ssl_context.load_ca_count == 0
|
||||||
|
|
||||||
|
touch_file(key_path)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
assert ssl_context.load_cert_chain_count == 1
|
||||||
|
assert ssl_context.load_ca_count == 0
|
||||||
|
|
||||||
|
touch_file(cert_path)
|
||||||
|
touch_file(ca_path)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
assert ssl_context.load_cert_chain_count == 2
|
||||||
|
assert ssl_context.load_ca_count == 1
|
||||||
|
|
||||||
|
ssl_refresher.stop()
|
||||||
|
|
||||||
|
touch_file(cert_path)
|
||||||
|
touch_file(ca_path)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
assert ssl_context.load_cert_chain_count == 2
|
||||||
|
assert ssl_context.load_ca_count == 1
|
||||||
@ -128,6 +128,7 @@ async def run_server(args: Namespace,
|
|||||||
shutdown_task = await serve_http(
|
shutdown_task = await serve_http(
|
||||||
app,
|
app,
|
||||||
sock=None,
|
sock=None,
|
||||||
|
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
log_level=args.log_level,
|
log_level=args.log_level,
|
||||||
@ -152,6 +153,11 @@ if __name__ == "__main__":
|
|||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The CA certificates file")
|
help="The CA certificates file")
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-ssl-refresh",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Refresh SSL Context when SSL certificate files change")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ssl-cert-reqs",
|
"--ssl-cert-reqs",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@ -12,13 +12,16 @@ from fastapi import FastAPI, Request, Response
|
|||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
||||||
from vllm.engine.multiprocessing import MQEngineDeadError
|
from vllm.engine.multiprocessing import MQEngineDeadError
|
||||||
|
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import find_process_using_port
|
from vllm.utils import find_process_using_port
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def serve_http(app: FastAPI, sock: Optional[socket.socket],
|
async def serve_http(app: FastAPI,
|
||||||
|
sock: Optional[socket.socket],
|
||||||
|
enable_ssl_refresh: bool = False,
|
||||||
**uvicorn_kwargs: Any):
|
**uvicorn_kwargs: Any):
|
||||||
logger.info("Available routes are:")
|
logger.info("Available routes are:")
|
||||||
for route in app.routes:
|
for route in app.routes:
|
||||||
@ -31,6 +34,7 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
|
|||||||
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
|
logger.info("Route: %s, Methods: %s", path, ', '.join(methods))
|
||||||
|
|
||||||
config = uvicorn.Config(app, **uvicorn_kwargs)
|
config = uvicorn.Config(app, **uvicorn_kwargs)
|
||||||
|
config.load()
|
||||||
server = uvicorn.Server(config)
|
server = uvicorn.Server(config)
|
||||||
_add_shutdown_handlers(app, server)
|
_add_shutdown_handlers(app, server)
|
||||||
|
|
||||||
@ -39,9 +43,17 @@ async def serve_http(app: FastAPI, sock: Optional[socket.socket],
|
|||||||
server_task = loop.create_task(
|
server_task = loop.create_task(
|
||||||
server.serve(sockets=[sock] if sock else None))
|
server.serve(sockets=[sock] if sock else None))
|
||||||
|
|
||||||
|
ssl_cert_refresher = None if not enable_ssl_refresh else SSLCertRefresher(
|
||||||
|
ssl_context=config.ssl,
|
||||||
|
key_path=config.ssl_keyfile,
|
||||||
|
cert_path=config.ssl_certfile,
|
||||||
|
ca_path=config.ssl_ca_certs)
|
||||||
|
|
||||||
def signal_handler() -> None:
|
def signal_handler() -> None:
|
||||||
# prevents the uvicorn signal handler to exit early
|
# prevents the uvicorn signal handler to exit early
|
||||||
server_task.cancel()
|
server_task.cancel()
|
||||||
|
if ssl_cert_refresher:
|
||||||
|
ssl_cert_refresher.stop()
|
||||||
|
|
||||||
async def dummy_shutdown() -> None:
|
async def dummy_shutdown() -> None:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -960,6 +960,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
|||||||
shutdown_task = await serve_http(
|
shutdown_task = await serve_http(
|
||||||
app,
|
app,
|
||||||
sock=sock,
|
sock=sock,
|
||||||
|
enable_ssl_refresh=args.enable_ssl_refresh,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
log_level=args.uvicorn_log_level,
|
log_level=args.uvicorn_log_level,
|
||||||
|
|||||||
@ -164,6 +164,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
|||||||
type=nullable_str,
|
type=nullable_str,
|
||||||
default=None,
|
default=None,
|
||||||
help="The CA certificates file.")
|
help="The CA certificates file.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-ssl-refresh",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Refresh SSL Context when SSL certificate files change")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ssl-cert-reqs",
|
"--ssl-cert-reqs",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
74
vllm/entrypoints/ssl.py
Normal file
74
vllm/entrypoints/ssl.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from ssl import SSLContext
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from watchfiles import Change, awatch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SSLCertRefresher:
|
||||||
|
"""A class that monitors SSL certificate files and
|
||||||
|
reloads them when they change.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
ssl_context: SSLContext,
|
||||||
|
key_path: Optional[str] = None,
|
||||||
|
cert_path: Optional[str] = None,
|
||||||
|
ca_path: Optional[str] = None) -> None:
|
||||||
|
self.ssl = ssl_context
|
||||||
|
self.key_path = key_path
|
||||||
|
self.cert_path = cert_path
|
||||||
|
self.ca_path = ca_path
|
||||||
|
|
||||||
|
# Setup certification chain watcher
|
||||||
|
def update_ssl_cert_chain(change: Change, file_path: str) -> None:
|
||||||
|
logger.info("Reloading SSL certificate chain")
|
||||||
|
assert self.key_path and self.cert_path
|
||||||
|
self.ssl.load_cert_chain(self.cert_path, self.key_path)
|
||||||
|
|
||||||
|
self.watch_ssl_cert_task = None
|
||||||
|
if self.key_path and self.cert_path:
|
||||||
|
self.watch_ssl_cert_task = asyncio.create_task(
|
||||||
|
self._watch_files([self.key_path, self.cert_path],
|
||||||
|
update_ssl_cert_chain))
|
||||||
|
|
||||||
|
# Setup CA files watcher
|
||||||
|
def update_ssl_ca(change: Change, file_path: str) -> None:
|
||||||
|
logger.info("Reloading SSL CA certificates")
|
||||||
|
assert self.ca_path
|
||||||
|
self.ssl.load_verify_locations(self.ca_path)
|
||||||
|
|
||||||
|
self.watch_ssl_ca_task = None
|
||||||
|
if self.ca_path:
|
||||||
|
self.watch_ssl_ca_task = asyncio.create_task(
|
||||||
|
self._watch_files([self.ca_path], update_ssl_ca))
|
||||||
|
|
||||||
|
async def _watch_files(self, paths, fun: Callable[[Change, str],
|
||||||
|
None]) -> None:
|
||||||
|
"""Watch multiple file paths asynchronously."""
|
||||||
|
logger.info("SSLCertRefresher monitors files: %s", paths)
|
||||||
|
async for changes in awatch(*paths):
|
||||||
|
try:
|
||||||
|
for change, file_path in changes:
|
||||||
|
logger.info("File change detected: %s - %s", change.name,
|
||||||
|
file_path)
|
||||||
|
fun(change, file_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"SSLCertRefresher failed taking action on file change. "
|
||||||
|
"Error: %s", e)
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop watching files."""
|
||||||
|
if self.watch_ssl_cert_task:
|
||||||
|
self.watch_ssl_cert_task.cancel()
|
||||||
|
self.watch_ssl_cert_task = None
|
||||||
|
if self.watch_ssl_ca_task:
|
||||||
|
self.watch_ssl_ca_task.cancel()
|
||||||
|
self.watch_ssl_ca_task = None
|
||||||
Loading…
x
Reference in New Issue
Block a user