Add tokenization_kwargs to encode for embedding model truncation (#21033)
commit 44554a0068 (parent 226b452a20)
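
In short, the change threads an optional `tokenization_kwargs` dict from the public `encode` entry points down to tokenization, so callers can ask for prompt truncation when embedding inputs exceed the model's context. A minimal usage sketch, assuming the dict is forwarded to the Hugging Face tokenizer so its standard truncation arguments apply (model name and values are illustrative):

```python
# Minimal usage sketch. Assumes tokenization_kwargs is passed through to the
# Hugging Face tokenizer; the model name and max_length are illustrative.
from vllm import LLM

llm = LLM(model="BAAI/bge-small-en-v1.5", task="embed")

outputs = llm.encode(
    ["a document that may exceed the embedding model's context window"],
    tokenization_kwargs={"truncation": True, "max_length": 128},
)
for out in outputs:
    print(out.outputs.data.shape)  # one pooled embedding per prompt
```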
```diff
@@ -438,6 +438,7 @@ class _AsyncLLMEngine(LLMEngine):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
         Async version of
@@ -468,6 +469,7 @@ class _AsyncLLMEngine(LLMEngine):
             prompt,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
+            tokenization_kwargs=tokenization_kwargs,
         )

         if isinstance(params, SamplingParams) and \
@@ -862,6 +864,7 @@ class AsyncLLMEngine(EngineClient):
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
         data_parallel_rank: Optional[int] = None,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]:
         if not self.is_running:
             if self.start_engine_loop:
@@ -889,6 +892,7 @@ class AsyncLLMEngine(EngineClient):
                 prompt_adapter_request=prompt_adapter_request,
                 priority=priority,
                 data_parallel_rank=data_parallel_rank,
+                tokenization_kwargs=tokenization_kwargs,
             )

             return stream.generator()
@@ -996,6 +1000,7 @@ class AsyncLLMEngine(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """Generate outputs for a request from a pooling model.

@@ -1070,6 +1075,7 @@ class AsyncLLMEngine(EngineClient):
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 priority=priority,
+                tokenization_kwargs=tokenization_kwargs,
             ):
                 yield LLMEngine.validate_output(output, PoolingRequestOutput)
         except asyncio.CancelledError:
```
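On the async engine side, the same keyword now travels through `generate`/`encode` into `add_request`. A hedged sketch of the new encode path (model name, request id, and engine args are illustrative; only the `tokenization_kwargs` parameter itself comes from this diff):

```python
# Hedged sketch of the async encode path; assumes the signature shown in the
# hunks above (prompt, pooling_params, request_id, ..., tokenization_kwargs).
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.pooling_params import PoolingParams


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="BAAI/bge-small-en-v1.5"))
    async for output in engine.encode(
            "a very long document ...",
            PoolingParams(),
            request_id="embed-0",
            tokenization_kwargs={"truncation": True, "max_length": 128},
    ):
        print(output.outputs.data.shape)


asyncio.run(main())
```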
```diff
@@ -965,6 +965,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -981,6 +982,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -997,6 +999,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -1014,6 +1017,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -1031,6 +1035,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -1046,6 +1051,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         ...

@@ -1066,6 +1072,7 @@ class LLM:
         lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         pooling_task: PoolingTask = "encode",
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> list[PoolingRequestOutput]:
         """Apply pooling to the hidden states corresponding to the input
         prompts.
```
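The seven near-identical hunks above all touch `LLM.encode` because it is declared with `typing.overload` stubs for the different prompt/parameter combinations, so the new keyword must be repeated on every stub and on the implementation. A schematic sketch (parameter lists heavily simplified relative to the real signatures; the placeholder class only keeps the sketch self-contained):

```python
# Schematic only: the real overloads take many more parameters; this shows
# why the same added line appears in every hunk.
from typing import Any, Optional, Union, overload


class PoolingRequestOutput:  # placeholder, not the real vLLM class
    ...


class LLM:

    @overload
    def encode(self, prompts: str, *,
               tokenization_kwargs: Optional[dict[str, Any]] = None,
               ) -> list[PoolingRequestOutput]: ...

    @overload
    def encode(self, prompts: list[str], *,
               tokenization_kwargs: Optional[dict[str, Any]] = None,
               ) -> list[PoolingRequestOutput]: ...

    def encode(self, prompts: Union[str, list[str]], *,
               tokenization_kwargs: Optional[dict[str, Any]] = None,
               ) -> list[PoolingRequestOutput]:
        ...  # the implementation threads the kwarg down to tokenization
```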
```diff
@@ -1131,9 +1138,11 @@ class LLM:
         for pooling_param in pooling_params:
             pooling_param.verify(pooling_task, model_config)

-        tokenization_kwargs = dict[str, Any]()
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens, tokenization_kwargs)
+        if tokenization_kwargs is None:
+            tokenization_kwargs = dict[str, Any]()
+            _validate_truncation_size(model_config.max_model_len,
+                                      truncate_prompt_tokens,
+                                      tokenization_kwargs)

         self._validate_and_add_requests(
             prompts=parsed_prompts,
```
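This hunk is the behavioral core of the change: truncation defaults derived from `truncate_prompt_tokens` are now built only when the caller did not pass `tokenization_kwargs`, so an explicit dict takes precedence. Below is a simplified, illustrative stand-in for the helper; the real `_validate_truncation_size` in vLLM may differ in details such as sentinel handling:

```python
# Illustrative stand-in for the helper named in the diff, not vLLM's actual
# implementation. Presumed behavior: validate the requested truncation length
# against the model's context window and record it as tokenizer kwargs.
from typing import Any, Optional


def _validate_truncation_size(max_model_len: int,
                              truncate_prompt_tokens: Optional[int],
                              tokenization_kwargs: dict[str, Any]) -> None:
    if truncate_prompt_tokens is None:
        return
    if truncate_prompt_tokens > max_model_len:
        raise ValueError(
            f"truncate_prompt_tokens ({truncate_prompt_tokens}) exceeds "
            f"max_model_len ({max_model_len})")
    # Expressed as standard Hugging Face tokenizer kwargs.
    tokenization_kwargs["truncation"] = True
    tokenization_kwargs["max_length"] = truncate_prompt_tokens


# With the new guard, these defaults are computed only when the caller did
# not supply tokenization_kwargs; an explicit dict is used as-is.
kwargs: dict[str, Any] = {}
_validate_truncation_size(512, 128, kwargs)
print(kwargs)  # {'truncation': True, 'max_length': 128}
```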
```diff
@@ -437,6 +437,7 @@ class AsyncLLM(EngineClient):
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         priority: int = 0,
+        tokenization_kwargs: Optional[dict[str, Any]] = None,
     ) -> AsyncGenerator[PoolingRequestOutput, None]:
         """
         Main function called by the API server to kick off a request
@@ -465,6 +466,7 @@ class AsyncLLM(EngineClient):
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 priority=priority,
+                tokenization_kwargs=tokenization_kwargs,
             )

             # The output_handler task pushes items into the queue.
```