diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 264b7f58ad1ac..8977b7c8571bb 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -101,10 +101,9 @@ steps:
   - pytest -v -s distributed/test_pynccl.py
  - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
 
-##### fast check tests #####
-##### 1 GPU test #####
-
 - label: Metrics, Tracing Test # 10min
+  num_gpus: 2
+  fast_check: true
   source_file_dependencies:
   - vllm/
   - tests/metrics
@@ -118,6 +117,9 @@ steps:
     opentelemetry-semantic-conventions-ai"
   - pytest -v -s tracing
 
+##### fast check tests #####
+##### 1 GPU test #####
+
 - label: Regression Test # 5min
   mirror_hardwares: [amd]
   source_file_dependencies:
diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py
index 90f26400952b9..3cee3b890862a 100644
--- a/tests/tracing/test_tracing.py
+++ b/tests/tracing/test_tracing.py
@@ -114,5 +114,71 @@ def test_traces(trace_service):
         SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
     e2e_time = metrics.finished_time - metrics.arrival_time
     assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
+    assert metrics.scheduler_time > 0
     assert attributes.get(
         SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER) == metrics.scheduler_time
+    # Model forward and model execute times should be None, since detailed
+    # tracing is not enabled.
+    assert metrics.model_forward_time is None
+    assert metrics.model_execute_time is None
+
+
+def test_traces_with_detailed_steps(trace_service):
+    os.environ[OTEL_EXPORTER_OTLP_TRACES_INSECURE] = "true"
+
+    sampling_params = SamplingParams(temperature=0.01,
+                                     top_p=0.1,
+                                     max_tokens=256)
+    model = "facebook/opt-125m"
+    llm = LLM(
+        model=model,
+        otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
+        collect_detailed_traces="all",
+    )
+    prompts = ["This is a short prompt"]
+    outputs = llm.generate(prompts, sampling_params=sampling_params)
+
+    timeout = 5
+    if not trace_service.evt.wait(timeout):
+        raise TimeoutError(
+            f"The fake trace service didn't receive a trace within "
+            f"the {timeout} seconds timeout")
+
+    attributes = decode_attributes(trace_service.request.resource_spans[0].
+                                   scope_spans[0].spans[0].attributes)
+    assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model
+    assert attributes.get(
+        SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id
+    assert attributes.get(
+        SpanAttributes.LLM_REQUEST_TEMPERATURE) == sampling_params.temperature
+    assert attributes.get(
+        SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
+    assert attributes.get(
+        SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
+    assert attributes.get(
+        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
+    assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
+    assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
+        outputs[0].prompt_token_ids)
+    completion_tokens = sum(len(o.token_ids) for o in outputs[0].outputs)
+    assert attributes.get(
+        SpanAttributes.LLM_USAGE_COMPLETION_TOKENS) == completion_tokens
+    metrics = outputs[0].metrics
+    assert attributes.get(
+        SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE) == metrics.time_in_queue
+    ttft = metrics.first_token_time - metrics.arrival_time
+    assert attributes.get(
+        SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
+    e2e_time = metrics.finished_time - metrics.arrival_time
+    assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
+    assert metrics.scheduler_time > 0
+    assert attributes.get(
+        SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER) == metrics.scheduler_time
+    assert metrics.model_forward_time > 0
+    assert attributes.get(
+        SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD) == pytest.approx(
+            metrics.model_forward_time / 1000)
+    assert metrics.model_execute_time > 0
+    assert attributes.get(SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE
+                          ) == metrics.model_execute_time
+    assert metrics.model_forward_time < 1000 * metrics.model_execute_time
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 6ed75a6e2ea6b..287de60149670 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1033,9 +1033,9 @@ class Scheduler:
         # Schedule sequence groups.
         # This function call changes the internal states of the scheduler
         # such as self.running, self.swapped, and self.waiting.
+        scheduler_start_time = time.perf_counter()
         scheduler_outputs = self._schedule()
         now = time.time()
-        scheduler_start_time = time.perf_counter()
 
         if not self.cache_config.enable_prefix_caching:
             common_computed_block_nums = []
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 315b2f50a919d..6c7259129a109 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -948,11 +948,6 @@ class EngineArgs:
                 raise ValueError(
                     f"Invalid module {m} in collect_detailed_traces. "
                     f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
-            if (m == "model"
-                    or m == "all") and self.pipeline_parallel_size > 1:
-                raise ValueError(
-                    "Collection of detailed traces for the 'model' module is "
-                    "not yet supported with pipeline parallelism.")
         observability_config = ObservabilityConfig(
             otlp_traces_endpoint=self.otlp_traces_endpoint,
             collect_model_forward_time="model" in detailed_trace_modules
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 47068f77ec5df..d01b4e781b608 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1554,6 +1554,21 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 
         # Compute the logits in the last pipeline stage.
         if not get_pp_group().is_last_rank:
+            if (self.is_driver_worker
+                    and hidden_or_intermediate_states is not None
+                    and isinstance(hidden_or_intermediate_states,
+                                   IntermediateTensors)
+                    and self.observability_config is not None
+                    and self.observability_config.collect_model_forward_time):
+                model_forward_end.synchronize()
+                model_forward_time = model_forward_start.elapsed_time(
+                    model_forward_end)
+                orig_model_forward_time = 0.0
+                if intermediate_tensors is not None:
+                    orig_model_forward_time = intermediate_tensors.tensors.get(
+                        "model_forward_time", torch.tensor(0.0)).item()
+                hidden_or_intermediate_states.tensors["model_forward_time"] = (
+                    torch.tensor(model_forward_time + orig_model_forward_time))
             return hidden_or_intermediate_states
 
         logits = self.model.compute_logits(hidden_or_intermediate_states,
@@ -1573,11 +1588,16 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             model_forward_end.synchronize()
             model_forward_time = model_forward_start.elapsed_time(
                 model_forward_end)
+            orig_model_forward_time = 0.0
+            if intermediate_tensors is not None:
+                orig_model_forward_time = intermediate_tensors.tensors.get(
+                    "model_forward_time", torch.tensor(0.0)).item()
             # If there are multiple workers, we are still tracking the latency
             # from the start time of the driver worker to the end time of the
             # driver worker. The model forward time will then end up covering
             # the communication time as well.
-            output.model_forward_time = model_forward_time
+            output.model_forward_time = (orig_model_forward_time +
+                                         model_forward_time)
 
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
@@ -1735,7 +1755,7 @@ class CUDAGraphRunner:
                                                       **kwargs)
         if intermediate_tensors is not None:
             for key in intermediate_tensors.tensors:
-                if key != "model_execute_time":
+                if key != "model_execute_time" and key != "model_forward_time":
                     self.input_buffers[key].copy_(intermediate_tensors[key],
                                                   non_blocking=True)
         # Run the graph.
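
Usage sketch (illustrative, not part of the patch): with the pipeline-parallel guard removed from `EngineArgs`, `collect_detailed_traces` can now be combined with `pipeline_parallel_size > 1`. A minimal sketch, assuming at least two GPUs, the Ray backend for pipeline parallelism, and an OTLP collector at `grpc://localhost:4317`; the endpoint and parallel degree are placeholders, while the `LLM` keyword arguments mirror the test added above.

```python
# Illustrative only: exercise detailed tracing together with pipeline
# parallelism, the combination that the removed EngineArgs check used to reject.
# Assumes >= 2 GPUs and an OTLP collector listening at grpc://localhost:4317.
import os

from vllm import LLM, SamplingParams

os.environ["OTEL_EXPORTER_OTLP_TRACES_INSECURE"] = "true"

llm = LLM(
    model="facebook/opt-125m",
    pipeline_parallel_size=2,              # placeholder parallel degree
    distributed_executor_backend="ray",    # PP typically runs on the Ray backend
    otlp_traces_endpoint="grpc://localhost:4317",  # placeholder collector address
    collect_detailed_traces="all",         # now allowed together with PP
)
outputs = llm.generate(["This is a short prompt"],
                       SamplingParams(temperature=0.01, max_tokens=32))

# model_forward_time is reported in milliseconds and, with this patch,
# accumulates the forward time contributed by every pipeline stage.
metrics = outputs[0].metrics
print(metrics.model_forward_time, metrics.model_execute_time)
```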