Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-10 00:15:51 +08:00
[V0 Deprecation] Remove V0 Tracing & Metrics tests (#25115)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit 2fc24e94f9
parent 2c3c1bd07a
@@ -217,16 +217,14 @@ steps:
   num_gpus: 2
   source_file_dependencies:
   - vllm/
-  - tests/metrics
   - tests/v1/tracing
   commands:
-  - pytest -v -s metrics
   - "pip install \
     'opentelemetry-sdk>=1.26.0' \
     'opentelemetry-api>=1.26.0' \
     'opentelemetry-exporter-otlp>=1.26.0' \
     'opentelemetry-semantic-conventions-ai>=0.4.1'"
-  - pytest -v -s tracing
+  - pytest -v -s v1/tracing
 
 ##### fast check tests #####
 ##### 1 GPU test #####
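For local reproduction, the snippet below is not part of the commit; it is a rough Python equivalent of the updated CI command, assuming it is run from the vLLM tests/ directory with the OpenTelemetry packages listed in the pipeline step already installed.

import sys

import pytest

# Hypothetical local equivalent of the CI step's final command,
# "pytest -v -s v1/tracing", run from the vLLM tests/ directory.
sys.exit(pytest.main(["-v", "-s", "v1/tracing"]))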
@@ -1,268 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import ray
from prometheus_client import REGISTRY

import vllm.envs as envs
from vllm import EngineArgs, LLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """
    This module tests V0 internals, so set VLLM_USE_V1=0.
    """
    monkeypatch.setenv('VLLM_USE_V1', '0')


MODELS = [
    "distilbert/distilgpt2",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [128])
def test_metric_counter_prompt_tokens(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    with vllm_runner(model,
                     dtype=dtype,
                     disable_log_stats=False,
                     gpu_memory_utilization=0.4) as vllm_model:
        tokenizer = vllm_model.llm.get_tokenizer()
        prompt_token_counts = [
            len(tokenizer.encode(p)) for p in example_prompts
        ]
        # This test needs at least 2 prompts in a batch of different lengths to
        # verify their token count is correct despite padding.
        assert len(example_prompts) > 1, "at least 2 prompts are required"
        assert prompt_token_counts[0] != prompt_token_counts[1], (
            "prompts of different lengths are required")
        vllm_prompt_token_count = sum(prompt_token_counts)

        _ = vllm_model.generate_greedy(example_prompts, max_tokens)
        stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus']
        metric_count = stat_logger.metrics.counter_prompt_tokens.labels(
            **stat_logger.labels)._value.get()

        assert vllm_prompt_token_count == metric_count, (
            f"prompt token count: {vllm_prompt_token_count!r}\n"
            f"metric: {metric_count!r}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [128])
def test_metric_counter_generation_tokens(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    with vllm_runner(model,
                     dtype=dtype,
                     disable_log_stats=False,
                     gpu_memory_utilization=0.4) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
        tokenizer = vllm_model.llm.get_tokenizer()
        stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus']
        metric_count = stat_logger.metrics.counter_generation_tokens.labels(
            **stat_logger.labels)._value.get()
        vllm_generation_count = 0
        for i in range(len(example_prompts)):
            vllm_output_ids, vllm_output_str = vllm_outputs[i]
            prompt_ids = tokenizer.encode(example_prompts[i])
            # vllm_output_ids contains both prompt tokens and generation tokens.
            # We're interested only in the count of the generation tokens.
            vllm_generation_count += len(vllm_output_ids) - len(prompt_ids)

        assert vllm_generation_count == metric_count, (
            f"generation token count: {vllm_generation_count!r}\n"
            f"metric: {metric_count!r}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize(
    "served_model_name",
    [None, [], ["ModelName0"], ["ModelName0", "ModelName1", "ModelName2"]])
def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str,
                                   served_model_name: list[str]) -> None:
    with vllm_runner(model,
                     dtype=dtype,
                     disable_log_stats=False,
                     gpu_memory_utilization=0.3,
                     served_model_name=served_model_name) as vllm_model:
        stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus']
        metrics_tag_content = stat_logger.labels["model_name"]

    if envs.VLLM_CI_USE_S3:
        model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}"
    if served_model_name is None or served_model_name == []:
        assert metrics_tag_content == model, (
            f"Metrics tag model_name is wrong! expect: {model!r}\n"
            f"actual: {metrics_tag_content!r}")
    else:
        assert metrics_tag_content == served_model_name[0], (
            f"Metrics tag model_name is wrong! expect: "
            f"{served_model_name[0]!r}\n"
            f"actual: {metrics_tag_content!r}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("disable_log_stats", [True, False])
@pytest.mark.asyncio
async def test_async_engine_log_metrics_regression(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    disable_log_stats: bool,
) -> None:
    """
    Regression test ensuring async engine generates metrics
    when disable_log_stats=False
    (see: https://github.com/vllm-project/vllm/pull/4150#pullrequestreview-2008176678)
    """
    engine_args = AsyncEngineArgs(
        model=model,
        dtype=dtype,
        disable_log_stats=disable_log_stats,
    )
    async_engine = AsyncLLMEngine.from_engine_args(engine_args)
    for i, prompt in enumerate(example_prompts):
        results = async_engine.generate(
            prompt,
            SamplingParams(max_tokens=max_tokens),
            f"request-id-{i}",
        )
        # Exhaust the async iterator to make the async engine work
        async for _ in results:
            pass

    assert_metrics(model, async_engine.engine, disable_log_stats,
                   len(example_prompts))


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("disable_log_stats", [True, False])
def test_engine_log_metrics_regression(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    disable_log_stats: bool,
) -> None:
    engine_args = EngineArgs(
        model=model,
        dtype=dtype,
        disable_log_stats=disable_log_stats,
    )
    engine = LLMEngine.from_engine_args(engine_args)
    for i, prompt in enumerate(example_prompts):
        engine.add_request(
            f"request-id-{i}",
            prompt,
            SamplingParams(max_tokens=max_tokens),
        )
    while engine.has_unfinished_requests():
        engine.step()

    if envs.VLLM_CI_USE_S3:
        model = f"{MODEL_WEIGHTS_S3_BUCKET}/{model}"
    assert_metrics(model, engine, disable_log_stats, len(example_prompts))


def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool,
                   num_requests: int) -> None:
    if disable_log_stats:
        with pytest.raises(AttributeError):
            _ = engine.stat_loggers
    else:
        assert (engine.stat_loggers
                is not None), "engine.stat_loggers should be set"
        # Ensure the count bucket of request-level histogram metrics matches
        # the number of requests as a simple sanity check to ensure metrics are
        # generated
        labels = {'model_name': model}
        request_histogram_metrics = [
            "vllm:e2e_request_latency_seconds",
            "vllm:request_prompt_tokens",
            "vllm:request_generation_tokens",
            "vllm:request_params_n",
            "vllm:request_params_max_tokens",
        ]
        for metric_name in request_histogram_metrics:
            metric_value = REGISTRY.get_sample_value(f"{metric_name}_count",
                                                     labels)
            assert (
                metric_value == num_requests), "Metrics should be collected"


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # This test is quite weak - it only checks that we can use
    # RayPrometheusStatLogger without exceptions.
    # Checking whether the metrics are actually emitted is unfortunately
    # non-trivial.

    # We have to run in a Ray task for Ray metrics to be emitted correctly
    @ray.remote(num_gpus=1)
    def _inner():

        class _RayPrometheusStatLogger(RayPrometheusStatLogger):

            def __init__(self, *args, **kwargs):
                self._i = 0
                super().__init__(*args, **kwargs)

            def log(self, *args, **kwargs):
                self._i += 1
                return super().log(*args, **kwargs)

        engine_args = EngineArgs(
            model=model,
            dtype=dtype,
            disable_log_stats=False,
        )
        engine = LLMEngine.from_engine_args(engine_args)
        logger = _RayPrometheusStatLogger(
            local_interval=0.5,
            labels=dict(model_name=engine.model_config.served_model_name),
            vllm_config=engine.vllm_config)
        engine.add_logger("ray", logger)
        for i, prompt in enumerate(example_prompts):
            engine.add_request(
                f"request-id-{i}",
                prompt,
                SamplingParams(max_tokens=max_tokens),
            )
        while engine.has_unfinished_requests():
            engine.step()
        assert logger._i > 0, ".log must be called at least once"

    ray.get(_inner.remote())
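The deleted assert_metrics helper above leans on a prometheus_client convention: every histogram exposes a synthetic "<name>_count" sample equal to the number of observations. The snippet below is not part of the commit; it is a minimal, self-contained sketch of that pattern, using an illustrative metric name and label rather than vLLM's real metrics.

from prometheus_client import CollectorRegistry, Histogram

# Illustrative histogram; the name and label are hypothetical, not vLLM's.
registry = CollectorRegistry()
latency = Histogram("demo_e2e_request_latency_seconds",
                    "End-to-end latency of demo requests",
                    labelnames=["model_name"],
                    registry=registry)

for seconds in (0.1, 0.2, 0.3):
    latency.labels(model_name="demo-model").observe(seconds)

# The "<metric>_count" sample is what assert_metrics reads to confirm one
# histogram observation was recorded per request.
count = registry.get_sample_value("demo_e2e_request_latency_seconds_count",
                                  {"model_name": "demo-model"})
assert count == 3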
@@ -1,237 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# type: ignore
from __future__ import annotations

import threading
from collections.abc import Iterable
from concurrent import futures
from typing import Callable, Generator, Literal

import grpc
import pytest
from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import (
    ExportTraceServiceResponse)
from opentelemetry.proto.collector.trace.v1.trace_service_pb2_grpc import (
    TraceServiceServicer, add_TraceServiceServicer_to_server)
from opentelemetry.proto.common.v1.common_pb2 import AnyValue, KeyValue
from opentelemetry.sdk.environment_variables import (
    OTEL_EXPORTER_OTLP_TRACES_INSECURE)

from vllm import LLM, SamplingParams
from vllm.tracing import SpanAttributes


@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
    """
    Since this module is V0 only, set VLLM_USE_V1=0 for
    all tests in the module.
    """
    with monkeypatch.context() as m:
        m.setenv('VLLM_USE_V1', '0')
        yield


FAKE_TRACE_SERVER_ADDRESS = "localhost:4317"

FieldName = Literal['bool_value', 'string_value', 'int_value', 'double_value',
                    'array_value']


def decode_value(value: AnyValue):
    field_decoders: dict[FieldName, Callable] = {
        "bool_value": (lambda v: v.bool_value),
        "string_value": (lambda v: v.string_value),
        "int_value": (lambda v: v.int_value),
        "double_value": (lambda v: v.double_value),
        "array_value":
        (lambda v: [decode_value(item) for item in v.array_value.values]),
    }
    for field, decoder in field_decoders.items():
        if value.HasField(field):
            return decoder(value)
    raise ValueError(f"Couldn't decode value: {value}")


def decode_attributes(attributes: Iterable[KeyValue]):
    return {kv.key: decode_value(kv.value) for kv in attributes}


class FakeTraceService(TraceServiceServicer):

    def __init__(self):
        self.request = None
        self.evt = threading.Event()

    def Export(self, request, context):
        self.request = request
        self.evt.set()
        return ExportTraceServiceResponse()


@pytest.fixture
def trace_service() -> Generator[FakeTraceService, None, None]:
    """Fixture to set up a fake gRPC trace service"""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1))
    service = FakeTraceService()
    add_TraceServiceServicer_to_server(service, server)
    server.add_insecure_port(FAKE_TRACE_SERVER_ADDRESS)
    server.start()

    yield service

    server.stop(None)


def test_traces(
    monkeypatch: pytest.MonkeyPatch,
    trace_service: FakeTraceService,
):
    with monkeypatch.context() as m:
        m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")

        sampling_params = SamplingParams(
            temperature=0.01,
            top_p=0.1,
            max_tokens=256,
        )
        model = "facebook/opt-125m"
        llm = LLM(
            model=model,
            otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
        )
        prompts = ["This is a short prompt"]
        outputs = llm.generate(prompts, sampling_params=sampling_params)

        timeout = 5
        if not trace_service.evt.wait(timeout):
            raise TimeoutError(
                f"The fake trace service didn't receive a trace within "
                f"the {timeout} seconds timeout")

        request = trace_service.request
        assert len(request.resource_spans) == 1, (
            f"Expected 1 resource span, "
            f"but got {len(request.resource_spans)}")
        assert len(request.resource_spans[0].scope_spans) == 1, (
            f"Expected 1 scope span, "
            f"but got {len(request.resource_spans[0].scope_spans)}")
        assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
            f"Expected 1 span, "
            f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")

        attributes = decode_attributes(
            request.resource_spans[0].scope_spans[0].spans[0].attributes)
        assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
                              ) == sampling_params.temperature
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
                              ) == sampling_params.max_tokens
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
                outputs[0].prompt_token_ids)
        completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
        metrics = outputs[0].metrics
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE
                              ) == metrics.time_in_queue
        ttft = metrics.first_token_time - metrics.arrival_time
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
        e2e_time = metrics.finished_time - metrics.arrival_time
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
        assert metrics.scheduler_time > 0
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
                              ) == metrics.scheduler_time
        # Model forward and model execute should be none, since detailed traces
        # is not enabled.
        assert metrics.model_forward_time is None
        assert metrics.model_execute_time is None


def test_traces_with_detailed_steps(
    monkeypatch: pytest.MonkeyPatch,
    trace_service: FakeTraceService,
):
    with monkeypatch.context() as m:
        m.setenv(OTEL_EXPORTER_OTLP_TRACES_INSECURE, "true")

        sampling_params = SamplingParams(
            temperature=0.01,
            top_p=0.1,
            max_tokens=256,
        )
        model = "facebook/opt-125m"
        llm = LLM(
            model=model,
            otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
            collect_detailed_traces=["all"],
        )
        prompts = ["This is a short prompt"]
        outputs = llm.generate(prompts, sampling_params=sampling_params)

        timeout = 5
        if not trace_service.evt.wait(timeout):
            raise TimeoutError(
                f"The fake trace service didn't receive a trace within "
                f"the {timeout} seconds timeout")

        request = trace_service.request
        assert len(request.resource_spans) == 1, (
            f"Expected 1 resource span, "
            f"but got {len(request.resource_spans)}")
        assert len(request.resource_spans[0].scope_spans) == 1, (
            f"Expected 1 scope span, "
            f"but got {len(request.resource_spans[0].scope_spans)}")
        assert len(request.resource_spans[0].scope_spans[0].spans) == 1, (
            f"Expected 1 span, "
            f"but got {len(request.resource_spans[0].scope_spans[0].spans)}")

        attributes = decode_attributes(
            request.resource_spans[0].scope_spans[0].spans[0].attributes)
        assert attributes.get(SpanAttributes.GEN_AI_RESPONSE_MODEL) == model
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_ID) == outputs[0].request_id
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE
                              ) == sampling_params.temperature
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_TOP_P) == sampling_params.top_p
        assert attributes.get(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS
                              ) == sampling_params.max_tokens
        assert attributes.get(
            SpanAttributes.GEN_AI_REQUEST_N) == sampling_params.n
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS) == len(
                outputs[0].prompt_token_ids)
        completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
        assert attributes.get(
            SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS) == completion_tokens
        metrics = outputs[0].metrics
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE
                              ) == metrics.time_in_queue
        ttft = metrics.first_token_time - metrics.arrival_time
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
        e2e_time = metrics.finished_time - metrics.arrival_time
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_E2E) == e2e_time
        assert metrics.scheduler_time > 0
        assert attributes.get(SpanAttributes.GEN_AI_LATENCY_TIME_IN_SCHEDULER
                              ) == metrics.scheduler_time
        assert metrics.model_forward_time > 0
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_FORWARD
        ) == pytest.approx(metrics.model_forward_time / 1000)
        assert metrics.model_execute_time > 0
        assert attributes.get(
            SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE
        ) == metrics.model_execute_time
        assert metrics.model_forward_time < 1000 * metrics.model_execute_time
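The deleted tracing tests above stand up a fake OTLP gRPC collector on localhost:4317 and let vLLM export spans to it. For orientation only, the snippet below is not part of the commit; it is a minimal sketch of the client side of that arrangement, using the opentelemetry-sdk and opentelemetry-exporter-otlp packages installed in the pipeline step, with an illustrative service name, span name, and attribute.

from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor

# Export spans over insecure gRPC to the same kind of endpoint the fake
# trace service above listens on.
provider = TracerProvider(resource=Resource.create({"service.name": "demo"}))
provider.add_span_processor(
    BatchSpanProcessor(OTLPSpanExporter(endpoint="localhost:4317",
                                        insecure=True)))
trace.set_tracer_provider(provider)

with trace.get_tracer(__name__).start_as_current_span("demo-span") as span:
    # Span attributes travel as the KeyValue pairs that decode_attributes()
    # parses on the collector side.
    span.set_attribute("gen_ai.request.max_tokens", 256)

provider.shutdown()  # flush pending spans before the process exits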