mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 10:54:28 +08:00
[Feature][V1]: suupports cached_tokens in response usage (#18149)
Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
parent
54af915949
commit
b046cf792d
@ -19,7 +19,8 @@ def model() -> LLM:
|
||||
enable_prefix_caching=True,
|
||||
long_prefill_token_threshold=2,
|
||||
max_num_batched_tokens=6,
|
||||
max_num_seqs=3)
|
||||
max_num_seqs=3,
|
||||
block_size=16)
|
||||
|
||||
|
||||
def test_concurrent_partial_prefill(model):
|
||||
@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model):
|
||||
assert len(outputs) == 3
|
||||
for output in outputs:
|
||||
assert len(output.outputs) == 1
|
||||
|
||||
|
||||
def test_prefix_cache_stats_is_recorded(model):
|
||||
# 17 tokens will make sure first 16 tokens are cached in a block
|
||||
input_tokens = {"prompt_token_ids": [101] * 17}
|
||||
_ = model.generate([input_tokens])
|
||||
outputs = model.generate([input_tokens])
|
||||
assert outputs[0].num_cached_tokens == 16
|
||||
|
||||
@ -457,7 +457,9 @@ class Scheduler(SchedulerInterface):
|
||||
token_budget -= num_new_tokens
|
||||
request.status = RequestStatus.RUNNING
|
||||
request.num_computed_tokens = num_computed_tokens
|
||||
|
||||
# Count the number of prifix cached tokens.
|
||||
if request.num_cached_tokens < 0:
|
||||
request.num_cached_tokens = num_computed_tokens
|
||||
# Encoder-related.
|
||||
if encoder_inputs_to_schedule:
|
||||
scheduled_encoder_inputs[request.request_id] = (
|
||||
@ -798,6 +800,7 @@ class Scheduler(SchedulerInterface):
|
||||
stop_reason=request.stop_reason,
|
||||
events=request.take_events(),
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
num_cached_tokens=request.num_cached_tokens,
|
||||
))
|
||||
|
||||
else:
|
||||
|
||||
@ -107,6 +107,9 @@ class EngineCoreOutput(
|
||||
events: Optional[list[EngineCoreEvent]] = None
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
# The number of tokens with prefix cache hits.
|
||||
num_cached_tokens: int = 0
|
||||
|
||||
@property
|
||||
def finished(self) -> bool:
|
||||
return self.finish_reason is not None
|
||||
|
||||
@ -147,6 +147,7 @@ class RequestState:
|
||||
finish_reason: Optional[FinishReason],
|
||||
stop_reason: Union[int, str, None],
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None,
|
||||
num_cached_tokens: int = 0,
|
||||
) -> Optional[RequestOutput]:
|
||||
|
||||
finished = finish_reason is not None
|
||||
@ -169,7 +170,7 @@ class RequestState:
|
||||
return None
|
||||
|
||||
return self._new_request_output(request_id, outputs, finished,
|
||||
kv_transfer_params)
|
||||
kv_transfer_params, num_cached_tokens)
|
||||
|
||||
def _new_request_output(
|
||||
self,
|
||||
@ -177,6 +178,7 @@ class RequestState:
|
||||
outputs: list[CompletionOutput],
|
||||
finished: bool,
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None,
|
||||
num_cached_tokens: int = 0,
|
||||
) -> RequestOutput:
|
||||
|
||||
if self.output_kind == RequestOutputKind.DELTA:
|
||||
@ -193,6 +195,7 @@ class RequestState:
|
||||
outputs=outputs,
|
||||
finished=finished,
|
||||
kv_transfer_params=kv_transfer_params,
|
||||
num_cached_tokens=num_cached_tokens,
|
||||
)
|
||||
|
||||
def _new_completion_output(
|
||||
@ -340,7 +343,7 @@ class OutputProcessor:
|
||||
finish_reason = engine_core_output.finish_reason
|
||||
stop_reason = engine_core_output.stop_reason
|
||||
kv_transfer_params = engine_core_output.kv_transfer_params
|
||||
|
||||
num_cached_tokens = engine_core_output.num_cached_tokens
|
||||
req_state.is_prefilling = False
|
||||
|
||||
# 2) Detokenize the token ids into text and perform stop checks.
|
||||
@ -356,7 +359,7 @@ class OutputProcessor:
|
||||
# 4) Create and handle RequestOutput objects.
|
||||
if request_output := req_state.make_request_output(
|
||||
new_token_ids, finish_reason, stop_reason,
|
||||
kv_transfer_params):
|
||||
kv_transfer_params, num_cached_tokens):
|
||||
if req_state.queue is not None:
|
||||
# AsyncLLM: put into queue for handling by generate().
|
||||
req_state.queue.put(request_output)
|
||||
|
||||
@ -77,6 +77,10 @@ class Request:
|
||||
self.output_token_ids = ConstantList(self._output_token_ids)
|
||||
self.all_token_ids = ConstantList(self._all_token_ids)
|
||||
|
||||
# State
|
||||
# The number of tokens with prefix cache hits.
|
||||
self.num_cached_tokens = -1
|
||||
|
||||
@classmethod
|
||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||
if request.mm_inputs is not None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user