diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py index 2f8f62cf2d1e4..a492daf3b49ca 100644 --- a/tests/tracing/test_tracing.py +++ b/tests/tracing/test_tracing.py @@ -114,3 +114,5 @@ def test_traces(trace_service): SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft e2e_time = metrics.finished_time - metrics.arrival_time assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time + assert attributes.get(SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER + ) == metrics.scheduler_time diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 4a0e2b4184936..4d2edc02139c4 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -24,6 +24,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: load_config=engine_config.load_config, lora_config=engine_config.lora_config, prompt_adapter_config=engine_config.prompt_adapter_config, + observability_config=engine_config.observability_config, is_driver_worker=True, ) return model_runner diff --git a/vllm/config.py b/vllm/config.py index 59cabbfc965da..4207466cfc5c0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1656,11 +1656,26 @@ class ObservabilityConfig: """Configuration for observability.""" otlp_traces_endpoint: Optional[str] = None + # Collecting detailed timing information for each request can be expensive. + + # If set, collects the model forward time for the request. + collect_model_forward_time: bool = False + + # If set, collects the model execute time for the request. + collect_model_execute_time: bool = False + def __post_init__(self): if not is_otel_installed() and self.otlp_traces_endpoint is not None: raise ValueError("OpenTelemetry packages must be installed before " "configuring 'otlp_traces_endpoint'") + if ((self.collect_model_forward_time + or self.collect_model_execute_time) + and self.otlp_traces_endpoint is None): + raise ValueError( + "collect_model_forward_time or collect_model_execute_time " + "requires --otlp-traces-endpoint to be set.") + @dataclass(frozen=True) class EngineConfig: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index a40f6e2e248b9..b16850c7eb9f8 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1032,6 +1032,7 @@ class Scheduler: # such as self.running, self.swapped, and self.waiting. scheduler_outputs = self._schedule() now = time.time() + scheduler_start_time = time.perf_counter() if not self.cache_config.enable_prefix_caching: common_computed_block_nums = [] @@ -1127,6 +1128,17 @@ class Scheduler: self._seq_group_metadata_cache.reset() + scheduler_time = time.perf_counter() - scheduler_start_time + # Add this to scheduler time to all the sequences that are currently + # running. This will help estimate if the scheduler is a significant + # component in the e2e latency. 
+ for seq_group in self.running: + if seq_group is not None and seq_group.metrics is not None: + if seq_group.metrics.scheduler_time is not None: + seq_group.metrics.scheduler_time += scheduler_time + else: + seq_group.metrics.scheduler_time = scheduler_time + return seq_group_metadata_list, scheduler_outputs def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b6d2ea463940f..73698511fdbb7 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -20,6 +20,8 @@ if TYPE_CHECKING: logger = init_logger(__name__) +ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] + def nullable_str(val: str): if not val or val == "None": @@ -117,6 +119,7 @@ class EngineArgs: disable_logprobs_during_spec_decoding: Optional[bool] = None otlp_traces_endpoint: Optional[str] = None + collect_detailed_traces: Optional[str] = None def __post_init__(self): if self.tokenizer is None: @@ -660,6 +663,16 @@ class EngineArgs: type=str, default=None, help='Target URL to which OpenTelemetry traces will be sent.') + parser.add_argument( + '--collect-detailed-traces', + type=str, + default=None, + help="Valid choices are " + + ",".join(ALLOWED_DETAILED_TRACE_MODULES) + + ". It makes sense to set this only if --otlp-traces-endpoint is" + " set. If set, it will collect detailed traces for the specified " + "modules. This involves use of possibly costly and or blocking " + "operations and hence might have a performance impact.") return parser @@ -852,8 +865,26 @@ class EngineArgs: decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) + detailed_trace_modules = [] + if self.collect_detailed_traces is not None: + detailed_trace_modules = self.collect_detailed_traces.split(",") + for m in detailed_trace_modules: + if m not in ALLOWED_DETAILED_TRACE_MODULES: + raise ValueError( + f"Invalid module {m} in collect_detailed_traces. 
" + f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}") + if (m == "model" + or m == "all") and self.pipeline_parallel_size > 1: + raise ValueError( + "Collection of detailed traces for the 'model' module is " + "not yet supported with pipeline parallelism.") observability_config = ObservabilityConfig( - otlp_traces_endpoint=self.otlp_traces_endpoint) + otlp_traces_endpoint=self.otlp_traces_endpoint, + collect_model_forward_time="model" in detailed_trace_modules + or "all" in detailed_trace_modules, + collect_model_execute_time="worker" in detailed_trace_modules + or "all" in detailed_trace_modules, + ) if (model_config.get_sliding_window() is not None and scheduler_config.chunked_prefill_enabled diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dcaf375f9b15d..39bb1f9c274fa 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -267,6 +267,7 @@ class LLMEngine: speculative_config=speculative_config, load_config=load_config, prompt_adapter_config=prompt_adapter_config, + observability_config=self.observability_config, ) if not self.model_config.embedding_mode: @@ -1183,6 +1184,22 @@ class LLMEngine: seq_group = scheduled_seq_group.seq_group seq_group.update_num_computed_tokens( scheduled_seq_group.token_chunk_size) + if output is not None and len(output) > 0: + for o in output: + if (isinstance(o, SamplerOutput) + and seq_group.metrics is not None): + if seq_group.metrics.model_forward_time is not None: + seq_group.metrics.model_forward_time += ( + o.model_forward_time) + else: + seq_group.metrics.model_forward_time = ( + o.model_forward_time) + if seq_group.metrics.model_execute_time is not None: + seq_group.metrics.model_execute_time += ( + o.model_execute_time) + else: + seq_group.metrics.model_execute_time = ( + o.model_execute_time) if self.model_config.embedding_mode: self._process_sequence_group_outputs(seq_group, outputs) continue @@ -1575,6 +1592,18 @@ class LLMEngine: seq_span.set_attribute( SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) + if metrics.scheduler_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER, + metrics.scheduler_time) + if metrics.model_forward_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD, + metrics.model_forward_time / 1000.0) + if metrics.model_execute_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE, + metrics.model_execute_time) def is_encoder_decoder_model(self): return self.model_config.is_encoder_decoder_model diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index a848bc70941c1..bc4f544554ae4 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -2,8 +2,8 @@ from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, + ModelConfig, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest @@ -32,6 +32,7 @@ class ExecutorBase(ABC): multimodal_config: Optional[MultiModalConfig], speculative_config: Optional[SpeculativeConfig], prompt_adapter_config: Optional[PromptAdapterConfig], + observability_config: 
Optional[ObservabilityConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -43,7 +44,7 @@ class ExecutorBase(ABC): self.multimodal_config = multimodal_config self.speculative_config = speculative_config self.prompt_adapter_config = prompt_adapter_config - + self.observability_config = observability_config self._init_executor() @abstractmethod diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 3e77af0e20323..57b9e2b33b982 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -60,6 +60,7 @@ class GPUExecutor(ExecutorBase): prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0), + observability_config=self.observability_config, ) def _get_create_worker_kwargs( diff --git a/vllm/sequence.py b/vllm/sequence.py index fd2dc96566786..7349bc6f13bd6 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -92,6 +92,13 @@ class RequestMetrics: first_token_time: The time when the first token was generated. time_in_queue: The time the request spent in the queue. finished_time: The time when the request was finished. + scheduler_time: The time spent in the scheduler when this request was + being considered by the scheduler. + model_forward_time: The time spent in the model forward pass when this + request was in the batch. + model_execute_time: The time spent in the model execute function. This + will include model forward, block/sync across + workers, cpu-gpu sync time and sampling time. """ arrival_time: float last_token_time: float @@ -99,6 +106,9 @@ class RequestMetrics: first_token_time: Optional[float] time_in_queue: Optional[float] finished_time: Optional[float] = None + scheduler_time: Optional[float] = None + model_forward_time: Optional[float] = None + model_execute_time: Optional[float] = None class SequenceData: @@ -968,6 +978,13 @@ class SamplerOutput: # Optional last hidden states from the model. hidden_states: Optional[torch.Tensor] = None + # Time taken in the forward pass for this across all workers + model_forward_time: Optional[float] = None + + # Time taken in the model execute function. This will include model forward, + # block/sync across workers, cpu-gpu sync time and sampling time. 
+ model_execute_time: Optional[float] = None + def __getitem__(self, idx: int): return self.outputs[idx] diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index b76a1ab4cf243..7707d38a0f666 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -23,8 +23,8 @@ except ImportError: FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) + ModelConfig, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.multimodal import MultiModalInputs from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, @@ -69,6 +69,7 @@ class TP1DraftModelRunner(ModelRunner): multimodal_config: Optional[MultiModalConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, ): if return_hidden_states: raise ValueError( @@ -88,6 +89,7 @@ class TP1DraftModelRunner(ModelRunner): multimodal_config=multimodal_config, prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, + observability_config=observability_config, ) self.flashinfer_decode_workspace_buffer = None diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py index 957f2f8c8843e..e5b6933a5ce1c 100644 --- a/vllm/spec_decode/target_model_runner.py +++ b/vllm/spec_decode/target_model_runner.py @@ -1,8 +1,8 @@ from typing import List, Optional from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) + ModelConfig, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.sequence import SequenceGroupMetadata from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) @@ -32,7 +32,8 @@ class TargetModelRunner(ModelRunner): is_driver_worker: bool = False, prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, - return_hidden_states: bool = False): + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None): # An internal boolean member variable to indicate if token log # probabilities are needed or not. self.disable_logprobs = True @@ -49,6 +50,7 @@ class TargetModelRunner(ModelRunner): multimodal_config=multimodal_config, prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, + observability_config=observability_config, ) def prepare_model_input( diff --git a/vllm/tracing.py b/vllm/tracing.py index 7ac38e6a0f663..8bd71b8fd9ea5 100644 --- a/vllm/tracing.py +++ b/vllm/tracing.py @@ -92,6 +92,12 @@ class SpanAttributes(BaseSpanAttributes): LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue" LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token" LLM_LATENCY_E2E = "gen_ai.latency.e2e" + LLM_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler" + # Time taken in the forward pass for this across all workers + LLM_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward" + # Time taken in the model execute function. 
This will include model + # forward, block/sync across workers, cpu-gpu sync time and sampling time. + LLM_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute" def contains_trace_headers(headers: Mapping[str, str]) -> bool: diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 72ab96cf3c2e1..197c4c730e5a7 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type import torch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) + ModelConfig, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalInputs @@ -45,6 +45,7 @@ class EmbeddingModelRunner( is_driver_worker: bool = False, prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, ): super().__init__(model_config, parallel_config, @@ -56,7 +57,8 @@ class EmbeddingModelRunner( kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, prompt_adapter_config=prompt_adapter_config, - multimodal_config=multimodal_config) + multimodal_config=multimodal_config, + observability_config=observability_config) @torch.inference_mode() def execute_model( diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index d9b323f2af09e..4e66a04674c2a 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -10,8 +10,8 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, get_global_forced_attn_backend, global_force_attn_backend) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) + ModelConfig, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -82,6 +82,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): is_driver_worker: bool = False, prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, ): ''' EncoderDecoderModelRunner constructor. 
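The configuration plumbing and span attributes above are driven by the new `--collect-detailed-traces` flag. A minimal usage sketch of how the flag is meant to be combined with `--otlp-traces-endpoint`, assuming the offline Python API mirrors the CLI (the model name and endpoint URL below are placeholders, and the OpenTelemetry packages must be installed for the endpoint to be accepted):

```python
# Sketch only: the field names follow the EngineArgs additions in this diff;
# the model and OTLP endpoint values are placeholders.
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine

engine_args = EngineArgs(
    model="facebook/opt-125m",
    # Required: detailed traces are only valid when an OTLP endpoint is set.
    otlp_traces_endpoint="http://localhost:4317",
    # Comma-separated modules: "model" collects forward-pass timing,
    # "worker" collects execute_model timing, "all" enables both.
    collect_detailed_traces="model,worker",
)
engine = LLMEngine.from_engine_args(engine_args)
```

Per the validation added in `create_engine_config`, `"model"` (and `"all"`) is rejected when pipeline parallelism is enabled, and any module outside `model`, `worker`, `all` raises a `ValueError`.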
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c71e29381e645..cfbbb6698cd8a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -27,8 +27,8 @@ except ImportError: import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig) + ModelConfig, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig) from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY @@ -806,6 +806,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): prompt_adapter_config: Optional[PromptAdapterConfig] = None, multimodal_config: Optional[MultiModalConfig] = None, return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, ): self.model_config = model_config self.parallel_config = parallel_config @@ -818,6 +819,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): self.prompt_adapter_config = prompt_adapter_config self.multimodal_config = multimodal_config self.return_hidden_states = return_hidden_states + self.observability_config = observability_config self.device = self.device_config.device self.pin_memory = is_pin_memory_available() @@ -1527,6 +1529,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start = torch.cuda.Event(enable_timing=True) + model_forward_end = torch.cuda.Event(enable_timing=True) + model_forward_start.record() + hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -1537,6 +1545,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): device=self.device), **seqlen_agnostic_kwargs) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.record() + # Compute the logits in the last pipeline stage. if not get_pp_group().is_last_rank: return hidden_or_intermediate_states @@ -1552,6 +1564,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): logits=logits, sampling_metadata=model_input.sampling_metadata, ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + # If there are multiple workers, we are still tracking the latency + # from the start time of the driver worker to the end time of the + # driver worker. The model forward time will then end up covering + # the communication time as well. 
+ output.model_forward_time = model_forward_time if self.return_hidden_states: # we only need to pass hidden states of most recent token @@ -1709,8 +1732,9 @@ class CUDAGraphRunner: **kwargs) if intermediate_tensors is not None: for key in intermediate_tensors.tensors: - self.input_buffers[key].copy_(intermediate_tensors[key], - non_blocking=True) + if key != "model_execute_time": + self.input_buffers[key].copy_(intermediate_tensors[key], + non_blocking=True) # Run the graph. self.graph.replay() # Return the output tensor. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 45751eceacbca..90b844bf42139 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,8 +7,8 @@ import torch import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, MultiModalConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, + ModelConfig, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -51,6 +51,7 @@ class Worker(LocalOrDistributedWorkerBase): prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + observability_config: Optional[ObservabilityConfig] = None, ) -> None: self.model_config = model_config self.parallel_config = parallel_config @@ -73,6 +74,7 @@ class Worker(LocalOrDistributedWorkerBase): from vllm.utils import init_cached_hf_modules init_cached_hf_modules() self.multimodal_config = multimodal_config + self.observability_config = observability_config # Return hidden states from target model if the draft model is an # mlp_speculator @@ -102,6 +104,7 @@ class Worker(LocalOrDistributedWorkerBase): is_driver_worker=is_driver_worker, prompt_adapter_config=prompt_adapter_config, multimodal_config=multimodal_config, + observability_config=observability_config, **speculative_args, ) # Uninitialized cache engine. 
Will be initialized by diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index e56440693b895..20db3dad1caab 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,11 +1,13 @@ import dataclasses import importlib import os +import time from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import torch +from vllm.config import ObservabilityConfig from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -172,6 +174,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): """ is_driver_worker: bool model_runner: ModelRunnerBase + observability_config: Optional[ObservabilityConfig] = None @property @abstractmethod @@ -219,6 +222,7 @@ class LocalOrDistributedWorkerBase(WorkerBase): ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" + start_time = time.perf_counter() if self.is_driver_worker: if execute_model_req is None: if self.do_metadata_broadcast: @@ -265,21 +269,36 @@ class LocalOrDistributedWorkerBase(WorkerBase): return [] intermediate_tensors = None + orig_model_execute_time = 0.0 if not get_pp_group().is_first_rank: intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( all_gather_group=get_tp_group())) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + orig_model_execute_time = intermediate_tensors.tensors.get( + "model_execute_time", torch.tensor(0)).item() output = self.model_runner.execute_model( model_input, self.kv_cache[worker_input.virtual_engine] if self.kv_cache is not None else None, intermediate_tensors, num_steps) - + model_execute_time = time.perf_counter() - start_time if not get_pp_group().is_last_rank: # output is IntermediateTensors + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + output.tensors["model_execute_time"] = torch.tensor( + model_execute_time + orig_model_execute_time) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) return [None] + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time + and output is not None): + for o in output: + o.model_execute_time = (orig_model_execute_time + + model_execute_time) # output is List[SamplerOutput] return output
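For reference, the scheduler and `execute_model` timings above are plain CPU measurements via `time.perf_counter()`, while the forward pass is timed with CUDA events because the GPU work is asynchronous. A standalone sketch of that event pattern, as used in `ModelRunner.execute_model` (the `timed_forward` helper below is illustrative and not part of this diff):

```python
# Illustrative sketch of the torch.cuda.Event timing pattern used in this PR
# to time the forward pass without adding synchronization inside the pass.
import torch

def timed_forward(model_executable, *args, **kwargs):
    forward_start = torch.cuda.Event(enable_timing=True)
    forward_end = torch.cuda.Event(enable_timing=True)
    forward_start.record()
    output = model_executable(*args, **kwargs)
    forward_end.record()
    # Event timestamps are recorded asynchronously on the GPU stream; wait for
    # the end event before reading the elapsed time.
    forward_end.synchronize()
    # elapsed_time() returns milliseconds, which is why llm_engine.py divides
    # model_forward_time by 1000.0 before setting the span attribute, whereas
    # scheduler_time and model_execute_time are already in seconds.
    forward_time_ms = forward_start.elapsed_time(forward_end)
    return output, forward_time_ms
```

Because only the driver worker records these events, the reported forward time spans from the driver's start to the driver's end of the pass, so with multiple workers it also absorbs inter-worker communication time, as noted in the comment inside `execute_model`.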