mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 22:14:34 +08:00
[Misc] Add more scoping for improved trace (#28329)
Signed-off-by: Wei Wei <wwei6@meta.com>
This commit is contained in:
parent
40d33264c6
commit
bf6a3d0ff5
@ -38,6 +38,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.utils import record_function_or_nullcontext
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -259,49 +260,52 @@ class Scheduler(SchedulerInterface):
|
||||
continue
|
||||
|
||||
# Schedule newly needed KV blocks for the request.
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens,
|
||||
)
|
||||
|
||||
if new_blocks is not None:
|
||||
# The request can be scheduled.
|
||||
break
|
||||
|
||||
# The request cannot be scheduled.
|
||||
# Preempt the lowest-priority request.
|
||||
if self.policy == SchedulingPolicy.PRIORITY:
|
||||
preempted_req = max(
|
||||
self.running,
|
||||
key=lambda r: (r.priority, r.arrival_time),
|
||||
)
|
||||
self.running.remove(preempted_req)
|
||||
if preempted_req in scheduled_running_reqs:
|
||||
scheduled_running_reqs.remove(preempted_req)
|
||||
token_budget += num_scheduled_tokens[preempted_req.request_id]
|
||||
req_to_new_blocks.pop(preempted_req.request_id)
|
||||
num_scheduled_tokens.pop(preempted_req.request_id)
|
||||
req_index -= 1
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
|
||||
self.kv_cache_manager.free(preempted_req)
|
||||
self.encoder_cache_manager.free(preempted_req)
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
preempted_req.num_computed_tokens = 0
|
||||
preempted_req.num_preemptions += 1
|
||||
if self.log_stats:
|
||||
preempted_req.record_event(
|
||||
EngineCoreEventType.PREEMPTED, scheduled_timestamp
|
||||
with record_function_or_nullcontext("schedule: allocate_slots"):
|
||||
while True:
|
||||
new_blocks = self.kv_cache_manager.allocate_slots(
|
||||
request,
|
||||
num_new_tokens,
|
||||
num_lookahead_tokens=self.num_lookahead_tokens,
|
||||
)
|
||||
|
||||
self.waiting.prepend_request(preempted_req)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt. Cannot schedule this request.
|
||||
break
|
||||
if new_blocks is not None:
|
||||
# The request can be scheduled.
|
||||
break
|
||||
|
||||
# The request cannot be scheduled.
|
||||
# Preempt the lowest-priority request.
|
||||
if self.policy == SchedulingPolicy.PRIORITY:
|
||||
preempted_req = max(
|
||||
self.running,
|
||||
key=lambda r: (r.priority, r.arrival_time),
|
||||
)
|
||||
self.running.remove(preempted_req)
|
||||
if preempted_req in scheduled_running_reqs:
|
||||
scheduled_running_reqs.remove(preempted_req)
|
||||
token_budget += num_scheduled_tokens[
|
||||
preempted_req.request_id
|
||||
]
|
||||
req_to_new_blocks.pop(preempted_req.request_id)
|
||||
num_scheduled_tokens.pop(preempted_req.request_id)
|
||||
req_index -= 1
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
|
||||
self.kv_cache_manager.free(preempted_req)
|
||||
self.encoder_cache_manager.free(preempted_req)
|
||||
preempted_req.status = RequestStatus.PREEMPTED
|
||||
preempted_req.num_computed_tokens = 0
|
||||
preempted_req.num_preemptions += 1
|
||||
if self.log_stats:
|
||||
preempted_req.record_event(
|
||||
EngineCoreEventType.PREEMPTED, scheduled_timestamp
|
||||
)
|
||||
|
||||
self.waiting.prepend_request(preempted_req)
|
||||
preempted_reqs.append(preempted_req)
|
||||
if preempted_req == request:
|
||||
# No more request to preempt. Cannot schedule this request.
|
||||
break
|
||||
|
||||
if new_blocks is None:
|
||||
# Cannot schedule this request.
|
||||
@ -599,13 +603,14 @@ class Scheduler(SchedulerInterface):
|
||||
# Get the longest common prefix among all requests in the running queue.
|
||||
# This can be potentially used for cascade attention.
|
||||
num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request.request_id
|
||||
with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
|
||||
if self.running:
|
||||
any_request = self.running[0]
|
||||
num_common_prefix_blocks = (
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request.request_id
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Construct the scheduler output.
|
||||
new_reqs_data = [
|
||||
@ -614,13 +619,14 @@ class Scheduler(SchedulerInterface):
|
||||
)
|
||||
for req in scheduled_new_reqs
|
||||
]
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs,
|
||||
scheduled_resumed_reqs,
|
||||
num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks,
|
||||
)
|
||||
with record_function_or_nullcontext("schedule: make_cached_request_data"):
|
||||
cached_reqs_data = self._make_cached_request_data(
|
||||
scheduled_running_reqs,
|
||||
scheduled_resumed_reqs,
|
||||
num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks,
|
||||
)
|
||||
|
||||
# Record the request ids that were scheduled in this step.
|
||||
self.prev_step_scheduled_req_ids.clear()
|
||||
@ -649,8 +655,8 @@ class Scheduler(SchedulerInterface):
|
||||
if self.connector is not None:
|
||||
meta = self.connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
self._update_after_schedule(scheduler_output)
|
||||
with record_function_or_nullcontext("schedule: update_after_schedule"):
|
||||
self._update_after_schedule(scheduler_output)
|
||||
return scheduler_output
|
||||
|
||||
def _update_after_schedule(
|
||||
|
||||
@ -61,6 +61,7 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.utils import record_function_or_nullcontext
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -315,17 +316,21 @@ class EngineCore:
|
||||
# or finished and not yet removed from the batch.
|
||||
if not self.scheduler.has_requests():
|
||||
return {}, False
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
if model_output is None:
|
||||
model_output = self.model_executor.sample_tokens(grammar_output)
|
||||
with record_function_or_nullcontext("core step: schedule"):
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
with record_function_or_nullcontext("core step: execute_model"):
|
||||
future = self.model_executor.execute_model(scheduler_output, non_block=True)
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
if model_output is None:
|
||||
model_output = self.model_executor.sample_tokens(grammar_output)
|
||||
|
||||
with record_function_or_nullcontext("core step: update_from_output"):
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
|
||||
return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
|
||||
|
||||
@ -363,32 +368,49 @@ class EngineCore:
|
||||
model_executed = False
|
||||
deferred_scheduler_output = None
|
||||
if self.scheduler.has_requests():
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
exec_future = self.model_executor.execute_model(
|
||||
scheduler_output, non_block=True
|
||||
)
|
||||
with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: execute_model"
|
||||
):
|
||||
exec_future = self.model_executor.execute_model(
|
||||
scheduler_output, non_block=True
|
||||
)
|
||||
model_executed = scheduler_output.total_num_scheduled_tokens > 0
|
||||
|
||||
if scheduler_output.pending_structured_output_tokens:
|
||||
# We need to defer sampling until we have processed the model output
|
||||
# from the prior step.
|
||||
deferred_scheduler_output = scheduler_output
|
||||
# Block-wait for execute to return (continues running async on the GPU).
|
||||
with self.log_error_detail(scheduler_output):
|
||||
exec_result = exec_future.result()
|
||||
assert exec_result is None
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: pending_structured_output_tokens"
|
||||
):
|
||||
# We need to defer sampling until we have processed the model output
|
||||
# from the prior step.
|
||||
deferred_scheduler_output = scheduler_output
|
||||
# Block-wait for execute to return
|
||||
# (continues running async on the GPU).
|
||||
with self.log_error_detail(scheduler_output):
|
||||
exec_result = exec_future.result()
|
||||
assert exec_result is None
|
||||
else:
|
||||
# We aren't waiting for any tokens, get any grammar output immediately.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: get_grammar_bitmask"
|
||||
):
|
||||
# We aren't waiting for any tokens, get any grammar
|
||||
# output immediately.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
scheduler_output
|
||||
)
|
||||
# Block-wait for execute to return (continues running async on the GPU).
|
||||
with self.log_error_detail(scheduler_output):
|
||||
exec_result = exec_future.result()
|
||||
|
||||
if exec_result is None:
|
||||
# Call sample tokens.
|
||||
future = self.model_executor.sample_tokens(
|
||||
grammar_output, non_block=True
|
||||
)
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: sample_tokens"
|
||||
):
|
||||
# Call sample tokens.
|
||||
future = self.model_executor.sample_tokens(
|
||||
grammar_output, non_block=True
|
||||
)
|
||||
else:
|
||||
# No sampling required (e.g. all requests finished).
|
||||
future = cast(Future[ModelRunnerOutput], exec_future)
|
||||
@ -408,27 +430,34 @@ class EngineCore:
|
||||
# only be called when the scheduler contains requests or the queue
|
||||
# is non-empty.
|
||||
return None, False
|
||||
|
||||
# Block until the next result is available.
|
||||
future, scheduler_output = batch_queue.pop()
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
|
||||
# Block until the next result is available.
|
||||
future, scheduler_output = batch_queue.pop()
|
||||
with self.log_error_detail(scheduler_output):
|
||||
model_output = future.result()
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: update_from_output"
|
||||
):
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, model_output
|
||||
)
|
||||
|
||||
# NOTE(nick): We can either handle the deferred tasks here or save
|
||||
# in a field and do it immediately once step_with_batch_queue is
|
||||
# re-called. The latter slightly favors TTFT over TPOT/throughput.
|
||||
if deferred_scheduler_output:
|
||||
# We now have the tokens needed to compute the bitmask for the
|
||||
# deferred request. Get the bitmask and call sample tokens.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
deferred_scheduler_output
|
||||
)
|
||||
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
|
||||
batch_queue.appendleft((future, deferred_scheduler_output))
|
||||
with record_function_or_nullcontext(
|
||||
"core step_with_batch_queue: deferred_scheduler_output"
|
||||
):
|
||||
# We now have the tokens needed to compute the bitmask for the
|
||||
# deferred request. Get the bitmask and call sample tokens.
|
||||
grammar_output = self.scheduler.get_grammar_bitmask(
|
||||
deferred_scheduler_output
|
||||
)
|
||||
future = self.model_executor.sample_tokens(
|
||||
grammar_output, non_block=True
|
||||
)
|
||||
batch_queue.appendleft((future, deferred_scheduler_output))
|
||||
|
||||
return engine_core_outputs, model_executed
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ from vllm.v1.executor import Executor
|
||||
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
|
||||
from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
from vllm.v1.utils import record_function_or_nullcontext
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -280,28 +281,32 @@ class LLMEngine:
|
||||
return []
|
||||
|
||||
# 1) Get EngineCoreOutput from the EngineCore.
|
||||
outputs = self.engine_core.get_output()
|
||||
with record_function_or_nullcontext("llm_genine step: get_output"):
|
||||
outputs = self.engine_core.get_output()
|
||||
|
||||
# 2) Process EngineCoreOutputs.
|
||||
iteration_stats = IterationStats() if self.log_stats else None
|
||||
processed_outputs = self.output_processor.process_outputs(
|
||||
outputs.outputs,
|
||||
engine_core_timestamp=outputs.timestamp,
|
||||
iteration_stats=iteration_stats,
|
||||
)
|
||||
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
|
||||
with record_function_or_nullcontext("llm_genine step: process_outputs"):
|
||||
iteration_stats = IterationStats() if self.log_stats else None
|
||||
processed_outputs = self.output_processor.process_outputs(
|
||||
outputs.outputs,
|
||||
engine_core_timestamp=outputs.timestamp,
|
||||
iteration_stats=iteration_stats,
|
||||
)
|
||||
self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
|
||||
|
||||
# 3) Abort any reqs that finished due to stop strings.
|
||||
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
|
||||
with record_function_or_nullcontext("llm_genine step: abort_requests"):
|
||||
self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
|
||||
|
||||
# 4) Record stats
|
||||
if self.logger_manager is not None and outputs.scheduler_stats is not None:
|
||||
self.logger_manager.record(
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
mm_cache_stats=self.processor.stat_mm_cache(),
|
||||
)
|
||||
self.do_log_stats_with_interval()
|
||||
with record_function_or_nullcontext("llm_genine step: record_stats"):
|
||||
if self.logger_manager is not None and outputs.scheduler_stats is not None:
|
||||
self.logger_manager.record(
|
||||
scheduler_stats=outputs.scheduler_stats,
|
||||
iteration_stats=iteration_stats,
|
||||
mm_cache_stats=self.processor.stat_mm_cache(),
|
||||
)
|
||||
self.do_log_stats_with_interval()
|
||||
|
||||
return processed_outputs.request_outputs
|
||||
|
||||
|
||||
@ -2525,7 +2525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
"after execute_model() returns None."
|
||||
)
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
with record_function_or_nullcontext("Preprocess"):
|
||||
with record_function_or_nullcontext("gpu_model_runner: preprocess"):
|
||||
with self.synchronize_input_prep():
|
||||
# Update persistent batch states.
|
||||
self._update_states(scheduler_output)
|
||||
@ -2648,7 +2648,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
batch_descriptor=batch_descriptor,
|
||||
ubatch_slices=ubatch_slices,
|
||||
),
|
||||
record_function_or_nullcontext("Forward"),
|
||||
record_function_or_nullcontext("gpu_model_runner: forward"),
|
||||
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
|
||||
):
|
||||
model_output = self._model_forward(
|
||||
@ -2659,7 +2659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
with record_function_or_nullcontext("Postprocess"):
|
||||
with record_function_or_nullcontext("gpu_model_runner: postprocess"):
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
# True when EAGLE 3 is used.
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
@ -2756,12 +2756,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
scheduler_output, grammar_output, self.input_batch, logits
|
||||
)
|
||||
|
||||
with record_function_or_nullcontext("Sample"):
|
||||
with record_function_or_nullcontext("gpu_model_runner: sample"):
|
||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||
|
||||
def propose_draft_token_ids(sampled_token_ids):
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
with record_function_or_nullcontext("Draft"):
|
||||
with record_function_or_nullcontext("gpu_model_runner: draft"):
|
||||
self._draft_token_ids = self.propose_draft_token_ids(
|
||||
scheduler_output,
|
||||
sampled_token_ids,
|
||||
@ -2799,7 +2799,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# as inputs, and does not need to wait for bookkeeping to finish.
|
||||
propose_draft_token_ids(sampler_output.sampled_token_ids)
|
||||
|
||||
with record_function_or_nullcontext("Bookkeep"):
|
||||
with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
|
||||
(
|
||||
num_nans_in_logits,
|
||||
logprobs_lists,
|
||||
@ -2826,37 +2826,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# tokens on the CPU, so they are run after bookkeeping.
|
||||
propose_draft_token_ids(valid_sampled_token_ids)
|
||||
|
||||
with record_function_or_nullcontext("EPLB"):
|
||||
with record_function_or_nullcontext("gpu_model_runner: eplb"):
|
||||
self.eplb_step()
|
||||
|
||||
output = ModelRunnerOutput(
|
||||
req_ids=req_ids_output_copy,
|
||||
req_id_to_index=req_id_to_index_output_copy,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
kv_connector_output=kv_connector_output,
|
||||
num_nans_in_logits=num_nans_in_logits,
|
||||
)
|
||||
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
|
||||
output = ModelRunnerOutput(
|
||||
req_ids=req_ids_output_copy,
|
||||
req_id_to_index=req_id_to_index_output_copy,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
kv_connector_output=kv_connector_output,
|
||||
num_nans_in_logits=num_nans_in_logits,
|
||||
)
|
||||
|
||||
if not self.use_async_scheduling:
|
||||
return output
|
||||
|
||||
async_output = AsyncGPUModelRunnerOutput(
|
||||
model_runner_output=output,
|
||||
sampled_token_ids=sampler_output.sampled_token_ids,
|
||||
logprobs_tensors=sampler_output.logprobs_tensors,
|
||||
invalid_req_indices=invalid_req_indices,
|
||||
async_output_copy_stream=self.async_output_copy_stream,
|
||||
)
|
||||
|
||||
# Save ref of sampled_token_ids CPU tensor if the batch contains
|
||||
# any requests with sampling params that that require output ids.
|
||||
self.input_batch.set_async_sampled_token_ids(
|
||||
async_output.sampled_token_ids_cpu,
|
||||
async_output.async_copy_ready_event,
|
||||
)
|
||||
with record_function_or_nullcontext(
|
||||
"gpu_model_runner: AsyncGPUModelRunnerOutput"
|
||||
):
|
||||
async_output = AsyncGPUModelRunnerOutput(
|
||||
model_runner_output=output,
|
||||
sampled_token_ids=sampler_output.sampled_token_ids,
|
||||
logprobs_tensors=sampler_output.logprobs_tensors,
|
||||
invalid_req_indices=invalid_req_indices,
|
||||
async_output_copy_stream=self.async_output_copy_stream,
|
||||
)
|
||||
with record_function_or_nullcontext(
|
||||
"gpu_model_runner: set_async_sampled_token_ids"
|
||||
):
|
||||
# Save ref of sampled_token_ids CPU tensor if the batch contains
|
||||
# any requests with sampling params that that require output ids.
|
||||
self.input_batch.set_async_sampled_token_ids(
|
||||
async_output.sampled_token_ids_cpu,
|
||||
async_output.async_copy_ready_event,
|
||||
)
|
||||
|
||||
return async_output
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user