[Misc] Add more scoping for improved trace (#28329)

Signed-off-by: Wei Wei <wwei6@meta.com>
Wei Wei 2025-11-10 13:03:21 -08:00 committed by GitHub
parent 40d33264c6
commit bf6a3d0ff5
4 changed files with 192 additions and 148 deletions
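
The changes below wrap each scheduler, engine-core, LLM-engine, and GPU-model-runner phase in record_function_or_nullcontext from vllm.v1.utils, so every phase shows up as a named range in a collected profiler trace. As a rough mental model only (a sketch, not the actual vLLM implementation, whose enablement logic may differ), such a helper can be thought of as forwarding to torch.profiler.record_function when tracing is wanted and degrading to a no-op context manager otherwise; the ENABLE_TRACE_SCOPES toggle below is invented purely for illustration.

# Hypothetical sketch of a record_function_or_nullcontext-style helper.
# The real helper lives in vllm.v1.utils and may gate tracing differently;
# ENABLE_TRACE_SCOPES is a made-up environment variable for illustration.
import contextlib
import os

import torch


def record_function_or_nullcontext(name: str):
    """Return a named profiler range, or a no-op context manager."""
    if os.environ.get("ENABLE_TRACE_SCOPES") == "1":
        # torch.profiler.record_function emits a labeled range that shows
        # up in exported traces (chrome://tracing, Perfetto, etc.).
        return torch.profiler.record_function(name)
    # Tracing disabled: return a context manager that does nothing.
    return contextlib.nullcontext()

With that shape, each new "with record_function_or_nullcontext(...)" block in the diff costs essentially nothing when profiling is disabled.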


@@ -38,6 +38,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
logger = init_logger(__name__)
@@ -259,49 +260,52 @@ class Scheduler(SchedulerInterface):
continue
# Schedule newly needed KV blocks for the request.
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
if new_blocks is not None:
# The request can be scheduled.
break
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[preempted_req.request_id]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
with record_function_or_nullcontext("schedule: allocate_slots"):
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens,
)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
break
if new_blocks is not None:
# The request can be scheduled.
break
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
token_budget += num_scheduled_tokens[
preempted_req.request_id
]
req_to_new_blocks.pop(preempted_req.request_id)
num_scheduled_tokens.pop(preempted_req.request_id)
req_index -= 1
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_preemptions += 1
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp
)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt. Cannot schedule this request.
break
if new_blocks is None:
# Cannot schedule this request.
@@ -599,13 +603,14 @@ class Scheduler(SchedulerInterface):
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request.request_id
)
)
)
# Construct the scheduler output.
new_reqs_data = [
@@ -614,13 +619,14 @@ class Scheduler(SchedulerInterface):
)
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
with record_function_or_nullcontext("schedule: make_cached_request_data"):
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
# Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear()
@@ -649,8 +655,8 @@ class Scheduler(SchedulerInterface):
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
self._update_after_schedule(scheduler_output)
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output
def _update_after_schedule(


@@ -61,6 +61,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.utils import record_function_or_nullcontext
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -315,17 +316,21 @@ class EngineCore:
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
with record_function_or_nullcontext("core step: schedule"):
scheduler_output = self.scheduler.schedule()
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("core step: execute_model"):
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
with record_function_or_nullcontext("core step: update_from_output"):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
@@ -363,32 +368,49 @@ class EngineCore:
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
scheduler_output = self.scheduler.schedule()
with record_function_or_nullcontext(
"core step_with_batch_queue: execute_model"
):
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if scheduler_output.pending_structured_output_tokens:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
with record_function_or_nullcontext(
"core step_with_batch_queue: pending_structured_output_tokens"
):
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return
# (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with record_function_or_nullcontext(
"core step_with_batch_queue: get_grammar_bitmask"
):
# We aren't waiting for any tokens, get any grammar
# output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(
scheduler_output
)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
with record_function_or_nullcontext(
"core step_with_batch_queue: sample_tokens"
):
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
@@ -408,27 +430,34 @@ class EngineCore:
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
with self.log_error_detail(scheduler_output):
model_output = future.result()
with record_function_or_nullcontext(
"core step_with_batch_queue: update_from_output"
):
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
with record_function_or_nullcontext(
"core step_with_batch_queue: deferred_scheduler_output"
):
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed


@@ -36,6 +36,7 @@ from vllm.v1.executor import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
from vllm.v1.metrics.stats import IterationStats
from vllm.v1.utils import record_function_or_nullcontext
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__)
@@ -280,28 +281,32 @@ class LLMEngine:
return []
# 1) Get EngineCoreOutput from the EngineCore.
outputs = self.engine_core.get_output()
with record_function_or_nullcontext("llm_engine step: get_output"):
outputs = self.engine_core.get_output()
# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
with record_function_or_nullcontext("llm_engine step: process_outputs"):
iteration_stats = IterationStats() if self.log_stats else None
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
iteration_stats=iteration_stats,
)
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
# 3) Abort any reqs that finished due to stop strings.
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
with record_function_or_nullcontext("llm_engine step: abort_requests"):
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
# 4) Record stats
if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()
with record_function_or_nullcontext("llm_engine step: record_stats"):
if self.logger_manager is not None and outputs.scheduler_stats is not None:
self.logger_manager.record(
scheduler_stats=outputs.scheduler_stats,
iteration_stats=iteration_stats,
mm_cache_stats=self.processor.stat_mm_cache(),
)
self.do_log_stats_with_interval()
return processed_outputs.request_outputs


@@ -2525,7 +2525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
"after execute_model() returns None."
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("Preprocess"):
with record_function_or_nullcontext("gpu_model_runner: preprocess"):
with self.synchronize_input_prep():
# Update persistent batch states.
self._update_states(scheduler_output)
@@ -2648,7 +2648,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
),
record_function_or_nullcontext("Forward"),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
):
model_output = self._model_forward(
@@ -2659,7 +2659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
**model_kwargs,
)
with record_function_or_nullcontext("Postprocess"):
with record_function_or_nullcontext("gpu_model_runner: postprocess"):
if self.use_aux_hidden_state_outputs:
# True when EAGLE 3 is used.
hidden_states, aux_hidden_states = model_output
@@ -2756,12 +2756,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output, grammar_output, self.input_batch, logits
)
with record_function_or_nullcontext("Sample"):
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
def propose_draft_token_ids(sampled_token_ids):
assert spec_decode_common_attn_metadata is not None
with record_function_or_nullcontext("Draft"):
with record_function_or_nullcontext("gpu_model_runner: draft"):
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
sampled_token_ids,
@@ -2799,7 +2799,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# as inputs, and does not need to wait for bookkeeping to finish.
propose_draft_token_ids(sampler_output.sampled_token_ids)
with record_function_or_nullcontext("Bookkeep"):
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
(
num_nans_in_logits,
logprobs_lists,
@@ -2826,37 +2826,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# tokens on the CPU, so they are run after bookkeeping.
propose_draft_token_ids(valid_sampled_token_ids)
with record_function_or_nullcontext("EPLB"):
with record_function_or_nullcontext("gpu_model_runner: eplb"):
self.eplb_step()
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
)
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
output = ModelRunnerOutput(
req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=kv_connector_output,
num_nans_in_logits=num_nans_in_logits,
)
if not self.use_async_scheduling:
return output
async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that require output ids.
self.input_batch.set_async_sampled_token_ids(
async_output.sampled_token_ids_cpu,
async_output.async_copy_ready_event,
)
with record_function_or_nullcontext(
"gpu_model_runner: AsyncGPUModelRunnerOutput"
):
async_output = AsyncGPUModelRunnerOutput(
model_runner_output=output,
sampled_token_ids=sampler_output.sampled_token_ids,
logprobs_tensors=sampler_output.logprobs_tensors,
invalid_req_indices=invalid_req_indices,
async_output_copy_stream=self.async_output_copy_stream,
)
with record_function_or_nullcontext(
"gpu_model_runner: set_async_sampled_token_ids"
):
# Save ref of sampled_token_ids CPU tensor if the batch contains
# any requests with sampling params that require output ids.
self.input_batch.set_async_sampled_token_ids(
async_output.sampled_token_ids_cpu,
async_output.async_copy_ready_event,
)
return async_output
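
Taken together, these finer-grained scope names make it possible to attribute time to individual scheduler, engine, and model-runner phases whenever a trace is captured. A small, standalone illustration of the effect, using plain torch.profiler rather than vLLM's own profiling entry points and with stand-in phase bodies, might look like this:

# Standalone illustration only; the phase bodies are stand-ins and vLLM's
# profiling entry points may differ.
from torch.profiler import ProfilerActivity, profile, record_function


def fake_engine_step() -> None:
    # Mimic a few of the phases named in this commit.
    with record_function("core step: schedule"):
        pass
    with record_function("core step: execute_model"):
        pass
    with record_function("core step: update_from_output"):
        pass


with profile(activities=[ProfilerActivity.CPU]) as prof:
    fake_engine_step()

# Each named range becomes its own event in the exported trace and its own
# row in the summary table.
prof.export_chrome_trace("trace.json")
print(prof.key_averages().table(sort_by="cpu_time_total"))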