Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Core] Add span metrics for model_forward, scheduler and sampler time (#7089)
This commit is contained in:
parent 70d268a399
commit 933790c209
@@ -114,3 +114,5 @@ def test_traces(trace_service):
         SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN) == ttft
     e2e_time = metrics.finished_time - metrics.arrival_time
     assert attributes.get(SpanAttributes.LLM_LATENCY_E2E) == e2e_time
+    assert attributes.get(SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER
+                          ) == metrics.scheduler_time

@@ -24,6 +24,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
         load_config=engine_config.load_config,
         lora_config=engine_config.lora_config,
         prompt_adapter_config=engine_config.prompt_adapter_config,
+        observability_config=engine_config.observability_config,
         is_driver_worker=True,
     )
     return model_runner

@@ -1656,11 +1656,26 @@ class ObservabilityConfig:
     """Configuration for observability."""
     otlp_traces_endpoint: Optional[str] = None
 
+    # Collecting detailed timing information for each request can be expensive.
+
+    # If set, collects the model forward time for the request.
+    collect_model_forward_time: bool = False
+
+    # If set, collects the model execute time for the request.
+    collect_model_execute_time: bool = False
+
     def __post_init__(self):
         if not is_otel_installed() and self.otlp_traces_endpoint is not None:
             raise ValueError("OpenTelemetry packages must be installed before "
                              "configuring 'otlp_traces_endpoint'")
 
+        if ((self.collect_model_forward_time
+             or self.collect_model_execute_time)
+                and self.otlp_traces_endpoint is None):
+            raise ValueError(
+                "collect_model_forward_time or collect_model_execute_time "
+                "requires --otlp-traces-endpoint to be set.")
+
 
 @dataclass(frozen=True)
 class EngineConfig:

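As a rough usage sketch of the new validation (not part of the diff, and assuming the OpenTelemetry packages are installed): enabling either detailed-timing flag without an OTLP endpoint is rejected in __post_init__, since the timing data is only exported through traces.

# Hypothetical sketch; ObservabilityConfig is the dataclass added above.
from vllm.config import ObservabilityConfig

# Valid: detailed timing together with a trace endpoint (placeholder URL).
cfg = ObservabilityConfig(otlp_traces_endpoint="http://localhost:4317",
                          collect_model_forward_time=True)

# Invalid: raises ValueError because no endpoint is configured.
try:
    ObservabilityConfig(collect_model_forward_time=True)
except ValueError as e:
    print(e)
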
@@ -1032,6 +1032,7 @@ class Scheduler:
         # such as self.running, self.swapped, and self.waiting.
         scheduler_outputs = self._schedule()
         now = time.time()
+        scheduler_start_time = time.perf_counter()
 
         if not self.cache_config.enable_prefix_caching:
             common_computed_block_nums = []
@@ -1127,6 +1128,17 @@
 
         self._seq_group_metadata_cache.reset()
 
+        scheduler_time = time.perf_counter() - scheduler_start_time
+        # Add this to scheduler time to all the sequences that are currently
+        # running. This will help estimate if the scheduler is a significant
+        # component in the e2e latency.
+        for seq_group in self.running:
+            if seq_group is not None and seq_group.metrics is not None:
+                if seq_group.metrics.scheduler_time is not None:
+                    seq_group.metrics.scheduler_time += scheduler_time
+                else:
+                    seq_group.metrics.scheduler_time = scheduler_time
+
         return seq_group_metadata_list, scheduler_outputs
 
     def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:

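The pattern above is plain wall-clock accounting: one perf_counter delta per scheduling pass, charged to every sequence group that is currently running. A minimal standalone sketch of the same bookkeeping, with hypothetical names rather than vLLM classes:

import time
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class Metrics:
    scheduler_time: Optional[float] = None  # seconds; None until first pass

def schedule_once(running: List[str], metrics_by_id: Dict[str, Metrics]) -> None:
    start = time.perf_counter()
    # ... scheduling work would happen here ...
    elapsed = time.perf_counter() - start
    # Charge this pass to every request that is currently running.
    for req_id in running:
        m = metrics_by_id[req_id]
        m.scheduler_time = elapsed if m.scheduler_time is None else m.scheduler_time + elapsed
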
@@ -20,6 +20,8 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
+ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
+
 
 def nullable_str(val: str):
     if not val or val == "None":
@@ -117,6 +119,7 @@ class EngineArgs:
     disable_logprobs_during_spec_decoding: Optional[bool] = None
 
     otlp_traces_endpoint: Optional[str] = None
+    collect_detailed_traces: Optional[str] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -660,6 +663,16 @@ class EngineArgs:
            type=str,
            default=None,
            help='Target URL to which OpenTelemetry traces will be sent.')
+        parser.add_argument(
+            '--collect-detailed-traces',
+            type=str,
+            default=None,
+            help="Valid choices are " +
+            ",".join(ALLOWED_DETAILED_TRACE_MODULES) +
+            ". It makes sense to set this only if --otlp-traces-endpoint is"
+            " set. If set, it will collect detailed traces for the specified "
+            "modules. This involves use of possibly costly and or blocking "
+            "operations and hence might have a performance impact.")
 
         return parser
 
@@ -852,8 +865,26 @@ class EngineArgs:
         decoding_config = DecodingConfig(
             guided_decoding_backend=self.guided_decoding_backend)
 
+        detailed_trace_modules = []
+        if self.collect_detailed_traces is not None:
+            detailed_trace_modules = self.collect_detailed_traces.split(",")
+        for m in detailed_trace_modules:
+            if m not in ALLOWED_DETAILED_TRACE_MODULES:
+                raise ValueError(
+                    f"Invalid module {m} in collect_detailed_traces. "
+                    f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}")
+            if (m == "model"
+                    or m == "all") and self.pipeline_parallel_size > 1:
+                raise ValueError(
+                    "Collection of detailed traces for the 'model' module is "
+                    "not yet supported with pipeline parallelism.")
         observability_config = ObservabilityConfig(
-            otlp_traces_endpoint=self.otlp_traces_endpoint)
+            otlp_traces_endpoint=self.otlp_traces_endpoint,
+            collect_model_forward_time="model" in detailed_trace_modules
+            or "all" in detailed_trace_modules,
+            collect_model_execute_time="worker" in detailed_trace_modules
+            or "all" in detailed_trace_modules,
+        )
 
         if (model_config.get_sliding_window() is not None
                 and scheduler_config.chunked_prefill_enabled

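Putting the new knobs together: --collect-detailed-traces takes a comma-separated subset of model, worker, all, and it only has an effect when --otlp-traces-endpoint is also given. A hedged sketch of setting the same options programmatically (the model name and endpoint are placeholders, and it assumes EngineArgs.create_engine_config() builds the bundled configs as in this version):

# Sketch only; assumes an OTLP collector is listening on localhost:4317.
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="facebook/opt-125m",                     # placeholder model
    otlp_traces_endpoint="http://localhost:4317",  # placeholder endpoint
    collect_detailed_traces="model,worker",        # or "all"
)
engine_config = engine_args.create_engine_config()
obs = engine_config.observability_config
assert obs.collect_model_forward_time and obs.collect_model_execute_time
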
@@ -267,6 +267,7 @@ class LLMEngine:
             speculative_config=speculative_config,
             load_config=load_config,
             prompt_adapter_config=prompt_adapter_config,
+            observability_config=self.observability_config,
         )
 
         if not self.model_config.embedding_mode:
@@ -1183,6 +1184,22 @@ class LLMEngine:
             seq_group = scheduled_seq_group.seq_group
             seq_group.update_num_computed_tokens(
                 scheduled_seq_group.token_chunk_size)
+            if output is not None and len(output) > 0:
+                for o in output:
+                    if (isinstance(o, SamplerOutput)
+                            and seq_group.metrics is not None):
+                        if seq_group.metrics.model_forward_time is not None:
+                            seq_group.metrics.model_forward_time += (
+                                o.model_forward_time)
+                        else:
+                            seq_group.metrics.model_forward_time = (
+                                o.model_forward_time)
+                        if seq_group.metrics.model_execute_time is not None:
+                            seq_group.metrics.model_execute_time += (
+                                o.model_execute_time)
+                        else:
+                            seq_group.metrics.model_execute_time = (
+                                o.model_execute_time)
             if self.model_config.embedding_mode:
                 self._process_sequence_group_outputs(seq_group, outputs)
                 continue
@@ -1575,6 +1592,18 @@ class LLMEngine:
                 seq_span.set_attribute(
                     SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
             seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
+            if metrics.scheduler_time is not None:
+                seq_span.set_attribute(
+                    SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER,
+                    metrics.scheduler_time)
+            if metrics.model_forward_time is not None:
+                seq_span.set_attribute(
+                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD,
+                    metrics.model_forward_time / 1000.0)
+            if metrics.model_execute_time is not None:
+                seq_span.set_attribute(
+                    SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE,
+                    metrics.model_execute_time)
 
     def is_encoder_decoder_model(self):
         return self.model_config.is_encoder_decoder_model

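One unit detail worth noting in the hunk above: model_forward_time is accumulated from CUDA event timings and is therefore in milliseconds, so it is divided by 1000.0 when exported, while scheduler_time and model_execute_time are time.perf_counter() deltas already in seconds. A one-line sanity check of the conversion, with an illustrative number:

model_forward_time_ms = 12.5                  # accumulated from CUDA events (milliseconds)
exported_seconds = model_forward_time_ms / 1000.0
assert exported_seconds == 0.0125             # what lands in gen_ai.latency.time_in_model_forward
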
@@ -2,8 +2,8 @@ from abc import ABC, abstractmethod
 from typing import List, Optional, Set, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -32,6 +32,7 @@ class ExecutorBase(ABC):
         multimodal_config: Optional[MultiModalConfig],
         speculative_config: Optional[SpeculativeConfig],
         prompt_adapter_config: Optional[PromptAdapterConfig],
+        observability_config: Optional[ObservabilityConfig],
     ) -> None:
         self.model_config = model_config
         self.cache_config = cache_config
@@ -43,7 +44,7 @@ class ExecutorBase(ABC):
         self.multimodal_config = multimodal_config
         self.speculative_config = speculative_config
         self.prompt_adapter_config = prompt_adapter_config
-
+        self.observability_config = observability_config
         self._init_executor()
 
     @abstractmethod
@@ -60,6 +60,7 @@ class GPUExecutor(ExecutorBase):
             prompt_adapter_config=self.prompt_adapter_config,
             is_driver_worker=(not self.parallel_config)
             or (rank % self.parallel_config.tensor_parallel_size == 0),
+            observability_config=self.observability_config,
         )
 
     def _get_create_worker_kwargs(

@@ -92,6 +92,13 @@ class RequestMetrics:
         first_token_time: The time when the first token was generated.
         time_in_queue: The time the request spent in the queue.
         finished_time: The time when the request was finished.
+        scheduler_time: The time spent in the scheduler when this request was
+                        being considered by the scheduler.
+        model_forward_time: The time spent in the model forward pass when this
+                            request was in the batch.
+        model_execute_time: The time spent in the model execute function. This
+                            will include model forward, block/sync across
+                            workers, cpu-gpu sync time and sampling time.
     """
     arrival_time: float
     last_token_time: float
@@ -99,6 +106,9 @@ class RequestMetrics:
     first_token_time: Optional[float]
     time_in_queue: Optional[float]
     finished_time: Optional[float] = None
+    scheduler_time: Optional[float] = None
+    model_forward_time: Optional[float] = None
+    model_execute_time: Optional[float] = None
 
 
 class SequenceData:
@@ -968,6 +978,13 @@ class SamplerOutput:
     # Optional last hidden states from the model.
     hidden_states: Optional[torch.Tensor] = None
 
+    # Time taken in the forward pass for this across all workers
+    model_forward_time: Optional[float] = None
+
+    # Time taken in the model execute function. This will include model forward,
+    # block/sync across workers, cpu-gpu sync time and sampling time.
+    model_execute_time: Optional[float] = None
+
     def __getitem__(self, idx: int):
         return self.outputs[idx]

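All three new fields default to None rather than 0.0 so that "never measured" can be told apart from "measured zero". The accumulate-or-initialize pattern used throughout the commit looks like this in isolation (a sketch with a hypothetical helper, not code from the diff):

from typing import Optional

def accumulate(total: Optional[float], delta: float) -> float:
    """Start the running total on the first observation, then keep adding."""
    return delta if total is None else total + delta

# e.g. metrics.scheduler_time = accumulate(metrics.scheduler_time, scheduler_time)
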
@@ -23,8 +23,8 @@ except ImportError:
 FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig)
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalInputs
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
@@ -69,6 +69,7 @@ class TP1DraftModelRunner(ModelRunner):
         multimodal_config: Optional[MultiModalConfig] = None,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         return_hidden_states: bool = False,
+        observability_config: Optional[ObservabilityConfig] = None,
     ):
         if return_hidden_states:
             raise ValueError(
@@ -88,6 +89,7 @@ class TP1DraftModelRunner(ModelRunner):
             multimodal_config=multimodal_config,
             prompt_adapter_config=prompt_adapter_config,
             return_hidden_states=return_hidden_states,
+            observability_config=observability_config,
         )
 
         self.flashinfer_decode_workspace_buffer = None

@@ -1,8 +1,8 @@
 from typing import List, Optional
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig)
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig)
 from vllm.sequence import SequenceGroupMetadata
 from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
                                       ModelRunner)
@@ -32,7 +32,8 @@ class TargetModelRunner(ModelRunner):
                  is_driver_worker: bool = False,
                  prompt_adapter_config: Optional[PromptAdapterConfig] = None,
                  multimodal_config: Optional[MultiModalConfig] = None,
-                 return_hidden_states: bool = False):
+                 return_hidden_states: bool = False,
+                 observability_config: Optional[ObservabilityConfig] = None):
         # An internal boolean member variable to indicate if token log
         # probabilities are needed or not.
         self.disable_logprobs = True
@@ -49,6 +50,7 @@ class TargetModelRunner(ModelRunner):
                          multimodal_config=multimodal_config,
                          prompt_adapter_config=prompt_adapter_config,
                          return_hidden_states=return_hidden_states,
+                         observability_config=observability_config,
                          )
 
     def prepare_model_input(

@@ -92,6 +92,12 @@ class SpanAttributes(BaseSpanAttributes):
     LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"
     LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token"
     LLM_LATENCY_E2E = "gen_ai.latency.e2e"
+    LLM_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler"
+    # Time taken in the forward pass for this across all workers
+    LLM_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward"
+    # Time taken in the model execute function. This will include model
+    # forward, block/sync across workers, cpu-gpu sync time and sampling time.
+    LLM_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute"
 
 
 def contains_trace_headers(headers: Mapping[str, str]) -> bool:

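For anyone consuming the exported spans, the three new latency attributes map onto the RequestMetrics fields populated earlier in this commit; the dict below is illustrative only, not library code:

# Span attribute name -> RequestMetrics field it is populated from
NEW_LATENCY_ATTRIBUTES = {
    "gen_ai.latency.time_in_scheduler": "RequestMetrics.scheduler_time",
    "gen_ai.latency.time_in_model_forward": "RequestMetrics.model_forward_time",
    "gen_ai.latency.time_in_model_execute": "RequestMetrics.model_execute_time",
}
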
@@ -4,8 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Type
 import torch
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig)
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.multimodal import MultiModalInputs
@@ -45,6 +45,7 @@ class EmbeddingModelRunner(
         is_driver_worker: bool = False,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         multimodal_config: Optional[MultiModalConfig] = None,
+        observability_config: Optional[ObservabilityConfig] = None,
     ):
         super().__init__(model_config,
                          parallel_config,
@@ -56,7 +57,8 @@ class EmbeddingModelRunner(
                          kv_cache_dtype=kv_cache_dtype,
                          is_driver_worker=is_driver_worker,
                          prompt_adapter_config=prompt_adapter_config,
-                         multimodal_config=multimodal_config)
+                         multimodal_config=multimodal_config,
+                         observability_config=observability_config)
 
     @torch.inference_mode()
     def execute_model(

@@ -10,8 +10,8 @@ from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
                                      get_global_forced_attn_backend,
                                      global_force_attn_backend)
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig)
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig)
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
@@ -82,6 +82,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         is_driver_worker: bool = False,
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         multimodal_config: Optional[MultiModalConfig] = None,
+        observability_config: Optional[ObservabilityConfig] = None,
     ):
         '''
         EncoderDecoderModelRunner constructor.

@@ -27,8 +27,8 @@ except ImportError:
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig)
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig)
 from vllm.distributed import get_pp_group
 from vllm.distributed.parallel_state import graph_capture
 from vllm.inputs import INPUT_REGISTRY
@@ -806,6 +806,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         multimodal_config: Optional[MultiModalConfig] = None,
         return_hidden_states: bool = False,
+        observability_config: Optional[ObservabilityConfig] = None,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -818,6 +819,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.prompt_adapter_config = prompt_adapter_config
         self.multimodal_config = multimodal_config
         self.return_hidden_states = return_hidden_states
+        self.observability_config = observability_config
 
         self.device = self.device_config.device
         self.pin_memory = is_pin_memory_available()
@@ -1527,6 +1529,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_seqlen_agnostic else {}
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_forward_time):
+            model_forward_start = torch.cuda.Event(enable_timing=True)
+            model_forward_end = torch.cuda.Event(enable_timing=True)
+            model_forward_start.record()
+
         hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
@@ -1537,6 +1545,10 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 device=self.device),
             **seqlen_agnostic_kwargs)
 
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_forward_time):
+            model_forward_end.record()
+
         # Compute the logits in the last pipeline stage.
         if not get_pp_group().is_last_rank:
             return hidden_or_intermediate_states
@@ -1552,6 +1564,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_forward_time
+                and output is not None):
+            model_forward_end.synchronize()
+            model_forward_time = model_forward_start.elapsed_time(
+                model_forward_end)
+            # If there are multiple workers, we are still tracking the latency
+            # from the start time of the driver worker to the end time of the
+            # driver worker. The model forward time will then end up covering
+            # the communication time as well.
+            output.model_forward_time = model_forward_time
 
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
@@ -1709,8 +1732,9 @@ class CUDAGraphRunner:
                                 **kwargs)
         if intermediate_tensors is not None:
             for key in intermediate_tensors.tensors:
-                self.input_buffers[key].copy_(intermediate_tensors[key],
-                                              non_blocking=True)
+                if key != "model_execute_time":
+                    self.input_buffers[key].copy_(intermediate_tensors[key],
+                                                  non_blocking=True)
         # Run the graph.
         self.graph.replay()
         # Return the output tensor.

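The forward-pass measurement above uses CUDA events rather than host timers, so it captures GPU time without forcing a synchronization at the start of the step; the only sync happens once, after sampling, when the elapsed time is read. A self-contained sketch of the same pattern with a hypothetical helper (note that elapsed_time returns milliseconds, which is why LLMEngine divides by 1000.0 before exporting):

import torch

def timed_forward(model, inputs):
    """Hypothetical helper, not vLLM code: run model(inputs) and return (output, elapsed_ms)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()                           # enqueued on the current CUDA stream
    output = model(inputs)                   # asynchronous GPU work
    end.record()
    end.synchronize()                        # block only when the number is needed
    return output, start.elapsed_time(end)   # milliseconds
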
@@ -7,8 +7,8 @@ import torch
 import torch.distributed
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, MultiModalConfig, ParallelConfig,
-                         PromptAdapterConfig, SchedulerConfig,
+                         ModelConfig, MultiModalConfig, ObservabilityConfig,
+                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
@@ -51,6 +51,7 @@ class Worker(LocalOrDistributedWorkerBase):
         prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
         model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None,
+        observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -73,6 +74,7 @@ class Worker(LocalOrDistributedWorkerBase):
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()
         self.multimodal_config = multimodal_config
+        self.observability_config = observability_config
 
         # Return hidden states from target model if the draft model is an
         # mlp_speculator
@@ -102,6 +104,7 @@ class Worker(LocalOrDistributedWorkerBase):
             is_driver_worker=is_driver_worker,
             prompt_adapter_config=prompt_adapter_config,
            multimodal_config=multimodal_config,
+            observability_config=observability_config,
            **speculative_args,
        )
        # Uninitialized cache engine. Will be initialized by

@@ -1,11 +1,13 @@
 import dataclasses
 import importlib
 import os
+import time
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import torch
 
+from vllm.config import ObservabilityConfig
 from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -172,6 +174,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
     """
     is_driver_worker: bool
     model_runner: ModelRunnerBase
+    observability_config: Optional[ObservabilityConfig] = None
 
     @property
     @abstractmethod
@@ -219,6 +222,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
     ) -> Optional[List[SamplerOutput]]:
         """Executes at least one model step on the given sequences, unless no
         sequences are provided."""
+        start_time = time.perf_counter()
         if self.is_driver_worker:
             if execute_model_req is None:
                 if self.do_metadata_broadcast:
@@ -265,21 +269,36 @@ class LocalOrDistributedWorkerBase(WorkerBase):
             return []
 
         intermediate_tensors = None
+        orig_model_execute_time = 0.0
         if not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
                     all_gather_group=get_tp_group()))
+            if (self.observability_config is not None
+                    and self.observability_config.collect_model_execute_time):
+                orig_model_execute_time = intermediate_tensors.tensors.get(
+                    "model_execute_time", torch.tensor(0)).item()
 
         output = self.model_runner.execute_model(
             model_input, self.kv_cache[worker_input.virtual_engine]
             if self.kv_cache is not None else None, intermediate_tensors,
             num_steps)
 
+        model_execute_time = time.perf_counter() - start_time
         if not get_pp_group().is_last_rank:
             # output is IntermediateTensors
+            if (self.observability_config is not None
+                    and self.observability_config.collect_model_execute_time):
+                output.tensors["model_execute_time"] = torch.tensor(
+                    model_execute_time + orig_model_execute_time)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
             return [None]
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_execute_time
+                and output is not None):
+            for o in output:
+                o.model_execute_time = (orig_model_execute_time +
+                                        model_execute_time)
 
         # output is List[SamplerOutput]
         return output

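Because only the last pipeline-parallel rank produces SamplerOutput objects, the execute time measured on earlier ranks has to ride along with the activations: each rank adds its own perf_counter delta to a scalar stashed in the intermediate-tensor dict and forwards it downstream. A stripped-down sketch of that accumulation, independent of vLLM's IntermediateTensors class (names are illustrative):

import time
import torch

def run_stage(stage_fn, tensors_from_prev_rank):
    """Run one pipeline stage and carry an accumulated execute time forward."""
    start = time.perf_counter()
    carried = tensors_from_prev_rank.get("model_execute_time",
                                         torch.tensor(0.0)).item()
    outputs = stage_fn(tensors_from_prev_rank)         # this rank's work
    elapsed = time.perf_counter() - start
    outputs["model_execute_time"] = torch.tensor(carried + elapsed)
    return outputs                                     # sent to the next rank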