mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 11:05:01 +08:00

Fix wrong truncate_prompt_tokens type hint (#22761)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>

This commit is contained in:
parent 038e9be4eb
commit 5b8077b8ac
@@ -51,7 +51,7 @@ from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, Device, is_list_of
+from vllm.utils import Counter, Device, as_iter, is_list_of
 from vllm.v1.sample.logits_processor import LogitsProcessor
 
 if TYPE_CHECKING:
@@ -364,14 +364,6 @@ class LLM:
             # Use default sampling params.
             sampling_params = self.get_default_sampling_params()
 
-        tokenization_kwargs: dict[str, Any] = {}
-        truncate_prompt_tokens = None
-        if isinstance(sampling_params, SamplingParams):
-            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
-
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens, tokenization_kwargs)
-
         # Add any modality specific loras to the corresponding prompts
         lora_request = self._get_modality_specific_lora_reqs(
             prompts, lora_request)
@@ -381,7 +373,6 @@ class LLM:
             params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
 
@@ -871,6 +862,8 @@ class LLM:
                 If `False`, no progress bar is created.
             lora_request: LoRA request to use for generation, if any.
             pooling_task: Override the pooling task to use.
+            tokenization_kwargs: overrides tokenization_kwargs set in
+                pooling_params
 
         Returns:
             A list of `PoolingRequestOutput` objects containing the
@@ -916,24 +909,17 @@ class LLM:
             # Use default pooling params.
             pooling_params = PoolingParams()
 
-        if isinstance(pooling_params, PoolingParams):
-            pooling_params.verify(pooling_task, model_config)
-        else:
-            for pooling_param in pooling_params:
-                pooling_param.verify(pooling_task, model_config)
-
-        if tokenization_kwargs is None:
-            tokenization_kwargs = dict[str, Any]()
-            _validate_truncation_size(model_config.max_model_len,
-                                      truncate_prompt_tokens,
-                                      tokenization_kwargs)
+        for param in as_iter(pooling_params):
+            param.verify(pooling_task, model_config)
+            # for backwards compatibility
+            if truncate_prompt_tokens is not None:
+                param.truncate_prompt_tokens = truncate_prompt_tokens
 
         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
         )
 
         outputs = self._run_engine(use_tqdm=use_tqdm)
@@ -1385,7 +1371,6 @@ class LLM:
         *,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
         priority: Optional[list[int]] = None,
     ) -> None:
         if isinstance(prompts, (str, dict)):
@@ -1412,7 +1397,17 @@ class LLM:
             tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
             it = tqdm_func(it, desc="Adding requests")
 
+        model_config = self.llm_engine.model_config
+
         for i, prompt in enumerate(it):
+
+            param = params[i] if isinstance(params, Sequence) else params
+
+            tokenization_kwargs: dict[str, Any] = {}
+            _validate_truncation_size(model_config.max_model_len,
+                                      param.truncate_prompt_tokens,
+                                      tokenization_kwargs)
+
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
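With the changes above, per-request truncation rides on the params objects themselves (PoolingParams for pooling runs, SamplingParams for generation) instead of a separate tokenization argument threaded through LLM. A minimal offline sketch of the resulting usage, assuming an embedding-capable model; the model name, the task argument, and the prompt are placeholders, not part of this commit:

# Hypothetical usage sketch; model name, task argument, and prompt are placeholders.
from vllm import LLM, PoolingParams

llm = LLM(model="intfloat/e5-small-v2", task="embed")

# truncate_prompt_tokens=-1 truncates to the model's maximum length;
# a positive k keeps only the last k prompt tokens (left truncation);
# None (the default) disables truncation.
params = PoolingParams(truncate_prompt_tokens=-1)
outputs = llm.encode(["a very long document ..."], pooling_params=params)
print(outputs[0])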
@@ -452,7 +452,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     min_tokens: int = 0
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
     bad_words: list[str] = Field(default_factory=list)
@@ -995,7 +995,7 @@ class CompletionRequest(OpenAIBaseModel):
     min_tokens: int = 0
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
     allowed_token_ids: Optional[list[int]] = None
     prompt_logprobs: Optional[int] = None
     # --8<-- [end:completion-sampling-params]
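The loosened Field(ge=1) -> Field(ge=-1) bound only widens what Pydantic accepts at the schema level; the stricter rule that 0 and anything below -1 remain invalid is enforced later in SamplingParams (see the validation hunk further down). A small self-contained sketch of the constraint, using a stand-in model rather than the real request classes:

# Stand-in Pydantic v2 model to illustrate the relaxed bound; not the real request class.
from typing import Annotated, Optional
from pydantic import BaseModel, Field, ValidationError

class _Req(BaseModel):
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None

print(_Req(truncate_prompt_tokens=-1))   # now accepted: means "use the model's max length"
print(_Req(truncate_prompt_tokens=128))  # keep only the last 128 prompt tokens
try:
    _Req(truncate_prompt_tokens=-2)      # still rejected by ge=-1
except ValidationError as e:
    print(e.errors()[0]["type"])         # 'greater_than_equal'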
@@ -1325,8 +1325,10 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     # --8<-- [end:embedding-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             normalize=self.normalize)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize)
 
 
 class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1393,8 +1395,10 @@ class EmbeddingChatRequest(OpenAIBaseModel):
         return data
 
     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             normalize=self.normalize)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize)
 
 
 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@@ -1430,7 +1434,9 @@ class ScoreRequest(OpenAIBaseModel):
     # --8<-- [end:score-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class RerankRequest(OpenAIBaseModel):
@@ -1460,7 +1466,9 @@ class RerankRequest(OpenAIBaseModel):
     # --8<-- [end:rerank-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class RerankDocument(BaseModel):
@@ -1618,7 +1626,9 @@ class ClassificationRequest(OpenAIBaseModel):
     # --8<-- [end:classification-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class ClassificationData(OpenAIBaseModel):
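Because every to_pooling_params() above now forwards truncate_prompt_tokens, the value sent on an embeddings, score, rerank, or classification request reaches the pooling path directly. A hedged client-side sketch; the server URL, model name, and the use of extra_body for vLLM-specific fields are assumptions about a typical deployment, not part of this diff:

# Assumes a vLLM OpenAI-compatible server is running locally; names are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.embeddings.create(
    model="intfloat/e5-small-v2",
    input="some very long document ...",
    # vLLM-specific extension: -1 truncates to the model's maximum length.
    extra_body={"truncate_prompt_tokens": -1},
)
print(len(resp.data[0].embedding))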
@@ -237,7 +237,6 @@ class OpenAIServingChat(OpenAIServing):
                     documents=request.documents,
                     chat_template_kwargs=request.chat_template_kwargs,
                     tool_parser=tool_parser,
-                    truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 )
             else:
@@ -61,7 +61,6 @@ class ClassificationMixin(OpenAIServing):
                 ctx.request,
                 ctx.tokenizer,
                 ctx.request.input,
-                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
             )
 
         return None
@@ -157,18 +156,6 @@ class ServingClassification(ClassificationMixin):
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
-    def _validate_request(
-        self,
-        ctx: ClassificationServeContext,
-    ) -> Optional[ErrorResponse]:
-        if error := super()._validate_request(ctx):
-            return error
-
-        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
-
-        return None
-
     @override
     def _create_pooling_params(
         self,
@@ -137,7 +137,6 @@ class OpenAIServingCompletion(OpenAIServing):
                 request,
                 tokenizer,
                 request.prompt,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         except ValueError as e:
@@ -97,7 +97,6 @@ class EmbeddingMixin(OpenAIServing):
                 # so there is no need to append extra tokens to the input
                 add_generation_prompt=False,
                 continue_final_message=False,
-                truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                 add_special_tokens=ctx.request.add_special_tokens,
             )
         else:
@@ -106,7 +105,6 @@ class EmbeddingMixin(OpenAIServing):
                 ctx.request,
                 tokenizer,
                 ctx.request.input,
-                truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                 add_special_tokens=ctx.request.add_special_tokens,
             )
         return None
@@ -631,18 +629,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
-    def _validate_request(
-        self,
-        ctx: ServeContext[EmbeddingRequest],
-    ) -> Optional[ErrorResponse]:
-        if error := super()._validate_request(ctx):
-            return error
-
-        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
-
-        return None
-
     @override
     def _create_pooling_params(
         self,
@@ -165,7 +165,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
 
     # Shared across most requests
     tokenizer: Optional[AnyTokenizer] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
 
     # `protected_namespaces` resolves Pydantic v2's warning
     # on conflict with protected namespace "model_"
@@ -297,14 +296,12 @@ class OpenAIServing:
         truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
                                          None)
 
-        if truncate_prompt_tokens is not None:
-            if truncate_prompt_tokens <= self.max_model_len:
-                ctx.truncate_prompt_tokens = truncate_prompt_tokens
-            else:
-                return self.create_error_response(
-                    "truncate_prompt_tokens value is "
-                    "greater than max_model_len."
-                    " Please, select a smaller truncation size.")
+        if truncate_prompt_tokens is not None and \
+            truncate_prompt_tokens > self.max_model_len:
+            return self.create_error_response(
+                "truncate_prompt_tokens value is "
+                "greater than max_model_len."
+                " Please, select a smaller truncation size.")
         return None
 
     def _create_pooling_params(
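The base _validate_request no longer copies the value onto the serve context; it only rejects values that exceed the model length, and the handlers read the field straight off the request. A standalone sketch of the simplified check; the 8192 limit is an arbitrary stand-in for self.max_model_len:

# Standalone sketch of the simplified validation; numbers are arbitrary examples.
from typing import Optional

MAX_MODEL_LEN = 8192  # stand-in for self.max_model_len

def validate_truncate(truncate_prompt_tokens: Optional[int]) -> Optional[str]:
    """Return an error message, or None if the request is acceptable."""
    if truncate_prompt_tokens is not None and \
            truncate_prompt_tokens > MAX_MODEL_LEN:
        return ("truncate_prompt_tokens value is greater than max_model_len. "
                "Please, select a smaller truncation size.")
    return None

assert validate_truncate(None) is None        # no truncation requested
assert validate_truncate(-1) is None          # sentinel, clamped later
assert validate_truncate(8192) is None        # at the limit is fine
assert validate_truncate(8193) is not None    # too large -> error response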
@@ -528,7 +525,6 @@ class OpenAIServing:
         request: AnyRequest,
         prompt: str,
         tokenizer: AnyTokenizer,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
         async_tokenizer = self._get_async_tokenizer(tokenizer)
@@ -538,6 +534,9 @@ class OpenAIServing:
                         "do_lower_case", False)):
             prompt = prompt.lower()
 
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         if truncate_prompt_tokens is None:
             encoded = await async_tokenizer(
                 prompt, add_special_tokens=add_special_tokens)
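The tokenize helpers now look the limit up on the request object itself instead of threading it through every call; getattr(..., None) keeps this safe for request types that never define the field. A tiny illustration with throwaway classes (both class names are hypothetical):

# Throwaway classes purely to illustrate the getattr fallback.
class EmbeddingLikeRequest:
    truncate_prompt_tokens = 128

class TokenizeLikeRequest:
    pass  # no truncate_prompt_tokens attribute at all

for request in (EmbeddingLikeRequest(), TokenizeLikeRequest()):
    limit = getattr(request, "truncate_prompt_tokens", None)
    print(type(request).__name__, "->", limit)  # 128, then None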
@@ -565,8 +564,10 @@ class OpenAIServing:
         request: AnyRequest,
         prompt_ids: list[int],
         tokenizer: Optional[AnyTokenizer],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
     ) -> TextTokensPrompt:
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         if truncate_prompt_tokens is None:
             input_ids = prompt_ids
         elif truncate_prompt_tokens < 0:
@@ -652,7 +653,6 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_input: Union[str, list[int]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> TextTokensPrompt:
         """
@@ -664,7 +664,6 @@ class OpenAIServing:
                 request,
                 tokenizer,
                 [prompt_input],
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
         ):
             return result
@@ -675,7 +674,6 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> AsyncGenerator[TextTokensPrompt, None]:
         """
@@ -689,7 +687,6 @@ class OpenAIServing:
                 request,
                 prompt=prompt,
                 tokenizer=tokenizer,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
             )
         else:
@@ -697,7 +694,6 @@ class OpenAIServing:
                 request,
                 prompt_ids=prompt,
                 tokenizer=tokenizer,
-                truncate_prompt_tokens=truncate_prompt_tokens,
             )
 
     async def _tokenize_prompt_input_or_inputs_async(
@@ -706,7 +702,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
         """
@@ -719,6 +714,12 @@ class OpenAIServing:
         inputs_embeds = list[EmbedsPrompt]()
         inputs_text = list[TextTokensPrompt]()
 
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
+        if (truncate_prompt_tokens or 0) < 0:
+            truncate_prompt_tokens = self.max_model_len
+
         if (isinstance(request, CompletionRequest)
                 and request.prompt_embeds is not None):
             inputs_embeds.extend(
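The (truncate_prompt_tokens or 0) < 0 guard is a compact way to treat None and non-negative values as "leave as-is" while mapping the -1 sentinel (or any negative) to the model's maximum length. A small sketch of the same normalization outside the serving class; the 8192 figure is an arbitrary stand-in for self.max_model_len:

from typing import Optional

MAX_MODEL_LEN = 8192  # arbitrary stand-in for self.max_model_len

def normalize_truncation(truncate_prompt_tokens: Optional[int]) -> Optional[int]:
    # None stays None (no truncation), positive k stays k,
    # and the -1 sentinel becomes the model's maximum length.
    if (truncate_prompt_tokens or 0) < 0:
        return MAX_MODEL_LEN
    return truncate_prompt_tokens

assert normalize_truncation(None) is None
assert normalize_truncation(256) == 256
assert normalize_truncation(-1) == MAX_MODEL_LEN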
@@ -748,14 +749,10 @@ class OpenAIServing:
                     request,
                     prompt_input["content"],
                     tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens)
             else:
                 task = self._normalize_prompt_tokens_to_input(
-                    request,
-                    prompt_input["content"],
-                    tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens)
+                    request, prompt_input["content"], tokenizer=tokenizer)
             tasks.append(task)
 
         # Wait for all tokenization tasks to complete
@@ -772,7 +769,6 @@ class OpenAIServing:
                                TokenizeCompletionRequest],
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
         add_special_tokens: bool = ...,
     ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
         ...
@@ -784,7 +780,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
         add_special_tokens: bool = ...,
     ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
             EngineTokensPrompt, EngineEmbedsPrompt]]]:
@@ -796,7 +791,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> tuple[Union[list[TextTokensPrompt], list[Union[
             TextTokensPrompt, EmbedsPrompt]]], Union[
@@ -813,7 +807,6 @@ class OpenAIServing:
             request,
             tokenizer,
             input_or_inputs,
-            truncate_prompt_tokens=truncate_prompt_tokens,
             add_special_tokens=add_special_tokens,
         )
 
@@ -866,7 +859,6 @@ class OpenAIServing:
         documents: Optional[list[dict[str, str]]] = None,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
         tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = False,
     ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
                list[EngineTokensPrompt]]:
@@ -941,7 +933,6 @@ class OpenAIServing:
                 request,
                 tokenizer,
                 request_prompt,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
             )
         else:
@@ -120,7 +120,6 @@ class OpenAIServingPooling(OpenAIServing):
                 # so there is no need to append extra tokens to the input
                 add_generation_prompt=False,
                 continue_final_message=False,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         else:
@@ -129,7 +128,6 @@ class OpenAIServingPooling(OpenAIServing):
                 request,
                 tokenizer,
                 request.input,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         except (ValueError, TypeError, jinja2.TemplateError) as e:
@@ -266,12 +266,14 @@ class ServingScores(OpenAIServing):
         request: Union[ScoreRequest, RerankRequest],
         request_id: str,
         raw_request: Optional[Request] = None,
-        truncate_prompt_tokens: Optional[int] = None,
     ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
         lora_request = self._maybe_get_adapters(request)
 
         tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
                                   tokenization_kwargs)
@@ -343,7 +345,6 @@ class ServingScores(OpenAIServing):
             request,
             request_id,
             raw_request,
-            request.truncate_prompt_tokens,
         )
         if isinstance(final_res_batch, ErrorResponse):
             return final_res_batch
@@ -391,7 +392,6 @@ class ServingScores(OpenAIServing):
             request,
             request_id,
             raw_request,
-            request.truncate_prompt_tokens,
         )
         if isinstance(final_res_batch, ErrorResponse):
             return final_res_batch
@@ -346,6 +346,22 @@ class InputPreprocessor:
     ) -> EmbedsInputs:
         return self._process_embeds(parsed_content)
 
+    def _truncate_inputs(
+            self,
+            inputs: list[int],
+            tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
+
+        if not tokenization_kwargs or "truncation" not in \
+            tokenization_kwargs or self.tokenizer is None:
+            return inputs
+
+        max_length = tokenization_kwargs["max_length"]
+
+        if self.tokenizer.truncation_side == "left":
+            return inputs[-max_length:]
+        else:
+            return inputs[:max_length]
+
     def _process_tokens(
         self,
         parsed_content: TokensPrompt,
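The new _truncate_inputs helper only fires when the tokenization kwargs actually request truncation, and it honours the tokenizer's truncation_side. A self-contained sketch of the slicing behaviour; the helper below mimics the logic with a plain string argument standing in for tokenizer.truncation_side:

# Mimics the slicing in _truncate_inputs; "left"/"right" stand in for
# tokenizer.truncation_side on a real tokenizer.
def truncate(token_ids: list[int], max_length: int, truncation_side: str) -> list[int]:
    if truncation_side == "left":
        return token_ids[-max_length:]   # keep the last max_length tokens
    return token_ids[:max_length]        # keep the first max_length tokens

ids = [1, 2, 3, 4, 5, 6]
assert truncate(ids, 3, "left") == [4, 5, 6]
assert truncate(ids, 3, "right") == [1, 2, 3]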
@@ -354,7 +370,8 @@ class InputPreprocessor:
         *,
         mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_token_ids = parsed_content["prompt_token_ids"]
+        prompt_token_ids = self._truncate_inputs(
+            parsed_content["prompt_token_ids"], tokenization_kwargs)
 
         inputs: Union[TokenInputs, MultiModalInputs]
         if multi_modal_data := parsed_content.get("multi_modal_data"):
@@ -382,7 +399,8 @@ class InputPreprocessor:
         *,
         mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_token_ids = parsed_content["prompt_token_ids"]
+        prompt_token_ids = self._truncate_inputs(
+            parsed_content["prompt_token_ids"], tokenization_kwargs)
 
         inputs: Union[TokenInputs, MultiModalInputs]
         if multi_modal_data := parsed_content.get("multi_modal_data"):
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Annotated, Any, Optional
 
 import msgspec
 
@@ -27,6 +27,11 @@ class PoolingParams(
         the classification outputs.
         softmax: Whether to apply softmax to the reward outputs.
     """
+    truncate_prompt_tokens: Optional[Annotated[int,
+                                               msgspec.Meta(ge=-1)]] = None
+    """If set to -1, will use the truncation size supported by the model. If
+    set to an integer k, will use only the last k tokens from the prompt
+    (i.e., left truncation). If set to `None`, truncation is disabled."""
 
     ## for embeddings models
     dimensions: Optional[int] = None
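With the field added to PoolingParams, pooling requests can carry their own truncation limit, mirroring SamplingParams. A minimal sketch of an equivalent msgspec struct with the same ge=-1 bound; this is a stand-in, not the real class, which has more fields and its own verify() logic:

# Stand-in struct showing the msgspec.Meta(ge=-1) bound; not the real PoolingParams.
from typing import Annotated, Optional
import msgspec

class MiniPoolingParams(msgspec.Struct):
    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None
    dimensions: Optional[int] = None

print(MiniPoolingParams(truncate_prompt_tokens=-1))

# msgspec enforces the bound when decoding/converting, e.g.:
try:
    msgspec.convert({"truncate_prompt_tokens": -2}, MiniPoolingParams)
except msgspec.ValidationError as e:
    print(e)  # Expected `int` >= -1 ...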
@@ -182,7 +182,8 @@ class SamplingParams(
     optionally prompt tokens as a first argument."""
     include_stop_str_in_output: bool = False
     """Whether to include the stop strings in output text."""
-    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int,
+                                               msgspec.Meta(ge=-1)]] = None
     """If set to -1, will use the truncation size supported by the model. If
     set to an integer k, will use only the last k tokens from the prompt
     (i.e., left truncation). If set to `None`, truncation is disabled."""
@@ -241,7 +242,8 @@ class SamplingParams(
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[list[LogitsProcessor]] = None,
         truncate_prompt_tokens: Optional[Annotated[int,
-                                                   msgspec.Meta(ge=1)]] = None,
+                                                   msgspec.Meta(
+                                                       ge=-1)]] = None,
         output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
         guided_decoding: Optional[GuidedDecodingParams] = None,
         logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
@@ -411,9 +413,11 @@ class SamplingParams(
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
         if (self.truncate_prompt_tokens is not None
-                and self.truncate_prompt_tokens < 1):
-            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
-                             f"got {self.truncate_prompt_tokens}")
+                and (self.truncate_prompt_tokens == 0
+                     or self.truncate_prompt_tokens < -1)):
+            raise ValueError(
+                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
+                f"got {self.truncate_prompt_tokens}")
         assert isinstance(self.stop_token_ids, list)
         if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
             raise ValueError(f"stop_token_ids must contain only integers, "
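The reworked check accepts None, the -1 sentinel, and any positive integer, while still rejecting 0 and anything below -1. A standalone mirror of the condition with a few spot checks:

# Standalone mirror of the new validity rule for truncate_prompt_tokens.
from typing import Optional

def is_valid_truncation(value: Optional[int]) -> bool:
    return value is None or not (value == 0 or value < -1)

assert is_valid_truncation(None)       # truncation disabled
assert is_valid_truncation(-1)         # truncate to the model's max length
assert is_valid_truncation(1)          # keep only the last prompt token
assert not is_valid_truncation(0)      # still rejected
assert not is_valid_truncation(-2)     # still rejected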
@@ -23,6 +23,7 @@ class TokenizerGroup:
         self.tokenizer_config = tokenizer_config
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
+        self.truncation_side = tokenizer_config.get("truncation_side", "left")
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
         max_loras = tokenizer_config.get("max_loras", 0)
         self.lora_tokenizers = LRUCache[int, AnyTokenizer](
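TokenizerGroup now caches truncation_side (defaulting to "left") so the input preprocessor can decide which end of the prompt to drop. A tiny sketch of the lookup against a plain config dict; the other keys are purely illustrative:

# Illustrative config dicts; only "truncation_side" matters for this change.
tokenizer_config = {"trust_remote_code": False}
print(tokenizer_config.get("truncation_side", "left"))   # "left" (default)

tokenizer_config = {"truncation_side": "right"}
print(tokenizer_config.get("truncation_side", "left"))   # "right"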
@@ -1328,6 +1328,12 @@ def as_list(maybe_list: Iterable[T]) -> list[T]:
     return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
 
 
+def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
+    if isinstance(obj, str) or not isinstance(obj, Iterable):
+        obj = [obj]
+    return obj
+
+
 # `collections` helpers
 def is_list_of(
     value: object,
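as_iter lets the LLM.encode loop treat a single PoolingParams and a list of them uniformly, with a special case so a bare string is not iterated character by character. A short, self-contained usage sketch that re-implements the helper so it runs on its own:

# Same behaviour as the as_iter helper added above, shown on plain values.
from collections.abc import Iterable
from typing import TypeVar, Union

T = TypeVar("T")

def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
    if isinstance(obj, str) or not isinstance(obj, Iterable):
        obj = [obj]
    return obj

print(list(as_iter(42)))            # [42]       - a scalar is wrapped in a list
print(list(as_iter([1, 2, 3])))     # [1, 2, 3]  - iterables pass through unchanged
print(list(as_iter("abc")))         # ['abc']    - strings are NOT split into characters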