[Core] Add span metrics for model_forward, scheduler and sampler time (#7089)

Mahesh Keralapura 2024-08-09 13:55:13 -07:00 committed by GitHub
parent 70d268a399
commit 933790c209
17 changed files with 189 additions and 21 deletions

View File

@ -114,3 +114,5 @@ def test_traces(trace_service):
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
e2e_time = metrics.finished_time - metrics.arrival_time
assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
assert attributes.get(SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER
) == metrics.scheduler_time

View File

@ -24,6 +24,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
load_config=engine_config.load_config,
lora_config=engine_config.lora_config,
prompt_adapter_config=engine_config.prompt_adapter_config,
observability_config=engine_config.observability_config,
is_driver_worker=True,
)
return model_runner

View File

@ -1656,11 +1656,26 @@ class ObservabilityConfig:
"""Configuration for observability."""
otlp_traces_endpoint: Optional[str] = None
# Collecting detailed timing information for each request can be expensive.
# If set, collects the model forward time for the request.
collect_model_forward_time: bool = False
# If set, collects the model execute time for the request.
collect_model_execute_time: bool = False
def __post_init__(self):
if not is_otel_installed() and self.otlp_traces_endpoint is not None:
raise ValueError("OpenTelemetry packages must be installed before "
"configuring 'otlp_traces_endpoint'")
if ((self.collect_model_forward_time
or self.collect_model_execute_time)
and self.otlp_traces_endpoint is None):
raise ValueError(
"collect_model_forward_time or collect_model_execute_time "
"requires --otlp-traces-endpoint to be set.")
@dataclass(frozen=True)
class EngineConfig:
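
A minimal usage sketch of the config above, assuming vLLM and the OpenTelemetry packages are installed; the endpoint URL is only an example value:

from vllm.config import ObservabilityConfig

# Detailed timing collection is only allowed together with a trace endpoint.
cfg = ObservabilityConfig(
    otlp_traces_endpoint="http://localhost:4317",  # example collector address
    collect_model_forward_time=True,
    collect_model_execute_time=True,
)

try:
    # No endpoint configured: __post_init__ rejects detailed collection.
    ObservabilityConfig(collect_model_execute_time=True)
except ValueError as err:
    print(err)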

View File

@ -1032,6 +1032,7 @@ class Scheduler:
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()
        now = time.time()
        scheduler_start_time = time.perf_counter()

        if not self.cache_config.enable_prefix_caching:
            common_computed_block_nums = []
@ -1127,6 +1128,17 @@ class Scheduler:
        self._seq_group_metadata_cache.reset()

        scheduler_time = time.perf_counter() - scheduler_start_time
        # Add the scheduler time to all the sequences that are currently
        # running. This will help estimate whether the scheduler is a
        # significant component of the e2e latency.
        for seq_group in self.running:
            if seq_group is not None and seq_group.metrics is not None:
                if seq_group.metrics.scheduler_time is not None:
                    seq_group.metrics.scheduler_time += scheduler_time
                else:
                    seq_group.metrics.scheduler_time = scheduler_time

        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
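
The accumulation above is cumulative per request: every call to schedule() adds its own wall-clock cost to each sequence group that is currently running, so a long-lived request accumulates scheduler time across many iterations. A standalone sketch of that pattern (the helper and metrics class below are illustrative, not vLLM internals):

import time
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _Metrics:
    scheduler_time: Optional[float] = None


def _record_scheduler_time(running: List[_Metrics],
                           scheduler_start_time: float) -> None:
    scheduler_time = time.perf_counter() - scheduler_start_time
    for metrics in running:
        if metrics.scheduler_time is not None:
            metrics.scheduler_time += scheduler_time
        else:
            metrics.scheduler_time = scheduler_time


m = _Metrics()
for _ in range(3):  # three scheduling iterations while the request is running
    start = time.perf_counter()
    # ... scheduling work would happen here ...
    _record_scheduler_time([m], start)
print(m.scheduler_time)  # cumulative scheduler time across all three calls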

View File

@ -20,6 +20,8 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
def nullable_str(val: str):
if not val or val == "None":
@ -117,6 +119,7 @@ class EngineArgs:
disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
def __post_init__(self):
if self.tokenizer is None:
@ -660,6 +663,16 @@ class EngineArgs:
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
        parser.add_argument(
            '--collect-detailed-traces',
            type=str,
            default=None,
            help="Valid choices are " +
            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
            ". It makes sense to set this only if --otlp-traces-endpoint is"
            " set. If set, it will collect detailed traces for the specified "
            "modules. This involves use of possibly costly and/or blocking "
            "operations and hence might have a performance impact.")
        return parser
@ -852,8 +865,26 @@ class EngineArgs:
        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        detailed_trace_modules = []
        if self.collect_detailed_traces is not None:
            detailed_trace_modules = self.collect_detailed_traces.split(",")
        for m in detailed_trace_modules:
            if m not in ALLOWED_DETAILED_TRACE_MODULES:
                raise ValueError(
                    f"Invalid module {m} in collect_detailed_traces. "
                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
            if (m == "model"
                    or m == "all") and self.pipeline_parallel_size > 1:
                raise ValueError(
                    "Collection of detailed traces for the 'model' module is "
                    "not yet supported with pipeline parallelism.")
observability_config = ObservabilityConfig(
otlp_traces_endpoint=self.otlp_traces_endpoint)
otlp_traces_endpoint=self.otlp_traces_endpoint,
collect_model_forward_time="model" in detailed_trace_modules
or "all" in detailed_trace_modules,
collect_model_execute_time="worker" in detailed_trace_modules
or "all" in detailed_trace_modules,
)
if (model_config.get_sliding_window() is not None
and scheduler_config.chunked_prefill_enabled
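
A hedged sketch of driving the new flag programmatically; the model name and endpoint are placeholder values, create_engine_config() is assumed to resolve the model config, and engine_config.observability_config is assumed to be the ObservabilityConfig built above:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",                     # placeholder model
    otlp_traces_endpoint="http://localhost:4317",  # placeholder endpoint
    collect_detailed_traces="model,worker",        # or "all"
)
engine_config = engine_args.create_engine_config()
obs = engine_config.observability_config
# "model" enables forward-pass timing, "worker" enables execute timing.
print(obs.collect_model_forward_time, obs.collect_model_execute_time)

On the command line the equivalent is --otlp-traces-endpoint plus --collect-detailed-traces model,worker (or all); invalid module names, and the 'model' module under pipeline parallelism, are rejected with a ValueError as shown above.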

View File

@ -267,6 +267,7 @@ class LLMEngine:
speculative_config=speculative_config,
load_config=load_config,
prompt_adapter_config=prompt_adapter_config,
observability_config=self.observability_config,
)
if not self.model_config.embedding_mode:
@ -1183,6 +1184,22 @@ class LLMEngine:
            seq_group = scheduled_seq_group.seq_group
            seq_group.update_num_computed_tokens(
                scheduled_seq_group.token_chunk_size)

            if output is not None and len(output) > 0:
                for o in output:
                    if (isinstance(o, SamplerOutput)
                            and seq_group.metrics is not None):
                        if seq_group.metrics.model_forward_time is not None:
                            seq_group.metrics.model_forward_time += (
                                o.model_forward_time)
                        else:
                            seq_group.metrics.model_forward_time = (
                                o.model_forward_time)
                        if seq_group.metrics.model_execute_time is not None:
                            seq_group.metrics.model_execute_time += (
                                o.model_execute_time)
                        else:
                            seq_group.metrics.model_execute_time = (
                                o.model_execute_time)

            if self.model_config.embedding_mode:
                self._process_sequence_group_outputs(seq_group, outputs)
                continue
@ -1575,6 +1592,18 @@ class LLMEngine:
            seq_span.set_attribute(
                SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
            if metrics.scheduler_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
                    metrics.scheduler_time)
            if metrics.model_forward_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
                    metrics.model_forward_time / 1000.0)
            if metrics.model_execute_time is not None:
                seq_span.set_attribute(
                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
                    metrics.model_execute_time)

    def is_encoder_decoder_model(self):
        return self.model_config.is_encoder_decoder_model
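
Note the unit handling above: scheduler_time and model_execute_time are perf_counter deltas in seconds, while model_forward_time is accumulated from CUDA event elapsed_time() values in milliseconds, hence the division by 1000.0 at export. A small sketch that mirrors this per-request export step (the helper function is illustrative; the attribute strings match the constants added in the tracing changes below):

def latency_attributes(metrics) -> dict:
    """Collect the new latency span attributes for one request's metrics."""
    attrs = {}
    if metrics.scheduler_time is not None:
        attrs["gen_ai.latency.time_in_scheduler"] = metrics.scheduler_time
    if metrics.model_forward_time is not None:
        # CUDA events report milliseconds; convert to seconds like the
        # other latency attributes.
        attrs["gen_ai.latency.time_in_model_forward"] = (
            metrics.model_forward_time / 1000.0)
    if metrics.model_execute_time is not None:
        attrs["gen_ai.latency.time_in_model_execute"] = (
            metrics.model_execute_time)
    return attrs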

View File

@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
from typing import List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
@ -32,6 +32,7 @@ class ExecutorBase(ABC):
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
@ -43,7 +44,7 @@ class ExecutorBase(ABC):
self.multimodal_config = multimodal_config
self.speculative_config = speculative_config
self.prompt_adapter_config = prompt_adapter_config
self.observability_config = observability_config
self._init_executor()
@abstractmethod

View File

@ -60,6 +60,7 @@ class GPUExecutor(ExecutorBase):
prompt_adapter_config=self.prompt_adapter_config,
is_driver_worker=(not self.parallel_config)
or (rank % self.parallel_config.tensor_parallel_size == 0),
observability_config=self.observability_config,
)
def _get_create_worker_kwargs(

View File

@ -92,6 +92,13 @@ class RequestMetrics:
first_token_time: The time when the first token was generated.
time_in_queue: The time the request spent in the queue.
finished_time: The time when the request was finished.
scheduler_time: The time spent in the scheduler while this request was
being considered for scheduling.
model_forward_time: The time spent in the model forward pass when this
request was in the batch.
model_execute_time: The time spent in the model execute function. This
will include model forward, block/sync across
workers, cpu-gpu sync time and sampling time.
"""
arrival_time: float
last_token_time: float
@ -99,6 +106,9 @@ class RequestMetrics:
first_token_time: Optional[float]
time_in_queue: Optional[float]
finished_time: Optional[float] = None
scheduler_time: Optional[float] = None
model_forward_time: Optional[float] = None
model_execute_time: Optional[float] = None
class SequenceData:
@ -968,6 +978,13 @@ class SamplerOutput:
# Optional last hidden states from the model.
hidden_states: Optional[torch.Tensor] = None
# Time taken in the forward pass for this step, across all workers
model_forward_time: Optional[float] = None
# Time taken in the model execute function. This will include model forward,
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time: Optional[float] = None
def __getitem__(self, idx: int):
return self.outputs[idx]
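
A hedged end-to-end sketch of reading the new per-request fields, assuming RequestOutput exposes its RequestMetrics via .metrics, that the OpenTelemetry packages are installed, and that a collector is reachable at the example endpoint; the model name is a placeholder:

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",                     # placeholder model
    otlp_traces_endpoint="http://localhost:4317",  # placeholder endpoint
    collect_detailed_traces="all",
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
for out in outputs:
    m = out.metrics
    # scheduler_time and model_execute_time are in seconds;
    # model_forward_time is in milliseconds (CUDA event timing).
    print(m.scheduler_time, m.model_forward_time, m.model_execute_time)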

View File

@ -23,8 +23,8 @@ except ImportError:
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
@ -69,6 +69,7 @@ class TP1DraftModelRunner(ModelRunner):
multimodal_config: Optional[MultiModalConfig] = None,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
):
if return_hidden_states:
raise ValueError(
@ -88,6 +89,7 @@ class TP1DraftModelRunner(ModelRunner):
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
)
self.flashinfer_decode_workspace_buffer = None

View File

@ -1,8 +1,8 @@
from typing import List, Optional
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.sequence import SequenceGroupMetadata
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
ModelRunner)
@ -32,7 +32,8 @@ class TargetModelRunner(ModelRunner):
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False):
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None):
# An internal boolean member variable to indicate if token log
# probabilities are needed or not.
self.disable_logprobs = True
@ -49,6 +50,7 @@ class TargetModelRunner(ModelRunner):
multimodal_config=multimodal_config,
prompt_adapter_config=prompt_adapter_config,
return_hidden_states=return_hidden_states,
observability_config=observability_config,
)
def prepare_model_input(

View File

@ -92,6 +92,12 @@ class SpanAttributes(BaseSpanAttributes):
LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
LLM_LATENCY_E2E = "gen_ai.latency.e2e"
LLM_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"
# Total time spent in the model forward pass while this request was
# running, across all workers (seconds).
LLM_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
# Time taken in the model execute function. This will include model
# forward, block/sync across workers, cpu-gpu sync time and sampling time.
LLM_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"
def contains_trace_headers(headers: Mapping[str, str]) -> bool:

View File

@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs
@ -45,6 +45,7 @@ class EmbeddingModelRunner(
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
super().__init__(model_config,
parallel_config,
@ -56,7 +57,8 @@ class EmbeddingModelRunner(
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config)
multimodal_config=multimodal_config,
observability_config=observability_config)
@torch.inference_mode()
def execute_model(

View File

@ -10,8 +10,8 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
get_global_forced_attn_backend,
global_force_attn_backend)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@ -82,6 +82,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
is_driver_worker: bool = False,
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
observability_config: Optional[ObservabilityConfig] = None,
):
'''
EncoderDecoderModelRunner constructor.

View File

@ -27,8 +27,8 @@ except ImportError:
import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig)
from vllm.distributed import get_pp_group
from vllm.distributed.parallel_state import graph_capture
from vllm.inputs import INPUT_REGISTRY
@ -806,6 +806,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
return_hidden_states: bool = False,
observability_config: Optional[ObservabilityConfig] = None,
):
self.model_config = model_config
self.parallel_config = parallel_config
@ -818,6 +819,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.prompt_adapter_config = prompt_adapter_config
self.multimodal_config = multimodal_config
self.return_hidden_states = return_hidden_states
self.observability_config = observability_config
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
@ -1527,6 +1529,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
model_forward_start = torch.cuda.Event(enable_timing=True)
model_forward_end = torch.cuda.Event(enable_timing=True)
model_forward_start.record()
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
@ -1537,6 +1545,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                                         device=self.device),
            **seqlen_agnostic_kwargs)
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end.record()

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states
@ -1552,6 +1564,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time
                and output is not None):
            model_forward_end.synchronize()
            model_forward_time = model_forward_start.elapsed_time(
                model_forward_end)
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = model_forward_time

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
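
The timing above uses CUDA events rather than host timers so the forward pass is measured on the GPU timeline, and synchronize() is deferred until after sampling, presumably to avoid adding a sync point in the middle of the step. A standalone sketch of the pattern (requires a CUDA-capable GPU; elapsed_time() returns milliseconds):

import torch

def timed_forward(model, *args, **kwargs):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    out = model(*args, **kwargs)
    end.record()
    end.synchronize()  # wait until the recorded GPU work has finished
    return out, start.elapsed_time(end)  # milliseconds

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")
_, forward_ms = timed_forward(model, x)
print(f"forward took {forward_ms:.3f} ms")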
@ -1709,8 +1732,9 @@ class CUDAGraphRunner:
**kwargs)
if intermediate_tensors is not None:
for key in intermediate_tensors.tensors:
self.input_buffers[key].copy_(intermediate_tensors[key],
non_blocking=True)
if key != "model_execute_time":
self.input_buffers[key].copy_(intermediate_tensors[key],
non_blocking=True)
# Run the graph.
self.graph.replay()
# Return the output tensor.

View File

@ -7,8 +7,8 @@ import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
ModelConfig, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
@ -51,6 +51,7 @@ class Worker(LocalOrDistributedWorkerBase):
prompt_adapter_config: Optional[PromptAdapterConfig] = None,
is_driver_worker: bool = False,
model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
observability_config: Optional[ObservabilityConfig] = None,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
@ -73,6 +74,7 @@ class Worker(LocalOrDistributedWorkerBase):
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.multimodal_config = multimodal_config
self.observability_config = observability_config
# Return hidden states from target model if the draft model is an
# mlp_speculator
@ -102,6 +104,7 @@ class Worker(LocalOrDistributedWorkerBase):
is_driver_worker=is_driver_worker,
prompt_adapter_config=prompt_adapter_config,
multimodal_config=multimodal_config,
observability_config=observability_config,
**speculative_args,
)
# Uninitialized cache engine. Will be initialized by

View File

@ -1,11 +1,13 @@
import dataclasses
import importlib
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.config import ObservabilityConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -172,6 +174,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
"""
is_driver_worker: bool
model_runner: ModelRunnerBase
observability_config: Optional[ObservabilityConfig] = None
@property
@abstractmethod
@ -219,6 +222,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        start_time = time.perf_counter()

        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
@ -265,21 +269,36 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return []
        intermediate_tensors = None
        orig_model_execute_time = 0.0
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                orig_model_execute_time = intermediate_tensors.tensors.get(
                    "model_execute_time", torch.tensor(0)).item()

        output = self.model_runner.execute_model(
            model_input, self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None, intermediate_tensors,
            num_steps)

        model_execute_time = time.perf_counter() - start_time
        if not get_pp_group().is_last_rank:
            # output is IntermediateTensors
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                output.tensors["model_execute_time"] = torch.tensor(
                    model_execute_time + orig_model_execute_time)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return [None]

        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time
                and output is not None):
            for o in output:
                o.model_execute_time = (orig_model_execute_time +
                                        model_execute_time)

        # output is List[SamplerOutput]
        return output
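
With pipeline parallelism, each stage only sees its own slice of the step, so the running execute-time total is piggybacked on the intermediate tensors sent to the next stage, and the last stage attaches the sum to the SamplerOutput. A standalone sketch of that accounting pattern (plain dicts stand in for IntermediateTensors; names are illustrative):

import time
import torch

def run_stage(recv_tensors, do_work, is_last_rank):
    start = time.perf_counter()
    prev = recv_tensors.get("model_execute_time", torch.tensor(0.)).item()
    output = do_work()  # this stage's share of the model execution
    elapsed = time.perf_counter() - start
    if not is_last_rank:
        # Forward the running total to the next pipeline stage.
        output["model_execute_time"] = torch.tensor(prev + elapsed)
        return output  # vLLM sends this via send_tensor_dict
    return prev + elapsed  # total across stages, attached to SamplerOutput

stage_out = run_stage({}, lambda: {"hidden": torch.zeros(1)}, is_last_rank=False)
total_s = run_stage(stage_out, lambda: None, is_last_rank=True)
print(f"total model execute time: {total_s:.6f} s")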