[Feature][V1]: supports cached_tokens in response usage (#18149)

Co-authored-by: simon-mo <xmo@berkeley.edu>
Chauncey 2025-05-23 16:41:03 +08:00 committed by GitHub
parent 54af915949
commit b046cf792d
5 changed files with 27 additions and 5 deletions
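
The user-visible effect is that prefix-cache hits can be reported in the `usage` block of the OpenAI-compatible server's responses. None of the serving-layer code is in this diff, so the following is only a hedged client-side sketch: the endpoint, model name, and the `usage.prompt_tokens_details.cached_tokens` field follow OpenAI client conventions, and depending on the server version the `--enable-prompt-tokens-details` flag may also be needed (an assumption, not shown in these five files).

from openai import OpenAI

# Assumes a vLLM OpenAI-compatible server running locally with the V1 engine
# and prefix caching enabled; model name and port are placeholders.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

messages = [{"role": "user", "content": "Summarize the vLLM V1 scheduler in one sentence."}]

# Send the same prompt twice; the second request can be served partly from the
# prefix cache, and the hit count shows up in the usage details.
client.chat.completions.create(model="my-model", messages=messages)
resp = client.chat.completions.create(model="my-model", messages=messages)

details = resp.usage.prompt_tokens_details
print("cached_tokens:", details.cached_tokens if details else 0)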

View File

@@ -19,7 +19,8 @@ def model() -> LLM:
enable_prefix_caching=True,
long_prefill_token_threshold=2,
max_num_batched_tokens=6,
max_num_seqs=3)
max_num_seqs=3,
block_size=16)
def test_concurrent_partial_prefill(model):
@@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model):
assert len(outputs) == 3
for output in outputs:
assert len(output.outputs) == 1
def test_prefix_cache_stats_is_recorded(model):
# 17 tokens will make sure first 16 tokens are cached in a block
input_tokens = {"prompt_token_ids": [101] * 17}
_ = model.generate([input_tokens])
outputs = model.generate([input_tokens])
assert outputs[0].num_cached_tokens == 16
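
An offline analogue of the new test, using a plain string prompt instead of raw token ids. Cache hits are counted per full block, which is why the test above expects 16 of its 17 prompt tokens to be reported. The model name and prompt below are placeholders; the settings mirror the fixture above.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",   # placeholder model
          enable_prefix_caching=True,
          block_size=16)

prompt = "vLLM prefix caching demo. " * 16
params = SamplingParams(max_tokens=8)

llm.generate([prompt], params)            # first pass warms the prefix cache
outputs = llm.generate([prompt], params)  # second pass should hit the cache

# Populated from EngineCoreOutput.num_cached_tokens via the output processor
# (see the diffs below).
print(outputs[0].num_cached_tokens)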

View File

@@ -457,7 +457,9 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
@@ -798,6 +800,7 @@ class Scheduler(SchedulerInterface):
stop_reason=request.stop_reason,
events=request.take_events(),
kv_transfer_params=kv_transfer_params,
num_cached_tokens=request.num_cached_tokens,
))
else:
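
The `< 0` check above makes the recording one-shot: `Request.num_cached_tokens` starts at -1 (see the last file in this commit) and is set exactly once, on the first scheduling pass, when `num_computed_tokens` equals the number of prompt tokens served from the prefix cache. A standalone sketch of that sentinel pattern, with simplified names rather than the actual scheduler code:

class RequestSketch:
    """Illustrative stand-in for vllm.v1.request.Request."""

    def __init__(self) -> None:
        # -1 means "prefix-cache hits not recorded yet".
        self.num_cached_tokens = -1


def record_prefix_cache_hits(req: RequestSketch, num_computed_tokens: int) -> None:
    # Only the first call records the hit count; later scheduling passes
    # (e.g. after preemption) keep the originally recorded value.
    if req.num_cached_tokens < 0:
        req.num_cached_tokens = num_computed_tokens


req = RequestSketch()
record_prefix_cache_hits(req, 16)  # first pass: 16 prompt tokens came from cache
record_prefix_cache_hits(req, 48)  # later pass: ignored
assert req.num_cached_tokens == 16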

View File

@@ -107,6 +107,9 @@ class EngineCoreOutput(
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0
@property
def finished(self) -> bool:
return self.finish_reason is not None
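
EngineCoreOutput is the message that travels from the engine core to the frontend, so the new field defaults to 0 to stay backward compatible with payloads that do not carry it. A generic sketch of that pattern with msgspec; the actual base-class options of EngineCoreOutput are not shown in this hunk and are assumed here:

import msgspec


class OutputSketch(msgspec.Struct, array_like=True, omit_defaults=True):
    """Illustrative stand-in for an engine-core output message."""
    request_id: str
    num_cached_tokens: int = 0  # defaulted: absent on the wire decodes as 0


# A payload encoded without the new field still decodes cleanly.
old_payload = msgspec.msgpack.encode(OutputSketch(request_id="req-1"))
decoded = msgspec.msgpack.decode(old_payload, type=OutputSketch)
assert decoded.num_cached_tokens == 0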

View File

@@ -147,6 +147,7 @@ class RequestState:
finish_reason: Optional[FinishReason],
stop_reason: Union[int, str, None],
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> Optional[RequestOutput]:
finished = finish_reason is not None
@@ -169,7 +170,7 @@ class RequestState:
return None
return self._new_request_output(request_id, outputs, finished,
kv_transfer_params)
kv_transfer_params, num_cached_tokens)
def _new_request_output(
self,
@@ -177,6 +178,7 @@ class RequestState:
outputs: list[CompletionOutput],
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
num_cached_tokens: int = 0,
) -> RequestOutput:
if self.output_kind == RequestOutputKind.DELTA:
@@ -193,6 +195,7 @@ class RequestState:
outputs=outputs,
finished=finished,
kv_transfer_params=kv_transfer_params,
num_cached_tokens=num_cached_tokens,
)
def _new_completion_output(
@@ -340,7 +343,7 @@ class OutputProcessor:
finish_reason = engine_core_output.finish_reason
stop_reason = engine_core_output.stop_reason
kv_transfer_params = engine_core_output.kv_transfer_params
num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False
# 2) Detokenize the token ids into text and perform stop checks.
@@ -356,7 +359,7 @@ class OutputProcessor:
# 4) Create and handle RequestOutput objects.
if request_output := req_state.make_request_output(
new_token_ids, finish_reason, stop_reason,
kv_transfer_params):
kv_transfer_params, num_cached_tokens):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
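
Since `make_request_output` now threads `num_cached_tokens` into every `RequestOutput` it builds, the count is available for streaming (DELTA) outputs as well. A hedged streaming sketch with the OpenAI client; as above, the endpoint and model name are placeholders and usage reporting in streams follows the OpenAI `stream_options` convention:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="my-model",  # placeholder: whatever the server was launched with
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
    stream_options={"include_usage": True},
)

for chunk in stream:
    # With include_usage, only the final chunk carries the usage block.
    if chunk.usage is not None:
        details = chunk.usage.prompt_tokens_details
        print("cached_tokens:", details.cached_tokens if details else 0)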

View File

@@ -77,6 +77,10 @@ class Request:
self.output_token_ids = ConstantList(self._output_token_ids)
self.all_token_ids = ConstantList(self._all_token_ids)
# State
# The number of tokens with prefix cache hits.
self.num_cached_tokens = -1
@classmethod
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
if request.mm_inputs is not None: