[Frontend] Use request id from header (#10968)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Joe Runde 2024-12-09 22:46:29 -07:00 committed by GitHub
parent 391d7b2763
commit 980ad394a8
8 changed files with 27 additions and 13 deletions
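
In practice, this change makes the OpenAI-compatible frontend derive its request ids from a caller-supplied X-Request-Id header when one is present, instead of always minting a random UUID. A minimal client sketch of the new behavior in Python; the server address and model name below are placeholders and not part of this diff:

# Hypothetical client call against a running vLLM OpenAI-compatible server.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",      # assumed address
    headers={"X-Request-Id": "trace-abc-123"},   # id picked up by this change
    json={"model": "my-model", "prompt": "Hello", "max_tokens": 8},
)
# With this change the returned id is built from the header,
# e.g. "cmpl-trace-abc-123" rather than "cmpl-<random uuid>".
print(resp.json().get("id"))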

View File

@@ -16,5 +16,6 @@ mistral_common >= 1.5.0
 aiohttp
 starlette
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
+fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 requests

View File

@@ -305,7 +305,7 @@ async def health(raw_request: Request) -> Response:
 async def tokenize(request: TokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
 
-    generator = await handler.create_tokenize(request)
+    generator = await handler.create_tokenize(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)
@@ -319,7 +319,7 @@ async def tokenize(request: TokenizeRequest, raw_request: Request):
 async def detokenize(request: DetokenizeRequest, raw_request: Request):
     handler = tokenization(raw_request)
 
-    generator = await handler.create_detokenize(request)
+    generator = await handler.create_detokenize(request, raw_request)
     if isinstance(generator, ErrorResponse):
         return JSONResponse(content=generator.model_dump(),
                             status_code=generator.code)

View File

@@ -176,7 +176,8 @@ class OpenAIServingChat(OpenAIServing):
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
 
-        request_id = f"chatcmpl-{request.request_id}"
+        request_id = "chatcmpl-" \
+            f"{self._base_request_id(raw_request, request.request_id)}"
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
         if raw_request:

View File

@@ -30,7 +30,7 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import merge_async_iterators, random_uuid
+from vllm.utils import merge_async_iterators
 
 logger = init_logger(__name__)
@@ -86,7 +86,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 "suffix is not currently supported")
 
         model_name = self.base_model_paths[0].name
-        request_id = f"cmpl-{random_uuid()}"
+        request_id = f"cmpl-{self._base_request_id(raw_request)}"
         created_time = int(time.time())
 
         request_metadata = RequestResponseMetadata(request_id=request_id)

View File

@@ -19,7 +19,7 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
 from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
-from vllm.utils import merge_async_iterators, random_uuid
+from vllm.utils import merge_async_iterators
 
 logger = init_logger(__name__)
@@ -110,7 +110,7 @@ class OpenAIServingEmbedding(OpenAIServing):
                 "dimensions is currently not supported")
 
         model_name = request.model
-        request_id = f"embd-{random_uuid()}"
+        request_id = f"embd-{self._base_request_id(raw_request)}"
         created_time = int(time.monotonic())
 
         truncate_prompt_tokens = None

View File

@@ -6,6 +6,7 @@ from http import HTTPStatus
 from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                     Optional, Sequence, Tuple, TypedDict, Union)
 
+from fastapi import Request
 from pydantic import Field
 from starlette.datastructures import Headers
 from typing_extensions import Annotated
@@ -47,7 +48,7 @@ from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import AtomicCounter, is_list_of, make_async
+from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid
 
 logger = init_logger(__name__)
@@ -565,6 +566,14 @@ class OpenAIServing:
 
         return None
 
+    @staticmethod
+    def _base_request_id(raw_request: Request,
+                         default: Optional[str] = None) -> Optional[str]:
+        """Pulls the request id to use from a header, if provided"""
+        default = default or random_uuid()
+        return raw_request.headers.get(
+            "X-Request-Id", default) if raw_request is not None else default
+
     @staticmethod
     def _get_decoded_token(logprob: Logprob,
                            token_id: int,
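
The _base_request_id helper added above is the core of the change: prefer the X-Request-Id header, then a caller-supplied default (for chat completions that default is the request body's request_id field), and only then a fresh random UUID. A standalone sketch of the same precedence, using a plain mapping in place of fastapi.Request headers; this is an illustration only, and unlike starlette's Headers a plain dict is not case-insensitive:

import uuid
from typing import Mapping, Optional

def base_request_id(headers: Optional[Mapping[str, str]],
                    default: Optional[str] = None) -> str:
    # Same precedence as the method above: header > default > random uuid.
    default = default or uuid.uuid4().hex
    if headers is None:
        return default
    return headers.get("X-Request-Id", default)

# Header present: the caller-supplied id wins.
assert base_request_id({"X-Request-Id": "trace-123"}) == "trace-123"
# No header, but a default (e.g. request.request_id for chat): default wins.
assert base_request_id({}, default="req-from-body") == "req-from-body"
# Neither: a fresh random id, matching the old behavior.
print(base_request_id({}))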

View File

@@ -15,7 +15,7 @@ from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
-from vllm.utils import make_async, merge_async_iterators, random_uuid
+from vllm.utils import make_async, merge_async_iterators
 
 logger = init_logger(__name__)
@@ -102,7 +102,7 @@ class OpenAIServingScores(OpenAIServing):
             return error_check_ret
 
         model_name = request.model
-        request_id = f"score-{random_uuid()}"
+        request_id = f"score-{self._base_request_id(raw_request)}"
         created_time = int(time.monotonic())
 
         truncate_prompt_tokens = request.truncate_prompt_tokens

View File

@@ -1,5 +1,7 @@
 from typing import Final, List, Optional, Union
 
+from fastapi import Request
+
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
@@ -17,7 +19,6 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                      LoRAModulePath,
                                                      OpenAIServing)
 from vllm.logger import init_logger
-from vllm.utils import random_uuid
 
 logger = init_logger(__name__)
@@ -48,12 +49,13 @@ class OpenAIServingTokenization(OpenAIServing):
     async def create_tokenize(
         self,
         request: TokenizeRequest,
+        raw_request: Request,
     ) -> Union[TokenizeResponse, ErrorResponse]:
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        request_id = f"tokn-{random_uuid()}"
+        request_id = f"tokn-{self._base_request_id(raw_request)}"
 
         try:
             (
@@ -112,12 +114,13 @@ class OpenAIServingTokenization(OpenAIServing):
     async def create_detokenize(
         self,
         request: DetokenizeRequest,
+        raw_request: Request,
     ) -> Union[DetokenizeResponse, ErrorResponse]:
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
             return error_check_ret
 
-        request_id = f"tokn-{random_uuid()}"
+        request_id = f"tokn-{self._base_request_id(raw_request)}"
 
         (
             lora_request,