From ae122b1cbde96c871fb74611363e04eecfbcce03 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 3 Mar 2025 19:04:45 +0000 Subject: [PATCH] [WIP][V1][Metrics] Implement max_num_generation_tokens, request_params_n, and request_params_max_tokens metrics (#14055) Signed-off-by: Mark McLoughlin --- tests/entrypoints/openai/test_metrics.py | 6 +++ vllm/v1/engine/output_processor.py | 13 ++++++ vllm/v1/engine/parallel_sampling.py | 39 +++++++++++++++++- vllm/v1/metrics/loggers.py | 50 ++++++++++++++++++++++++ vllm/v1/metrics/stats.py | 5 +++ 5 files changed, 111 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 39ce4ba23548d..2bffd0ce138e6 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -239,6 +239,12 @@ EXPECTED_METRICS_V1 = [ "vllm:request_generation_tokens_sum", "vllm:request_generation_tokens_bucket", "vllm:request_generation_tokens_count", + "vllm:request_params_n_sum", + "vllm:request_params_n_bucket", + "vllm:request_params_n_count", + "vllm:request_params_max_tokens_sum", + "vllm:request_params_max_tokens_bucket", + "vllm:request_params_max_tokens_count", "vllm:time_to_first_token_seconds_sum", "vllm:time_to_first_token_seconds_bucket", "vllm:time_to_first_token_seconds_count", diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 4e1d1e3bf51bc..75c638a854f8f 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -36,6 +36,7 @@ class RequestState: prompt_token_ids: list[int], logprobs_processor: LogprobsProcessor, detokenizer: IncrementalDetokenizer, + max_tokens_param: Optional[int], arrival_time: float, queue: Optional[asyncio.Queue[RequestOutput]], log_stats: bool, @@ -50,6 +51,7 @@ class RequestState: self.prompt_len = len(prompt_token_ids) self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer + self.max_tokens_param = 
max_tokens_param self.is_prefilling = True self.queue = queue @@ -83,6 +85,8 @@ class RequestState: tokenizer=tokenizer, request=request, ), + max_tokens_param=(request.sampling_params.max_tokens if + request.sampling_params is not None else None), arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, @@ -198,6 +202,8 @@ class OutputProcessor: req_state = self.request_states.pop(request_id, None) if req_state is not None: self.lora_states.abort_request(req_state) + if req_state.parent_req is not None: + req_state.parent_req.finish_child_request(request_id) def add_request( self, @@ -310,6 +316,8 @@ class OutputProcessor: # If req not finished in EngineCore, but Detokenizer # detected stop string, abort needed in EngineCore. reqs_to_abort.append(req_id) + if req_state.parent_req is not None: + req_state.parent_req.finish_child_request(req_id) # Track per-request stats self._update_stats_from_finished(req_state, finish_reason, @@ -350,5 +358,10 @@ class OutputProcessor: iteration_stats.update_from_finished_request( finish_reason=finish_reason, num_prompt_tokens=len(req_state.prompt_token_ids), + max_tokens_param=req_state.max_tokens_param, req_stats=req_state.stats) self.lora_states.finish_request(req_state) + + ParentRequest.observe_finished_request( + req_state.parent_req, iteration_stats, + req_state.stats.num_generation_tokens) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index adced8973b033..4e2c78173b513 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -6,6 +6,7 @@ from typing import Callable, Optional, Union from vllm.outputs import CompletionOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams +from vllm.v1.metrics.stats import IterationStats class ParentRequest: @@ -18,9 +19,15 @@ class ParentRequest: request_id: str sampling_params: SamplingParams + # To track the completion of child requests + 
child_requests: set[str] + # To aggregate child completions when not streaming output_aggregator: Optional[RequestOutput] + # To find the max number of generated tokens across all children + max_num_generation_tokens: int + # To efficiently obtain child sampling params cached_child_sampling_params: Optional[SamplingParams] @@ -29,7 +36,9 @@ class ParentRequest: self.request_id = request_id self.sampling_params = sampling_params + self.child_requests = set() self.output_aggregator = None + self.max_num_generation_tokens = 0 self.cached_child_sampling_params = None @classmethod @@ -82,8 +91,12 @@ class ParentRequest: Returns: (request ID, sampling_params) tuple """ - return (f"{index}_{self.request_id}", - self._get_child_sampling_params(index)) + child_req_id = f"{index}_{self.request_id}" + self.child_requests.add(child_req_id) + return (child_req_id, self._get_child_sampling_params(index)) + + def finish_child_request(self, req_id: str): + self.child_requests.remove(req_id) @property def n(self) -> int: @@ -117,3 +130,25 @@ class ParentRequest: request_output.outputs = sorted(request_output.outputs, key=lambda x: x.index) return request_output + + def observe_num_generation_tokens(self, num_generation_tokens: int): + self.max_num_generation_tokens = max(num_generation_tokens, + self.max_num_generation_tokens) + return self.max_num_generation_tokens + + @staticmethod + def observe_finished_request(parent_req: Optional['ParentRequest'], + iteration_stats: IterationStats, + num_generation_tokens: int): + + n_param = parent_req.n if parent_req is not None else 1 + + if parent_req is not None: + num_generation_tokens = parent_req.observe_num_generation_tokens( + num_generation_tokens) + + # Child requests finished, we can now record to iteration stats + if parent_req is None or not parent_req.child_requests: + iteration_stats.max_num_generation_tokens_iter.append( + num_generation_tokens) + iteration_stats.n_params_iter.append(n_param) diff --git 
a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 7f6de79104841..d02b9a5da2793 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -106,6 +106,9 @@ class PrometheusStatLogger(StatLoggerBase): max_model_len = vllm_config.model_config.max_model_len + # + # Scheduler state + # self.gauge_scheduler_running = prometheus_client.Gauge( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", @@ -116,6 +119,9 @@ class PrometheusStatLogger(StatLoggerBase): documentation="Number of requests waiting to be processed.", labelnames=labelnames).labels(*labelvalues) + # + # GPU cache + # self.gauge_gpu_cache_usage = prometheus_client.Gauge( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", @@ -133,6 +139,9 @@ class PrometheusStatLogger(StatLoggerBase): "GPU prefix cache hits, in terms of number of cached blocks.", labelnames=labelnames).labels(*labelvalues) + # + # Counters + # self.counter_num_preempted_reqs = prometheus_client.Counter( name="vllm:num_preemptions_total", documentation="Cumulative number of preemption from the engine.", @@ -159,6 +168,9 @@ class PrometheusStatLogger(StatLoggerBase): reason] = counter_request_success_base.labels(*(labelvalues + [str(reason)])) + # + # Histograms of counts + # self.histogram_num_prompt_tokens_request = \ prometheus_client.Histogram( name="vllm:request_prompt_tokens", @@ -180,6 +192,31 @@ class PrometheusStatLogger(StatLoggerBase): buckets=build_cudagraph_buckets(vllm_config), labelnames=labelnames).labels(*labelvalues) + self.histogram_max_num_generation_tokens_request = \ + prometheus_client.Histogram( + name="vllm:request_max_num_generation_tokens", + documentation= + "Histogram of maximum number of requested generation tokens.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames).labels(*labelvalues) + + self.histogram_n_request = \ + prometheus_client.Histogram( + 
name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + buckets=[1, 2, 5, 10, 20], + labelnames=labelnames).labels(*labelvalues) + + self.histogram_max_tokens_request = \ + prometheus_client.Histogram( + name="vllm:request_params_max_tokens", + documentation="Histogram of the max_tokens request parameter.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames).labels(*labelvalues) + + # + # Histogram of timing intervals + # self.histogram_time_to_first_token = \ prometheus_client.Histogram( name="vllm:time_to_first_token_seconds", @@ -239,6 +276,9 @@ class PrometheusStatLogger(StatLoggerBase): buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) + # + # LoRA metrics + # self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: self.labelname_max_lora = "max_lora" @@ -255,6 +295,9 @@ class PrometheusStatLogger(StatLoggerBase): self.labelname_running_lora_adapters, ]) + # + # Cache config info metric + # self.log_metrics_info("cache_config", vllm_config.cache_config) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): @@ -296,6 +339,11 @@ class PrometheusStatLogger(StatLoggerBase): iteration_stats.num_prompt_tokens + \ iteration_stats.num_generation_tokens) + for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter: + self.histogram_max_num_generation_tokens_request.observe( + max_gen_tokens) + for n_param in iteration_stats.n_params_iter: + self.histogram_n_request.observe(n_param) for ttft in iteration_stats.time_to_first_tokens_iter: self.histogram_time_to_first_token.observe(ttft) for tpot in iteration_stats.time_per_output_tokens_iter: @@ -317,6 +365,8 @@ class PrometheusStatLogger(StatLoggerBase): finished_request.num_prompt_tokens) self.histogram_num_generation_tokens_request.observe( finished_request.num_generation_tokens) + self.histogram_max_tokens_request.observe( + finished_request.max_tokens_param) if 
self.gauge_lora_info is not None: running_lora_adapters = \ diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index abdca95670e11..14ec7d2d7463f 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -66,6 +66,7 @@ class FinishedRequestStats: e2e_latency: float = 0.0 num_prompt_tokens: int = 0 num_generation_tokens: int = 0 + max_tokens_param: Optional[int] = None queued_time: float = 0.0 prefill_time: float = 0.0 inference_time: float = 0.0 @@ -81,6 +82,8 @@ class IterationStats: self.num_prompt_tokens = 0 self.num_preempted_reqs = 0 self.finished_requests: list[FinishedRequestStats] = [] + self.max_num_generation_tokens_iter: list[int] = [] + self.n_params_iter: list[int] = [] self.time_to_first_tokens_iter: list[float] = [] self.time_per_output_tokens_iter: list[float] = [] self.waiting_lora_adapters: dict[str, int] = {} @@ -150,6 +153,7 @@ class IterationStats: def update_from_finished_request(self, finish_reason: "FinishReason", num_prompt_tokens: int, + max_tokens_param: Optional[int], req_stats: RequestStateStats): e2e_latency = self._time_since(req_stats.arrival_time) @@ -173,6 +177,7 @@ class IterationStats: e2e_latency=e2e_latency, num_prompt_tokens=num_prompt_tokens, num_generation_tokens=req_stats.num_generation_tokens, + max_tokens_param=max_tokens_param, queued_time=queued_time, prefill_time=prefill_time, inference_time=inference_time,