Adds anthropic /v1/messages endpoint to openai api_server (#27882)

Signed-off-by: bbartels <benjamin@bartels.dev>
Signed-off-by: Benjamin Bartels <benjamin@bartels.dev>
Benjamin Bartels 2025-11-01 19:45:42 +00:00 committed by GitHub
parent c2ed069b32
commit 1e88fb751b
5 changed files with 139 additions and 462 deletions
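
This change exposes Anthropic's /v1/messages API on the standard OpenAI-compatible server, so the anthropic SDK can talk to vLLM directly. A minimal sketch of a client call, assuming a server is already running on localhost:8000 and serving Qwen/Qwen3-0.6B (the model used in the tests below); the dummy API key mirrors the test utilities and is only needed because the SDK requires one:

import anthropic

client = anthropic.Anthropic(
    base_url="http://localhost:8000",
    api_key="token-abc123",  # placeholder; a test server without --api-key does not enforce auth
    max_retries=0,
)

resp = client.messages.create(
    model="Qwen/Qwen3-0.6B",
    max_tokens=128,
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(resp.model_dump_json())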

View File

@@ -5,7 +5,7 @@ import anthropic
import pytest
import pytest_asyncio
from ...utils import RemoteAnthropicServer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen3-0.6B"
@@ -23,13 +23,13 @@ def server(): # noqa: F811
"claude-3-7-sonnet-latest",
]
with RemoteAnthropicServer(MODEL_NAME, args) as remote_server:
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client(server):
async with server.get_async_client() as async_client:
async with server.get_async_client_anthropic() as async_client:
yield async_client
@@ -105,37 +105,37 @@ async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):
print(f"Anthropic response: {resp.model_dump_json()}")
@pytest.mark.asyncio
async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic):
resp = await client.messages.create(
model="claude-3-7-sonnet-latest",
max_tokens=1024,
messages=[
{
"role": "user",
"content": "What's the weather like in New York today?",
}
],
tools=[
{
"name": "get_current_weather",
"description": "Useful for querying the weather "
"in a specified city.",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City or region, for example: "
"New York, London, Tokyo, etc.",
}
},
"required": ["location"],
},
}
],
stream=True,
)
async for chunk in resp:
print(chunk.model_dump_json())
@pytest.mark.asyncio
async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic):
resp = await client.messages.create(
model="claude-3-7-sonnet-latest",
max_tokens=1024,
messages=[
{
"role": "user",
"content": "What's the weather like in New York today?",
}
],
tools=[
{
"name": "get_current_weather",
"description": "Useful for querying the weather in a specified city.",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City or region, for example: "
"New York, London, Tokyo, etc.",
}
},
"required": ["location"],
},
}
],
stream=True,
)
async for chunk in resp:
print(chunk.model_dump_json())

View File

@@ -247,6 +247,23 @@ class RemoteOpenAIServer:
**kwargs,
)
def get_client_anthropic(self, **kwargs):
if "timeout" not in kwargs:
kwargs["timeout"] = 600
return anthropic.Anthropic(
base_url=self.url_for(),
api_key=self.DUMMY_API_KEY,
max_retries=0,
**kwargs,
)
def get_async_client_anthropic(self, **kwargs):
if "timeout" not in kwargs:
kwargs["timeout"] = 600
return anthropic.AsyncAnthropic(
base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs
)
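
The two helpers added above return Anthropic SDK clients bound to the same test server. A sketch of direct usage outside a pytest fixture, assuming the tests.utils import path and an illustrative model and set of serve args:

from tests.utils import RemoteOpenAIServer  # assumed import path for the test helper

# Model name and serve args are illustrative; the real tests pass their own.
with RemoteOpenAIServer("Qwen/Qwen3-0.6B", ["--max-model-len", "2048"]) as server:
    client = server.get_client_anthropic(timeout=120)
    msg = client.messages.create(
        model="Qwen/Qwen3-0.6B",
        max_tokens=64,
        messages=[{"role": "user", "content": "ping"}],
    )
    print(msg.content)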
class RemoteOpenAIServerCustom(RemoteOpenAIServer):
"""Launch test server with custom child process"""
@@ -293,131 +310,6 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer):
self.proc.kill()
class RemoteAnthropicServer:
DUMMY_API_KEY = "token-abc123" # vLLM's Anthropic server does not need API key
def __init__(
self,
model: str,
vllm_serve_args: list[str],
*,
env_dict: dict[str, str] | None = None,
seed: int | None = 0,
auto_port: bool = True,
max_wait_seconds: float | None = None,
) -> None:
if auto_port:
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
raise ValueError(
"You have manually specified the port when `auto_port=True`."
)
# Don't mutate the input args
vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())]
if seed is not None:
if "--seed" in vllm_serve_args:
raise ValueError(
f"You have manually specified the seed when `seed={seed}`."
)
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.")
subparsers = parser.add_subparsers(required=False, dest="subparser")
parser = ServeSubcommand().subparser_init(subparsers)
args = parser.parse_args(["--model", model, *vllm_serve_args])
self.host = str(args.host or "localhost")
self.port = int(args.port)
self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None
# download the model before starting the server to avoid timeout
is_local = os.path.isdir(model)
if not is_local:
engine_args = AsyncEngineArgs.from_cli_args(args)
model_config = engine_args.create_model_config()
load_config = engine_args.create_load_config()
model_loader = get_model_loader(load_config)
model_loader.download_model(model_config)
env = os.environ.copy()
# the current process might initialize cuda,
# to be safe, we should use spawn method
env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
if env_dict is not None:
env.update(env_dict)
self.proc = subprocess.Popen(
[
sys.executable,
"-m",
"vllm.entrypoints.anthropic.api_server",
model,
*vllm_serve_args,
],
env=env,
stdout=sys.stdout,
stderr=sys.stderr,
)
max_wait_seconds = max_wait_seconds or 240
self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.proc.terminate()
try:
self.proc.wait(8)
except subprocess.TimeoutExpired:
# force kill if needed
self.proc.kill()
def _wait_for_server(self, *, url: str, timeout: float):
# run health check
start = time.time()
while True:
try:
if requests.get(url).status_code == 200:
break
except Exception:
# this exception can only be raised by requests.get,
# which means the server is not ready yet.
# the stack trace is not useful, so we suppress it
# by using `raise from None`.
result = self.proc.poll()
if result is not None and result != 0:
raise RuntimeError("Server exited unexpectedly.") from None
time.sleep(0.5)
if time.time() - start > timeout:
raise RuntimeError("Server failed to start in time.") from None
@property
def url_root(self) -> str:
return f"http://{self.host}:{self.port}"
def url_for(self, *parts: str) -> str:
return self.url_root + "/" + "/".join(parts)
def get_client(self, **kwargs):
if "timeout" not in kwargs:
kwargs["timeout"] = 600
return anthropic.Anthropic(
base_url=self.url_for(),
api_key=self.DUMMY_API_KEY,
max_retries=0,
**kwargs,
)
def get_async_client(self, **kwargs):
if "timeout" not in kwargs:
kwargs["timeout"] = 600
return anthropic.AsyncAnthropic(
base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs
)
def _test_completion(
client: openai.OpenAI,
model: str,

View File

@@ -1,301 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from:
# https://github.com/vllm/vllm/entrypoints/openai/api_server.py
import asyncio
import signal
import tempfile
from argparse import Namespace
from http import HTTPStatus
import uvloop
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicErrorResponse,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.api_server import (
build_async_engine_client,
create_server_socket,
lifespan,
load_log_config,
validate_api_server_args,
validate_json_request,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_models import (
BaseModelPath,
OpenAIServingModels,
)
#
# yapf: enable
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import (
cli_env_setup,
load_aware_call,
process_chat_template,
process_lora_modules,
with_cancellation,
)
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger("vllm.entrypoints.anthropic.api_server")
_running_tasks: set[asyncio.Task] = set()
router = APIRouter()
def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
"""Health check."""
await engine_client(raw_request).check_health()
return Response(status_code=200)
@router.get("/ping", response_class=Response)
@router.post("/ping", response_class=Response)
async def ping(raw_request: Request) -> Response:
"""Ping check. Endpoint required for SageMaker"""
return await health(raw_request)
@router.post(
"/v1/messages",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
handler = messages(raw_request)
if handler is None:
return messages(raw_request).create_error_response(
message="The model does not support Messages API"
)
generator = await handler.create_messages(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump())
elif isinstance(generator, AnthropicMessagesResponse):
logger.debug(
"Anthropic Messages Response: %s", generator.model_dump(exclude_none=True)
)
return JSONResponse(content=generator.model_dump(exclude_none=True))
return StreamingResponse(content=generator, media_type="text/event-stream")
async def init_app_state(
engine_client: EngineClient,
state: State,
args: Namespace,
) -> None:
vllm_config = engine_client.vllm_config
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
base_model_paths = [
BaseModelPath(name=name, model_path=args.model) for name in served_model_names
]
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.vllm_config = vllm_config
model_config = vllm_config.model_config
default_mm_loras = (
vllm_config.lora_config.default_mm_loras
if vllm_config.lora_config is not None
else {}
)
lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)
resolved_chat_template = await process_chat_template(
args.chat_template, engine_client, model_config
)
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
)
await state.openai_serving_models.init_static_loras()
state.anthropic_serving_messages = AnthropicServingMessages(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
)
def setup_server(args):
"""Validate API server args, set up signal handler, create socket
ready to serve."""
logger.info("vLLM API server version %s", VLLM_VERSION)
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
validate_api_server_args(args)
# workaround to make sure that we bind the port before the engine is set up.
# This avoids race conditions with ray.
# see https://github.com/vllm-project/vllm/issues/8204
sock_addr = (args.host or "", args.port)
sock = create_server_socket(sock_addr)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
set_ulimit()
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
addr, port = sock_addr
is_ssl = args.ssl_keyfile and args.ssl_certfile
host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0"
listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
return listen_address, sock
async def run_server(args, **uvicorn_kwargs) -> None:
"""Run a single-worker API server."""
listen_address, sock = setup_server(args)
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
def build_app(args: Namespace) -> FastAPI:
app = FastAPI(lifespan=lifespan)
app.include_router(router)
app.root_path = args.root_path
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
return app
async def run_server_worker(
listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
"""Run a single API server worker."""
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
server_index = client_config.get("client_index", 0) if client_config else 0
# Load logging config for uvicorn if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
uvicorn_kwargs["log_config"] = log_config
async with build_async_engine_client(
args,
client_config=client_config,
) as engine_client:
app = build_app(args)
await init_app_state(engine_client, app.state, args)
logger.info("Starting vLLM API server %d on %s", server_index, listen_address)
shutdown_task = await serve_http(
app,
sock=sock,
enable_ssl_refresh=args.enable_ssl_refresh,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
# NOTE: When the 'disable_uvicorn_access_log' value is True,
# no access log will be output.
access_log=not args.disable_uvicorn_access_log,
timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile,
ssl_ca_certs=args.ssl_ca_certs,
ssl_cert_reqs=args.ssl_cert_reqs,
**uvicorn_kwargs,
)
# NB: Await server shutdown only after the backend context is exited
try:
await shutdown_task
finally:
sock.close()
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/entrypoints/cli/main.py for CLI
# entrypoints.
cli_env_setup()
parser = FlexibleArgumentParser(
description="vLLM Anthropic-Compatible RESTful API server."
)
parser = make_arg_parser(parser)
args = parser.parse_args()
validate_parsed_serve_args(args)
uvloop.run(run_server(args))

View File

@@ -41,6 +41,13 @@ import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import Device, EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicError,
AnthropicErrorResponse,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
@@ -308,6 +315,10 @@ def responses(request: Request) -> OpenAIServingResponses | None:
return request.app.state.openai_serving_responses
def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages
def chat(request: Request) -> OpenAIServingChat | None:
return request.app.state.openai_serving_chat
@@ -591,6 +602,63 @@ async def cancel_responses(response_id: str, raw_request: Request):
return JSONResponse(content=response.model_dump())
@router.post(
"/v1/messages",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
def translate_error_response(response: ErrorResponse) -> JSONResponse:
anthropic_error = AnthropicErrorResponse(
error=AnthropicError(
type=response.error.type,
message=response.error.message,
)
)
return JSONResponse(
status_code=response.error.code, content=anthropic_error.model_dump()
)
handler = messages(raw_request)
if handler is None:
error = base(raw_request).create_error_response(
message="The model does not support Messages API"
)
return translate_error_response(error)
try:
generator = await handler.create_messages(request, raw_request)
except Exception as e:
logger.exception("Error in create_messages: %s", e)
return JSONResponse(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
content=AnthropicErrorResponse(
error=AnthropicError(
type="internal_error",
message=str(e),
)
).model_dump(),
)
if isinstance(generator, ErrorResponse):
return translate_error_response(generator)
elif isinstance(generator, AnthropicMessagesResponse):
logger.debug(
"Anthropic Messages Response: %s", generator.model_dump(exclude_none=True)
)
return JSONResponse(content=generator.model_dump(exclude_none=True))
return StreamingResponse(content=generator, media_type="text/event-stream")
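
For reference, the new route accepts the Anthropic Messages request body directly; a hedged sketch of a raw HTTP call, assuming a local server on port 8000 started without --api-key (so no auth header is needed):

import requests

payload = {
    "model": "Qwen/Qwen3-0.6B",  # illustrative; use one of the served model names
    "max_tokens": 64,
    "messages": [{"role": "user", "content": "Hello"}],
}
r = requests.post("http://localhost:8000/v1/messages", json=payload, timeout=120)
# Non-streaming requests return an AnthropicMessagesResponse JSON body;
# with "stream": true the response is served as text/event-stream.
print(r.status_code, r.json())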
@router.post(
"/v1/chat/completions",
dependencies=[Depends(validate_json_request)],
@@ -1817,6 +1885,24 @@ async def init_app_state(
if "transcription" in supported_tasks
else None
)
state.anthropic_serving_messages = (
AnthropicServingMessages(
engine_client,
state.openai_serving_models,
args.response_role,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
reasoning_parser=args.structured_outputs_config.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
enable_force_include_usage=args.enable_force_include_usage,
)
if "generate" in supported_tasks
else None
)
state.enable_server_load_tracking = args.enable_server_load_tracking
state.server_load_metrics = 0