mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-19 09:57:00 +08:00
[Feature][V1]: suupports cached_tokens in response usage (#18149)
Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
parent
54af915949
commit
b046cf792d
@ -19,7 +19,8 @@ def model() -> LLM:
|
|||||||
enable_prefix_caching=True,
|
enable_prefix_caching=True,
|
||||||
long_prefill_token_threshold=2,
|
long_prefill_token_threshold=2,
|
||||||
max_num_batched_tokens=6,
|
max_num_batched_tokens=6,
|
||||||
max_num_seqs=3)
|
max_num_seqs=3,
|
||||||
|
block_size=16)
|
||||||
|
|
||||||
|
|
||||||
def test_concurrent_partial_prefill(model):
|
def test_concurrent_partial_prefill(model):
|
||||||
@ -27,3 +28,11 @@ def test_concurrent_partial_prefill(model):
|
|||||||
assert len(outputs) == 3
|
assert len(outputs) == 3
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
assert len(output.outputs) == 1
|
assert len(output.outputs) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefix_cache_stats_is_recorded(model):
|
||||||
|
# 17 tokens will make sure first 16 tokens are cached in a block
|
||||||
|
input_tokens = {"prompt_token_ids": [101] * 17}
|
||||||
|
_ = model.generate([input_tokens])
|
||||||
|
outputs = model.generate([input_tokens])
|
||||||
|
assert outputs[0].num_cached_tokens == 16
|
||||||
|
|||||||
@ -457,7 +457,9 @@ class Scheduler(SchedulerInterface):
|
|||||||
token_budget -= num_new_tokens
|
token_budget -= num_new_tokens
|
||||||
request.status = RequestStatus.RUNNING
|
request.status = RequestStatus.RUNNING
|
||||||
request.num_computed_tokens = num_computed_tokens
|
request.num_computed_tokens = num_computed_tokens
|
||||||
|
# Count the number of prifix cached tokens.
|
||||||
|
if request.num_cached_tokens < 0:
|
||||||
|
request.num_cached_tokens = num_computed_tokens
|
||||||
# Encoder-related.
|
# Encoder-related.
|
||||||
if encoder_inputs_to_schedule:
|
if encoder_inputs_to_schedule:
|
||||||
scheduled_encoder_inputs[request.request_id] = (
|
scheduled_encoder_inputs[request.request_id] = (
|
||||||
@ -798,6 +800,7 @@ class Scheduler(SchedulerInterface):
|
|||||||
stop_reason=request.stop_reason,
|
stop_reason=request.stop_reason,
|
||||||
events=request.take_events(),
|
events=request.take_events(),
|
||||||
kv_transfer_params=kv_transfer_params,
|
kv_transfer_params=kv_transfer_params,
|
||||||
|
num_cached_tokens=request.num_cached_tokens,
|
||||||
))
|
))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -107,6 +107,9 @@ class EngineCoreOutput(
|
|||||||
events: Optional[list[EngineCoreEvent]] = None
|
events: Optional[list[EngineCoreEvent]] = None
|
||||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||||
|
|
||||||
|
# The number of tokens with prefix cache hits.
|
||||||
|
num_cached_tokens: int = 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def finished(self) -> bool:
|
def finished(self) -> bool:
|
||||||
return self.finish_reason is not None
|
return self.finish_reason is not None
|
||||||
|
|||||||
@ -147,6 +147,7 @@ class RequestState:
|
|||||||
finish_reason: Optional[FinishReason],
|
finish_reason: Optional[FinishReason],
|
||||||
stop_reason: Union[int, str, None],
|
stop_reason: Union[int, str, None],
|
||||||
kv_transfer_params: Optional[dict[str, Any]] = None,
|
kv_transfer_params: Optional[dict[str, Any]] = None,
|
||||||
|
num_cached_tokens: int = 0,
|
||||||
) -> Optional[RequestOutput]:
|
) -> Optional[RequestOutput]:
|
||||||
|
|
||||||
finished = finish_reason is not None
|
finished = finish_reason is not None
|
||||||
@ -169,7 +170,7 @@ class RequestState:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
return self._new_request_output(request_id, outputs, finished,
|
return self._new_request_output(request_id, outputs, finished,
|
||||||
kv_transfer_params)
|
kv_transfer_params, num_cached_tokens)
|
||||||
|
|
||||||
def _new_request_output(
|
def _new_request_output(
|
||||||
self,
|
self,
|
||||||
@ -177,6 +178,7 @@ class RequestState:
|
|||||||
outputs: list[CompletionOutput],
|
outputs: list[CompletionOutput],
|
||||||
finished: bool,
|
finished: bool,
|
||||||
kv_transfer_params: Optional[dict[str, Any]] = None,
|
kv_transfer_params: Optional[dict[str, Any]] = None,
|
||||||
|
num_cached_tokens: int = 0,
|
||||||
) -> RequestOutput:
|
) -> RequestOutput:
|
||||||
|
|
||||||
if self.output_kind == RequestOutputKind.DELTA:
|
if self.output_kind == RequestOutputKind.DELTA:
|
||||||
@ -193,6 +195,7 @@ class RequestState:
|
|||||||
outputs=outputs,
|
outputs=outputs,
|
||||||
finished=finished,
|
finished=finished,
|
||||||
kv_transfer_params=kv_transfer_params,
|
kv_transfer_params=kv_transfer_params,
|
||||||
|
num_cached_tokens=num_cached_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _new_completion_output(
|
def _new_completion_output(
|
||||||
@ -340,7 +343,7 @@ class OutputProcessor:
|
|||||||
finish_reason = engine_core_output.finish_reason
|
finish_reason = engine_core_output.finish_reason
|
||||||
stop_reason = engine_core_output.stop_reason
|
stop_reason = engine_core_output.stop_reason
|
||||||
kv_transfer_params = engine_core_output.kv_transfer_params
|
kv_transfer_params = engine_core_output.kv_transfer_params
|
||||||
|
num_cached_tokens = engine_core_output.num_cached_tokens
|
||||||
req_state.is_prefilling = False
|
req_state.is_prefilling = False
|
||||||
|
|
||||||
# 2) Detokenize the token ids into text and perform stop checks.
|
# 2) Detokenize the token ids into text and perform stop checks.
|
||||||
@ -356,7 +359,7 @@ class OutputProcessor:
|
|||||||
# 4) Create and handle RequestOutput objects.
|
# 4) Create and handle RequestOutput objects.
|
||||||
if request_output := req_state.make_request_output(
|
if request_output := req_state.make_request_output(
|
||||||
new_token_ids, finish_reason, stop_reason,
|
new_token_ids, finish_reason, stop_reason,
|
||||||
kv_transfer_params):
|
kv_transfer_params, num_cached_tokens):
|
||||||
if req_state.queue is not None:
|
if req_state.queue is not None:
|
||||||
# AsyncLLM: put into queue for handling by generate().
|
# AsyncLLM: put into queue for handling by generate().
|
||||||
req_state.queue.put(request_output)
|
req_state.queue.put(request_output)
|
||||||
|
|||||||
@ -77,6 +77,10 @@ class Request:
|
|||||||
self.output_token_ids = ConstantList(self._output_token_ids)
|
self.output_token_ids = ConstantList(self._output_token_ids)
|
||||||
self.all_token_ids = ConstantList(self._all_token_ids)
|
self.all_token_ids = ConstantList(self._all_token_ids)
|
||||||
|
|
||||||
|
# State
|
||||||
|
# The number of tokens with prefix cache hits.
|
||||||
|
self.num_cached_tokens = -1
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||||
if request.mm_inputs is not None:
|
if request.mm_inputs is not None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user