Adds anthropic /v1/messages endpoint to openai api_server (#27882)

Signed-off-by: bbartels <benjamin@bartels.dev>
Signed-off-by: Benjamin Bartels <benjamin@bartels.dev>

parent: c2ed069b32
commit: 1e88fb751b
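For orientation (not part of the diff): with this change the standard OpenAI-compatible api_server also answers Anthropic-style Messages requests, so the Anthropic Python SDK can be pointed directly at a running vLLM instance. The sketch below assumes a local server started with `vllm serve Qwen/Qwen3-0.6B` on the default port 8000; the prompt and token budget are illustrative.

import anthropic

# Point the Anthropic SDK at the server root; the SDK itself appends /v1/messages.
client = anthropic.Anthropic(
    base_url="http://localhost:8000",
    api_key="token-abc123",  # dummy key; vLLM does not validate it unless --api-key is set
    max_retries=0,
)

resp = client.messages.create(
    model="Qwen/Qwen3-0.6B",
    max_tokens=128,
    messages=[{"role": "user", "content": "Say hello."}],
)
print(resp.model_dump_json())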
@@ -5,7 +5,7 @@ import anthropic
import pytest
import pytest_asyncio

-from ...utils import RemoteAnthropicServer
+from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen3-0.6B"

@@ -23,13 +23,13 @@ def server():  # noqa: F811
        "claude-3-7-sonnet-latest",
    ]

-    with RemoteAnthropicServer(MODEL_NAME, args) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
-    async with server.get_async_client() as async_client:
+    async with server.get_async_client_anthropic() as async_client:
        yield async_client

@@ -105,37 +105,37 @@ async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):

    print(f"Anthropic response: {resp.model_dump_json()}")

-@pytest.mark.asyncio
-async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic):
-    resp = await client.messages.create(
-        model="claude-3-7-sonnet-latest",
-        max_tokens=1024,
-        messages=[
-            {
-                "role": "user",
-                "content": "What's the weather like in New York today?",
-            }
-        ],
-        tools=[
-            {
-                "name": "get_current_weather",
-                "description": "Useful for querying the weather "
-                "in a specified city.",
-                "input_schema": {
-                    "type": "object",
-                    "properties": {
-                        "location": {
-                            "type": "string",
-                            "description": "City or region, for example: "
-                            "New York, London, Tokyo, etc.",
-                        }
-                    },
-                    "required": ["location"],
-                },
-            }
-        ],
-        stream=True,
-    )
-
-    async for chunk in resp:
-        print(chunk.model_dump_json())
+@pytest.mark.asyncio
+async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic):
+    resp = await client.messages.create(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=1024,
+        messages=[
+            {
+                "role": "user",
+                "content": "What's the weather like in New York today?",
+            }
+        ],
+        tools=[
+            {
+                "name": "get_current_weather",
+                "description": "Useful for querying the weather in a specified city.",
+                "input_schema": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "City or region, for example: "
+                            "New York, London, Tokyo, etc.",
+                        }
+                    },
+                    "required": ["location"],
+                },
+            }
+        ],
+        stream=True,
+    )
+
+    async for chunk in resp:
+        print(chunk.model_dump_json())
tests/utils.py (142 changed lines)
@@ -247,6 +247,23 @@ class RemoteOpenAIServer:
            **kwargs,
        )

+    def get_client_anthropic(self, **kwargs):
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = 600
+        return anthropic.Anthropic(
+            base_url=self.url_for(),
+            api_key=self.DUMMY_API_KEY,
+            max_retries=0,
+            **kwargs,
+        )
+
+    def get_async_client_anthropic(self, **kwargs):
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = 600
+        return anthropic.AsyncAnthropic(
+            base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs
+        )


class RemoteOpenAIServerCustom(RemoteOpenAIServer):
    """Launch test server with custom child process"""
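For context (not part of the diff), a hypothetical synchronous test built on the new helper could look like the sketch below; the import path, server arguments, and assertion are illustrative assumptions.

# Hypothetical usage of the new RemoteOpenAIServer.get_client_anthropic() helper.
# Illustrative only: import path assumes running from the repo root, and the
# server args are placeholders.
from tests.utils import RemoteOpenAIServer


def test_messages_smoke():
    with RemoteOpenAIServer("Qwen/Qwen3-0.6B", ["--max-model-len", "2048"]) as server:
        client = server.get_client_anthropic()  # synchronous anthropic.Anthropic client
        resp = client.messages.create(
            model="Qwen/Qwen3-0.6B",
            max_tokens=32,
            messages=[{"role": "user", "content": "Say hi."}],
        )
        assert resp.content  # expect at least one content block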
@@ -293,131 +310,6 @@ class RemoteOpenAIServerCustom(RemoteOpenAIServer):
            self.proc.kill()


-class RemoteAnthropicServer:
-    DUMMY_API_KEY = "token-abc123"  # vLLM's Anthropic server does not need API key
-
-    def __init__(
-        self,
-        model: str,
-        vllm_serve_args: list[str],
-        *,
-        env_dict: dict[str, str] | None = None,
-        seed: int | None = 0,
-        auto_port: bool = True,
-        max_wait_seconds: float | None = None,
-    ) -> None:
-        if auto_port:
-            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
-                raise ValueError(
-                    "You have manually specified the port when `auto_port=True`."
-                )
-
-            # Don't mutate the input args
-            vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())]
-        if seed is not None:
-            if "--seed" in vllm_serve_args:
-                raise ValueError(
-                    f"You have manually specified the seed when `seed={seed}`."
-                )
-
-            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
-
-        parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.")
-        subparsers = parser.add_subparsers(required=False, dest="subparser")
-        parser = ServeSubcommand().subparser_init(subparsers)
-        args = parser.parse_args(["--model", model, *vllm_serve_args])
-        self.host = str(args.host or "localhost")
-        self.port = int(args.port)
-
-        self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None
-
-        # download the model before starting the server to avoid timeout
-        is_local = os.path.isdir(model)
-        if not is_local:
-            engine_args = AsyncEngineArgs.from_cli_args(args)
-            model_config = engine_args.create_model_config()
-            load_config = engine_args.create_load_config()
-
-            model_loader = get_model_loader(load_config)
-            model_loader.download_model(model_config)
-
-        env = os.environ.copy()
-        # the current process might initialize cuda,
-        # to be safe, we should use spawn method
-        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-        if env_dict is not None:
-            env.update(env_dict)
-        self.proc = subprocess.Popen(
-            [
-                sys.executable,
-                "-m",
-                "vllm.entrypoints.anthropic.api_server",
-                model,
-                *vllm_serve_args,
-            ],
-            env=env,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-        max_wait_seconds = max_wait_seconds or 240
-        self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.proc.terminate()
-        try:
-            self.proc.wait(8)
-        except subprocess.TimeoutExpired:
-            # force kill if needed
-            self.proc.kill()
-
-    def _wait_for_server(self, *, url: str, timeout: float):
-        # run health check
-        start = time.time()
-        while True:
-            try:
-                if requests.get(url).status_code == 200:
-                    break
-            except Exception:
-                # this exception can only be raised by requests.get,
-                # which means the server is not ready yet.
-                # the stack trace is not useful, so we suppress it
-                # by using `raise from None`.
-                result = self.proc.poll()
-                if result is not None and result != 0:
-                    raise RuntimeError("Server exited unexpectedly.") from None
-
-                time.sleep(0.5)
-                if time.time() - start > timeout:
-                    raise RuntimeError("Server failed to start in time.") from None
-
-    @property
-    def url_root(self) -> str:
-        return f"http://{self.host}:{self.port}"
-
-    def url_for(self, *parts: str) -> str:
-        return self.url_root + "/" + "/".join(parts)
-
-    def get_client(self, **kwargs):
-        if "timeout" not in kwargs:
-            kwargs["timeout"] = 600
-        return anthropic.Anthropic(
-            base_url=self.url_for(),
-            api_key=self.DUMMY_API_KEY,
-            max_retries=0,
-            **kwargs,
-        )
-
-    def get_async_client(self, **kwargs):
-        if "timeout" not in kwargs:
-            kwargs["timeout"] = 600
-        return anthropic.AsyncAnthropic(
-            base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs
-        )


def _test_completion(
    client: openai.OpenAI,
    model: str,
@@ -1,301 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from:
# https://github.com/vllm/vllm/entrypoints/openai/api_server.py

import asyncio
import signal
import tempfile
from argparse import Namespace
from http import HTTPStatus

import uvloop
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State

import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client,
    create_server_socket,
    lifespan,
    load_log_config,
    validate_api_server_args,
    validate_json_request,
)
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.protocol import ErrorResponse
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    OpenAIServingModels,
)

#
# yapf: enable
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    process_chat_template,
    process_lora_modules,
    with_cancellation,
)
from vllm.logger import init_logger
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION

prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger("vllm.entrypoints.anthropic.api_server")

_running_tasks: set[asyncio.Task] = set()

router = APIRouter()


def messages(request: Request) -> AnthropicServingMessages:
    return request.app.state.anthropic_serving_messages


def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
    await engine_client(raw_request).check_health()
    return Response(status_code=200)


@router.get("/ping", response_class=Response)
@router.post("/ping", response_class=Response)
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
    return await health(raw_request)


@router.post(
    "/v1/messages",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
    handler = messages(raw_request)
    if handler is None:
        return messages(raw_request).create_error_response(
            message="The model does not support Messages API"
        )

    generator = await handler.create_messages(request, raw_request)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump())

    elif isinstance(generator, AnthropicMessagesResponse):
        logger.debug(
            "Anthropic Messages Response: %s", generator.model_dump(exclude_none=True)
        )
        return JSONResponse(content=generator.model_dump(exclude_none=True))

    return StreamingResponse(content=generator, media_type="text/event-stream")


async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    model_config = vllm_config.model_config

    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    resolved_chat_template = await process_chat_template(
        args.chat_template, engine_client, model_config
    )

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()
    state.anthropic_serving_messages = AnthropicServingMessages(
        engine_client,
        state.openai_serving_models,
        args.response_role,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
    )


def setup_server(args):
    """Validate API server args, set up signal handler, create socket
    ready to serve."""

    logger.info("vLLM API server version %s", VLLM_VERSION)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    validate_api_server_args(args)

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    sock_addr = (args.host or "", args.port)
    sock = create_server_socket(sock_addr)

    # workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active
    set_ulimit()

    def signal_handler(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    addr, port = sock_addr
    is_ssl = args.ssl_keyfile and args.ssl_certfile
    host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0"
    listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"

    return listen_address, sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server."""
    listen_address, sock = setup_server(args)
    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)


def build_app(args: Namespace) -> FastAPI:
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    return app


async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker."""

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    server_index = client_config.get("client_index", 0) if client_config else 0

    # Load logging config for uvicorn if specified
    log_config = load_log_config(args.log_config_file)
    if log_config is not None:
        uvicorn_kwargs["log_config"] = log_config

    async with build_async_engine_client(
        args,
        client_config=client_config,
    ) as engine_client:
        app = build_app(args)

        await init_app_state(engine_client, app.state, args)

        logger.info("Starting vLLM API server %d on %s", server_index, listen_address)
        shutdown_task = await serve_http(
            app,
            sock=sock,
            enable_ssl_refresh=args.enable_ssl_refresh,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            # NOTE: When the 'disable_uvicorn_access_log' value is True,
            # no access log will be output.
            access_log=not args.disable_uvicorn_access_log,
            timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

        # NB: Await server shutdown only after the backend context is exited
        try:
            await shutdown_task
        finally:
            sock.close()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()
    parser = FlexibleArgumentParser(
        description="vLLM Anthropic-Compatible RESTful API server."
    )
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))
@@ -41,6 +41,13 @@ import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import Device, EngineClient
+from vllm.entrypoints.anthropic.protocol import (
+    AnthropicError,
+    AnthropicErrorResponse,
+    AnthropicMessagesRequest,
+    AnthropicMessagesResponse,
+)
+from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args

@@ -308,6 +315,10 @@ def responses(request: Request) -> OpenAIServingResponses | None:
    return request.app.state.openai_serving_responses


+def messages(request: Request) -> AnthropicServingMessages:
+    return request.app.state.anthropic_serving_messages
+
+
def chat(request: Request) -> OpenAIServingChat | None:
    return request.app.state.openai_serving_chat

@@ -591,6 +602,63 @@ async def cancel_responses(response_id: str, raw_request: Request):
    return JSONResponse(content=response.model_dump())


+@router.post(
+    "/v1/messages",
+    dependencies=[Depends(validate_json_request)],
+    responses={
+        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
+        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
+        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
+        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
+    },
+)
+@with_cancellation
+@load_aware_call
+async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
+    def translate_error_response(response: ErrorResponse) -> JSONResponse:
+        anthropic_error = AnthropicErrorResponse(
+            error=AnthropicError(
+                type=response.error.type,
+                message=response.error.message,
+            )
+        )
+        return JSONResponse(
+            status_code=response.error.code, content=anthropic_error.model_dump()
+        )
+
+    handler = messages(raw_request)
+    if handler is None:
+        error = base(raw_request).create_error_response(
+            message="The model does not support Messages API"
+        )
+        return translate_error_response(error)
+
+    try:
+        generator = await handler.create_messages(request, raw_request)
+    except Exception as e:
+        logger.exception("Error in create_messages: %s", e)
+        return JSONResponse(
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+            content=AnthropicErrorResponse(
+                error=AnthropicError(
+                    type="internal_error",
+                    message=str(e),
+                )
+            ).model_dump(),
+        )
+
+    if isinstance(generator, ErrorResponse):
+        return translate_error_response(generator)
+
+    elif isinstance(generator, AnthropicMessagesResponse):
+        logger.debug(
+            "Anthropic Messages Response: %s", generator.model_dump(exclude_none=True)
+        )
+        return JSONResponse(content=generator.model_dump(exclude_none=True))
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(validate_json_request)],
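For reference (not part of the diff), the new route can also be exercised with a plain HTTP request; the sketch below uses requests, which the test utilities already depend on. Host, port, and model name are assumptions, and setting "stream": true would switch the response to an SSE event stream.

import requests

# Illustrative raw-HTTP call against the new /v1/messages route (assumes a local
# server on the default port serving "Qwen/Qwen3-0.6B").
payload = {
    "model": "Qwen/Qwen3-0.6B",
    "max_tokens": 64,
    "messages": [{"role": "user", "content": "Hello!"}],
}
r = requests.post("http://localhost:8000/v1/messages", json=payload, timeout=600)
r.raise_for_status()
print(r.json())  # non-streaming requests return the serialized AnthropicMessagesResponse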
@@ -1817,6 +1885,24 @@ async def init_app_state(
        if "transcription" in supported_tasks
        else None
    )
+    state.anthropic_serving_messages = (
+        AnthropicServingMessages(
+            engine_client,
+            state.openai_serving_models,
+            args.response_role,
+            request_logger=request_logger,
+            chat_template=resolved_chat_template,
+            chat_template_content_format=args.chat_template_content_format,
+            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+            enable_auto_tools=args.enable_auto_tool_choice,
+            tool_parser=args.tool_call_parser,
+            reasoning_parser=args.structured_outputs_config.reasoning_parser,
+            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+            enable_force_include_usage=args.enable_force_include_usage,
+        )
+        if "generate" in supported_tasks
+        else None
+    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0