diff --git a/tests/models/language/pooling/test_extract_hidden_states.py b/tests/models/language/pooling/test_extract_hidden_states.py
index f8e3fa7d1560..0d41b93233d5 100644
--- a/tests/models/language/pooling/test_extract_hidden_states.py
+++ b/tests/models/language/pooling/test_extract_hidden_states.py
@@ -11,7 +11,7 @@ from vllm import TokensPrompt
     ["Qwen/Qwen3-0.6B"],
 )
 @torch.inference_mode
-def test_embed_models(hf_runner, vllm_runner, model: str):
+def test_extract_hidden_states(hf_runner, vllm_runner, model: str):
     n_prompt_tokens = [55, 56, 57]
     token_prompts = [[1024 + i for i in range(n)] for n in n_prompt_tokens]
 
@@ -21,7 +21,7 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
         enforce_eager=True,
         runner="pooling",
         enable_chunked_prefill=False,
-        enable_prefix_caching=False,
+        enable_prefix_caching=True,
     ) as vllm_model:
         pooling_outputs = vllm_model.llm.encode(
             [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
@@ -30,4 +30,29 @@ def test_embed_models(hf_runner, vllm_runner, model: str):
 
         for n, output in zip(n_prompt_tokens, pooling_outputs):
             assert len(output.prompt_token_ids) == n
+            assert len(output.outputs.data) == n
             assert output.num_cached_tokens == 0
+
+        # Test enable_prefix_caching together with all pooling: the
+        # request must skip reading the prefix cache, which it does via
+        # request.skip_reading_prefix_cache.
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="token_embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert len(output.outputs.data) == n
+            assert output.num_cached_tokens == 0
+
+        # Requests with skip_reading_prefix_cache can still write to the
+        # cache, which accelerates subsequent requests.
+        pooling_outputs = vllm_model.llm.encode(
+            [TokensPrompt(prompt_token_ids=t) for t in token_prompts],
+            pooling_task="embed",
+        )
+
+        for n, output in zip(n_prompt_tokens, pooling_outputs):
+            assert len(output.prompt_token_ids) == n
+            assert output.num_cached_tokens > 0
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 72a8320cc1bf..5c3dfa8ac9cb 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -57,6 +57,7 @@ class PoolingParams(
     ## Internal use only
     task: PoolingTask | None = None
     requires_token_ids: bool = False
+    skip_reading_prefix_cache: bool | None = None
     extra_kwargs: dict[str, Any] | None = None
     output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
 
@@ -93,6 +94,8 @@ class PoolingParams(
         # plugin task uses io_processor.parse_request to verify inputs,
         # skipping PoolingParams verify
         if self.task == "plugin":
+            if self.skip_reading_prefix_cache is None:
+                self.skip_reading_prefix_cache = True
             return
 
         # NOTE: Task validation needs to done against the model instance,
@@ -122,6 +125,15 @@ class PoolingParams(
             if getattr(self, k, None) is None:
                 setattr(self, k, getattr(pooler_config, k))
 
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled, all-pooling outputs would only
+            # cover the uncached tokens (fewer than n_prompt_tokens),
+            # so such requests must skip reading the prefix cache.
+            if self.task in ["token_embed", "token_classify"]:
+                self.skip_reading_prefix_cache = True
+            else:
+                self.skip_reading_prefix_cache = False
+
         self._verify_step_pooling(pooler_config, valid_parameters)
 
     def _verify_step_pooling(
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index dd820840410e..901d66163452 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -254,6 +254,8 @@ class SamplingParams(
     generated token can complete the sequence."""
     _bad_words_token_ids: list[list[int]] | None = None
 
+    skip_reading_prefix_cache: bool | None = None
+
     @staticmethod
     def from_optional(
         n: int | None = 1,
@@ -414,6 +416,12 @@ class SamplingParams(
             self.structured_outputs = self.guided_decoding
             self.guided_decoding = None
 
+        if self.skip_reading_prefix_cache is None:
+            # If prefix caching is enabled, prompt logprobs would only be
+            # computed for the uncached tokens (fewer than n_prompt_tokens),
+            # so such requests must skip reading the prefix cache.
+            self.skip_reading_prefix_cache = self.prompt_logprobs is not None
+
     def _verify_args(self) -> None:
         if not isinstance(self.n, int):
             raise ValueError(f"n must be an int, but is of type {type(self.n)}")
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 63a1ff06e404..7f405fc248ac 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -185,12 +185,11 @@ class KVCacheManager:
                 - A list of blocks that are computed for the request.
                 - The number of computed tokens.
         """
-        # Prefix caching is disabled or
-        # When the request requires prompt logprobs, we skip prefix caching.
-        if not self.enable_caching or (
-            request.sampling_params is not None
-            and request.sampling_params.prompt_logprobs is not None
-        ):
+        # Skip looking up a prefix cache hit when prefix caching is
+        # disabled or the request is marked as skipping KV cache reads
+        # (which happens when the request requires prompt logprobs or
+        # uses an all-pooling task).
+        if not self.enable_caching or request.skip_reading_prefix_cache:
             return self.empty_kv_cache_blocks, 0
 
         # NOTE: When all tokens hit the cache, we must recompute the last token
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 7a5f1183ed48..3d92906fbf4b 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -127,6 +127,8 @@ class Request:
         self.get_hash_new_full_blocks = partial(block_hasher, self)
         self.block_hashes = self.get_hash_new_full_blocks()
 
+        self.skip_reading_prefix_cache = self.get_skip_reading_prefix_cache()
+
     @classmethod
     def from_engine_core_request(
         cls,
@@ -180,6 +182,19 @@ def num_output_tokens(self) -> int:
         return len(self._output_token_ids)
 
+    def get_skip_reading_prefix_cache(self) -> bool:
+        if (
+            self.sampling_params is not None
+            and self.sampling_params.skip_reading_prefix_cache is not None
+        ):
+            return self.sampling_params.skip_reading_prefix_cache
+        elif (
+            self.pooling_params is not None
+            and self.pooling_params.skip_reading_prefix_cache is not None
+        ):
+            return self.pooling_params.skip_reading_prefix_cache
+        return False
+
     def is_finished(self) -> bool:
         return RequestStatus.is_finished(self.status)
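A minimal usage sketch of the behavior this patch is meant to produce (not part of the diff; `LLM` and `TokensPrompt` are the public vLLM entry points used by the test above, and the `runner`/`enable_prefix_caching` constructor flags are assumed to be accepted by `LLM` directly, as they are through the test's `vllm_runner` wrapper):

    from vllm import LLM, TokensPrompt

    # Prefix caching stays enabled engine-wide; the per-request
    # PoolingParams decide whether the cache may be *read*.
    llm = LLM(
        model="Qwen/Qwen3-0.6B",
        runner="pooling",
        enforce_eager=True,
        enable_prefix_caching=True,
    )

    prompts = [TokensPrompt(prompt_token_ids=[1024 + i for i in range(56)])]

    # "token_embed" is an all-pooling task, so PoolingParams resolves
    # skip_reading_prefix_cache=True: every prompt token is recomputed
    # (num_cached_tokens == 0), while its blocks are still written to the cache.
    token_outputs = llm.encode(prompts, pooling_task="token_embed")

    # "embed" pools to a single vector, so this request may reuse the
    # blocks written above (num_cached_tokens > 0).
    embed_outputs = llm.encode(prompts, pooling_task="embed")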