diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 9d900e691b0a0..479524a117995 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -51,7 +51,7 @@ from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, Device, is_list_of
+from vllm.utils import Counter, Device, as_iter, is_list_of
 from vllm.v1.sample.logits_processor import LogitsProcessor

 if TYPE_CHECKING:
@@ -364,14 +364,6 @@ class LLM:
             # Use default sampling params.
             sampling_params = self.get_default_sampling_params()

-        tokenization_kwargs: dict[str, Any] = {}
-        truncate_prompt_tokens = None
-        if isinstance(sampling_params, SamplingParams):
-            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
-
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens, tokenization_kwargs)
-
         # Add any modality specific loras to the corresponding prompts
         lora_request = self._get_modality_specific_lora_reqs(
             prompts, lora_request)
@@ -381,7 +373,6 @@ class LLM:
             params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )

@@ -871,6 +862,8 @@ class LLM:
                 If `False`, no progress bar is created.
             lora_request: LoRA request to use for generation, if any.
             pooling_task: Override the pooling task to use.
+            tokenization_kwargs: overrides tokenization_kwargs set in
+                pooling_params

         Returns:
             A list of `PoolingRequestOutput` objects containing the
@@ -916,24 +909,17 @@ class LLM:
             # Use default pooling params.
             pooling_params = PoolingParams()

-        if isinstance(pooling_params, PoolingParams):
-            pooling_params.verify(pooling_task, model_config)
-        else:
-            for pooling_param in pooling_params:
-                pooling_param.verify(pooling_task, model_config)
-
-        if tokenization_kwargs is None:
-            tokenization_kwargs = dict[str, Any]()
-            _validate_truncation_size(model_config.max_model_len,
-                                      truncate_prompt_tokens,
-                                      tokenization_kwargs)
+        for param in as_iter(pooling_params):
+            param.verify(pooling_task, model_config)
+            # for backwards compatibility
+            if truncate_prompt_tokens is not None:
+                param.truncate_prompt_tokens = truncate_prompt_tokens

         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
         )

         outputs = self._run_engine(use_tqdm=use_tqdm)
@@ -1385,7 +1371,6 @@ class LLM:
         *,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
         priority: Optional[list[int]] = None,
     ) -> None:
         if isinstance(prompts, (str, dict)):
@@ -1412,7 +1397,17 @@ class LLM:
             tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
             it = tqdm_func(it, desc="Adding requests")

+        model_config = self.llm_engine.model_config
+
         for i, prompt in enumerate(it):
+
+            param = params[i] if isinstance(params, Sequence) else params
+
+            tokenization_kwargs: dict[str, Any] = {}
+            _validate_truncation_size(model_config.max_model_len,
+                                      param.truncate_prompt_tokens,
+                                      tokenization_kwargs)
+
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 5cb41bd93d4bc..0fa1136b47b85 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -452,7 +452,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     min_tokens: int = 0
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
     bad_words: list[str] = Field(default_factory=list)
@@ -995,7 +995,7 @@ class CompletionRequest(OpenAIBaseModel):
     min_tokens: int = 0
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
     allowed_token_ids: Optional[list[int]] = None
     prompt_logprobs: Optional[int] = None
     # --8<-- [end:completion-sampling-params]
@@ -1325,8 +1325,10 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     # --8<-- [end:embedding-extra-params]

     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             normalize=self.normalize)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize)


 class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1393,8 +1395,10 @@ class EmbeddingChatRequest(OpenAIBaseModel):
         return data

     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             normalize=self.normalize)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize)


 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@@ -1430,7 +1434,9 @@ class ScoreRequest(OpenAIBaseModel):
     # --8<-- [end:score-extra-params]

     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)


 class RerankRequest(OpenAIBaseModel):
@@ -1460,7 +1466,9 @@ class RerankRequest(OpenAIBaseModel):
     # --8<-- [end:rerank-extra-params]

     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)


 class RerankDocument(BaseModel):
@@ -1618,7 +1626,9 @@ class ClassificationRequest(OpenAIBaseModel):
     # --8<-- [end:classification-extra-params]

     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)


 class ClassificationData(OpenAIBaseModel):
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 1c0ffdfb91897..6300d0758c3d4 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -237,7 +237,6 @@ class OpenAIServingChat(OpenAIServing):
                 documents=request.documents,
                 chat_template_kwargs=request.chat_template_kwargs,
                 tool_parser=tool_parser,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         else:
diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py
index 1d510d0b60a2d..b4fdc36390319 100644
--- a/vllm/entrypoints/openai/serving_classification.py
+++ b/vllm/entrypoints/openai/serving_classification.py
@@ -61,7 +61,6 @@ class ClassificationMixin(OpenAIServing):
                 ctx.request,
                 ctx.tokenizer,
                 ctx.request.input,
-                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
             )

         return None
@@ -157,18 +156,6 @@ class ServingClassification(ClassificationMixin):

         return await super().handle(ctx)  # type: ignore

-    @override
-    def _validate_request(
-        self,
-        ctx: ClassificationServeContext,
-    ) -> Optional[ErrorResponse]:
-        if error := super()._validate_request(ctx):
-            return error
-
-        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
-
-        return None
-
     @override
     def _create_pooling_params(
         self,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index f461d7609b945..11effba8f9eb3 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -137,7 +137,6 @@ class OpenAIServingCompletion(OpenAIServing):
                 request,
                 tokenizer,
                 request.prompt,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         except ValueError as e:
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index 45c1932f1873c..0a0d98db2d0d8 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -97,7 +97,6 @@ class EmbeddingMixin(OpenAIServing):
                 # so there is no need to append extra tokens to the input
                 add_generation_prompt=False,
                 continue_final_message=False,
-                truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                 add_special_tokens=ctx.request.add_special_tokens,
             )
         else:
@@ -106,7 +105,6 @@ class EmbeddingMixin(OpenAIServing):
                 ctx.request,
                 tokenizer,
                 ctx.request.input,
-                truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                 add_special_tokens=ctx.request.add_special_tokens,
             )
         return None
@@ -631,18 +629,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):

         return await super().handle(ctx)  # type: ignore

-    @override
-    def _validate_request(
-        self,
-        ctx: ServeContext[EmbeddingRequest],
-    ) -> Optional[ErrorResponse]:
-        if error := super()._validate_request(ctx):
-            return error
-
-        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
-
-        return None
-
     @override
     def _create_pooling_params(
         self,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index ca6f3987936da..320c1e61f1d13 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -165,7 +165,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,

     # Shared across most requests
     tokenizer: Optional[AnyTokenizer] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

     # `protected_namespaces` resolves Pydantic v2's warning
     # on conflict with protected namespace "model_"
@@ -297,14 +296,12 @@ class OpenAIServing:
         truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
                                          None)
-        if truncate_prompt_tokens is not None:
-            if truncate_prompt_tokens <= self.max_model_len:
-                ctx.truncate_prompt_tokens = truncate_prompt_tokens
-            else:
-                return self.create_error_response(
-                    "truncate_prompt_tokens value is "
-                    "greater than max_model_len."
-                    " Please, select a smaller truncation size.")
+        if truncate_prompt_tokens is not None and \
+            truncate_prompt_tokens > self.max_model_len:
+            return self.create_error_response(
+                "truncate_prompt_tokens value is "
+                "greater than max_model_len."
+                " Please, select a smaller truncation size.")
         return None

     def _create_pooling_params(
         self,
@@ -528,7 +525,6 @@ class OpenAIServing:
         request: AnyRequest,
         prompt: str,
         tokenizer: AnyTokenizer,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
         async_tokenizer = self._get_async_tokenizer(tokenizer)
@@ -538,6 +534,9 @@ class OpenAIServing:
                 "do_lower_case", False)):
             prompt = prompt.lower()

+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         if truncate_prompt_tokens is None:
             encoded = await async_tokenizer(
                 prompt, add_special_tokens=add_special_tokens)
@@ -565,8 +564,10 @@ class OpenAIServing:
         request: AnyRequest,
         prompt_ids: list[int],
         tokenizer: Optional[AnyTokenizer],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
     ) -> TextTokensPrompt:
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         if truncate_prompt_tokens is None:
             input_ids = prompt_ids
         elif truncate_prompt_tokens < 0:
@@ -652,7 +653,6 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_input: Union[str, list[int]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> TextTokensPrompt:
         """
@@ -664,7 +664,6 @@ class OpenAIServing:
                 request,
                 tokenizer,
                 [prompt_input],
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
         ):
             return result
@@ -675,7 +674,6 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> AsyncGenerator[TextTokensPrompt, None]:
         """
@@ -689,7 +687,6 @@ class OpenAIServing:
                     request,
                     prompt=prompt,
                     tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
                 )
             else:
@@ -697,7 +694,6 @@ class OpenAIServing:
                     request,
                     prompt_ids=prompt,
                     tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
                 )

     async def _tokenize_prompt_input_or_inputs_async(
         self,
@@ -706,7 +702,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
         """
@@ -719,6 +714,12 @@ class OpenAIServing:
         inputs_embeds = list[EmbedsPrompt]()
         inputs_text = list[TextTokensPrompt]()

+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
+        if (truncate_prompt_tokens or 0) < 0:
+            truncate_prompt_tokens = self.max_model_len
+
         if (isinstance(request, CompletionRequest)
                 and request.prompt_embeds is not None):
             inputs_embeds.extend(
@@ -748,14 +749,10 @@ class OpenAIServing:
                     request,
                     prompt_input["content"],
                     tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens)
             else:
                 task = self._normalize_prompt_tokens_to_input(
-                    request,
-                    prompt_input["content"],
-                    tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens)
+                    request, prompt_input["content"], tokenizer=tokenizer)
             tasks.append(task)

         # Wait for all tokenization tasks to complete
@@ -772,7 +769,6 @@ class OpenAIServing:
                        TokenizeCompletionRequest],
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
         add_special_tokens: bool = ...,
     ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
         ...
@@ -784,7 +780,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
         add_special_tokens: bool = ...,
     ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
             EngineTokensPrompt, EngineEmbedsPrompt]]]:
@@ -796,7 +791,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> tuple[Union[list[TextTokensPrompt], list[Union[
             TextTokensPrompt, EmbedsPrompt]]], Union[
@@ -813,7 +807,6 @@ class OpenAIServing:
             request,
             tokenizer,
             input_or_inputs,
-            truncate_prompt_tokens=truncate_prompt_tokens,
             add_special_tokens=add_special_tokens,
         )

@@ -866,7 +859,6 @@ class OpenAIServing:
         documents: Optional[list[dict[str, str]]] = None,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
         tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = False,
     ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
                list[EngineTokensPrompt]]:
@@ -941,7 +933,6 @@ class OpenAIServing:
                 request,
                 tokenizer,
                 request_prompt,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
             )
         else:
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index e8cb1aed84596..b2c2af2ec58e0 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -120,7 +120,6 @@ class OpenAIServingPooling(OpenAIServing):
                 # so there is no need to append extra tokens to the input
                 add_generation_prompt=False,
                 continue_final_message=False,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         else:
@@ -129,7 +128,6 @@ class OpenAIServingPooling(OpenAIServing):
                 request,
                 tokenizer,
                 request.input,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         except (ValueError, TypeError, jinja2.TemplateError) as e:
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index c54deb371d545..847c014a11dc3 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -266,12 +266,14 @@ class ServingScores(OpenAIServing):
         request: Union[ScoreRequest, RerankRequest],
         request_id: str,
         raw_request: Optional[Request] = None,
-        truncate_prompt_tokens: Optional[int] = None,
     ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
         lora_request = self._maybe_get_adapters(request)

         tokenizer = await self.engine_client.get_tokenizer(lora_request)

+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
                                   tokenization_kwargs)
@@ -343,7 +345,6 @@ class ServingScores(OpenAIServing):
                 request,
                 request_id,
                 raw_request,
-                request.truncate_prompt_tokens,
             )
             if isinstance(final_res_batch, ErrorResponse):
                 return final_res_batch
@@ -391,7 +392,6 @@ class ServingScores(OpenAIServing):
                 request,
                 request_id,
                 raw_request,
-                request.truncate_prompt_tokens,
             )
             if isinstance(final_res_batch, ErrorResponse):
                 return final_res_batch
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 3dbd9057fe0f7..2f2fbe274bf07 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -346,6 +346,22 @@ class InputPreprocessor:
     ) -> EmbedsInputs:
         return self._process_embeds(parsed_content)

+    def _truncate_inputs(
+        self,
+        inputs: list[int],
+        tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
+
+        if not tokenization_kwargs or "truncation" not in \
+            tokenization_kwargs or self.tokenizer is None:
+            return inputs
+
+        max_length = tokenization_kwargs["max_length"]
+
+        if self.tokenizer.truncation_side == "left":
+            return inputs[-max_length:]
+        else:
+            return inputs[:max_length]
+
     def _process_tokens(
         self,
         parsed_content: TokensPrompt,
@@ -354,7 +370,8 @@ class InputPreprocessor:
         *,
         mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_token_ids = parsed_content["prompt_token_ids"]
+        prompt_token_ids = self._truncate_inputs(
+            parsed_content["prompt_token_ids"], tokenization_kwargs)

         inputs: Union[TokenInputs, MultiModalInputs]
         if multi_modal_data := parsed_content.get("multi_modal_data"):
@@ -382,7 +399,8 @@ class InputPreprocessor:
         *,
         mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_token_ids = parsed_content["prompt_token_ids"]
+        prompt_token_ids = self._truncate_inputs(
+            parsed_content["prompt_token_ids"], tokenization_kwargs)

         inputs: Union[TokenInputs, MultiModalInputs]
         if multi_modal_data := parsed_content.get("multi_modal_data"):
diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py
index 29f037b4372cd..6672392b8d080 100644
--- a/vllm/pooling_params.py
+++ b/vllm/pooling_params.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Annotated, Any, Optional

 import msgspec

@@ -27,6 +27,11 @@ class PoolingParams(
         the classification outputs.
         softmax: Whether to apply softmax to the reward outputs.
     """
+    truncate_prompt_tokens: Optional[Annotated[int,
+                                               msgspec.Meta(ge=-1)]] = None
+    """If set to -1, will use the truncation size supported by the model. If
+    set to an integer k, will use only the last k tokens from the prompt
+    (i.e., left truncation). If set to `None`, truncation is disabled."""

     ## for embeddings models
     dimensions: Optional[int] = None
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index df4cca9ba1147..c7b4ba34c602e 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -182,7 +182,8 @@ class SamplingParams(
     optionally prompt tokens as a first argument."""
     include_stop_str_in_output: bool = False
     """Whether to include the stop strings in output text."""
-    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int,
+                                               msgspec.Meta(ge=-1)]] = None
     """If set to -1, will use the truncation size supported by the model. If
     set to an integer k, will use only the last k tokens from the prompt
     (i.e., left truncation). If set to `None`, truncation is disabled."""
@@ -241,7 +242,8 @@ class SamplingParams(
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[list[LogitsProcessor]] = None,
         truncate_prompt_tokens: Optional[Annotated[int,
-                                                   msgspec.Meta(ge=1)]] = None,
+                                                   msgspec.Meta(
+                                                       ge=-1)]] = None,
         output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
         guided_decoding: Optional[GuidedDecodingParams] = None,
         logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
@@ -411,9 +413,11 @@ class SamplingParams(
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
         if (self.truncate_prompt_tokens is not None
-                and self.truncate_prompt_tokens < 1):
-            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
-                             f"got {self.truncate_prompt_tokens}")
+                and (self.truncate_prompt_tokens == 0
+                     or self.truncate_prompt_tokens < -1)):
+            raise ValueError(
+                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
+                f"got {self.truncate_prompt_tokens}")
         assert isinstance(self.stop_token_ids, list)
         if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
             raise ValueError(f"stop_token_ids must contain only integers, "
diff --git a/vllm/transformers_utils/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group.py
index a8bb0398dfdb1..ae8220f9b9dc5 100644
--- a/vllm/transformers_utils/tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group.py
@@ -23,6 +23,7 @@ class TokenizerGroup:
         self.tokenizer_config = tokenizer_config
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
+        self.truncation_side = tokenizer_config.get("truncation_side", "left")
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
         max_loras = tokenizer_config.get("max_loras", 0)
         self.lora_tokenizers = LRUCache[int, AnyTokenizer](
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index c5ed10326fd50..698aaab3aaa02 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -1328,6 +1328,12 @@ def as_list(maybe_list: Iterable[T]) -> list[T]:
     return maybe_list if isinstance(maybe_list, list) else list(maybe_list)


+def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
+    if isinstance(obj, str) or not isinstance(obj, Iterable):
+        obj = [obj]
+    return obj
+
+
 # `collections` helpers
 def is_list_of(
     value: object,
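
Note (illustration only, not part of the patch): a minimal sketch of how the new
PoolingParams.truncate_prompt_tokens field added above could be exercised through
the offline LLM API once this change lands. The model name and prompt are
placeholders; per the docstring added in vllm/pooling_params.py, -1 is assumed to
mean "truncate to the model's maximum length" and a positive k keeps only the
last k prompt tokens (left truncation).

    # Hypothetical usage sketch; model and prompt are placeholders.
    from vllm import LLM, PoolingParams

    llm = LLM(model="BAAI/bge-m3")  # any embedding/pooling model
    # Keep only the last 128 prompt tokens, per the new per-request field.
    params = PoolingParams(truncate_prompt_tokens=128)
    outputs = llm.embed(["a very long document ..."], pooling_params=params)
    print(len(outputs[0].outputs.embedding))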