[Misc] Add more scoping for improved trace (#28329)
Signed-off-by: Wei Wei <wwei6@meta.com>
parent 40d33264c6
commit bf6a3d0ff5
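For context: every change below wraps an existing scheduler, engine-core, or model-runner step in a `record_function_or_nullcontext` scope from `vllm.v1.utils`, so that each step shows up as a named range in a torch profiler trace. The helper itself is not part of this diff; the sketch below shows the general shape of such a wrapper, assuming it degrades to a no-op context when profiling is off (the `VLLM_CUSTOM_SCOPES` env flag is illustrative, not vLLM's actual gating condition):

```python
# Hypothetical sketch of a record_function_or_nullcontext-style helper.
import os
from contextlib import nullcontext

from torch.profiler import record_function


def record_function_or_nullcontext(name: str):
    # Gate on an env var so the scopes cost nothing when tracing is off.
    # The flag name is illustrative; vLLM's real switch may differ.
    if os.environ.get("VLLM_CUSTOM_SCOPES", "0") == "1":
        # Emits a named range visible in torch profiler / chrome traces.
        return record_function(name)
    # A do-nothing context manager keeps the hot path cheap otherwise.
    return nullcontext()
```

Under this pattern, a block like `with record_function_or_nullcontext("schedule: allocate_slots"):` costs roughly nothing when disabled and emits one labeled slice per call when enabled.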
vllm/v1/core/sched/scheduler.py
@@ -38,6 +38,7 @@ from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import record_function_or_nullcontext
 
 logger = init_logger(__name__)
 
@@ -259,49 +260,52 @@ class Scheduler(SchedulerInterface):
                 continue
 
             # Schedule newly needed KV blocks for the request.
-            while True:
-                new_blocks = self.kv_cache_manager.allocate_slots(
-                    request,
-                    num_new_tokens,
-                    num_lookahead_tokens=self.num_lookahead_tokens,
-                )
-
-                if new_blocks is not None:
-                    # The request can be scheduled.
-                    break
-
-                # The request cannot be scheduled.
-                # Preempt the lowest-priority request.
-                if self.policy == SchedulingPolicy.PRIORITY:
-                    preempted_req = max(
-                        self.running,
-                        key=lambda r: (r.priority, r.arrival_time),
-                    )
-                    self.running.remove(preempted_req)
-                    if preempted_req in scheduled_running_reqs:
-                        scheduled_running_reqs.remove(preempted_req)
-                        token_budget += num_scheduled_tokens[preempted_req.request_id]
-                        req_to_new_blocks.pop(preempted_req.request_id)
-                        num_scheduled_tokens.pop(preempted_req.request_id)
-                        req_index -= 1
-                else:
-                    preempted_req = self.running.pop()
-
-                self.kv_cache_manager.free(preempted_req)
-                self.encoder_cache_manager.free(preempted_req)
-                preempted_req.status = RequestStatus.PREEMPTED
-                preempted_req.num_computed_tokens = 0
-                preempted_req.num_preemptions += 1
-                if self.log_stats:
-                    preempted_req.record_event(
-                        EngineCoreEventType.PREEMPTED, scheduled_timestamp
-                    )
-
-                self.waiting.prepend_request(preempted_req)
-                preempted_reqs.append(preempted_req)
-                if preempted_req == request:
-                    # No more request to preempt. Cannot schedule this request.
-                    break
+            with record_function_or_nullcontext("schedule: allocate_slots"):
+                while True:
+                    new_blocks = self.kv_cache_manager.allocate_slots(
+                        request,
+                        num_new_tokens,
+                        num_lookahead_tokens=self.num_lookahead_tokens,
+                    )
+
+                    if new_blocks is not None:
+                        # The request can be scheduled.
+                        break
+
+                    # The request cannot be scheduled.
+                    # Preempt the lowest-priority request.
+                    if self.policy == SchedulingPolicy.PRIORITY:
+                        preempted_req = max(
+                            self.running,
+                            key=lambda r: (r.priority, r.arrival_time),
+                        )
+                        self.running.remove(preempted_req)
+                        if preempted_req in scheduled_running_reqs:
+                            scheduled_running_reqs.remove(preempted_req)
+                            token_budget += num_scheduled_tokens[
+                                preempted_req.request_id
+                            ]
+                            req_to_new_blocks.pop(preempted_req.request_id)
+                            num_scheduled_tokens.pop(preempted_req.request_id)
+                            req_index -= 1
+                    else:
+                        preempted_req = self.running.pop()
+
+                    self.kv_cache_manager.free(preempted_req)
+                    self.encoder_cache_manager.free(preempted_req)
+                    preempted_req.status = RequestStatus.PREEMPTED
+                    preempted_req.num_computed_tokens = 0
+                    preempted_req.num_preemptions += 1
+                    if self.log_stats:
+                        preempted_req.record_event(
+                            EngineCoreEventType.PREEMPTED, scheduled_timestamp
+                        )
+
+                    self.waiting.prepend_request(preempted_req)
+                    preempted_reqs.append(preempted_req)
+                    if preempted_req == request:
+                        # No more request to preempt. Cannot schedule this request.
+                        break
 
             if new_blocks is None:
                 # Cannot schedule this request.
@@ -599,13 +603,14 @@ class Scheduler(SchedulerInterface):
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
         num_common_prefix_blocks = [0] * len(self.kv_cache_config.kv_cache_groups)
-        if self.running:
-            any_request = self.running[0]
-            num_common_prefix_blocks = (
-                self.kv_cache_manager.get_num_common_prefix_blocks(
-                    any_request.request_id
-                )
-            )
+        with record_function_or_nullcontext("schedule: get_num_common_prefix_blocks"):
+            if self.running:
+                any_request = self.running[0]
+                num_common_prefix_blocks = (
+                    self.kv_cache_manager.get_num_common_prefix_blocks(
+                        any_request.request_id
+                    )
+                )
 
         # Construct the scheduler output.
         new_reqs_data = [
@@ -614,13 +619,14 @@ class Scheduler(SchedulerInterface):
             )
             for req in scheduled_new_reqs
         ]
-        cached_reqs_data = self._make_cached_request_data(
-            scheduled_running_reqs,
-            scheduled_resumed_reqs,
-            num_scheduled_tokens,
-            scheduled_spec_decode_tokens,
-            req_to_new_blocks,
-        )
+        with record_function_or_nullcontext("schedule: make_cached_request_data"):
+            cached_reqs_data = self._make_cached_request_data(
+                scheduled_running_reqs,
+                scheduled_resumed_reqs,
+                num_scheduled_tokens,
+                scheduled_spec_decode_tokens,
+                req_to_new_blocks,
+            )
 
         # Record the request ids that were scheduled in this step.
         self.prev_step_scheduled_req_ids.clear()
@@ -649,8 +655,8 @@ class Scheduler(SchedulerInterface):
         if self.connector is not None:
             meta = self.connector.build_connector_meta(scheduler_output)
             scheduler_output.kv_connector_metadata = meta
-
-        self._update_after_schedule(scheduler_output)
+        with record_function_or_nullcontext("schedule: update_after_schedule"):
+            self._update_after_schedule(scheduler_output)
         return scheduler_output
 
     def _update_after_schedule(
vllm/v1/engine/core.py
@@ -61,6 +61,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.utils import record_function_or_nullcontext
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -315,17 +316,21 @@ class EngineCore:
         # or finished and not yet removed from the batch.
         if not self.scheduler.has_requests():
             return {}, False
-        scheduler_output = self.scheduler.schedule()
-        future = self.model_executor.execute_model(scheduler_output, non_block=True)
-        grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
-        with self.log_error_detail(scheduler_output):
-            model_output = future.result()
-            if model_output is None:
-                model_output = self.model_executor.sample_tokens(grammar_output)
-
-        engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, model_output
-        )
+        with record_function_or_nullcontext("core step: schedule"):
+            scheduler_output = self.scheduler.schedule()
+
+        with record_function_or_nullcontext("core step: execute_model"):
+            future = self.model_executor.execute_model(scheduler_output, non_block=True)
+            grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+            with self.log_error_detail(scheduler_output):
+                model_output = future.result()
+                if model_output is None:
+                    model_output = self.model_executor.sample_tokens(grammar_output)
+
+        with record_function_or_nullcontext("core step: update_from_output"):
+            engine_core_outputs = self.scheduler.update_from_output(
+                scheduler_output, model_output
+            )
 
         return engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0
@@ -363,32 +368,49 @@ class EngineCore:
         model_executed = False
         deferred_scheduler_output = None
         if self.scheduler.has_requests():
-            scheduler_output = self.scheduler.schedule()
-            exec_future = self.model_executor.execute_model(
-                scheduler_output, non_block=True
-            )
+            with record_function_or_nullcontext("core step_with_batch_queue: schedule"):
+                scheduler_output = self.scheduler.schedule()
+            with record_function_or_nullcontext(
+                "core step_with_batch_queue: execute_model"
+            ):
+                exec_future = self.model_executor.execute_model(
+                    scheduler_output, non_block=True
+                )
             model_executed = scheduler_output.total_num_scheduled_tokens > 0
 
             if scheduler_output.pending_structured_output_tokens:
-                # We need to defer sampling until we have processed the model output
-                # from the prior step.
-                deferred_scheduler_output = scheduler_output
-                # Block-wait for execute to return (continues running async on the GPU).
-                with self.log_error_detail(scheduler_output):
-                    exec_result = exec_future.result()
-                assert exec_result is None
+                with record_function_or_nullcontext(
+                    "core step_with_batch_queue: pending_structured_output_tokens"
+                ):
+                    # We need to defer sampling until we have processed the model output
+                    # from the prior step.
+                    deferred_scheduler_output = scheduler_output
+                    # Block-wait for execute to return
+                    # (continues running async on the GPU).
+                    with self.log_error_detail(scheduler_output):
+                        exec_result = exec_future.result()
+                    assert exec_result is None
             else:
-                # We aren't waiting for any tokens, get any grammar output immediately.
-                grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
+                with record_function_or_nullcontext(
+                    "core step_with_batch_queue: get_grammar_bitmask"
+                ):
+                    # We aren't waiting for any tokens, get any grammar
+                    # output immediately.
+                    grammar_output = self.scheduler.get_grammar_bitmask(
+                        scheduler_output
+                    )
                 # Block-wait for execute to return (continues running async on the GPU).
                 with self.log_error_detail(scheduler_output):
                     exec_result = exec_future.result()
 
                 if exec_result is None:
-                    # Call sample tokens.
-                    future = self.model_executor.sample_tokens(
-                        grammar_output, non_block=True
-                    )
+                    with record_function_or_nullcontext(
+                        "core step_with_batch_queue: sample_tokens"
+                    ):
+                        # Call sample tokens.
+                        future = self.model_executor.sample_tokens(
+                            grammar_output, non_block=True
+                        )
                 else:
                     # No sampling required (e.g. all requests finished).
                     future = cast(Future[ModelRunnerOutput], exec_future)
@@ -408,27 +430,34 @@ class EngineCore:
             # only be called when the scheduler contains requests or the queue
             # is non-empty.
             return None, False
 
-        # Block until the next result is available.
-        future, scheduler_output = batch_queue.pop()
-        with self.log_error_detail(scheduler_output):
-            model_output = future.result()
-
-        engine_core_outputs = self.scheduler.update_from_output(
-            scheduler_output, model_output
-        )
+        with record_function_or_nullcontext("core step_with_batch_queue: model_output"):
+            # Block until the next result is available.
+            future, scheduler_output = batch_queue.pop()
+            with self.log_error_detail(scheduler_output):
+                model_output = future.result()
+        with record_function_or_nullcontext(
+            "core step_with_batch_queue: update_from_output"
+        ):
+            engine_core_outputs = self.scheduler.update_from_output(
+                scheduler_output, model_output
+            )
 
         # NOTE(nick): We can either handle the deferred tasks here or save
         # in a field and do it immediately once step_with_batch_queue is
         # re-called. The latter slightly favors TTFT over TPOT/throughput.
         if deferred_scheduler_output:
-            # We now have the tokens needed to compute the bitmask for the
-            # deferred request. Get the bitmask and call sample tokens.
-            grammar_output = self.scheduler.get_grammar_bitmask(
-                deferred_scheduler_output
-            )
-            future = self.model_executor.sample_tokens(grammar_output, non_block=True)
-            batch_queue.appendleft((future, deferred_scheduler_output))
+            with record_function_or_nullcontext(
+                "core step_with_batch_queue: deferred_scheduler_output"
+            ):
+                # We now have the tokens needed to compute the bitmask for the
+                # deferred request. Get the bitmask and call sample tokens.
+                grammar_output = self.scheduler.get_grammar_bitmask(
+                    deferred_scheduler_output
+                )
+                future = self.model_executor.sample_tokens(
+                    grammar_output, non_block=True
+                )
+                batch_queue.appendleft((future, deferred_scheduler_output))
 
         return engine_core_outputs, model_executed
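Note that the scopes nest: `scheduler.schedule()` is now invoked inside the "core step: schedule" range, so the scheduler's own "schedule: allocate_slots" range appears as a child slice in the trace. A minimal, self-contained illustration of that nesting with torch's profiler (the labels are borrowed from the diff; the `pass` body stands in for the real work):

```python
# Nesting demo: child record_function ranges show up inside parent
# ranges on the same thread's timeline.
from torch.profiler import profile, record_function

with profile() as prof:
    with record_function("core step: schedule"):
        with record_function("schedule: allocate_slots"):
            pass  # stand-in for kv_cache_manager.allocate_slots(...)

# Both labels appear, with allocate_slots attributed inside schedule.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```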
vllm/v1/engine/llm_engine.py
@@ -36,6 +36,7 @@ from vllm.v1.executor import Executor
 from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
 from vllm.v1.metrics.reader import Metric, get_metrics_snapshot
 from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.utils import record_function_or_nullcontext
 from vllm.v1.worker.worker_base import WorkerBase
 
 logger = init_logger(__name__)
@@ -280,28 +281,32 @@ class LLMEngine:
             return []
 
         # 1) Get EngineCoreOutput from the EngineCore.
-        outputs = self.engine_core.get_output()
+        with record_function_or_nullcontext("llm_engine step: get_output"):
+            outputs = self.engine_core.get_output()
 
         # 2) Process EngineCoreOutputs.
-        iteration_stats = IterationStats() if self.log_stats else None
-        processed_outputs = self.output_processor.process_outputs(
-            outputs.outputs,
-            engine_core_timestamp=outputs.timestamp,
-            iteration_stats=iteration_stats,
-        )
-        self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
+        with record_function_or_nullcontext("llm_engine step: process_outputs"):
+            iteration_stats = IterationStats() if self.log_stats else None
+            processed_outputs = self.output_processor.process_outputs(
+                outputs.outputs,
+                engine_core_timestamp=outputs.timestamp,
+                iteration_stats=iteration_stats,
+            )
+            self.output_processor.update_scheduler_stats(outputs.scheduler_stats)
 
         # 3) Abort any reqs that finished due to stop strings.
-        self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
+        with record_function_or_nullcontext("llm_engine step: abort_requests"):
+            self.engine_core.abort_requests(processed_outputs.reqs_to_abort)
 
         # 4) Record stats
-        if self.logger_manager is not None and outputs.scheduler_stats is not None:
-            self.logger_manager.record(
-                scheduler_stats=outputs.scheduler_stats,
-                iteration_stats=iteration_stats,
-                mm_cache_stats=self.processor.stat_mm_cache(),
-            )
-            self.do_log_stats_with_interval()
+        with record_function_or_nullcontext("llm_engine step: record_stats"):
+            if self.logger_manager is not None and outputs.scheduler_stats is not None:
+                self.logger_manager.record(
+                    scheduler_stats=outputs.scheduler_stats,
+                    iteration_stats=iteration_stats,
+                    mm_cache_stats=self.processor.stat_mm_cache(),
+                )
+                self.do_log_stats_with_interval()
 
         return processed_outputs.request_outputs
vllm/v1/worker/gpu_model_runner.py
@@ -2525,7 +2525,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 "after execute_model() returns None."
             )
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        with record_function_or_nullcontext("Preprocess"):
+        with record_function_or_nullcontext("gpu_model_runner: preprocess"):
             with self.synchronize_input_prep():
                 # Update persistent batch states.
                 self._update_states(scheduler_output)
@@ -2648,7 +2648,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 batch_descriptor=batch_descriptor,
                 ubatch_slices=ubatch_slices,
             ),
-            record_function_or_nullcontext("Forward"),
+            record_function_or_nullcontext("gpu_model_runner: forward"),
             self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
         ):
             model_output = self._model_forward(
@@ -2659,7 +2659,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             **model_kwargs,
         )
 
-        with record_function_or_nullcontext("Postprocess"):
+        with record_function_or_nullcontext("gpu_model_runner: postprocess"):
             if self.use_aux_hidden_state_outputs:
                 # True when EAGLE 3 is used.
                 hidden_states, aux_hidden_states = model_output
@@ -2756,12 +2756,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             scheduler_output, grammar_output, self.input_batch, logits
         )
 
-        with record_function_or_nullcontext("Sample"):
+        with record_function_or_nullcontext("gpu_model_runner: sample"):
             sampler_output = self._sample(logits, spec_decode_metadata)
 
         def propose_draft_token_ids(sampled_token_ids):
             assert spec_decode_common_attn_metadata is not None
-            with record_function_or_nullcontext("Draft"):
+            with record_function_or_nullcontext("gpu_model_runner: draft"):
                 self._draft_token_ids = self.propose_draft_token_ids(
                     scheduler_output,
                     sampled_token_ids,
@@ -2799,7 +2799,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # as inputs, and does not need to wait for bookkeeping to finish.
             propose_draft_token_ids(sampler_output.sampled_token_ids)
 
-        with record_function_or_nullcontext("Bookkeep"):
+        with record_function_or_nullcontext("gpu_model_runner: bookkeep"):
             (
                 num_nans_in_logits,
                 logprobs_lists,
@@ -2826,37 +2826,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # tokens on the CPU, so they are run after bookkeeping.
             propose_draft_token_ids(valid_sampled_token_ids)
 
-        with record_function_or_nullcontext("EPLB"):
+        with record_function_or_nullcontext("gpu_model_runner: eplb"):
             self.eplb_step()
-
-        output = ModelRunnerOutput(
-            req_ids=req_ids_output_copy,
-            req_id_to_index=req_id_to_index_output_copy,
-            sampled_token_ids=valid_sampled_token_ids,
-            logprobs=logprobs_lists,
-            prompt_logprobs_dict=prompt_logprobs_dict,
-            pooler_output=[],
-            kv_connector_output=kv_connector_output,
-            num_nans_in_logits=num_nans_in_logits,
-        )
+        with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
+            output = ModelRunnerOutput(
+                req_ids=req_ids_output_copy,
+                req_id_to_index=req_id_to_index_output_copy,
+                sampled_token_ids=valid_sampled_token_ids,
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict=prompt_logprobs_dict,
+                pooler_output=[],
+                kv_connector_output=kv_connector_output,
+                num_nans_in_logits=num_nans_in_logits,
+            )
 
         if not self.use_async_scheduling:
             return output
-
-        async_output = AsyncGPUModelRunnerOutput(
-            model_runner_output=output,
-            sampled_token_ids=sampler_output.sampled_token_ids,
-            logprobs_tensors=sampler_output.logprobs_tensors,
-            invalid_req_indices=invalid_req_indices,
-            async_output_copy_stream=self.async_output_copy_stream,
-        )
-
-        # Save ref of sampled_token_ids CPU tensor if the batch contains
-        # any requests with sampling params that require output ids.
-        self.input_batch.set_async_sampled_token_ids(
-            async_output.sampled_token_ids_cpu,
-            async_output.async_copy_ready_event,
-        )
+        with record_function_or_nullcontext(
+            "gpu_model_runner: AsyncGPUModelRunnerOutput"
+        ):
+            async_output = AsyncGPUModelRunnerOutput(
+                model_runner_output=output,
+                sampled_token_ids=sampler_output.sampled_token_ids,
+                logprobs_tensors=sampler_output.logprobs_tensors,
+                invalid_req_indices=invalid_req_indices,
+                async_output_copy_stream=self.async_output_copy_stream,
+            )
+        with record_function_or_nullcontext(
+            "gpu_model_runner: set_async_sampled_token_ids"
+        ):
+            # Save ref of sampled_token_ids CPU tensor if the batch contains
+            # any requests with sampling params that require output ids.
+            self.input_batch.set_async_sampled_token_ids(
+                async_output.sampled_token_ids_cpu,
+                async_output.async_copy_ready_event,
+            )
 
         return async_output
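To see the new labels end to end, one can run a small generation under the torch profiler and open the exported trace. The sketch below assumes `record_function_or_nullcontext` resolves to torch's `record_function` while tracing is enabled; the model choice and output path are illustrative, and vLLM's built-in profiler hooks (typically enabled via the `VLLM_TORCH_PROFILER_DIR` environment variable) achieve the same thing:

```python
# Capture a chrome trace in which the commit's scope names
# ("core step: schedule", "gpu_model_runner: forward", ...) appear
# as labeled CPU ranges.
from torch.profiler import ProfilerActivity, profile

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative small model
params = SamplingParams(max_tokens=32)

# Drop ProfilerActivity.CUDA on CPU-only machines.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    llm.generate(["Hello, my name is"], params)

prof.export_chrome_trace("vllm_step_trace.json")  # view in Perfetto
```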