[V1] Aggregate chunked prompt logprobs in model runner (#14875)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in: parent 9cc645141d · commit 3aee6573dc
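The change moves chunked-prefill prompt-logprob aggregation out of the frontend OutputProcessor and into the GPU model runner: each request gets a pre-allocated CPU buffer sized for its full prompt, every prefill chunk is copied into the matching slice, and the complete tensors are only handed back once the whole prompt has been processed. Below is a minimal standalone sketch of that accumulation pattern; the class and method names are illustrative, not vLLM APIs.

```python
from typing import Optional

import torch


class ChunkedLogprobAccumulator:
    """Toy accumulator: fill a pre-allocated CPU tensor chunk by chunk."""

    def __init__(self, num_prompt_tokens: int, num_logprobs: int):
        # One row per prompt position except the first token, which has
        # no logprob (there is nothing to condition on).
        self.logprobs = torch.empty(num_prompt_tokens - 1, num_logprobs,
                                    dtype=torch.float32, device="cpu")
        self.num_filled = 0

    def add_chunk(self, chunk: torch.Tensor) -> Optional[torch.Tensor]:
        """Copy one chunk of rows in; return the full tensor once complete."""
        n = chunk.shape[0]
        self.logprobs[self.num_filled:self.num_filled + n].copy_(chunk)
        self.num_filled += n
        done = self.num_filled == self.logprobs.shape[0]
        return self.logprobs if done else None


# A 6-token prompt prefilled in chunks covering 2, 2, and 1 positions.
acc = ChunkedLogprobAccumulator(num_prompt_tokens=6, num_logprobs=4)
results = [acc.add_chunk(torch.randn(n, 4)) for n in (2, 2, 1)]
assert results[0] is None and results[1] is None
assert results[2].shape == (5, 4)
```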
@@ -627,8 +627,7 @@ class Scheduler(SchedulerInterface):
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
-            # Transmit partial if chunked prefill & prompt logprobs is enabled
-            if new_token_ids or prompt_logprobs_tensors is not None:
+            if new_token_ids:
                 # Add EngineCoreOutput for this Request.
                 outputs.append(
                     EngineCoreOutput(
@@ -639,6 +638,9 @@ class Scheduler(SchedulerInterface):
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                         stop_reason=request.stop_reason,
                         events=request.take_events()))
+            else:
+                # Invariant: EngineCore returns no partial prefill outputs.
+                assert not prompt_logprobs_tensors
 
             self.scheduled_req_ids.remove(request.request_id)
             if not stopped:
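With aggregation done in the model runner, the scheduler only emits an EngineCoreOutput when a request has newly sampled tokens, and any prompt logprobs it carries are complete by that point. A toy sketch of that invariant follows; the function and field names are illustrative, not the vLLM scheduler API.

```python
from typing import Optional


def maybe_emit_output(new_token_ids: list[int],
                      prompt_logprobs: Optional[object]) -> Optional[dict]:
    """Emit an output only when new tokens exist; otherwise nothing is pending."""
    if new_token_ids:
        # Fully aggregated prompt logprobs (if any) ride along with the
        # output that carries the newly sampled tokens.
        return {"new_token_ids": new_token_ids,
                "new_prompt_logprobs_tensors": prompt_logprobs}
    # Invariant: no partial prompt logprobs are streamed mid-prefill.
    assert prompt_logprobs is None
    return None


assert maybe_emit_output([], None) is None
assert maybe_emit_output([42], "logprobs")["new_token_ids"] == [42]
```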
@@ -115,7 +115,6 @@ class LogprobsProcessor:
         num_prompt_tokens, num_logprobs = logprobs.shape
 
         # Pythonize the torch tensors.
         # TODO(rob): experiment with doing this in EngineCore?
         prompt_token_ranks = ranks.tolist()
         prompt_logprobs = logprobs.tolist()
         token_ids = token_ids.tolist()
@@ -105,9 +105,7 @@ class RequestState:
         finished = finish_reason is not None
         final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
 
-        # In follow up, we will switch to invariant where EngineCore
-        # does not stream partial prefills.
-        if not finished and (self.is_prefilling or final_only):
+        if not finished and final_only:
             # Only the final output is required in FINAL_ONLY mode.
             return None
@@ -285,19 +283,7 @@ class OutputProcessor:
             finish_reason = engine_core_output.finish_reason
             stop_reason = engine_core_output.stop_reason
 
-            # TODO(andy): prompt logprobs + chunked prefill can
-            # result in engine core returning an output for a
-            # partial prefill (in order to send back partial
-            # prompt logprobs.) This breaks the invariant that
-            # process_outputs is only operating on engine core
-            # outputs associated with non-partial completions.
-            # Currently this is handled by having `is_prefilling`
-            # check for new decoded tokens, indicating that
-            # the completion is not partial.
-            #
-            # Follow up will aggregate partial prompt logprobs
-            # in the EngineCore.
-            req_state.is_prefilling = not new_token_ids
+            req_state.is_prefilling = False
 
             # 2) Detokenize the token ids into text and perform stop checks.
             stop_string = req_state.detokenizer.update(
@@ -306,8 +292,7 @@ class OutputProcessor:
                 finish_reason = FinishReason.STOP
                 stop_reason = stop_string
 
-            # 3) Compute sample and prompt logprobs for request,
-            # if required.
+            # 3) Compute sample and prompt logprobs for request, if required.
             req_state.logprobs_processor.update_from_output(engine_core_output)
 
             # 4) Create and handle RequestOutput objects.
@@ -100,15 +100,8 @@ class IterationStats:
         num_new_generation_tokens = len(output.new_token_ids)
 
         self.num_generation_tokens += num_new_generation_tokens
-        if is_prefilling and num_new_generation_tokens > 0:
-            # TODO(andy): we used to assert that num_new_generation_tokens
-            # > 0 with an invariant that EngineCore does not stream outputs
-            # for partially completed prefills (scheduler.update_from_output
-            # makes EngineCoreOutput iff num_computed_tokens == num_tokens).
-            # When prompt logprobs are enabled, we currently stream out the
-            # partially completed prompt.
-            # This will be reverted in a follow up PR and we should re-enable
-            # this assertion / invariant.
+        if is_prefilling:
+            assert num_new_generation_tokens > 0
             self.num_prompt_tokens += prompt_len
 
             first_token_latency = self._time_since(req_stats.arrival_time)
@@ -123,16 +116,12 @@ class IterationStats:
 
         # Process the batch-level "new tokens" engine core event
         if is_prefilling:
-            # TODO: re-enable no-output-for-partial-prefills invariant as above
-            if num_new_generation_tokens > 0:
-                req_stats.first_token_ts = engine_core_timestamp
+            req_stats.first_token_ts = engine_core_timestamp
         else:
             tpot = engine_core_timestamp - req_stats.last_token_ts
             self.time_per_output_tokens_iter.append(tpot)
 
-        # TODO: re-enable no-output-for-partial-prefills invariant as above
-        if num_new_generation_tokens > 0:
-            req_stats.last_token_ts = engine_core_timestamp
+        req_stats.last_token_ts = engine_core_timestamp
 
     def update_from_events(self, req_id: str, events: list["EngineCoreEvent"],
                            is_prefilling: bool, req_stats: RequestStateStats,
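With the no-partial-prefill-output invariant restored, the first output for a request always carries generated tokens, so time-to-first-token and time-per-output-token fall directly out of the engine-core timestamps. A small sketch of that bookkeeping, simplified from the IterationStats logic above; the class and attribute names are illustrative.

```python
import time
from typing import Optional


class ToyRequestTimes:
    def __init__(self) -> None:
        self.arrival_time = time.monotonic()
        self.first_token_ts: Optional[float] = None
        self.last_token_ts: Optional[float] = None
        self.tpots: list[float] = []

    def on_new_tokens(self, ts: float, is_prefilling: bool) -> None:
        if is_prefilling:
            # First tokens of the request: record time-to-first-token.
            self.first_token_ts = ts
        else:
            # Subsequent decode steps: record inter-token latency (TPOT).
            self.tpots.append(ts - self.last_token_ts)
        self.last_token_ts = ts


stats = ToyRequestTimes()
stats.on_new_tokens(time.monotonic(), is_prefilling=True)
time.sleep(0.01)
stats.on_new_tokens(time.monotonic(), is_prefilling=False)
assert stats.tpots and stats.tpots[0] > 0
```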
@@ -39,6 +39,25 @@ class LogprobsTensors(NamedTuple):
             self.selected_token_ranks.tolist(),
         )
 
+    @staticmethod
+    def empty_cpu(num_positions: int,
+                  num_tokens_per_position: int) -> "LogprobsTensors":
+        """Create empty LogprobsTensors on CPU."""
+
+        logprob_token_ids = torch.empty(
+            (num_positions, num_tokens_per_position),
+            dtype=torch.int32,
+            device="cpu")
+        logprobs = torch.empty_like(logprob_token_ids, dtype=torch.float32)
+        selected_token_ranks = torch.empty(num_positions,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        return LogprobsTensors(
+            logprob_token_ids=logprob_token_ids,
+            logprobs=logprobs,
+            selected_token_ranks=selected_token_ranks,
+        )
+
 
 @dataclass
 class SamplerOutput:
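The new `LogprobsTensors.empty_cpu` helper pre-allocates the CPU-side destination tensors once per request so that each prefill chunk can later be copied into a row slice. A short usage sketch built only from the signature and fields shown in the hunk above; the sizes are illustrative, and it assumes vLLM at this commit is importable. The `num_positions=7, num_tokens_per_position=5` choice mirrors the `num_prompt_tokens - 1, num_prompt_logprobs + 1` sizing used in the model runner hunk below (no logprob for the first prompt token, plus one extra slot per position beyond the requested top logprobs).

```python
import torch

from vllm.v1.outputs import LogprobsTensors

# A request with an 8-token prompt and 4 requested prompt logprobs.
dst = LogprobsTensors.empty_cpu(num_positions=7, num_tokens_per_position=5)

# Copy a 3-position chunk into rows 0..2 of the pre-allocated buffers.
chunk_ids = torch.zeros(3, 5, dtype=torch.int32)
chunk_lps = torch.zeros(3, 5, dtype=torch.float32)
chunk_ranks = torch.zeros(3, dtype=torch.int32)
sl = slice(0, 3)
dst.logprob_token_ids[sl].copy_(chunk_ids)
dst.logprobs[sl].copy_(chunk_lps)
dst.selected_token_ranks[sl].copy_(chunk_ranks)
```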
@@ -11,6 +11,7 @@ from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.utils import swap_dict_values
+from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import BlockTable
@@ -197,6 +198,9 @@ class InputBatch:
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: dict[str, int] = {}
 
+        # To accumulate prompt logprobs tensor chunks across prefill steps.
+        self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
+
         self.logit_bias: list[Optional[dict[int,
                                             float]]] = [None] * max_num_reqs
         self.has_allowed_token_ids: set[str] = set()
@@ -362,6 +366,7 @@ class InputBatch:
         self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.num_prompt_logprobs.pop(req_id, None)
+        self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
         # LoRA
         lora_id = self.request_lora_mapping[req_index]
@@ -1191,6 +1191,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if not num_prompt_logprobs_dict:
             return {}
 
+        in_progress_dict = self.input_batch.in_progress_prompt_logprobs_cpu
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
 
         # Since prompt logprobs are a rare feature, prioritize simple,
@@ -1206,16 +1207,36 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                 self.device, non_blocking=True)
 
+            # Set up target LogprobsTensors object.
+            logprobs_tensors = in_progress_dict.get(req_id)
+            if not logprobs_tensors:
+                # Create empty logprobs CPU tensors for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                logprobs_tensors = LogprobsTensors.empty_cpu(
+                    num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                in_progress_dict[req_id] = logprobs_tensors
+
             # Determine number of logits to retrieve.
-            start_tok = request.num_computed_tokens + 1
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
             num_remaining_tokens = num_prompt_tokens - start_tok
-            if num_tokens < num_remaining_tokens:
+            if num_tokens <= num_remaining_tokens:
                 # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
                 num_logits = num_tokens
             else:
                 # This is the last chunk of prompt tokens to return.
                 num_logits = num_remaining_tokens
                 completed_prefill_reqs.append(req_id)
+                prompt_logprobs_dict[req_id] = logprobs_tensors
+
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
 
             # Get the logits corresponding to this req's prompt tokens.
             # If this is a partial request (i.e. chunked prefill),
@@ -1236,19 +1257,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 logprobs, num_prompt_logprobs, tgt_token_ids)
 
             # Transfer GPU->CPU async.
-            prompt_logprobs_dict[req_id] = LogprobsTensors(
-                token_ids.to("cpu", non_blocking=True),
-                logprobs.to("cpu", non_blocking=True),
-                ranks.to("cpu", non_blocking=True),
-            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(
+                token_ids, non_blocking=True)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs,
+                                                         non_blocking=True)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(
+                ranks, non_blocking=True)
 
         # Remove requests that have completed prefill from the batch
         # num_prompt_logprobs_dict.
         for req_id in completed_prefill_reqs:
             del num_prompt_logprobs_dict[req_id]
+            del in_progress_dict[req_id]
 
         # Must synchronize the non-blocking GPU->CPU transfers.
-        torch.cuda.synchronize()
+        if prompt_logprobs_dict:
+            torch.cuda.synchronize()
 
         return prompt_logprobs_dict
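The model runner now copies each chunk device-to-CPU with `non_blocking=True` into the per-request buffers, and issues a single `torch.cuda.synchronize()` only when something was actually transferred this step. A standalone sketch of that transfer pattern; the tensors and helper function here are illustrative, not the vLLM runner.

```python
import torch


def copy_chunk_to_cpu(dst: torch.Tensor, src: torch.Tensor,
                      start_idx: int) -> None:
    """Async device->CPU copy of one chunk into a slice of a CPU buffer."""
    chunk_slice = slice(start_idx, start_idx + src.shape[0])
    dst[chunk_slice].copy_(src, non_blocking=True)


device = "cuda" if torch.cuda.is_available() else "cpu"
dst = torch.empty(10, 4, dtype=torch.float32, device="cpu")
copied_any = False

for start, n in ((0, 6), (6, 4)):  # two prefill chunks
    src = torch.randn(n, 4, device=device)
    copy_chunk_to_cpu(dst, src, start)
    copied_any = True

# Non-blocking copies must be synchronized before the CPU reads the data,
# but only if a transfer actually happened this step.
if copied_any and torch.cuda.is_available():
    torch.cuda.synchronize()
```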