From c9f09a4fe83ef13824ea1663214ac7aad08d2b31 Mon Sep 17 00:00:00 2001 From: Fred Reiss Date: Fri, 10 Jan 2025 17:04:58 -0800 Subject: [PATCH] [mypy] Fix mypy warnings in api_server.py (#11941) Signed-off-by: Fred Reiss --- vllm/entrypoints/openai/api_server.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 925d7db43138b..1aeefe86cd05e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -14,7 +14,7 @@ from argparse import Namespace from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import AsyncIterator, Optional, Set, Tuple +from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union import uvloop from fastapi import APIRouter, FastAPI, HTTPException, Request @@ -420,6 +420,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): "use the Pooling API (`/pooling`) instead.") res = await fallback_handler.create_pooling(request, raw_request) + + generator: Union[ErrorResponse, EmbeddingResponse] if isinstance(res, PoolingResponse): generator = EmbeddingResponse( id=res.id, @@ -494,7 +496,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request): return await create_score(request, raw_request) -TASK_HANDLERS = { +TASK_HANDLERS: Dict[str, Dict[str, tuple]] = { "generate": { "messages": (ChatCompletionRequest, create_chat_completion), "default": (CompletionRequest, create_completion), @@ -652,7 +654,7 @@ def build_app(args: Namespace) -> FastAPI: module_path, object_name = middleware.rsplit(".", 1) imported = getattr(importlib.import_module(module_path), object_name) if inspect.isclass(imported): - app.add_middleware(imported) + app.add_middleware(imported) # type: ignore[arg-type] elif inspect.iscoroutinefunction(imported): app.middleware("http")(imported) else: