[V1] feat:add engine v1 tracing (#20372)
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
Signed-off-by: Ye Zhang <zhysishu@gmail.com>
Signed-off-by: RichardoMu <44485717+RichardoMrMu@users.noreply.github.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Mu Huai <tianbowen.tbw@antgroup.com>
Co-authored-by: Ye Zhang <zhysishu@gmail.com>
Co-authored-by: Benjamin Bartels <benjamin@bartels.dev>
Co-authored-by: simon-mo <simon.mo@hey.com>
Co-authored-by: 瑜琮 <ly186375@antfin.com>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent 2e6bc46821
commit 40b6c9122b
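
For context, a minimal usage sketch (not part of this diff) of the V1 tracing this commit enables, mirroring the new test added below; the collector address and model are placeholder values:

# Minimal sketch: enabling V1 engine tracing against a local OTLP/gRPC collector.
import os

from vllm import LLM, SamplingParams

os.environ["VLLM_USE_V1"] = "1"
os.environ["OTEL_EXPORTER_OTLP_TRACES_INSECURE"] = "true"  # plain-text collector

llm = LLM(model="facebook/opt-125m",
          otlp_traces_endpoint="localhost:4317")  # OTLP gRPC collector endpoint
outputs = llm.generate(["This is a short prompt"],
                       sampling_params=SamplingParams(max_tokens=32))
# Each finished request is exported as an "llm_request" span carrying the
# gen_ai.* request, usage, and latency attributes added in this commit.
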
@@ -231,7 +231,7 @@ steps:
   source_file_dependencies:
   - vllm/
   - tests/metrics
-  - tests/tracing
+  - tests/v1/tracing
   commands:
   - pytest -v -s metrics
   - "pip install \
tests/v1/tracing/test_tracing.py (new file, 137 lines)
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# type: ignore
from __future__ import annotations

import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Generator, Literal

import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
    TraceServiceServicer, add_TraceServiceServicer_to_server)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import (
    OTEL_EXPORTER_OTLP_TRACES_INSECURE)

from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes

FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
                    'array_value']


def decode_value(value: AnyValue):
    field_decoders: dict[FieldName, Callable] = {
        "bool_value": (lambda v: v.bool_value),
        "string_value": (lambda v: v.string_value),
        "int_value": (lambda v: v.int_value),
        "double_value": (lambda v: v.double_value),
        "array_value":
        (lambda v: [decode_value(item) for item in v.array_value.values]),
    }
    for field, decoder in field_decoders.items():
        if value.HasField(field):
            return decoder(value)
    raise ValueError(f"Couldn't decode value: {value}")


def decode_attributes(attributes: Iterable[KeyValue]):
    return {kv.key: decode_value(kv.value) for kv in attributes}


class FakeTraceService(TraceServiceServicer):

    def __init__(self):
        self.request = None
        self.evt = threading.Event()

    def Export(self, request, context):
        self.request = request
        self.evt.set()
        return ExportTraceServiceResponse()


@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
    """Fixture to set up a fake gRPC trace service"""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    service = FakeTraceService()
    add_TraceServiceServicer_to_server(service, server)
    server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    server.start()

    yield service

    server.stop(None)


def test_traces(
    monkeypatch: pytest.MonkeyPatch,
    trace_service: FakeTraceService,
):
    with monkeypatch.context() as m:
        m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")
        m.setenv("VLLM_USE_V1", "1")
        sampling_params = SamplingParams(
            temperature=0.01,
            top_p=0.1,
            max_tokens=256,
        )
        model = "facebook/opt-125m"
        llm = LLM(model=model,
                  otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
                  gpu_memory_utilization=0.3,
                  disable_log_stats=False)
        prompts = ["This is a short prompt"]
        outputs = llm.generate(prompts, sampling_params=sampling_params)
        print(f"test_traces outputs is : {outputs}")

        timeout = 10
        if not trace_service.evt.wait(timeout):
            raise TimeoutError(
                f"The fake trace service didn't receive a trace within "
                f"the {timeout} seconds timeout")

        request = trace_service.request
        assert len(request.resource_spans) == 1, (
            f"Expected 1 resource span, "
            f"but got {len(request.resource_spans)}")
        assert len(request.resource_spans[0].scope_spans) == 1, (
            f"Expected 1 scope span, "
            f"but got {len(request.resource_spans[0].scope_spans)}")
        assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
            f"Expected 1 span, "
            f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")

        attributes = decode_attributes(
            request.resource_spans[0].scope_spans[0].spans[0].attributes)
        # assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
                              ) == sampling_params.temperature
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
                              ) == sampling_params.max_tokens
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
                outputs[0].prompt_token_ids)
        completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens

        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE) > 0
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) > 0
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) > 0
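
As a local debugging aid (not part of this change), the same opentelemetry-proto servicer API used by the fixture above can be run as a standalone collector that prints whatever the engine exports; the port mirrors FAKE_TRACE_SERVER_ADDRESS:

# Standalone sketch: a tiny OTLP/gRPC collector that prints span names and
# raw attribute key/value protos, useful for eyeballing the spans this test
# asserts on.
from concurrent import futures

import grpc
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
    TraceServiceServicer, add_TraceServiceServicer_to_server)


class PrintingTraceService(TraceServiceServicer):

    def Export(self, request, context):
        for resource_spans in request.resource_spans:
            for scope_spans in resource_spans.scope_spans:
                for span in scope_spans.spans:
                    print(span.name, {kv.key: kv.value for kv in span.attributes})
        return ExportTraceServiceResponse()


if __name__ == "__main__":
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    add_TraceServiceServicer_to_server(PrintingTraceService(), server)
    server.add_insecure_port("localhost:4317")
    server.start()
    server.wait_for_termination()
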
@@ -1491,12 +1491,6 @@ class EngineArgs:
                                recommend_to_remove=False)
             return False
 
-        # No OTLP observability so far.
-        if (self.otlp_traces_endpoint or self.collect_detailed_traces):
-            _raise_or_fallback(feature_name="--otlp-traces-endpoint",
-                               recommend_to_remove=False)
-            return False
-
         # V1 supports N-gram, Medusa, and Eagle speculative decoding.
         if (self.speculative_config is not None
                 and self.speculative_config.get("method") == "draft_model"):
@@ -119,6 +119,11 @@ class SpanAttributes:
     # forward, block/sync across workers, cpu-gpu sync time and sampling time.
     GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = (
         "gen_ai.latency.time_in_model_execute")
+    GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \
+        "gen_ai.latency.time_in_model_prefill"
+    GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
+    GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \
+        "gen_ai.latency.time_in_model_inference"
 
 
 def contains_trace_headers(headers: Mapping[str, str]) -> bool:
@@ -969,9 +969,9 @@ class Scheduler(SchedulerInterface):
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
                         kv_transfer_params=kv_transfer_params,
+                        trace_headers=request.trace_headers,
                         num_cached_tokens=request.num_cached_tokens,
                     ))
 
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Mapping
 from typing import Any, Optional, Union
 
 import msgspec
@@ -66,6 +67,8 @@ class EngineCoreRequest(
     current_wave: int = 0
     priority: int = 0
 
+    trace_headers: Optional[Mapping[str, str]] = None
+
 
 class EngineCoreEventType(enum.IntEnum):
     """The type of engine core request event."""
@@ -111,6 +114,7 @@ class EngineCoreOutput(
     events: Optional[list[EngineCoreEvent]] = None
     kv_transfer_params: Optional[dict[str, Any]] = None
 
+    trace_headers: Optional[Mapping[str, str]] = None
     # The number of tokens with prefix cache hits.
     num_cached_tokens: int = 0
 
@@ -144,7 +148,7 @@ class EngineCoreOutputs(
         omit_defaults=True,  # type: ignore[call-arg]
         gc=False):  # type: ignore[call-arg]
 
-    #NOTE(Nick): We could consider ways to make this more compact,
+    # NOTE(Nick): We could consider ways to make this more compact,
     # e.g. columnwise layout
 
     engine_index: int = 0
@@ -26,6 +26,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
+from vllm.tracing import init_tracer
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -97,6 +98,7 @@ class AsyncLLM(EngineClient):
 
         self.model_config = vllm_config.model_config
         self.vllm_config = vllm_config
+        self.observability_config = vllm_config.observability_config
         self.log_requests = log_requests
 
         self.log_stats = log_stats or (stat_loggers is not None)
@@ -124,6 +126,11 @@ class AsyncLLM(EngineClient):
         # OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(self.tokenizer,
                                                 log_stats=self.log_stats)
+        if self.observability_config.otlp_traces_endpoint is not None:
+            tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+            self.output_processor.tracer = tracer
 
         # EngineCore (starts the engine in background process).
         self.engine_core = EngineCoreClient.make_async_mp_client(
@@ -603,7 +610,7 @@ class AsyncLLM(EngineClient):
         return self.tokenizer.get_lora_tokenizer(lora_request)
 
     async def is_tracing_enabled(self) -> bool:
-        return False
+        return self.observability_config.otlp_traces_endpoint is not None
 
     async def do_log_stats(
         self,
@@ -19,6 +19,7 @@ from vllm.outputs import PoolingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.tasks import SupportedTask
+from vllm.tracing import init_tracer
 from vllm.transformers_utils.tokenizer_group import (
     TokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import UsageContext
@@ -65,6 +66,7 @@ class LLMEngine:
             "Set VLLM_USE_V1=0 and file and issue on Github.")
 
         self.vllm_config = vllm_config
+        self.observability_config = vllm_config.observability_config
         self.model_config = vllm_config.model_config
         self.cache_config = vllm_config.cache_config
 
@@ -99,6 +101,11 @@ class LLMEngine:
         # OutputProcessor (convert EngineCoreOutputs --> RequestOutput).
         self.output_processor = OutputProcessor(self.tokenizer,
                                                 log_stats=self.log_stats)
+        if self.observability_config.otlp_traces_endpoint is not None:
+            tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+            self.output_processor.tracer = tracer
 
         # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
         self.engine_core = EngineCoreClient.make_client(
@@ -11,6 +11,8 @@ import torch
 from vllm.outputs import (CompletionOutput, PoolingOutput,
                           PoolingRequestOutput, RequestOutput)
 from vllm.sampling_params import RequestOutputKind
+from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
+                          extract_trace_context)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@@ -71,7 +73,6 @@ class RequestOutputCollector:
 
 
 @dataclass
 class OutputProcessorOutput:
 
     request_outputs: list[Union[RequestOutput, PoolingRequestOutput]]
     reqs_to_abort: list[str]
 
@@ -93,6 +94,9 @@ class RequestState:
         arrival_time: float,
         queue: Optional[RequestOutputCollector],
         log_stats: bool,
+        top_p: Optional[float] = None,
+        n: Optional[int] = None,
+        temperature: Optional[float] = None,
     ):
         self.request_id = request_id
         self.parent_req = parent_req
@@ -105,6 +109,9 @@ class RequestState:
         self.logprobs_processor = logprobs_processor
         self.detokenizer = detokenizer
         self.max_tokens_param = max_tokens_param
+        self.top_p = top_p
+        self.n = n
+        self.temperature = temperature
         self.is_prefilling = True
         self.queue = queue
         self.num_cached_tokens = 0
@@ -137,10 +144,16 @@ class RequestState:
                 request=request,
             )
             max_tokens_param = sampling_params.max_tokens
+            top_p = sampling_params.top_p
+            n = sampling_params.n
+            temperature = sampling_params.temperature
         else:
             logprobs_processor = None
             detokenizer = None
             max_tokens_param = None
+            top_p = None
+            n = None
+            temperature = None
             assert request.pooling_params is not None
             output_kind = request.pooling_params.output_kind
 
@@ -156,6 +169,9 @@ class RequestState:
             logprobs_processor=logprobs_processor,
             detokenizer=detokenizer,
             max_tokens_param=max_tokens_param,
+            top_p=top_p,
+            n=n,
+            temperature=temperature,
             arrival_time=request.arrival_time,
             queue=queue,
             log_stats=log_stats,
@@ -274,16 +290,13 @@ class RequestState:
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""
 
-    def __init__(
-        self,
-        tokenizer: TokenizerGroup,
-        log_stats: bool,
-    ):
+    def __init__(self, tokenizer: TokenizerGroup, log_stats: bool):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()
+        self.tracer: Optional[Tracer] = None
 
     def get_num_unfinished_requests(self):
         return len(self.request_states)
@@ -441,7 +454,9 @@ class OutputProcessor:
                 # Track per-request stats
                 self._update_stats_from_finished(req_state, finish_reason,
                                                  iteration_stats)
 
+                if self.tracer:
+                    self.do_tracing(engine_core_output, req_state,
+                                    iteration_stats)
         self.lora_states.update_iteration_stats(iteration_stats)
 
         return OutputProcessorOutput(
@@ -449,6 +464,63 @@ class OutputProcessor:
             reqs_to_abort=reqs_to_abort,
         )
 
+    def do_tracing(self, engine_core_output: EngineCoreOutput,
+                   req_state: RequestState,
+                   iteration_stats: Optional[IterationStats]) -> None:
+        assert req_state.stats is not None
+        assert iteration_stats is not None
+        assert self.tracer is not None
+
+        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
+        trace_context = extract_trace_context(engine_core_output.trace_headers)
+        with (self.tracer.start_as_current_span(
+                "llm_request",
+                kind=SpanKind.SERVER,
+                context=trace_context,
+                start_time=arrival_time_nano_seconds) as span):
+            metrics = req_state.stats
+            e2e_time = iteration_stats.iteration_timestamp - \
+                metrics.arrival_time
+            queued_time = metrics.scheduled_ts - metrics.queued_ts
+            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
+            decode_time = metrics.last_token_ts - metrics.first_token_ts
+            inference_time = metrics.last_token_ts - metrics.scheduled_ts
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
+                metrics.first_token_latency)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                               queued_time)
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
+                               len(req_state.prompt_token_ids))
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
+                               metrics.num_generation_tokens)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
+                prefill_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
+                decode_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
+                inference_time)
+
+            # meta
+            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
+                               req_state.request_id)
+            if req_state.top_p:
+                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
+                                   req_state.top_p)
+            if req_state.max_tokens_param:
+                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
+                                   req_state.max_tokens_param)
+            if req_state.temperature:
+                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
+                                   req_state.temperature)
+            if req_state.n:
+                span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
+                                   req_state.n)
+
     def _update_stats_from_output(self, req_state: RequestState,
                                   engine_core_output: EngineCoreOutput,
                                   engine_core_timestamp: Optional[float],
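
To make the latency attributes above concrete, a small worked example (illustrative numbers only, not part of the diff) of the timing arithmetic do_tracing performs; note that inference_time equals prefill_time plus decode_time by construction:

# Worked example of the do_tracing latency formulas, with made-up timestamps
# in seconds relative to the same clock used by the request stats.
arrival_time, queued_ts, scheduled_ts = 100.00, 100.00, 100.05
first_token_ts, last_token_ts, iteration_timestamp = 100.25, 101.45, 101.50

queued_time = scheduled_ts - queued_ts          # 0.05s waiting in the scheduler queue
prefill_time = first_token_ts - scheduled_ts    # 0.20s until the first token
decode_time = last_token_ts - first_token_ts    # 1.20s generating the remaining tokens
inference_time = last_token_ts - scheduled_ts   # 1.40s == prefill_time + decode_time
e2e_time = iteration_timestamp - arrival_time   # 1.50s from arrival to final output
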
@@ -327,8 +327,6 @@ class Processor:
         # TODO(woosuk): Support pooling models.
         self._validate_lora(lora_request)
         self._validate_params(params, lora_request)
-        if trace_headers is not None:
-            raise ValueError("V1 does not support tracing yet.")
 
         data_parallel_size = self.vllm_config.parallel_config.data_parallel_size
         if data_parallel_rank is not None and not (0 <= data_parallel_rank <
@@ -435,6 +433,7 @@ class Processor:
             cache_salt=decoder_inputs.get("cache_salt"),
             priority=priority,
             data_parallel_rank=data_parallel_rank,
+            trace_headers=trace_headers,
         )
 
     def _validate_model_inputs(self,
@@ -68,6 +68,9 @@ class RequestStateStats:
     first_token_ts: float = 0.0
     last_token_ts: float = 0.0
 
+    # first token latency
+    first_token_latency: float = 0.0
+
 
 @dataclass
 class FinishedRequestStats:
@@ -116,6 +119,7 @@ class IterationStats:
 
         first_token_latency = self._time_since(req_stats.arrival_time)
         self.time_to_first_tokens_iter.append(first_token_latency)
+        req_stats.first_token_latency = first_token_latency
 
         req_stats.num_generation_tokens += num_new_generation_tokens
 
@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Mapping
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
@@ -35,6 +36,7 @@ class Request:
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
         priority: int = 0,
+        trace_headers: Optional[Mapping[str, str]] = None,
         block_hasher: Optional[Callable[["Request"],
                                         list["BlockHash"]]] = None,
     ) -> None:
@@ -100,7 +102,8 @@ class Request:
         # they should also be updated simultaneously.
         self.output_token_ids = ConstantList(self._output_token_ids)
         self.all_token_ids = ConstantList(self._all_token_ids)
 
+        # trace_headers
+        self.trace_headers = trace_headers
         # State
         # The number of tokens with prefix cache hits.
         self.num_cached_tokens = -1
@@ -136,6 +139,7 @@ class Request:
                 if request.sampling_params else None,
             cache_salt=request.cache_salt,
             priority=request.priority,
+            trace_headers=request.trace_headers,
             block_hasher=block_hasher,
         )
 