[V1] Aggregate chunked prompt logprobs in model runner (#14875)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-03-24 09:27:57 -07:00 committed by GitHub
parent 9cc645141d
commit 3aee6573dc
7 changed files with 68 additions and 44 deletions

View File

@@ -627,8 +627,7 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
if new_token_ids:
# Add EngineCoreOutput for this Request.
outputs.append(
EngineCoreOutput(
@@ -639,6 +638,9 @@ class Scheduler(SchedulerInterface):
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason,
events=request.take_events()))
else:
# Invariant: EngineCore returns no partial prefill outputs.
assert not prompt_logprobs_tensors
self.scheduled_req_ids.remove(request.request_id)
if not stopped:

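A minimal sketch of the new scheduler-side gating, using a hypothetical stand-in dataclass rather than the real EngineCoreOutput: an output is emitted only for requests that produced new sampled tokens, and the invariant that a partial prefill carries no prompt-logprobs chunk is asserted.

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeEngineCoreOutput:
    # Hypothetical stand-in for vllm.v1's EngineCoreOutput.
    request_id: str
    new_token_ids: list[int]
    new_prompt_logprobs_tensors: Optional[object] = None


def collect_outputs(
    new_token_ids_by_req: dict[str, list[int]],
    prompt_logprobs_dict: dict[str, object],
) -> list[FakeEngineCoreOutput]:
    outputs = []
    for req_id, new_token_ids in new_token_ids_by_req.items():
        prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
        if new_token_ids:
            # Emit one output per request with sampled tokens; the fully
            # aggregated prompt logprobs ride along with that output.
            outputs.append(
                FakeEngineCoreOutput(req_id, new_token_ids,
                                     prompt_logprobs_tensors))
        else:
            # Invariant: EngineCore returns no partial prefill outputs,
            # so there must be no prompt-logprobs chunk to transmit.
            assert not prompt_logprobs_tensors
    return outputs
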
View File

@@ -115,7 +115,6 @@ class LogprobsProcessor:
num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()

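For context, a small sketch of the Pythonize step above with dummy tensors (sizes are made up): once the model runner hands over the prompt logprobs for the whole prompt in one piece, the .tolist() conversion runs once per request rather than once per prefill chunk.

import torch

# Dummy tensors standing in for a request's aggregated prompt logprobs:
# 5 prompt positions, 3 logprobs per position.
ranks = torch.tensor([1, 4, 2, 1, 7], dtype=torch.int32)
logprobs = torch.randn(5, 3)
token_ids = torch.randint(0, 32_000, (5, 3), dtype=torch.int32)

num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the torch tensors in one shot.
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
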
View File

@@ -105,9 +105,7 @@ class RequestState:
finished = finish_reason is not None
final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if not finished and (self.is_prefilling or final_only):
if not finished and final_only:
# Only the final output is required in FINAL_ONLY mode.
return None
@@ -285,19 +283,7 @@ class OutputProcessor:
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks.
stop_string = req_state.detokenizer.update(
@@ -306,8 +292,7 @@ class OutputProcessor:
finish_reason = FinishReason.STOP
stop_reason = stop_string
# 3) Compute sample and prompt logprobs for request,
# if required.
# 3) Compute sample and prompt logprobs for request, if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.

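A minimal sketch of the simplified output-processor gating, using stripped-down stand-ins for RequestState and RequestOutputKind: because the EngineCore no longer streams partial prefills, is_prefilling can be cleared unconditionally on any output, and FINAL_ONLY suppression no longer needs a prefill check.

from dataclasses import dataclass
from enum import Enum, auto


class RequestOutputKind(Enum):
    # Hypothetical mirror of the real enum in vllm.sampling_params.
    CUMULATIVE = auto()
    DELTA = auto()
    FINAL_ONLY = auto()


@dataclass
class MiniRequestState:
    # Hypothetical stand-in for vllm.v1's RequestState.
    output_kind: RequestOutputKind
    is_prefilling: bool = True

    def should_emit_output(self, finished: bool) -> bool:
        # Only the final output is required in FINAL_ONLY mode; the old
        # `is_prefilling or final_only` condition is no longer needed.
        final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
        return finished or not final_only


def on_engine_core_output(req_state: MiniRequestState,
                          new_token_ids: list[int]) -> None:
    # Any output received corresponds to a completed prefill, so the request
    # is no longer prefilling regardless of how many tokens came back.
    req_state.is_prefilling = False
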
View File

@@ -100,15 +100,8 @@ class IterationStats:
num_new_generation_tokens = len(output.new_token_ids)
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling and num_new_generation_tokens > 0:
# TODO(andy): we used to assert that num_new_generation_tokens
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
if is_prefilling:
assert num_new_generation_tokens > 0
self.num_prompt_tokens += prompt_len
first_token_latency = self._time_since(req_stats.arrival_time)
@@ -123,16 +116,12 @@ class IterationStats:
# Process the batch-level "new tokens" engine core event
if is_prefilling:
# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.first_token_ts = engine_core_timestamp
req_stats.first_token_ts = engine_core_timestamp
else:
tpot = engine_core_timestamp - req_stats.last_token_ts
self.time_per_output_tokens_iter.append(tpot)
# TODO: re-enable no-output-for-partial-prefills invariant as above
if num_new_generation_tokens > 0:
req_stats.last_token_ts = engine_core_timestamp
req_stats.last_token_ts = engine_core_timestamp
def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
is_prefilling: bool, req_stats: RequestStateStats,

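A minimal sketch of the stats bookkeeping with the no-output-for-partial-prefills invariant restored. MiniIterationStats and MiniRequestStats are hypothetical stand-ins, the two update paths above are collapsed into a single method for brevity, and time.monotonic() stands in for the real _time_since helper; the point is that a prefill output is now guaranteed to carry new tokens, so the num_new_generation_tokens > 0 guards are gone.

import time
from dataclasses import dataclass, field


@dataclass
class MiniRequestStats:
    # Hypothetical stand-in for RequestStateStats.
    arrival_time: float
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0


@dataclass
class MiniIterationStats:
    # Hypothetical stand-in for IterationStats.
    num_generation_tokens: int = 0
    num_prompt_tokens: int = 0
    time_to_first_tokens_iter: list[float] = field(default_factory=list)
    time_per_output_tokens_iter: list[float] = field(default_factory=list)

    def update_from_output(self, new_token_ids: list[int], prompt_len: int,
                           is_prefilling: bool, engine_core_timestamp: float,
                           req_stats: MiniRequestStats) -> None:
        num_new_generation_tokens = len(new_token_ids)
        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            # Invariant: the EngineCore emits no output for a partial
            # prefill, so a prefill output always includes sampled tokens.
            assert num_new_generation_tokens > 0
            self.num_prompt_tokens += prompt_len
            self.time_to_first_tokens_iter.append(
                time.monotonic() - req_stats.arrival_time)
            req_stats.first_token_ts = engine_core_timestamp
        else:
            tpot = engine_core_timestamp - req_stats.last_token_ts
            self.time_per_output_tokens_iter.append(tpot)
            req_stats.last_token_ts = engine_core_timestamp
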
View File

@@ -39,6 +39,25 @@ class LogprobsTensors(NamedTuple):
self.selected_token_ranks.tolist(),
)
@staticmethod
def empty_cpu(num_positions: int,
num_tokens_per_position: int) -> "LogprobsTensors":
"""Create empty LogprobsTensors on CPU."""
logprob_token_ids = torch.empty(
(num_positions, num_tokens_per_position),
dtype=torch.int32,
device="cpu")
logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
selected_token_ranks = torch.empty(num_positions,
dtype=torch.int32,
device="cpu")
return LogprobsTensors(
logprob_token_ids=logprob_token_ids,
logprobs=logprobs,
selected_token_ranks=selected_token_ranks,
)
@dataclass
class SamplerOutput:

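A short usage sketch of the new empty_cpu helper (sizes are hypothetical): preallocate CPU tensors covering the whole prompt once, then copy each prefill chunk's results into the matching row slice, which is what the GPUModelRunner change below does with async GPU->CPU copies.

import torch

from vllm.v1.outputs import LogprobsTensors

num_prompt_tokens = 9      # hypothetical prompt length
num_prompt_logprobs = 4    # hypothetical top-k requested per position

# One row per prompt position after the first token; one extra column so the
# prompt token's own logprob is kept alongside the top-k candidates.
target = LogprobsTensors.empty_cpu(num_prompt_tokens - 1,
                                   num_prompt_logprobs + 1)

# Pretend this prefill chunk produced logprobs for prompt positions [0, 4).
chunk_token_ids = torch.zeros(4, num_prompt_logprobs + 1, dtype=torch.int32)
chunk_logprobs = torch.zeros(4, num_prompt_logprobs + 1)
chunk_ranks = torch.ones(4, dtype=torch.int32)

chunk = slice(0, 4)
target.logprob_token_ids[chunk].copy_(chunk_token_ids, non_blocking=True)
target.logprobs[chunk].copy_(chunk_logprobs, non_blocking=True)
target.selected_token_ranks[chunk].copy_(chunk_ranks, non_blocking=True)
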
View File

@@ -11,6 +11,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
@@ -197,6 +198,9 @@ class InputBatch:
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}
# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
self.logit_bias: list[Optional[dict[int,
float]]] = [None] * max_num_reqs
self.has_allowed_token_ids: set[str] = set()
@@ -362,6 +366,7 @@ class InputBatch:
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]

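A minimal sketch of how the new accumulation dict is used across steps (the helper names here are hypothetical; in the real code the dict lives on InputBatch and is driven by GPUModelRunner): the per-request buffer is created lazily on the first chunk, reused for later chunks, and dropped when prefill completes or when the request leaves the batch.

from vllm.v1.outputs import LogprobsTensors

# Mirrors InputBatch.in_progress_prompt_logprobs_cpu: per-request scratch
# buffers for prompt logprobs, keyed by request id.
in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}


def get_or_create_buffer(req_id: str, num_prompt_tokens: int,
                         num_prompt_logprobs: int) -> LogprobsTensors:
    # Hypothetical helper: allocate the full-prompt buffer on the first
    # chunk and reuse it for every later chunk of the same request.
    logprobs_tensors = in_progress_prompt_logprobs_cpu.get(req_id)
    if logprobs_tensors is None:
        logprobs_tensors = LogprobsTensors.empty_cpu(
            num_prompt_tokens - 1, num_prompt_logprobs + 1)
        in_progress_prompt_logprobs_cpu[req_id] = logprobs_tensors
    return logprobs_tensors


def drop_buffer(req_id: str) -> None:
    # Hypothetical helper matching the pop() added in the batch-removal
    # path above and the cleanup of completed prefills in the model runner.
    in_progress_prompt_logprobs_cpu.pop(req_id, None)
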
View File

@@ -1191,6 +1191,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not num_prompt_logprobs_dict:
return {}
in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
# Since prompt logprobs are a rare feature, prioritize simple,
@@ -1206,16 +1207,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
self.device, non_blocking=True)
# Set up target LogprobsTensors object.
logprobs_tensors = in_progress_dict.get(req_id)
if not logprobs_tensors:
# Create empty logprobs CPU tensors for the entire prompt.
# If chunked, we'll copy in slice by slice.
logprobs_tensors = LogprobsTensors.empty_cpu(
num_prompt_tokens - 1, num_prompt_logprobs + 1)
in_progress_dict[req_id] = logprobs_tensors
# Determine number of logits to retrieve.
start_tok = request.num_computed_tokens + 1
start_idx = request.num_computed_tokens
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens < num_remaining_tokens:
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(req_id)
prompt_logprobs_dict[req_id] = logprobs_tensors
if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue
# Get the logits corresponding to this req's prompt tokens.
# If this is a partial request (i.e. chunked prefill),
@@ -1236,19 +1257,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs, num_prompt_logprobs, tgt_token_ids)
# Transfer GPU->CPU async.
prompt_logprobs_dict[req_id] = LogprobsTensors(
token_ids.to("cpu", non_blocking=True),
logprobs.to("cpu", non_blocking=True),
ranks.to("cpu", non_blocking=True),
)
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
token_ids, non_blocking=True)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
non_blocking=True)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
ranks, non_blocking=True)
# Remove requests that have completed prefill from the batch
# num_prompt_logprobs_dict.
for req_id in completed_prefill_reqs:
del num_prompt_logprobs_dict[req_id]
del in_progress_dict[req_id]
# Must synchronize the non-blocking GPU->CPU transfers.
torch.cuda.synchronize()
if prompt_logprobs_dict:
torch.cuda.synchronize()
return prompt_logprobs_dict
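
A worked sketch of the per-step chunk accounting above, using hypothetical prompt and chunk sizes and a standalone helper (plan_chunk is not a real vLLM function): it shows how num_logits is chosen, when a request is marked as having completed prefill, and why both the == case and the num_logits <= 0 case exist.

def plan_chunk(num_prompt_tokens: int, num_computed_tokens: int,
               num_scheduled_tokens: int) -> tuple[slice, int, bool]:
    """Return (destination row slice, num_logits, prefill_completed)."""
    start_idx = num_computed_tokens
    start_tok = start_idx + 1
    num_remaining_tokens = num_prompt_tokens - start_tok
    if num_scheduled_tokens <= num_remaining_tokens:
        # Still a chunk. In the == case every prompt logprob has been
        # computed, but returning them is deferred to the next step, which
        # will also have newly sampled tokens to report.
        return (slice(start_idx, start_idx + num_scheduled_tokens),
                num_scheduled_tokens, False)
    # Final chunk: only the remaining prompt positions produce logits.
    # num_remaining_tokens can be 0 if the prior step prefilled exactly
    # num_prompt_tokens - 1 tokens; then there is nothing left to copy and
    # the already-aggregated tensors are simply returned.
    return (slice(start_idx, start_idx + num_remaining_tokens),
            num_remaining_tokens, True)


# A 10-token prompt prefilled in chunks of 4, 4 and 2 tokens:
print(plan_chunk(10, 0, 4))  # (slice(0, 4, None), 4, False) -> rows 0..3
print(plan_chunk(10, 4, 4))  # (slice(4, 8, None), 4, False) -> rows 4..7
print(plan_chunk(10, 8, 2))  # (slice(8, 9, None), 1, True)  -> last row

# A 9-token prompt in chunks of 4, 4 and 1: the second chunk hits the ==
# case, so the final step finds num_logits == 0 and just returns the
# tensors that were aggregated in the previous steps.
print(plan_chunk(9, 4, 4))   # (slice(4, 8, None), 4, False)
print(plan_chunk(9, 8, 1))   # (slice(8, 8, None), 0, True)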