mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 07:17:04 +08:00
[Core] Support serving encoder/decoder models (#7258)
This commit is contained in:
parent
0fa14907da
commit
7eb4a51c5f
2
.github/workflows/mypy.yaml
vendored
2
.github/workflows/mypy.yaml
vendored
@ -25,7 +25,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install mypy==1.9.0
|
||||
pip install mypy==1.11.1
|
||||
pip install types-setuptools
|
||||
pip install types-PyYAML
|
||||
pip install types-requests
|
||||
|
||||
@ -4,8 +4,8 @@ encoder/decoder models, specifically BART
|
||||
'''
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.utils import zip_enc_dec_prompt_lists
|
||||
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
|
||||
TokensPrompt, zip_enc_dec_prompts)
|
||||
|
||||
dtype = "float"
|
||||
|
||||
@ -61,9 +61,9 @@ enc_dec_prompt3 = ExplicitEncoderDecoderPrompt(
|
||||
)
|
||||
|
||||
# - Finally, here's a useful helper function for zipping encoder and
|
||||
# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
|
||||
# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
|
||||
# instances
|
||||
zipped_prompt_list = zip_enc_dec_prompt_lists(
|
||||
zipped_prompt_list = zip_enc_dec_prompts(
|
||||
['An encoder prompt', 'Another encoder prompt'],
|
||||
['A decoder prompt', 'Another decoder prompt'])
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
|
||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||
lm-format-enforcer == 0.10.3
|
||||
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
|
||||
typing_extensions
|
||||
typing_extensions >= 4.10
|
||||
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
||||
pyzmq
|
||||
gguf == 0.9.1
|
||||
|
||||
@ -8,7 +8,7 @@ isort==5.13.2
|
||||
clang-format==18.1.5
|
||||
|
||||
# type checking
|
||||
mypy==1.9.0
|
||||
mypy==1.11.1
|
||||
types-PyYAML
|
||||
types-requests
|
||||
types-setuptools
|
||||
|
||||
@ -3,6 +3,7 @@ import gc
|
||||
import os
|
||||
import sys
|
||||
from collections import UserList
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
|
||||
|
||||
import pytest
|
||||
@ -14,20 +15,19 @@ from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
|
||||
AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
|
||||
BatchFeature)
|
||||
|
||||
from tests.models.utils import DecoderPromptType
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.config import TokenizerPoolConfig
|
||||
from vllm.connections import global_http_connection
|
||||
from vllm.distributed import (destroy_distributed_environment,
|
||||
destroy_model_parallel)
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
|
||||
to_enc_dec_tuple_list, zip_enc_dec_prompts)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sequence import SampleLogprobs
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
|
||||
is_cpu, to_enc_dec_tuple_list,
|
||||
zip_enc_dec_prompt_lists)
|
||||
is_cpu)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
|
||||
return prompts
|
||||
|
||||
|
||||
class DecoderPromptType(Enum):
|
||||
"""For encoder/decoder models only."""
|
||||
CUSTOM = 1
|
||||
NONE = 2
|
||||
EMPTY_STR = 3
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_encoder_decoder_prompts() \
|
||||
-> Dict[DecoderPromptType,
|
||||
Tuple[List[str], List[Optional[str]]]]:
|
||||
def example_encoder_decoder_prompts(
|
||||
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
|
||||
'''
|
||||
Returns an encoder prompt list and a decoder prompt list, wherein each pair
|
||||
of same-index entries in both lists corresponds to an (encoder prompt,
|
||||
@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
|
||||
# NONE decoder prompt type
|
||||
return {
|
||||
DecoderPromptType.NONE:
|
||||
zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
|
||||
zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
|
||||
DecoderPromptType.EMPTY_STR:
|
||||
zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
|
||||
zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
|
||||
DecoderPromptType.CUSTOM:
|
||||
zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
|
||||
zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
|
||||
}
|
||||
|
||||
|
||||
@ -444,7 +450,7 @@ class HfRunner:
|
||||
|
||||
def generate_encoder_decoder_greedy_logprobs_limit(
|
||||
self,
|
||||
encoder_decoder_prompts: Tuple[List[str], List[str]],
|
||||
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
**kwargs: Any,
|
||||
@ -608,7 +614,7 @@ class VllmRunner:
|
||||
|
||||
def generate_encoder_decoder_w_logprobs(
|
||||
self,
|
||||
encoder_decoder_prompts: Tuple[List[str], List[str]],
|
||||
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
|
||||
sampling_params: SamplingParams,
|
||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||
'''
|
||||
@ -653,7 +659,7 @@ class VllmRunner:
|
||||
|
||||
def generate_encoder_decoder_greedy_logprobs(
|
||||
self,
|
||||
encoder_decoder_prompts: Tuple[List[str], List[str]],
|
||||
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
|
||||
|
||||
@ -11,9 +11,9 @@ pytest distributed/test_basic_distributed_correctness_enc_dec.py
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.utils import DecoderPromptType
|
||||
from vllm.utils import cuda_device_count_stateless
|
||||
|
||||
from ..conftest import DecoderPromptType
|
||||
from ..models.utils import check_logprobs_close
|
||||
from ..utils import fork_new_process_for_each_test
|
||||
|
||||
|
||||
50
tests/entrypoints/openai/test_encoder_decoder.py
Normal file
50
tests/entrypoints/openai/test_encoder_decoder.py
Normal file
@ -0,0 +1,50 @@
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "facebook/bart-base"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--enforce-eager",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server):
|
||||
return server.get_async_client()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
assert len(choice.text) >= 5
|
||||
assert choice.finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=2, total_tokens=7)
|
||||
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 1
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
Run `pytest tests/models/test_bart.py`.
|
||||
"""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
if not is_cpu():
|
||||
@ -11,22 +13,31 @@ if not is_cpu():
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.models.utils import DecoderPromptType
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ..conftest import DecoderPromptType
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]
|
||||
|
||||
DECODER_PROMPT_TYPES = ([
|
||||
DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
|
||||
DecoderPromptType.NONE
|
||||
])
|
||||
def vllm_to_hf_output(
|
||||
vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
|
||||
decoder_prompt_type: DecoderPromptType,
|
||||
):
|
||||
"""Sanitize vllm output to be comparable with hf output."""
|
||||
output_ids, output_str, out_logprobs = vllm_output
|
||||
|
||||
hf_output_str = output_str + "</s>"
|
||||
if decoder_prompt_type == DecoderPromptType.NONE:
|
||||
hf_output_str = "<s>" + hf_output_str
|
||||
|
||||
return output_ids, hf_output_str, out_logprobs
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
|
||||
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
@ -146,8 +157,13 @@ if not is_cpu():
|
||||
hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
|
||||
else 0)
|
||||
|
||||
check_logprobs_close(outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
num_outputs_0_skip_tokens=hf_skip_tokens)
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=[
|
||||
vllm_to_hf_output(vllm_output, decoder_prompt_type)
|
||||
for vllm_output in vllm_outputs
|
||||
],
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
num_outputs_0_skip_tokens=hf_skip_tokens,
|
||||
)
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from vllm.sequence import SampleLogprobs
|
||||
@ -136,13 +135,3 @@ def check_logprobs_close(
|
||||
warnings.simplefilter("always")
|
||||
|
||||
warnings.warn(fail_msg, stacklevel=2)
|
||||
|
||||
|
||||
class DecoderPromptType(Enum):
|
||||
'''
|
||||
For encoder/decoder models only -
|
||||
|
||||
'''
|
||||
CUSTOM = 1
|
||||
NONE = 2
|
||||
EMPTY_STR = 3
|
||||
|
||||
@ -2,7 +2,7 @@ from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.inputs import parse_and_batch_prompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
|
||||
STRING_INPUTS = [
|
||||
'',
|
||||
|
||||
@ -464,6 +464,16 @@ class ModelConfig:
|
||||
if t != "attention"
|
||||
])
|
||||
|
||||
@property
|
||||
def is_encoder_decoder_model(self) -> bool:
|
||||
"""Extract the HF encoder/decoder model flag."""
|
||||
return getattr(self.hf_config, "is_encoder_decoder", False)
|
||||
|
||||
@property
|
||||
def is_embedding_model(self) -> bool:
|
||||
"""Extract the embedding model flag."""
|
||||
return self.embedding_mode
|
||||
|
||||
|
||||
class CacheConfig:
|
||||
"""Configuration for the KV cache.
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping,
|
||||
Optional, Set, Tuple, Type, Union)
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
|
||||
@ -12,11 +13,14 @@ from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_timeout import asyncio_timeout
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.llm_engine import (DecoderPromptComponents, LLMEngine,
|
||||
PromptComponents)
|
||||
from vllm.engine.metrics import StatLoggerBase
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster, ray
|
||||
from vllm.inputs import LLMInputs, PromptInputs
|
||||
from vllm.inputs import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
|
||||
SingletonPromptInputs)
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
@ -293,38 +297,138 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
"""Stop the remote worker execution loop."""
|
||||
await self.model_executor.stop_remote_worker_execution_loop_async()
|
||||
|
||||
async def process_model_inputs_async(
|
||||
async def _tokenize_prompt_async(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> List[int]:
|
||||
"""Async version of :meth:`_tokenize_prompt`."""
|
||||
tokenizer = self.get_tokenizer_group("prompts must be None if "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
return await tokenizer.encode_async(request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
|
||||
async def _extract_prompt_components_async(
|
||||
self,
|
||||
inputs: SingletonPromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> PromptComponents:
|
||||
"""Async version of :meth:`_extract_prompt_components`."""
|
||||
if isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
multi_modal_data = None
|
||||
elif isinstance(inputs, dict):
|
||||
if "prompt_token_ids" in inputs:
|
||||
prompt = None
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
else:
|
||||
# NOTE: This extra assignment is required to pass mypy
|
||||
prompt = parsed_prompt = inputs["prompt"]
|
||||
prompt_token_ids = await self._tokenize_prompt_async(
|
||||
parsed_prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
else:
|
||||
assert_never(inputs)
|
||||
|
||||
return prompt, prompt_token_ids, multi_modal_data
|
||||
|
||||
async def _process_encoder_decoder_prompt_async(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
request_id: str,
|
||||
) -> EncoderDecoderLLMInputs:
|
||||
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
|
||||
encoder_comps: PromptComponents
|
||||
decoder_comps: DecoderPromptComponents
|
||||
|
||||
if is_explicit_encoder_decoder_prompt(inputs):
|
||||
encoder_task = self._extract_prompt_components_async(
|
||||
inputs["encoder_prompt"],
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
if (decoder_input := inputs["decoder_prompt"]) is None:
|
||||
encoder_comps = await encoder_task
|
||||
decoder_comps = None, None, None
|
||||
else:
|
||||
decoder_task = self._extract_prompt_components_async(
|
||||
decoder_input,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
encoder_comps, decoder_comps = await asyncio.gather(
|
||||
encoder_task, decoder_task)
|
||||
else:
|
||||
encoder_comps = await self._extract_prompt_components_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
decoder_comps = None, None, None
|
||||
|
||||
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
|
||||
|
||||
async def _process_decoder_only_prompt_async(
|
||||
self,
|
||||
inputs: SingletonPromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
if isinstance(inputs, str):
|
||||
inputs = {"prompt": inputs}
|
||||
"""Async version of :meth:`_process_decoder_only_prompt`."""
|
||||
prompt_comps = await self._extract_prompt_components_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
if "prompt_token_ids" not in inputs:
|
||||
tokenizer = self.get_tokenizer_group("prompts must be None if "
|
||||
"skip_tokenizer_init is True")
|
||||
return self._build_decoder_only_llm_inputs(
|
||||
prompt_comps,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
prompt_token_ids = await tokenizer.encode_async(
|
||||
async def process_model_inputs_async(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
|
||||
"""Async version of :meth:`process_model_inputs`."""
|
||||
if self.is_encoder_decoder_model():
|
||||
# Encoder-decoder model requires special mapping of
|
||||
# input prompts to encoder & decoder
|
||||
model_inputs = await self._process_encoder_decoder_prompt_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
prompt=inputs["prompt"],
|
||||
lora_request=lora_request)
|
||||
)
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
if is_explicit_encoder_decoder_prompt(inputs):
|
||||
raise ValueError("Cannot pass encoder-decoder prompt "
|
||||
"to decoder-only models")
|
||||
|
||||
if prompt_adapter_request:
|
||||
prompt_token_ids = [
|
||||
0
|
||||
] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \
|
||||
prompt_token_ids
|
||||
# Decoder-only operation
|
||||
model_inputs = await self._process_decoder_only_prompt_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=inputs.get("prompt"),
|
||||
multi_modal_data=inputs.get("multi_modal_data"))
|
||||
|
||||
return self.input_processor(llm_inputs)
|
||||
return self.input_processor(model_inputs)
|
||||
|
||||
async def add_request_async(
|
||||
self,
|
||||
@ -336,6 +440,7 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> None:
|
||||
"""Async version of :meth:`add_request`."""
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
@ -343,10 +448,11 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
arrival_time = time.time()
|
||||
|
||||
processed_inputs = await self.process_model_inputs_async(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
inputs=inputs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
|
||||
@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Type, TypeVar, Union
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
@ -22,10 +24,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import (INPUT_REGISTRY, LLMInputs, PromptInputs,
|
||||
get_prompt_type)
|
||||
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, LLMInputs,
|
||||
PromptInputs, SingletonPromptInputs)
|
||||
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
RequestOutputFactory)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@ -43,8 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
|
||||
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import (Counter, is_embedding_model_config,
|
||||
is_encoder_decoder_model_config)
|
||||
from vllm.utils import Counter
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -66,6 +69,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
|
||||
|
||||
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
|
||||
|
||||
PromptComponents = Tuple[Optional[str], List[int],
|
||||
Optional[MultiModalDataDict]]
|
||||
DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]],
|
||||
Optional[MultiModalDataDict]]
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""An LLM engine that receives requests and generates texts.
|
||||
@ -524,7 +532,7 @@ class LLMEngine:
|
||||
|
||||
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
|
||||
|
||||
def _get_decoder_start_token_id(self, ) -> Optional[int]:
|
||||
def _get_decoder_start_token_id(self) -> Optional[int]:
|
||||
'''
|
||||
Obtain the decoder start token id employed by an encoder/decoder
|
||||
model. Returns None for non-encoder/decoder models or if the
|
||||
@ -553,7 +561,7 @@ class LLMEngine:
|
||||
def _add_processed_request(
|
||||
self,
|
||||
request_id: str,
|
||||
processed_inputs: LLMInputs,
|
||||
processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
@ -613,11 +621,11 @@ class LLMEngine:
|
||||
def stop_remote_worker_execution_loop(self) -> None:
|
||||
self.model_executor.stop_remote_worker_execution_loop()
|
||||
|
||||
_LLMInputComponentsType = Tuple[str, List[int], ]
|
||||
_LLMInputComponentsType = Tuple[str, List[int]]
|
||||
|
||||
def _prepare_decoder_input_ids_for_generation(
|
||||
self,
|
||||
decoder_input_ids: Optional[List[int]] = None,
|
||||
decoder_input_ids: Optional[List[int]],
|
||||
) -> List[int]:
|
||||
"""
|
||||
Prepares `decoder_input_ids` for generation with encoder-decoder models.
|
||||
@ -639,14 +647,13 @@ class LLMEngine:
|
||||
* Processed token list
|
||||
"""
|
||||
|
||||
decoder_start_token_id: Optional[int] = (
|
||||
self._get_decoder_start_token_id())
|
||||
decoder_start_token_id = self._get_decoder_start_token_id()
|
||||
assert decoder_start_token_id is not None
|
||||
|
||||
if decoder_input_ids is None:
|
||||
# no decoder prompt input ->
|
||||
# use decoder_start_token_id as decoder_input_ids
|
||||
(decoder_input_ids) = self._get_default_enc_dec_decoder_prompt()
|
||||
decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
|
||||
|
||||
if (len(decoder_input_ids) == 0
|
||||
or decoder_input_ids[0] != decoder_start_token_id):
|
||||
@ -657,12 +664,11 @@ class LLMEngine:
|
||||
def _tokenize_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[str] = None,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> List[int]:
|
||||
'''
|
||||
Wrapper around application of the model's
|
||||
tokenizer.
|
||||
Wrapper around application of the model's tokenizer.
|
||||
|
||||
Arguments:
|
||||
|
||||
@ -678,87 +684,72 @@ class LLMEngine:
|
||||
tokenizer = self.get_tokenizer_group("prompts must be None if "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
prompt_token_ids = tokenizer.encode(request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
return tokenizer.encode(request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
|
||||
return prompt_token_ids
|
||||
|
||||
def _extract_single_prompt_for_enc_dec_input(
|
||||
def _extract_prompt_components(
|
||||
self,
|
||||
inputs: Optional[PromptInputs],
|
||||
request_id: Optional[str] = None,
|
||||
ptype: Optional[str] = None,
|
||||
is_encoder_prompt: bool = False,
|
||||
) -> Tuple[Optional[str], List[int]]:
|
||||
inputs: SingletonPromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
) -> PromptComponents:
|
||||
'''
|
||||
Only for encoder/decoder models:
|
||||
Extract prompt & prompt_token_ids from any single
|
||||
encoder or decoder input prompt. For encoder input prompts
|
||||
in particular, also extract multi-modal data.
|
||||
|
||||
This function handles the following scenarios:
|
||||
1. The user supplied a singleton encoder prompt
|
||||
& the prompt/prompt-token-ids must be extracted.
|
||||
2. The user supplied an explicit encoder/decoder
|
||||
prompt & the prompt/prompt-token-ids must be
|
||||
extracted from either the encoder and decoder prompts.
|
||||
|
||||
For decoder prompts in particular (scenario 2), special
|
||||
processing is applied to the returned decoder token ids.
|
||||
Extract the components of any single encoder or decoder input prompt.
|
||||
|
||||
Arguments:
|
||||
|
||||
* request_id
|
||||
* ptype: str representation of the input prompt type.
|
||||
If `ptype` is `None`, assume that the prompt
|
||||
type is unknown and must be inferred. This is the
|
||||
case for ExplicitEncoderDecoder sub-prompts.
|
||||
* inputs: single encoder or decoder input prompt
|
||||
* is_encoder_prompt: True if encoder input prompt.
|
||||
If False, decoder prompt tokens
|
||||
are preprocessed.
|
||||
* lora_request: this is only valid for decoder prompts
|
||||
|
||||
Returns:
|
||||
|
||||
* prompt
|
||||
* prompt_token_ids
|
||||
* multi_modal_data
|
||||
'''
|
||||
prompt_token_ids = None
|
||||
ptype = (get_prompt_type(inputs) if ptype is None else ptype)
|
||||
|
||||
if inputs is None:
|
||||
prompt = None
|
||||
elif ptype == 'str':
|
||||
if isinstance(inputs, str):
|
||||
prompt = inputs
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
elif ptype == 'TokensPrompt':
|
||||
prompt = None
|
||||
prompt_token_ids = inputs['prompt_token_ids']
|
||||
multi_modal_data = None
|
||||
elif isinstance(inputs, dict):
|
||||
if "prompt_token_ids" in inputs:
|
||||
prompt = None
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
else:
|
||||
# NOTE: This extra assignment is required to pass mypy
|
||||
prompt = parsed_prompt = inputs["prompt"]
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
parsed_prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
else:
|
||||
prompt = inputs['prompt']
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
assert_never(inputs)
|
||||
|
||||
if not is_encoder_prompt:
|
||||
# Apply special pre-processing to
|
||||
# decoder prompts
|
||||
prompt_token_ids = (self._prepare_decoder_input_ids_for_generation(
|
||||
prompt_token_ids, ))
|
||||
return prompt, prompt_token_ids, multi_modal_data
|
||||
|
||||
assert prompt_token_ids is not None
|
||||
def _apply_prompt_adapter(
|
||||
self,
|
||||
prompt_token_ids: List[int],
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> List[int]:
|
||||
if prompt_adapter_request:
|
||||
prompt_token_ids = (
|
||||
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
|
||||
+ prompt_token_ids)
|
||||
|
||||
return (
|
||||
prompt,
|
||||
prompt_token_ids,
|
||||
)
|
||||
return prompt_token_ids
|
||||
|
||||
def _get_default_enc_dec_decoder_prompt(self, ) -> List[int]:
|
||||
def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
|
||||
'''
|
||||
Specifically for encoder/decoder models:
|
||||
generate a default decoder prompt for when
|
||||
@ -792,18 +783,39 @@ class LLMEngine:
|
||||
|
||||
bos_token_id = self._get_bos_token_id()
|
||||
assert bos_token_id is not None
|
||||
prompt_token_ids: List[int] = [bos_token_id]
|
||||
return prompt_token_ids
|
||||
return [bos_token_id]
|
||||
|
||||
def _build_enc_dec_llm_inputs(
|
||||
self,
|
||||
encoder_comps: PromptComponents,
|
||||
decoder_comps: DecoderPromptComponents,
|
||||
) -> EncoderDecoderLLMInputs:
|
||||
encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps
|
||||
decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps
|
||||
|
||||
if encoder_mm_data is not None or decoder_mm_data is not None:
|
||||
raise ValueError("Multi-modal encoder-decoder models are "
|
||||
"not supported yet")
|
||||
|
||||
decoder_prompt_ids = (
|
||||
self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids))
|
||||
|
||||
return EncoderDecoderLLMInputs(
|
||||
prompt_token_ids=decoder_prompt_ids,
|
||||
prompt=decoder_prompt,
|
||||
encoder_prompt_token_ids=encoder_prompt_ids,
|
||||
encoder_prompt=encoder_prompt,
|
||||
)
|
||||
|
||||
def _process_encoder_decoder_prompt(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
request_id: Optional[str] = None,
|
||||
) -> LLMInputs:
|
||||
request_id: str,
|
||||
) -> EncoderDecoderLLMInputs:
|
||||
'''
|
||||
For encoder/decoder models only:
|
||||
Process an input prompt
|
||||
into an `LLMInputs` instance.
|
||||
Process an input prompt into an
|
||||
:class:`EncoderDecoderLLMInputs` instance.
|
||||
|
||||
There are two types of input prompts:
|
||||
singleton prompts which carry only the
|
||||
@ -830,136 +842,103 @@ class LLMEngine:
|
||||
|
||||
Returns:
|
||||
|
||||
* `LLMInputs` instance
|
||||
* :class:`EncoderDecoderLLMInputs` instance
|
||||
'''
|
||||
|
||||
ptype = get_prompt_type(inputs)
|
||||
encoder_comps: PromptComponents
|
||||
decoder_comps: DecoderPromptComponents
|
||||
|
||||
# Obtain encoder and decoder prompt tokens. Note
|
||||
# that, no matter what, the decoder
|
||||
# prompt type is unknown.
|
||||
if ptype == "ExplicitEncoderDecoder":
|
||||
# If input is explicit encoder/decoder prompt,
|
||||
# then it remains to be determined what type
|
||||
# of encoder prompt we have
|
||||
extracted_encoder_prompt = inputs.get('encoder_prompt')
|
||||
encoder_ptype = None
|
||||
# Extract decoder prompt from explicit
|
||||
# encoder/decoder prompt
|
||||
extracted_decoder_prompt = inputs.get('decoder_prompt')
|
||||
if is_explicit_encoder_decoder_prompt(inputs):
|
||||
encoder_comps = self._extract_prompt_components(
|
||||
inputs["encoder_prompt"],
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
if (decoder_input := inputs["decoder_prompt"]) is None:
|
||||
decoder_comps = None, None, None
|
||||
else:
|
||||
decoder_comps = self._extract_prompt_components(
|
||||
decoder_input,
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
# If input is singleton encoder prompt, then
|
||||
# we know the encoder prompt type
|
||||
extracted_encoder_prompt = inputs
|
||||
encoder_ptype = ptype
|
||||
# Decoder prompt is always unknown if
|
||||
# encoder/decoder prompt is not explicit
|
||||
extracted_decoder_prompt = None
|
||||
encoder_comps = self._extract_prompt_components(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
# Invoke helper function to obtain encoder
|
||||
# prompt and prompt token ids, either from
|
||||
# singleton encoder prompt or from the
|
||||
# encoder sub-prompt of an explicit
|
||||
# encoder/decode scenario 2), special
|
||||
# processing is applied to the returned decoder token ids
|
||||
(
|
||||
encoder_prompt,
|
||||
encoder_prompt_token_ids,
|
||||
) = self._extract_single_prompt_for_enc_dec_input(
|
||||
extracted_encoder_prompt,
|
||||
request_id=request_id,
|
||||
ptype=encoder_ptype,
|
||||
is_encoder_prompt=True,
|
||||
)
|
||||
decoder_comps = None, None, None
|
||||
|
||||
# Invoke helper method to obtain
|
||||
# decoder prompt and prompt token ids.
|
||||
#
|
||||
# The helper method will detect the decoder
|
||||
# prompt type.
|
||||
#
|
||||
# Helper method will also apply special
|
||||
# preprocessing unique to decoder prompts.
|
||||
(
|
||||
decoder_prompt,
|
||||
decoder_prompt_token_ids,
|
||||
) = self._extract_single_prompt_for_enc_dec_input(
|
||||
extracted_decoder_prompt,
|
||||
request_id=request_id,
|
||||
ptype=None,
|
||||
is_encoder_prompt=False,
|
||||
)
|
||||
return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps)
|
||||
|
||||
return LLMInputs(
|
||||
prompt_token_ids=decoder_prompt_token_ids,
|
||||
prompt=decoder_prompt,
|
||||
encoder_prompt_token_ids=encoder_prompt_token_ids,
|
||||
encoder_prompt=encoder_prompt,
|
||||
)
|
||||
def _build_decoder_only_llm_inputs(
|
||||
self,
|
||||
prompt_comps: PromptComponents,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
) -> LLMInputs:
|
||||
prompt, prompt_token_ids, multi_modal_data = prompt_comps
|
||||
|
||||
prompt_token_ids = self._apply_prompt_adapter(
|
||||
prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
|
||||
|
||||
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=prompt,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
def _process_decoder_only_prompt(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
inputs: SingletonPromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
request_id: Optional[str] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
'''
|
||||
For decoder-only models:
|
||||
Process an input prompt
|
||||
into an `LLMInputs` instance.
|
||||
Process an input prompt into an :class:`LLMInputs` instance.
|
||||
|
||||
Arguments:
|
||||
|
||||
* inputs: input prompt
|
||||
* lora_request
|
||||
* request_id
|
||||
* lora_request
|
||||
* prompt_adapter_request
|
||||
|
||||
Returns:
|
||||
|
||||
* `LLMInputs` instance
|
||||
* :class:`LLMInputs` instance
|
||||
'''
|
||||
|
||||
if isinstance(inputs, str):
|
||||
inputs = {"prompt": inputs}
|
||||
prompt = inputs.get("prompt")
|
||||
prompt_comps = self._extract_prompt_components(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
if "prompt_token_ids" not in inputs:
|
||||
prompt_token_ids = self._tokenize_prompt(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
if prompt_adapter_request:
|
||||
prompt_token_ids = (
|
||||
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens
|
||||
+ prompt_token_ids)
|
||||
|
||||
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=prompt,
|
||||
multi_modal_data=inputs.get("multi_modal_data"))
|
||||
return self._build_decoder_only_llm_inputs(
|
||||
prompt_comps,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
def process_model_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: PromptInputs,
|
||||
request_id: str,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
) -> LLMInputs:
|
||||
) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
|
||||
|
||||
if self.is_encoder_decoder_model():
|
||||
# Encoder-decoder model requires special mapping of
|
||||
# input prompts to encoder & decoder
|
||||
|
||||
model_inputs = self._process_encoder_decoder_prompt(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
if is_explicit_encoder_decoder_prompt(inputs):
|
||||
raise ValueError("Cannot pass encoder-decoder prompt "
|
||||
"to decoder-only models")
|
||||
|
||||
# Decoder-only operation
|
||||
model_inputs = self._process_decoder_only_prompt(
|
||||
inputs,
|
||||
@ -1029,10 +1008,11 @@ class LLMEngine:
|
||||
arrival_time = time.time()
|
||||
|
||||
processed_inputs = self.process_model_inputs(
|
||||
inputs,
|
||||
request_id=request_id,
|
||||
inputs=inputs,
|
||||
lora_request=lora_request,
|
||||
prompt_adapter_request=prompt_adapter_request)
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
@ -1597,7 +1577,7 @@ class LLMEngine:
|
||||
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
|
||||
|
||||
def is_encoder_decoder_model(self):
|
||||
return is_encoder_decoder_model_config(self.model_config)
|
||||
return self.model_config.is_encoder_decoder_model
|
||||
|
||||
def is_embedding_model(self):
|
||||
return is_embedding_model_config(self.model_config)
|
||||
return self.model_config.is_embedding_model
|
||||
|
||||
@ -2,8 +2,7 @@ import codecs
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
|
||||
cast, final)
|
||||
from typing import Any, Awaitable, Iterable, List, Optional, Tuple, Union, cast
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
@ -59,7 +58,7 @@ ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam,
|
||||
CustomChatCompletionMessageParam]
|
||||
|
||||
|
||||
@final # So that it should be compatible with Dict[str, str]
|
||||
# TODO: Make fields ReadOnly once mypy supports it
|
||||
class ConversationMessage(TypedDict):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
@ -6,8 +6,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
|
||||
parse_and_batch_prompt)
|
||||
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
|
||||
@ -40,9 +40,11 @@ def _get_allowed_token_ids_logits_processor(
|
||||
return AllowedTokenIdsLogitsProcessor(allowed_token_ids)
|
||||
|
||||
|
||||
def logit_bias_logits_processor(logit_bias: Dict[str,
|
||||
float], token_ids: List[int],
|
||||
logits: torch.Tensor) -> torch.Tensor:
|
||||
def logit_bias_logits_processor(
|
||||
logit_bias: Dict[int, float],
|
||||
token_ids: List[int],
|
||||
logits: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
for token_id, bias in logit_bias.items():
|
||||
logits[token_id] += bias
|
||||
return logits
|
||||
|
||||
@ -22,7 +22,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeRequest)
|
||||
# yapf: enable
|
||||
from vllm.inputs import parse_and_batch_prompt
|
||||
from vllm.inputs.parse import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from .data import (ExplicitEncoderDecoderPrompt, LLMInputs, ParsedText,
|
||||
ParsedTokens, PromptInputs, SingletonPromptInputs,
|
||||
TextPrompt, TokensPrompt, get_prompt_type,
|
||||
is_valid_encoder_decoder_llm_inputs, parse_and_batch_prompt)
|
||||
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
||||
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
|
||||
TokensPrompt, build_explicit_enc_dec_prompt,
|
||||
to_enc_dec_tuple_list, zip_enc_dec_prompts)
|
||||
from .registry import InputContext, InputRegistry
|
||||
|
||||
INPUT_REGISTRY = InputRegistry()
|
||||
@ -14,18 +14,17 @@ See also:
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"ParsedText",
|
||||
"ParsedTokens",
|
||||
"parse_and_batch_prompt",
|
||||
"TextPrompt",
|
||||
"TokensPrompt",
|
||||
"PromptInputs",
|
||||
"SingletonPromptInputs",
|
||||
"ExplicitEncoderDecoderPrompt",
|
||||
"LLMInputs",
|
||||
"EncoderDecoderLLMInputs",
|
||||
"build_explicit_enc_dec_prompt",
|
||||
"to_enc_dec_tuple_list",
|
||||
"zip_enc_dec_prompts",
|
||||
"INPUT_REGISTRY",
|
||||
"InputContext",
|
||||
"InputRegistry",
|
||||
"get_prompt_type",
|
||||
"is_valid_encoder_decoder_llm_inputs",
|
||||
"ExplicitEncoderDecoderPrompt",
|
||||
"SingletonPromptInputs",
|
||||
]
|
||||
|
||||
@ -1,71 +1,12 @@
|
||||
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
|
||||
TypedDict, Union, cast, overload)
|
||||
from typing import (TYPE_CHECKING, Generic, Iterable, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
from typing_extensions import NotRequired, TypedDict, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
content: str
|
||||
is_tokens: Literal[False]
|
||||
|
||||
|
||||
class ParsedTokens(TypedDict):
|
||||
content: List[int]
|
||||
is_tokens: Literal[True]
|
||||
|
||||
|
||||
# https://github.com/vllm-project/vllm/pull/4028
|
||||
@overload
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
|
||||
...
|
||||
|
||||
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
|
||||
if isinstance(prompt, str):
|
||||
# case 1: a string
|
||||
return [ParsedText(content=prompt, is_tokens=False)]
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if len(prompt) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
if isinstance(prompt[0], str):
|
||||
# case 2: array of strings
|
||||
return [
|
||||
ParsedText(content=elem, is_tokens=False)
|
||||
for elem in cast(List[str], prompt)
|
||||
]
|
||||
if isinstance(prompt[0], int):
|
||||
# case 3: array of tokens
|
||||
elem = cast(List[int], prompt)
|
||||
return [ParsedTokens(content=elem, is_tokens=True)]
|
||||
if isinstance(prompt[0], list):
|
||||
if len(prompt[0]) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
if isinstance(prompt[0][0], int):
|
||||
# case 4: array of token arrays
|
||||
return [
|
||||
ParsedTokens(content=elem, is_tokens=True)
|
||||
for elem in cast(List[List[int]], prompt)
|
||||
]
|
||||
|
||||
raise ValueError("prompt must be a string, array of strings, "
|
||||
"array of tokens, or array of token arrays")
|
||||
|
||||
|
||||
class TextPrompt(TypedDict):
|
||||
"""Schema for a text prompt."""
|
||||
|
||||
@ -103,39 +44,49 @@ Note that "singleton" is as opposed to a data structure
|
||||
which encapsulates multiple prompts, i.e. of the sort
|
||||
which may be utilized for encoder/decoder models when
|
||||
the user desires to express both the encoder & decoder
|
||||
prompts explicitly, i.e. ExplicitEncoderDecoderPrompt
|
||||
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
|
||||
|
||||
A prompt of type SingletonPromptInputs may be employed
|
||||
A prompt of type :class:`SingletonPromptInputs` may be employed
|
||||
as (1) input to a decoder-only model, (2) input to
|
||||
the encoder of an encoder/decoder model, in the scenario
|
||||
where the decoder-prompt is not specified explicitly, or
|
||||
(3) as a member of a larger data structure encapsulating
|
||||
more than one prompt, i.e. ExplicitEncoderDecoderPrompt
|
||||
more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
|
||||
"""
|
||||
|
||||
_T1_co = TypeVar("_T1_co",
|
||||
bound=SingletonPromptInputs,
|
||||
default=SingletonPromptInputs,
|
||||
covariant=True)
|
||||
_T2_co = TypeVar("_T2_co",
|
||||
bound=SingletonPromptInputs,
|
||||
default=SingletonPromptInputs,
|
||||
covariant=True)
|
||||
|
||||
class ExplicitEncoderDecoderPrompt(TypedDict):
|
||||
|
||||
# TODO: Make fields ReadOnly once mypy supports it
|
||||
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
||||
"""Represents an encoder/decoder model input prompt,
|
||||
comprising an explicit encoder prompt and a
|
||||
decoder prompt.
|
||||
|
||||
The encoder and decoder prompts, respectively,
|
||||
may formatted according to any of the
|
||||
SingletonPromptInputs schemas, and are not
|
||||
:class:`SingletonPromptInputs` schemas, and are not
|
||||
required to have the same schema.
|
||||
|
||||
Only the encoder prompt may have multi-modal data.
|
||||
|
||||
Note that an ExplicitEncoderDecoderPrompt may not
|
||||
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
|
||||
be used as an input to a decoder-only model,
|
||||
and that the `encoder_prompt` and `decoder_prompt`
|
||||
fields of this data structure may not themselves
|
||||
must be SingletonPromptInputs instances.
|
||||
fields of this data structure themselves must be
|
||||
:class:`SingletonPromptInputs` instances.
|
||||
"""
|
||||
|
||||
encoder_prompt: SingletonPromptInputs
|
||||
encoder_prompt: _T1_co
|
||||
|
||||
decoder_prompt: SingletonPromptInputs
|
||||
decoder_prompt: Optional[_T2_co]
|
||||
|
||||
|
||||
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
|
||||
@ -150,60 +101,12 @@ both decoder-only and encoder/decoder input types:
|
||||
"""
|
||||
|
||||
|
||||
def _has_required_keys(
|
||||
d: dict,
|
||||
required_keys: set,
|
||||
) -> bool:
|
||||
return required_keys.issubset(d.keys())
|
||||
|
||||
|
||||
def get_prompt_type(prompt: Optional[PromptInputs]) -> Optional[str]:
|
||||
"""
|
||||
Get the type-name of the prompt argument instance, given that
|
||||
isinstance() cannot apply to TypedDict subclasses directly.
|
||||
If the prompt is None, return 'None' as the type name.
|
||||
|
||||
Arguments:
|
||||
|
||||
* prompt: LLM input prompt or None
|
||||
|
||||
Returns:
|
||||
|
||||
* String representation of prompt type
|
||||
"""
|
||||
|
||||
if prompt is None:
|
||||
return 'None'
|
||||
|
||||
required_keys_dict = {
|
||||
'TextPrompt': {'prompt'},
|
||||
'TokensPrompt': {'prompt_token_ids'},
|
||||
'ExplicitEncoderDecoder': {'encoder_prompt', 'decoder_prompt'},
|
||||
}
|
||||
|
||||
if isinstance(prompt, dict):
|
||||
for (ptype, required_keys) in required_keys_dict.items():
|
||||
# Ignore type checking in the conditional below because type
|
||||
# checker does not understand that is_dict(prompt) narrows
|
||||
# down the possible types
|
||||
if _has_required_keys(
|
||||
prompt, # type: ignore
|
||||
required_keys):
|
||||
return ptype
|
||||
|
||||
raise ValueError(f"Invalid prompt {prompt}, valid types are "
|
||||
"required_keys_dict={required_keys_dict}")
|
||||
|
||||
if isinstance(prompt, str):
|
||||
return "str"
|
||||
|
||||
raise ValueError(f"Invalid prompt {prompt}")
|
||||
|
||||
|
||||
class LLMInputs(TypedDict):
|
||||
"""
|
||||
The inputs in :class:`~vllm.LLMEngine` before they are
|
||||
passed to the model executor.
|
||||
|
||||
This specifies the data required for decoder-only models.
|
||||
"""
|
||||
prompt_token_ids: List[int]
|
||||
"""The token IDs of the prompt."""
|
||||
@ -213,7 +116,21 @@ class LLMInputs(TypedDict):
|
||||
The original prompt text corresponding to the token IDs, if available.
|
||||
"""
|
||||
|
||||
encoder_prompt_token_ids: NotRequired[List[int]]
|
||||
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
|
||||
"""
|
||||
Optional multi-modal data to pass to the model,
|
||||
if the model supports it.
|
||||
"""
|
||||
|
||||
|
||||
class EncoderDecoderLLMInputs(LLMInputs):
|
||||
"""
|
||||
The inputs in :class:`~vllm.LLMEngine` before they are
|
||||
passed to the model executor.
|
||||
|
||||
This specifies the required data for encoder-decoder models.
|
||||
"""
|
||||
encoder_prompt_token_ids: List[int]
|
||||
"""The token IDs of the encoder prompt."""
|
||||
|
||||
encoder_prompt: NotRequired[Optional[str]]
|
||||
@ -222,20 +139,40 @@ class LLMInputs(TypedDict):
|
||||
available.
|
||||
"""
|
||||
|
||||
multi_modal_data: NotRequired[Optional["MultiModalDataDict"]]
|
||||
"""
|
||||
Optional multi-modal data to pass to the model,
|
||||
if the model supports it.
|
||||
"""
|
||||
|
||||
_T1 = TypeVar("_T1",
|
||||
bound=SingletonPromptInputs,
|
||||
default=SingletonPromptInputs)
|
||||
_T2 = TypeVar("_T2",
|
||||
bound=SingletonPromptInputs,
|
||||
default=SingletonPromptInputs)
|
||||
|
||||
|
||||
def is_valid_encoder_decoder_llm_inputs(inputs: LLMInputs) -> bool:
|
||||
"""
|
||||
Return True if the LLMInputs instance has the correct configuration
|
||||
for encoder/decoder.
|
||||
"""
|
||||
def build_explicit_enc_dec_prompt(
|
||||
encoder_prompt: _T1,
|
||||
decoder_prompt: Optional[_T2],
|
||||
) -> ExplicitEncoderDecoderPrompt[_T1, _T2]:
|
||||
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
|
||||
decoder_prompt=decoder_prompt)
|
||||
|
||||
# True if encoder prompt token ids field exists &
|
||||
# is not None
|
||||
return ('encoder_prompt_token_ids' in inputs
|
||||
and inputs['encoder_prompt_token_ids'] is not None)
|
||||
|
||||
def zip_enc_dec_prompts(
|
||||
enc_prompts: Iterable[_T1],
|
||||
dec_prompts: Iterable[Optional[_T2]],
|
||||
) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]:
|
||||
"""
|
||||
Zip encoder and decoder prompts together into a list of
|
||||
:class:`ExplicitEncoderDecoderPrompt` instances.
|
||||
"""
|
||||
return [
|
||||
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
|
||||
for (encoder_prompt, decoder_prompt) in zip(enc_prompts, dec_prompts)
|
||||
]
|
||||
|
||||
|
||||
def to_enc_dec_tuple_list(
|
||||
enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]],
|
||||
) -> List[Tuple[_T1, Optional[_T2]]]:
|
||||
return [(enc_dec_prompt["encoder_prompt"],
|
||||
enc_dec_prompt["decoder_prompt"])
|
||||
for enc_dec_prompt in enc_dec_prompts]
|
||||
|
||||
75
vllm/inputs/parse.py
Normal file
75
vllm/inputs/parse.py
Normal file
@ -0,0 +1,75 @@
|
||||
from typing import List, Literal, Sequence, TypedDict, Union, overload
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
||||
LLMInputs, PromptInputs)
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
content: str
|
||||
is_tokens: Literal[False]
|
||||
|
||||
|
||||
class ParsedTokens(TypedDict):
|
||||
content: List[int]
|
||||
is_tokens: Literal[True]
|
||||
|
||||
|
||||
@overload
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
|
||||
...
|
||||
|
||||
|
||||
def parse_and_batch_prompt(
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
|
||||
if isinstance(prompt, str):
|
||||
# case 1: a string
|
||||
return [ParsedText(content=prompt, is_tokens=False)]
|
||||
|
||||
if isinstance(prompt, list):
|
||||
if len(prompt) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
if is_list_of(prompt, str):
|
||||
# case 2: array of strings
|
||||
return [
|
||||
ParsedText(content=elem, is_tokens=False) for elem in prompt
|
||||
]
|
||||
if is_list_of(prompt, int):
|
||||
# case 3: array of tokens
|
||||
return [ParsedTokens(content=prompt, is_tokens=True)]
|
||||
if is_list_of(prompt, list):
|
||||
if len(prompt[0]) == 0:
|
||||
raise ValueError("please provide at least one prompt")
|
||||
|
||||
if is_list_of(prompt[0], int):
|
||||
# case 4: array of token arrays
|
||||
return [
|
||||
ParsedTokens(content=elem, is_tokens=True)
|
||||
for elem in prompt
|
||||
]
|
||||
|
||||
raise ValueError("prompt must be a string, array of strings, "
|
||||
"array of tokens, or array of token arrays")
|
||||
|
||||
|
||||
def is_explicit_encoder_decoder_prompt(
|
||||
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||
return isinstance(inputs, dict) and "encoder_prompt" in inputs
|
||||
|
||||
|
||||
def is_valid_encoder_decoder_llm_inputs(
|
||||
inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
|
||||
) -> TypeIs[EncoderDecoderLLMInputs]:
|
||||
return "encoder_prompt_token_ids" in inputs
|
||||
@ -1,7 +1,7 @@
|
||||
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
|
||||
Union, overload, runtime_checkable)
|
||||
|
||||
from typing_extensions import TypeGuard
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
|
||||
from vllm.logger import init_logger
|
||||
@ -37,18 +37,18 @@ class _SupportsVisionType(Protocol):
|
||||
|
||||
|
||||
@overload
|
||||
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
|
||||
def supports_vision(model: Type[object]) -> TypeIs[Type[SupportsVision]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
|
||||
def supports_vision(model: object) -> TypeIs[SupportsVision]:
|
||||
...
|
||||
|
||||
|
||||
def supports_vision(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
|
||||
) -> Union[TypeIs[Type[SupportsVision]], TypeIs[SupportsVision]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _SupportsVisionType)
|
||||
|
||||
@ -94,18 +94,18 @@ class _SupportsLoRAType(Protocol):
|
||||
|
||||
|
||||
@overload
|
||||
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
|
||||
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
|
||||
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
|
||||
...
|
||||
|
||||
|
||||
def supports_lora(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
|
||||
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
|
||||
result = _supports_lora(model)
|
||||
|
||||
if not result:
|
||||
@ -137,7 +137,7 @@ def supports_lora(
|
||||
|
||||
def _supports_lora(
|
||||
model: Union[Type[object], object],
|
||||
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
|
||||
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _SupportsLoRAType)
|
||||
|
||||
@ -172,18 +172,18 @@ class _HasInnerStateType(Protocol):
|
||||
|
||||
|
||||
@overload
|
||||
def has_inner_state(model: object) -> TypeGuard[HasInnerState]:
|
||||
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def has_inner_state(model: Type[object]) -> TypeGuard[Type[HasInnerState]]:
|
||||
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
|
||||
...
|
||||
|
||||
|
||||
def has_inner_state(
|
||||
model: Union[Type[object], object]
|
||||
) -> Union[TypeGuard[Type[HasInnerState]], TypeGuard[HasInnerState]]:
|
||||
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
|
||||
if isinstance(model, type):
|
||||
return isinstance(model, _HasInnerStateType)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ from vllm.inputs.registry import InputContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.image_processor import get_image_processor
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .base import MultiModalInputs, MultiModalPlugin
|
||||
|
||||
@ -113,7 +114,8 @@ class ImagePlugin(MultiModalPlugin):
|
||||
def _default_input_mapper(self, ctx: InputContext,
|
||||
data: object) -> MultiModalInputs:
|
||||
model_config = ctx.model_config
|
||||
if isinstance(data, (Image.Image, list)):
|
||||
|
||||
if isinstance(data, Image.Image) or is_list_of(data, Image.Image):
|
||||
image_processor = self._get_hf_image_processor(model_config)
|
||||
if image_processor is None:
|
||||
raise RuntimeError("No HuggingFace processor is available "
|
||||
@ -127,7 +129,7 @@ class ImagePlugin(MultiModalPlugin):
|
||||
raise
|
||||
|
||||
return MultiModalInputs(batch_data)
|
||||
elif isinstance(data, torch.Tensor):
|
||||
elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor):
|
||||
raise NotImplementedError("Embeddings input is not supported yet")
|
||||
|
||||
raise TypeError(f"Invalid image type: {type(data)}")
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.inputs import is_valid_encoder_decoder_llm_inputs
|
||||
from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
|
||||
@ -17,8 +17,8 @@ from collections import defaultdict
|
||||
from functools import lru_cache, partial, wraps
|
||||
from platform import uname
|
||||
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
|
||||
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
|
||||
Union, overload)
|
||||
Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
|
||||
Type, TypeVar, Union, overload)
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
@ -26,12 +26,10 @@ import numpy.typing as npt
|
||||
import psutil
|
||||
import torch
|
||||
import torch.types
|
||||
from typing_extensions import ParamSpec
|
||||
from typing_extensions import ParamSpec, TypeIs, assert_never
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
|
||||
SingletonPromptInputs)
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -812,6 +810,24 @@ def get_dtype_size(dtype: torch.dtype) -> int:
|
||||
return torch.tensor([], dtype=dtype).element_size()
|
||||
|
||||
|
||||
# `collections` helpers
|
||||
def is_list_of(
|
||||
value: object,
|
||||
typ: Type[T],
|
||||
*,
|
||||
check: Literal["first", "all"] = "first",
|
||||
) -> TypeIs[List[T]]:
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if check == "first":
|
||||
return len(value) == 0 or isinstance(value[0], typ)
|
||||
elif check == "all":
|
||||
return all(isinstance(v, typ) for v in value)
|
||||
|
||||
assert_never(check)
|
||||
|
||||
|
||||
def merge_dicts(dict1: Dict[K, List[T]],
|
||||
dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
|
||||
"""Merge 2 dicts that have key -> List of items.
|
||||
@ -959,6 +975,7 @@ def enable_trace_function_call_for_thread() -> None:
|
||||
enable_trace_function_call(log_path)
|
||||
|
||||
|
||||
# `functools` helpers
|
||||
def identity(value: T) -> T:
|
||||
return value
|
||||
|
||||
@ -1080,50 +1097,3 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
|
||||
"""Utility function to run async task in a lock"""
|
||||
async with lock:
|
||||
return await task(*args, **kwargs)
|
||||
|
||||
|
||||
def is_encoder_decoder_model_config(model_config) -> bool:
|
||||
'''
|
||||
Extract the HF encoder/decoder model flag from the ModelConfig instance.
|
||||
Return False if model_config is None.
|
||||
'''
|
||||
return model_config is not None and \
|
||||
getattr(model_config.hf_config,
|
||||
"is_encoder_decoder",
|
||||
False)
|
||||
|
||||
|
||||
def is_embedding_model_config(model_config) -> bool:
|
||||
'''
|
||||
Extract the embedding model flag from the ModelConfig instance.
|
||||
Return False if model_config is None.
|
||||
'''
|
||||
return model_config is not None and \
|
||||
model_config.embedding_mode
|
||||
|
||||
|
||||
def build_explicit_enc_dec_prompt(
|
||||
encoder_prompt: SingletonPromptInputs,
|
||||
decoder_prompt: SingletonPromptInputs,
|
||||
) -> ExplicitEncoderDecoderPrompt:
|
||||
return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
|
||||
decoder_prompt=decoder_prompt)
|
||||
|
||||
|
||||
def zip_enc_dec_prompt_lists(
|
||||
enc_prompt_list: List[SingletonPromptInputs],
|
||||
dec_prompt_list: List[SingletonPromptInputs],
|
||||
) -> List[ExplicitEncoderDecoderPrompt]:
|
||||
return [
|
||||
build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
|
||||
for (encoder_prompt,
|
||||
decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
|
||||
]
|
||||
|
||||
|
||||
def to_enc_dec_tuple_list(
|
||||
enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
|
||||
) -> List[Tuple[PromptInputs, PromptInputs]]:
|
||||
return [(enc_dec_prompt['encoder_prompt'],
|
||||
enc_dec_prompt['decoder_prompt'])
|
||||
for enc_dec_prompt in enc_dec_prompts]
|
||||
|
||||
@ -19,8 +19,6 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (is_embedding_model_config,
|
||||
is_encoder_decoder_model_config)
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
|
||||
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||||
@ -113,10 +111,10 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
|
||||
|
||||
def _is_encoder_decoder_model(self):
|
||||
return is_encoder_decoder_model_config(self.model_config)
|
||||
return self.model_config.is_encoder_decoder_model
|
||||
|
||||
def _is_embedding_model(self):
|
||||
return is_embedding_model_config(self.model_config)
|
||||
return self.model_config.is_embedding_model
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.device_config.device.type == "cuda":
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user