From eeec9e339005d887e0064f7b3e7771295ecd68e7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 13 Dec 2024 18:40:07 +0800 Subject: [PATCH] [Frontend] Separate pooling APIs in offline inference (#11129) Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 7 +- docs/source/models/pooling_models.rst | 53 +++- examples/offline_inference_classification.py | 28 ++ examples/offline_inference_embedding.py | 16 +- examples/offline_inference_scoring.py | 23 ++ ...ine_inference_vision_language_embedding.py | 2 +- tests/conftest.py | 18 +- tests/entrypoints/openai/test_score.py | 10 +- .../models/embedding/language/test_scoring.py | 10 +- tests/models/test_oot_registration.py | 5 +- vllm/__init__.py | 36 +-- vllm/engine/llm_engine.py | 17 +- vllm/entrypoints/llm.py | 143 ++++++++- vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/serving_embedding.py | 9 +- vllm/entrypoints/openai/serving_engine.py | 12 +- vllm/entrypoints/openai/serving_score.py | 12 +- vllm/model_executor/layers/pooler.py | 290 ++++++++++++------ vllm/model_executor/models/gritlm.py | 15 +- vllm/outputs.py | 225 +++++++++----- vllm/sequence.py | 40 ++- 21 files changed, 669 insertions(+), 304 deletions(-) create mode 100644 examples/offline_inference_classification.py create mode 100644 examples/offline_inference_scoring.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 6a6ee3cf713ae..97aae233db105 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -181,14 +181,14 @@ steps: commands: - VLLM_USE_V1=1 pytest -v -s v1 -- label: Examples Test # 15min +- label: Examples Test # 25min working_dir: "/vllm-workspace/examples" #mirror_hardwares: [amd] source_file_dependencies: - vllm/entrypoints - examples/ commands: - - pip install awscli tensorizer # for llava example and tensorizer test + - pip install tensorizer # for tensorizer test - python3 offline_inference.py - python3 cpu_offload.py - python3 offline_inference_chat.py @@ -198,6 +198,9 @@ steps: - python3 offline_inference_vision_language_multi_image.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py + - python3 offline_inference_classification.py + - python3 offline_inference_embedding.py + - python3 offline_inference_scoring.py - python3 offline_profile.py --model facebook/opt-125m - label: Prefix Caching Test # 9min diff --git a/docs/source/models/pooling_models.rst b/docs/source/models/pooling_models.rst index 7fa66274c3c5a..94475c5e6689d 100644 --- a/docs/source/models/pooling_models.rst +++ b/docs/source/models/pooling_models.rst @@ -6,7 +6,7 @@ Pooling Models vLLM also supports pooling models, including embedding, reranking and reward models. In vLLM, pooling models implement the :class:`~vllm.model_executor.models.VllmModelForPooling` interface. -These models use a :class:`~vllm.model_executor.layers.Pooler` to aggregate the final hidden states of the input +These models use a :class:`~vllm.model_executor.layers.Pooler` to extract the final hidden states of the input before returning them. .. note:: @@ -45,20 +45,48 @@ which takes priority over both the model's and Sentence Transformers's defaults. ^^^^^^^^^^^^^^ The :class:`~vllm.LLM.encode` method is available to all pooling models in vLLM. 
-It returns the aggregated hidden states directly. +It returns the extracted hidden states directly, which is useful for reward models. + +.. code-block:: python + + llm = LLM(model="Qwen/Qwen2.5-Math-RM-72B", task="reward") + output, = llm.encode("Hello, my name is") + + data = output.outputs.data + print(f"Prompt: {prompt!r} | Data: {data!r}") + +``LLM.embed`` +^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.embed` method outputs an embedding vector for each prompt. +It is primarily designed for embedding models. .. code-block:: python llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed") - outputs = llm.encode("Hello, my name is") + output, = llm.embed("Hello, my name is") - outputs = model.encode(prompts) - for output in outputs: - embeddings = output.outputs.embedding - print(f"Prompt: {prompt!r}, Embeddings (size={len(embeddings)}: {embeddings!r}") + embeds = output.outputs.embedding + print(f"Embeddings: {embeds!r} (size={len(embeds)})") A code example can be found in `examples/offline_inference_embedding.py `_. +``LLM.classify`` +^^^^^^^^^^^^^^^^ + +The :class:`~vllm.LLM.classify` method outputs a probability vector for each prompt. +It is primarily designed for classification models. + +.. code-block:: python + + llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", task="classify") + output, = llm.classify("Hello, my name is") + + probs = output.outputs.probs + print(f"Class Probabilities: {probs!r} (size={len(probs)})") + +A code example can be found in `examples/offline_inference_classification.py `_. + ``LLM.score`` ^^^^^^^^^^^^^ @@ -71,7 +99,16 @@ These types of models serve as rerankers between candidate query-document pairs vLLM can only perform the model inference component (e.g. embedding, reranking) of RAG. To handle RAG at a higher level, you should use integration frameworks such as `LangChain `_. -You can use `these tests `_ as reference. +.. code-block:: python + + llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score") + output, = llm.score("What is the capital of France?", + "The capital of Brazil is Brasilia.") + + score = output.outputs.score + print(f"Score: {score}") + +A code example can be found in `examples/offline_inference_scoring.py `_. Online Inference ---------------- diff --git a/examples/offline_inference_classification.py b/examples/offline_inference_classification.py new file mode 100644 index 0000000000000..de539b639a196 --- /dev/null +++ b/examples/offline_inference_classification.py @@ -0,0 +1,28 @@ +from vllm import LLM + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# Create an LLM. +# You should pass task="classify" for classification models +model = LLM( + model="jason9693/Qwen2.5-1.5B-apeach", + task="classify", + enforce_eager=True, +) + +# Generate logits. The output is a list of ClassificationRequestOutputs. +outputs = model.classify(prompts) + +# Print the outputs. +for prompt, output in zip(prompts, outputs): + probs = output.outputs.probs + probs_trimmed = ((str(probs[:16])[:-1] + + ", ...]") if len(probs) > 16 else probs) + print(f"Prompt: {prompt!r} | " + f"Class Probabilities: {probs_trimmed} (size={len(probs)})") diff --git a/examples/offline_inference_embedding.py b/examples/offline_inference_embedding.py index 17f6d992073d7..58d004313ad51 100644 --- a/examples/offline_inference_embedding.py +++ b/examples/offline_inference_embedding.py @@ -9,14 +9,20 @@ prompts = [ ] # Create an LLM. 
+# You should pass task="embed" for embedding models model = LLM( model="intfloat/e5-mistral-7b-instruct", - task="embed", # You should pass task="embed" for embedding models + task="embed", enforce_eager=True, ) -# Generate embedding. The output is a list of PoolingRequestOutputs. -outputs = model.encode(prompts) +# Generate embedding. The output is a list of EmbeddingRequestOutputs. +outputs = model.embed(prompts) + # Print the outputs. -for output in outputs: - print(output.outputs.embedding) # list of 4096 floats +for prompt, output in zip(prompts, outputs): + embeds = output.outputs.embedding + embeds_trimmed = ((str(embeds[:16])[:-1] + + ", ...]") if len(embeds) > 16 else embeds) + print(f"Prompt: {prompt!r} | " + f"Embeddings: {embeds_trimmed} (size={len(embeds)})") diff --git a/examples/offline_inference_scoring.py b/examples/offline_inference_scoring.py new file mode 100644 index 0000000000000..5da9e710959b5 --- /dev/null +++ b/examples/offline_inference_scoring.py @@ -0,0 +1,23 @@ +from vllm import LLM + +# Sample prompts. +text_1 = "What is the capital of France?" +texts_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." +] + +# Create an LLM. +# You should pass task="score" for cross-encoder models +model = LLM( + model="BAAI/bge-reranker-v2-m3", + task="score", + enforce_eager=True, +) + +# Generate scores. The output is a list of ScoringRequestOutputs. +outputs = model.score(text_1, texts_2) + +# Print the outputs. +for text_2, output in zip(texts_2, outputs): + score = output.outputs.score + print(f"Pair: {[text_1, text_2]!r} | Score: {score}") diff --git a/examples/offline_inference_vision_language_embedding.py b/examples/offline_inference_vision_language_embedding.py index bf466109f0981..4ce3d496bf45b 100644 --- a/examples/offline_inference_vision_language_embedding.py +++ b/examples/offline_inference_vision_language_embedding.py @@ -133,7 +133,7 @@ def run_encode(model: str, modality: QueryModality): if req_data.image is not None: mm_data["image"] = req_data.image - outputs = req_data.llm.encode({ + outputs = req_data.llm.embed({ "prompt": req_data.prompt, "multi_modal_data": mm_data, }) diff --git a/tests/conftest.py b/tests/conftest.py index 7606e0f11dfeb..4e939221329cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -719,14 +719,6 @@ class VllmRunner: return inputs - def classify(self, prompts: List[str]) -> List[str]: - req_outputs = self.model.encode(prompts) - outputs = [] - for req_output in req_outputs: - embedding = req_output.outputs.embedding - outputs.append(embedding) - return outputs - def generate( self, prompts: List[str], @@ -897,6 +889,10 @@ class VllmRunner: returned_outputs.append((token_ids, texts)) return returned_outputs + def classify(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.classify(prompts) + return [req_output.outputs.probs for req_output in req_outputs] + def encode( self, prompts: List[str], @@ -909,16 +905,16 @@ class VllmRunner: videos=videos, audios=audios) - req_outputs = self.model.encode(inputs) + req_outputs = self.model.embed(inputs) return [req_output.outputs.embedding for req_output in req_outputs] def score( self, text_1: Union[str, List[str]], text_2: Union[str, List[str]], - ) -> List[List[float]]: + ) -> List[float]: req_outputs = self.model.score(text_1, text_2) - return [req_output.outputs.embedding for req_output in req_outputs] + return [req_output.outputs.score for req_output in req_outputs] def __enter__(self): return self diff --git 
a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 7565ff7192f67..0698c19ad0023 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -39,8 +39,8 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer, assert score.id is not None assert score.data is not None assert len(score.data) == 2 - assert score.data[0].score[0] <= 0.01 - assert score.data[1].score[0] >= 0.9 + assert score.data[0].score <= 0.01 + assert score.data[1].score >= 0.9 @pytest.mark.asyncio @@ -67,8 +67,8 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer, assert score.id is not None assert score.data is not None assert len(score.data) == 2 - assert score.data[0].score[0] <= 0.01 - assert score.data[1].score[0] >= 0.9 + assert score.data[0].score <= 0.01 + assert score.data[1].score >= 0.9 @pytest.mark.asyncio @@ -90,4 +90,4 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer, assert score.id is not None assert score.data is not None assert len(score.data) == 1 - assert score.data[0].score[0] >= 0.9 + assert score.data[0].score >= 0.9 diff --git a/tests/models/embedding/language/test_scoring.py b/tests/models/embedding/language/test_scoring.py index 0c3115d195fc1..af31e1a635f65 100644 --- a/tests/models/embedding/language/test_scoring.py +++ b/tests/models/embedding/language/test_scoring.py @@ -42,7 +42,7 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): assert len(vllm_outputs) == 1 assert len(hf_outputs) == 1 - assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) @pytest.mark.parametrize("dtype", ["half"]) @@ -63,8 +63,8 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): assert len(vllm_outputs) == 2 assert len(hf_outputs) == 2 - assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) - assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01) + assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) + assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) @pytest.mark.parametrize("dtype", ["half"]) @@ -85,5 +85,5 @@ def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str): assert len(vllm_outputs) == 2 assert len(hf_outputs) == 2 - assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) - assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01) + assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) + assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 94be215258f89..2c413a633896a 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -2,7 +2,7 @@ import os import pytest -from vllm import LLM, PoolingParams, SamplingParams +from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from ..utils import fork_new_process_for_each_test @@ -36,9 +36,8 @@ def test_oot_registration_text_generation(dummy_opt_path): def test_oot_registration_embedding(dummy_gemma2_embedding_path): os.environ["VLLM_PLUGINS"] = "register_dummy_model" prompts = ["Hello, my name is", "The text does not matter"] - sampling_params = PoolingParams() llm = LLM(model=dummy_gemma2_embedding_path, load_format="dummy") - outputs = llm.encode(prompts, sampling_params) + outputs = llm.embed(prompts) for output in outputs: assert all(v == 0 for v in 
output.outputs.embedding) diff --git a/vllm/__init__.py b/vllm/__init__.py index a10f6d3128cb6..45252b93e3d54 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -7,8 +7,11 @@ from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry -from vllm.outputs import (CompletionOutput, PoolingOutput, - PoolingRequestOutput, RequestOutput) +from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput, + CompletionOutput, EmbeddingOutput, + EmbeddingRequestOutput, PoolingOutput, + PoolingRequestOutput, RequestOutput, ScoringOutput, + ScoringRequestOutput) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams @@ -27,6 +30,12 @@ __all__ = [ "CompletionOutput", "PoolingOutput", "PoolingRequestOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", + "ClassificationOutput", + "ClassificationRequestOutput", + "ScoringOutput", + "ScoringRequestOutput", "LLMEngine", "EngineArgs", "AsyncLLMEngine", @@ -34,26 +43,3 @@ __all__ = [ "initialize_ray_cluster", "PoolingParams", ] - - -def __getattr__(name: str): - import warnings - - if name == "EmbeddingOutput": - msg = ("EmbeddingOutput has been renamed to PoolingOutput. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return PoolingOutput - - if name == "EmbeddingRequestOutput": - msg = ("EmbeddingRequestOutput has been renamed to " - "PoolingRequestOutput. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return PoolingRequestOutput - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d756f71e4fa53..dc2d77d6927cd 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -46,11 +46,10 @@ from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - ParallelSampleSequenceGroup, Sequence, - SequenceGroup, SequenceGroupBase, - SequenceGroupMetadata, SequenceGroupOutput, - SequenceStatus) +from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, + PoolingSequenceGroupOutput, Sequence, SequenceGroup, + SequenceGroupBase, SequenceGroupMetadata, + SequenceGroupOutput, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -966,9 +965,9 @@ class LLMEngine: @staticmethod def _process_sequence_group_outputs( seq_group: SequenceGroup, - outputs: List[EmbeddingSequenceGroupOutput], + outputs: List[PoolingSequenceGroupOutput], ) -> None: - seq_group.embeddings = outputs[0].embeddings + seq_group.pooled_data = outputs[0].data for seq in seq_group.get_seqs(): seq.status = SequenceStatus.FINISHED_STOPPED @@ -1784,8 +1783,8 @@ class LLMEngine: num_prompt_tokens_iter) # Spec decode, if enabled, emits specialized metrics from the worker in # sampler output. 
- if model_output and (model_output[0].spec_decode_worker_metrics - is not None): + if model_output and isinstance(model_output[0], SamplerOutput) and ( + model_output[0].spec_decode_worker_metrics is not None): spec_decode_metrics = model_output[0].spec_decode_worker_metrics else: spec_decode_metrics = None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0bec978c4869c..11b2574ce42dd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -26,7 +26,9 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest, LLMGuidedOptions) -from vllm.outputs import PoolingRequestOutput, RequestOutput +from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, + PoolingRequestOutput, RequestOutput, + ScoringRequestOutput) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, @@ -120,7 +122,7 @@ class LLM: serving, use the :class:`~vllm.AsyncLLMEngine` class instead. """ - DEPRECATE_LEGACY: ClassVar[bool] = False + DEPRECATE_LEGACY: ClassVar[bool] = True """A flag to toggle whether to deprecate the legacy generate/encode API.""" DEPRECATE_INIT_POSARGS: ClassVar[bool] = True @@ -257,11 +259,14 @@ class LLM: self, prompts: Union[PromptType, Sequence[PromptType]], /, - *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, + *, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, ) -> List[RequestOutput]: ... @@ -275,6 +280,9 @@ class LLM: prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, ) -> List[RequestOutput]: ... @@ -288,6 +296,9 @@ class LLM: prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, ) -> List[RequestOutput]: ... @@ -302,6 +313,9 @@ class LLM: prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, ) -> List[RequestOutput]: ... @@ -316,6 +330,9 @@ class LLM: prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, ) -> List[RequestOutput]: ... 
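# Usage sketch (illustrative only, not part of the diff) for the task-specific
# entry points that this patch adds to ``LLM`` alongside the lower-level
# ``encode``. The model name and ``enforce_eager`` flag are taken from the new
# examples/offline_inference_classification.py script; any checkpoint served
# with a matching ``task`` should behave the same way.
from vllm import LLM

llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach",
          task="classify",
          enforce_eager=True)

# classify() wraps encode() and converts each PoolingRequestOutput into a
# ClassificationRequestOutput whose ``outputs.probs`` is a list of floats.
(output, ) = llm.classify("Hello, my name is")
print(f"num classes: {len(output.outputs.probs)}")

# encode() remains available and now returns the raw pooled tensor
# (PoolingOutput.data) rather than an ``embedding`` list.
(raw, ) = llm.encode("Hello, my name is")
print(raw.outputs.data.shape)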
@@ -328,6 +345,9 @@ class LLM: prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, ) -> List[RequestOutput]: ... @@ -678,11 +698,12 @@ class LLM: self, prompts: Union[PromptType, Sequence[PromptType]], /, - *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, + *, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[PoolingRequestOutput]: ... @@ -696,6 +717,7 @@ class LLM: prompt_token_ids: Optional[List[int]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[PoolingRequestOutput]: ... @@ -709,6 +731,7 @@ class LLM: prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[PoolingRequestOutput]: ... @@ -723,6 +746,7 @@ class LLM: prompt_token_ids: List[int], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[PoolingRequestOutput]: ... @@ -737,6 +761,7 @@ class LLM: prompt_token_ids: List[List[int]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[PoolingRequestOutput]: ... @@ -749,6 +774,7 @@ class LLM: prompt_token_ids: Union[List[int], List[List[int]]], use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[PoolingRequestOutput]: ... @@ -768,7 +794,8 @@ class LLM: lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[PoolingRequestOutput]: - """Generates the completions for the input prompts. + """Apply pooling to the hidden states corresponding to the input + prompts. This class automatically batches the given prompts, considering the memory constraint. For the best performance, put all of your prompts @@ -787,7 +814,7 @@ class LLM: Returns: A list of ``PoolingRequestOutput`` objects containing the - generated embeddings in the same order as the input prompts. + pooled hidden states in the same order as the input prompts. Note: Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is @@ -833,28 +860,110 @@ class LLM: return self.engine_class.validate_outputs(outputs, PoolingRequestOutput) + def embed( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[EmbeddingRequestOutput]: + """ + Generate an embedding vector for each prompt. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. 
You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``EmbeddingRequestOutput`` objects containing the + embedding vectors in the same order as the input prompts. + """ + if self.llm_engine.model_config.task != "embed": + raise ValueError( + "Embedding API is only enabled for `--task embed`") + + items = self.encode(prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + return [EmbeddingRequestOutput.from_base(item) for item in items] + + def classify( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[ClassificationRequestOutput]: + """ + Generate class logits for each prompt. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``ClassificationRequestOutput`` objects containing the + embedding vectors in the same order as the input prompts. + """ + if self.llm_engine.model_config.task != "classify": + raise ValueError( + "Classification API is only enabled for `--task classify`") + + items = self.encode(prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + return [ClassificationRequestOutput.from_base(item) for item in items] + def score( self, text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]], /, + *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> List[PoolingRequestOutput]: - """Generates similarity scores for all pairs . + ) -> List[ScoringRequestOutput]: + """Generate similarity scores for all pairs ````. - The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case - the text_1 sentence will be replicated N times to pair with the text_2 - sentences. The input pairs are used to build a list of prompts for the + The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``. + In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N`` + times to pair with the ``text_2`` sentences. + The input pairs are used to build a list of prompts for the cross encoder model. This class automatically batches the prompts, considering the memory constraint. For the best performance, put all of your texts into a single list and pass it to this method. 
Args: text_1: can be a single prompt or a list of prompts, in which - case it has to have the same length as the text_2 list + case it has to have the same length as the ``text_2`` list text_2: The texts to pair with the query to form the input to the LLM. See :class:`~vllm.inputs.PromptType` for more details about the format of each prompts. @@ -864,7 +973,7 @@ class LLM: generation, if any. Returns: - A list of ``PoolingRequestOutput`` objects containing the + A list of ``ScoringRequestOutput`` objects containing the generated scores in the same order as the input prompts. """ runner_type = self.llm_engine.model_config.runner_type @@ -884,6 +993,8 @@ class LLM: if not self.llm_engine.model_config.is_cross_encoder: raise ValueError("Your model does not support cross encoding") + if self.llm_engine.model_config.task != "score": + raise ValueError("Score API is only enabled for `--task score`") tokenizer = self.llm_engine.get_tokenizer() @@ -954,8 +1065,10 @@ class LLM: ) outputs = self._run_engine(use_tqdm=use_tqdm) - return self.engine_class.validate_outputs(outputs, - PoolingRequestOutput) + items = self.engine_class.validate_outputs(outputs, + PoolingRequestOutput) + + return [ScoringRequestOutput.from_base(item) for item in items] def start_profile(self) -> None: self.llm_engine.start_profile() diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ee94a9413f098..34c9f0a96216f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -900,7 +900,7 @@ class EmbeddingResponse(OpenAIBaseModel): class ScoreResponseData(OpenAIBaseModel): index: int object: str = "score" - score: Union[List[float], str] + score: float class ScoreResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 3f7b75e893cad..fd501ad4f833e 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -18,14 +18,15 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest, ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.logger import init_logger -from vllm.outputs import PoolingOutput, PoolingRequestOutput +from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, + PoolingRequestOutput) from vllm.utils import merge_async_iterators logger = init_logger(__name__) def _get_embedding( - output: PoolingOutput, + output: EmbeddingOutput, encoding_format: Literal["float", "base64"], ) -> Union[List[float], str]: if encoding_format == "float": @@ -46,8 +47,10 @@ def request_output_to_embedding_response( data: List[EmbeddingResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): + embedding_res = EmbeddingRequestOutput.from_base(final_res) prompt_token_ids = final_res.prompt_token_ids - embedding = _get_embedding(final_res.outputs, encoding_format) + + embedding = _get_embedding(embedding_res.outputs, encoding_format) embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) data.append(embedding_data) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index d5ad4354c78be..5b6a089e4c319 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -31,7 +31,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ErrorResponse, LoadLoraAdapterRequest, ModelCard, ModelList, - 
ModelPermission, + ModelPermission, ScoreRequest, TokenizeChatRequest, TokenizeCompletionRequest, UnloadLoraAdapterRequest) @@ -73,7 +73,7 @@ class LoRAModulePath: CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest, - EmbeddingCompletionRequest, + EmbeddingCompletionRequest, ScoreRequest, TokenizeCompletionRequest] ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, @@ -567,12 +567,14 @@ class OpenAIServing: return None @staticmethod - def _base_request_id(raw_request: Request, + def _base_request_id(raw_request: Optional[Request], default: Optional[str] = None) -> Optional[str]: """Pulls the request id to use from a header, if provided""" default = default or random_uuid() - return raw_request.headers.get( - "X-Request-Id", default) if raw_request is not None else default + if raw_request is None: + return default + + return raw_request.headers.get("X-Request-Id", default) @staticmethod def _get_decoded_token(logprob: Logprob, diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 4929e720c00e4..6f5cc14ac37cc 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -13,7 +13,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger -from vllm.outputs import PoolingRequestOutput +from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from vllm.utils import make_async, merge_async_iterators @@ -24,13 +24,13 @@ def request_output_to_score_response( final_res_batch: List[PoolingRequestOutput], request_id: str, created_time: int, model_name: str) -> ScoreResponse: data: List[ScoreResponseData] = [] - score = None num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): - if final_res is not None: - score = final_res.outputs.embedding - score_data = ScoreResponseData(index=idx, score=score) - data.append(score_data) + classify_res = ScoringRequestOutput.from_base(final_res) + + score_data = ScoreResponseData(index=idx, + score=classify_res.outputs.score) + data.append(score_data) usage = UsageInfo( prompt_tokens=num_prompt_tokens, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index e0d42e30ebef3..75bf33dc70a51 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,14 +1,16 @@ from enum import IntEnum -from typing import List, Optional +from typing import List, Optional, Union import torch import torch.nn as nn +import torch.nn.functional as F from transformers import PretrainedConfig +from typing_extensions import assert_never from vllm.config import PoolerConfig from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) -from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput +from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) @@ -22,7 +24,7 @@ class PoolingType(IntEnum): MEAN = 4 -class Pooler(nn.Module): +class SimplePooler(nn.Module): """A layer that pools specific information from hidden states. This layer does the following: @@ -35,22 +37,204 @@ class Pooler(nn.Module): normalize: Whether to normalize the pooled data. 
""" + @staticmethod + def from_pooling_type( + pooling_type: PoolingType, + *, + normalize: bool, + softmax: bool, + step_tag_id: Optional[int] = None, + returned_token_ids: Optional[List[int]] = None, + ) -> "SimplePooler": + if pooling_type == PoolingType.LAST: + assert step_tag_id is None and returned_token_ids is None + return LastPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.ALL: + assert step_tag_id is None and returned_token_ids is None + return AllPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.CLS: + assert step_tag_id is None and returned_token_ids is None + return CLSPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.MEAN: + assert step_tag_id is None and returned_token_ids is None + return MeanPool(normalize=normalize, softmax=softmax) + if pooling_type == PoolingType.STEP: + return StepPool(normalize=normalize, + softmax=softmax, + step_tag_id=step_tag_id, + returned_token_ids=returned_token_ids) + + assert_never(pooling_type) + + def __init__(self, *, normalize: bool, softmax: bool) -> None: + super().__init__() + + self.head = PoolerHead(normalize=normalize, softmax=softmax) + + def get_prompt_lens( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> torch.Tensor: + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + raise NotImplementedError + + def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput: + return PoolingSequenceGroupOutput(data) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.extract_states(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data) + pooled_outputs = [self.build_output(data) for data in pooled_data] + return PoolerOutput(outputs=pooled_outputs) + + +class CLSPool(SimplePooler): + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + first_token_flat_indices = torch.zeros_like(prompt_lens) + first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] + return hidden_states[first_token_flat_indices] + + +class LastPool(SimplePooler): + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 + return hidden_states[last_token_flat_indices] + + +class AllPool(SimplePooler): + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + offset = 0 + pooled_data = list[torch.Tensor]() + for prompt_len in prompt_lens: + pooled_data.append(hidden_states[offset:offset + prompt_len]) + offset += prompt_len + + return pooled_data + + +class MeanPool(SimplePooler): + + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + 
cumsum = torch.cumsum(hidden_states, dim=0) + start_indices = torch.cat([ + torch.tensor([0], device=hidden_states.device), + torch.cumsum(prompt_lens[:-1], dim=0) + ]) + end_indices = torch.cumsum(prompt_lens, dim=0) + return (cumsum[end_indices - 1] - cumsum[start_indices] + + hidden_states[start_indices]) / prompt_lens.unsqueeze(1) + + +class StepPool(SimplePooler): + def __init__( self, - pooling_type: PoolingType, + *, normalize: bool, softmax: bool, step_tag_id: Optional[int] = None, returned_token_ids: Optional[List[int]] = None, ): - super().__init__() + super().__init__(normalize=normalize, softmax=softmax) - self.pooling_type = pooling_type - self.normalize = normalize - self.softmax = softmax self.step_tag_id = step_tag_id self.returned_token_ids = returned_token_ids + def extract_states( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + + returned_token_ids = self.returned_token_ids + if returned_token_ids is not None and len(returned_token_ids) > 0: + hidden_states = hidden_states[:, returned_token_ids] + + step_tag_id = self.step_tag_id + + offset = 0 + pooled_data = list[torch.Tensor]() + for prompt_len, seq_data_i in zip(prompt_lens, + pooling_metadata.seq_data.values()): + pooled_data_i = hidden_states[offset:offset + prompt_len] + if step_tag_id is not None: + token_ids = torch.tensor(seq_data_i.prompt_token_ids) + pooled_data_i = pooled_data_i[token_ids == step_tag_id] + + offset += prompt_len + pooled_data.append(pooled_data_i) + + return pooled_data + + +class PoolerHead(nn.Module): + + def __init__(self, *, normalize: bool, softmax: bool) -> None: + super().__init__() + + self.normalize = normalize + self.softmax = softmax + + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor]): + if self.normalize: + if isinstance(pooled_data, list): + pooled_data = [ + F.normalize(data, p=2, dim=1) for data in pooled_data + ] + else: + pooled_data = F.normalize(pooled_data, p=2, dim=1) + + if self.softmax: + if isinstance(pooled_data, list): + pooled_data = [F.softmax(data, dim=-1) for data in pooled_data] + else: + pooled_data = F.softmax(pooled_data, dim=-1) + + return pooled_data + + +class Pooler(nn.Module): + @classmethod def from_config_with_defaults( cls, @@ -60,8 +244,8 @@ class Pooler(nn.Module): softmax: bool, step_tag_id: Optional[int] = None, returned_token_ids: Optional[List[int]] = None, - ) -> "Pooler": - return cls( + ) -> SimplePooler: + return SimplePooler.from_pooling_type( pooling_type=PoolingType[pooler_config.pooling_type] if pooler_config.pooling_type is not None else pooling_type, normalize=pooler_config.normalize @@ -75,85 +259,6 @@ class Pooler(nn.Module): returned_token_ids, ) - def forward( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - """Pools specific information from hidden states based on metadata.""" - - prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens - - if self.pooling_type is PoolingType.CLS: - first_token_flat_indices = torch.zeros_like(prompt_lens) - first_token_flat_indices[1:] += torch.cumsum(prompt_lens, - dim=0)[:-1] - pooled_data = hidden_states[first_token_flat_indices] - elif self.pooling_type == PoolingType.LAST: - last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 - pooled_data = hidden_states[last_token_flat_indices] - elif self.pooling_type == 
PoolingType.ALL: - offset = 0 - pooled_data = [] - for prompt_len in prompt_lens: - pooled_data.append(hidden_states[offset:offset + prompt_len]) - offset += prompt_len - elif self.pooling_type == PoolingType.MEAN: - # Calculate mean pooling - cumsum = torch.cumsum(hidden_states, dim=0) - start_indices = torch.cat([ - torch.tensor([0], device=hidden_states.device), - torch.cumsum(prompt_lens[:-1], dim=0) - ]) - end_indices = torch.cumsum(prompt_lens, dim=0) - pooled_data = ( - cumsum[end_indices - 1] - cumsum[start_indices] + - hidden_states[start_indices]) / prompt_lens.unsqueeze(1) - elif self.pooling_type == PoolingType.STEP: - returned_token_ids = self.returned_token_ids - if returned_token_ids is not None and len(returned_token_ids) > 0: - hidden_states = hidden_states[:, returned_token_ids] - - step_tag_id = self.step_tag_id - - offset = 0 - pooled_data = [] - for prompt_len, seq_data_i in zip( - prompt_lens, pooling_metadata.seq_data.values()): - pooled_data_i = hidden_states[offset:offset + prompt_len] - if step_tag_id is not None: - token_ids = torch.tensor(seq_data_i.prompt_token_ids) - pooled_data_i = pooled_data_i[token_ids == step_tag_id] - - offset += prompt_len - pooled_data.append(pooled_data_i) - else: - raise ValueError(f"Invalid pooling type: {self.pooling_type}") - - if self.normalize: - if isinstance(pooled_data, list): - pooled_data = [ - nn.functional.normalize(data, p=2, dim=1) - for data in pooled_data - ] - else: - pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) - - if self.softmax: - if isinstance(pooled_data, list): - pooled_data = [ - nn.functional.softmax(data, dim=-1) for data in pooled_data - ] - else: - pooled_data = nn.functional.softmax(pooled_data, dim=-1) - - pooled_outputs = [ - EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data - ] - - return PoolerOutput(outputs=pooled_outputs) - class CrossEncodingPooler(nn.Module): """A layer that pools specific information from hidden states. 
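# Standalone sketch (illustrative only, not part of the diff) of the pooling
# arithmetic used by the new LastPool and MeanPool classes above. vLLM
# flattens every prompt into one [total_tokens, hidden_size] tensor, so each
# pooler recovers per-prompt vectors from the prompt lengths alone.
# Toy shapes, CPU-only, no vLLM needed.
import torch

hidden_states = torch.randn(6, 4)      # two prompts of lengths 2 and 4
prompt_lens = torch.tensor([2, 4])

# LastPool: flat index of the final token of each prompt.
last_idx = torch.cumsum(prompt_lens, dim=0) - 1
last_pooled = hidden_states[last_idx]                  # [num_prompts, hidden]

# MeanPool: per-prompt mean via a running cumulative sum over all tokens.
cumsum = torch.cumsum(hidden_states, dim=0)
start = torch.cat([torch.tensor([0]), torch.cumsum(prompt_lens[:-1], dim=0)])
end = torch.cumsum(prompt_lens, dim=0)
mean_pooled = (cumsum[end - 1] - cumsum[start] +
               hidden_states[start]) / prompt_lens.unsqueeze(1)

assert last_pooled.shape == mean_pooled.shape == (2, 4)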
@@ -208,9 +313,8 @@ class CrossEncodingPooler(nn.Module): if self.pooler is not None: # apply classifier once on the full batch if possible pooled_output = self.classifier(pooled_output) - logits = self.default_activation_function(pooled_output) - pooled_outputs = [ - EmbeddingSequenceGroupOutput(data.tolist()) for data in logits - ] + scores = self.default_activation_function(pooled_output).squeeze(-1) + + pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 34c1332ac4a66..d179d6235424a 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -2,19 +2,20 @@ from array import array from typing import List, Optional, Union import torch -from torch import nn +import torch.nn as nn from xformers.ops.fmha.attn_bias import BlockDiagonalMask from vllm.attention import AttentionMetadata from vllm.attention.backends.xformers import XFormersImpl from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import PoolerHead from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (EmbeddingSequenceGroupOutput, IntermediateTensors, - PoolerOutput) +from vllm.sequence import (IntermediateTensors, PoolerOutput, + PoolingSequenceGroupOutput) logger = init_logger(__name__) @@ -52,6 +53,8 @@ class GritLMPooler(nn.Module): self.embed_pattern_ids = tokens_to_ids( ["▁<", "|", "embed", "|", ">", "<0x0A>"]) + self.head = PoolerHead(normalize=True, softmax=False) + def _find_array(self, arr: array, target: array, start_idx: int) -> int: """ Find the first occurrence of target in arr starting from start_idx. @@ -75,7 +78,7 @@ class GritLMPooler(nn.Module): return i return -1 - def _get_instruction_len(self, prompt_token_ids: array) -> bool: + def _get_instruction_len(self, prompt_token_ids: array) -> int: """ Get the length of the instruction in the prompt. @@ -168,10 +171,10 @@ class GritLMPooler(nn.Module): mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( 1) - pooled_data = nn.functional.normalize(mean_embeddings, p=2, dim=1) + pooled_data = self.head(mean_embeddings) pooled_outputs = [ - EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data + PoolingSequenceGroupOutput(data) for data in pooled_data ] return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/outputs.py b/vllm/outputs.py index 86264f604f6bc..8c6c1aca3a917 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,9 +1,13 @@ import time +import warnings from dataclasses import dataclass -from typing import Dict, List, Optional +from typing import Dict, Generic, List, Optional from typing import Sequence as GenericSequence from typing import Union +import torch +from typing_extensions import TypeVar + from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict from vllm.sampling_params import RequestOutputKind @@ -57,14 +61,26 @@ class PoolingOutput: """The output data of one pooling output of a request. Args: - embedding: The embedding vector, which is a list of floats. The - length of vector depends on the model as listed in the embedding guide. + data: The extracted hidden states. 
""" - embedding: List[float] + data: torch.Tensor def __repr__(self) -> str: - return (f"PoolingOutput(" - f"embedding={len(self.embedding)})") + return (f"PoolingOutput(data={self.data})") + + def __eq__(self, other: object) -> bool: + return (isinstance(other, self.__class__) and bool( + (self.data == other.data).all())) + + @property + def embedding(self) -> list[float]: + msg = ("`LLM.encode()` now returns raw outputs. " + "To return embeddings, use `LLM.embed()`. " + "To return class probabilities, use `LLM.classify()` " + "and access the `probs` attribute. ") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + return self.data.tolist() class RequestOutput: @@ -316,7 +332,10 @@ class RequestOutput: f"multi_modal_placeholders={self.multi_modal_placeholders})") -class PoolingRequestOutput: +_O = TypeVar("_O", default=PoolingOutput) + + +class PoolingRequestOutput(Generic[_O]): """ The output data of a pooling request to the LLM. @@ -327,24 +346,24 @@ class PoolingRequestOutput: finished (bool): A flag indicating whether the pooling is completed. """ - def __init__(self, request_id: str, outputs: "PoolingOutput", + def __init__(self, request_id: str, outputs: _O, prompt_token_ids: List[int], finished: bool): self.request_id = request_id self.prompt_token_ids = prompt_token_ids self.finished = finished self.outputs = outputs - @classmethod - def from_seq_group(cls, - seq_group: 'SequenceGroup') -> "PoolingRequestOutput": - if seq_group.embeddings is None: - raise ValueError( - "Embeddings are missing in seq_group for EmbeddingRequest.") - output = PoolingOutput(seq_group.embeddings) + @staticmethod + def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput": + pooled_data = seq_group.pooled_data + assert pooled_data is not None + + output = PoolingOutput(pooled_data) prompt_token_ids = seq_group.prompt_token_ids finished = seq_group.is_finished() - return cls(seq_group.request_id, output, prompt_token_ids, finished) + return PoolingRequestOutput(seq_group.request_id, output, + prompt_token_ids, finished) def __repr__(self): """ @@ -356,89 +375,137 @@ class PoolingRequestOutput: Returns: str: A string representation of the PoolingRequestOutput instance. """ - return (f"PoolingRequestOutput(request_id='{self.request_id}', " - f"outputs={repr(self.outputs)}, " + return (f"{type(self).__name__}(request_id={self.request_id!r}, " + f"outputs={self.outputs!r}, " f"prompt_token_ids={self.prompt_token_ids}, " f"finished={self.finished})") -@dataclass -class ScoreOutput: - """The output data of one completion output of a request. - - Args: - score: The score, which is a list of floats. - index: The correspondent text index of the score. - """ - index: int - score: List[float] - - def __repr__(self) -> str: - return (f"ScoreOutput(" - f"score={self.score}), " - f"index={self.index})") - - -class ScoreRequestOutput: - """ - The output data of an score request to the LLM. - - Args: - request_id (str): A unique identifier for the score request. - outputs (score): The embedding results for the given input. - """ - - def __init__(self, request_id: str, outputs: "ScoreOutput"): - self.request_id = request_id - self.outputs = outputs - - def __repr__(self): - """ - Returns a string representation of an ScoreRequestOutput instance. - - The representation includes the request_id and the number of outputs, - providing a quick overview of the embedding request's results. - - Returns: - str: A string representation of the ScoreRequestOutput instance. 
- """ - return (f"ScoreRequestOutput(request_id='{self.request_id}', " - f"outputs={repr(self.outputs)}") - - class RequestOutputFactory: @staticmethod def create(seq_group: SequenceGroup, seq_id_to_seq_group: Dict[str, SequenceGroupBase], use_cache: bool = False): - # Determine the type based on a condition, for example: - if hasattr(seq_group, - 'embeddings') and seq_group.embeddings is not None: + if seq_group.pooled_data is not None: return PoolingRequestOutput.from_seq_group(seq_group) else: return RequestOutput.from_seq_group(seq_group, use_cache, seq_id_to_seq_group) -def __getattr__(name: str): - import warnings +@dataclass +class EmbeddingOutput: + """The output data of one embedding output of a request. - if name == "EmbeddingOutput": - msg = ("EmbeddingOutput has been renamed to PoolingOutput. " - "The original name will be removed in an upcoming version.") + Args: + embedding: The embedding vector, which is a list of floats. + Its length depends on the hidden dimension of the model. + """ + embedding: list[float] - warnings.warn(DeprecationWarning(msg), stacklevel=2) + @staticmethod + def from_base(pooling_output: PoolingOutput): + pooled_data = pooling_output.data + if pooled_data.ndim != 1: + raise ValueError("pooled_data should be a 1-D embedding vector") - return PoolingOutput + return EmbeddingOutput(pooled_data.tolist()) - if name == "EmbeddingRequestOutput": - msg = ("EmbeddingRequestOutput has been renamed to " - "PoolingRequestOutput. " - "The original name will be removed in an upcoming version.") + @property + def hidden_size(self) -> int: + return len(self.embedding) - warnings.warn(DeprecationWarning(msg), stacklevel=2) + def __repr__(self) -> str: + return f"EmbeddingOutput(hidden_size={self.hidden_size})" - return PoolingRequestOutput - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") +class EmbeddingRequestOutput(PoolingRequestOutput[EmbeddingOutput]): + + @staticmethod + def from_base(request_output: PoolingRequestOutput): + return EmbeddingRequestOutput( + request_id=request_output.request_id, + outputs=EmbeddingOutput.from_base(request_output.outputs), + prompt_token_ids=request_output.prompt_token_ids, + finished=request_output.finished, + ) + + +@dataclass +class ClassificationOutput: + """The output data of one classification output of a request. + + Args: + probs: The probability vector, which is a list of floats. + Its length depends on the number of classes. + """ + probs: list[float] + + @staticmethod + def from_base(pooling_output: PoolingOutput): + pooled_data = pooling_output.data + if pooled_data.ndim != 1: + raise ValueError("pooled_data should be a 1-D probability vector") + + return ClassificationOutput(pooled_data.tolist()) + + @property + def num_classes(self) -> int: + return len(self.probs) + + def __repr__(self) -> str: + return f"ClassificationOutput(num_classes={self.num_classes})" + + +class ClassificationRequestOutput(PoolingRequestOutput[ClassificationOutput]): + + @staticmethod + def from_base(request_output: PoolingRequestOutput): + return ClassificationRequestOutput( + request_id=request_output.request_id, + outputs=ClassificationOutput.from_base(request_output.outputs), + prompt_token_ids=request_output.prompt_token_ids, + finished=request_output.finished, + ) + + +@dataclass +class ScoringOutput: + """The output data of one scoring output of a request. + + Args: + score: The similarity score, which is a scalar value. 
+ """ + score: float + + @staticmethod + def from_base(pooling_output: PoolingOutput): + pooled_data = pooling_output.data + if pooled_data.ndim != 0: + raise ValueError("pooled_data should be a scalar score") + + return ScoringOutput(pooled_data.item()) + + def __repr__(self) -> str: + return f"ScoringOutput(score={self.score})" + + @property + def embedding(self) -> list[float]: + msg = ("`LLM.score()` now returns scalar scores. " + "Please access it via the `score` attribute. ") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + return [self.score] + + +class ScoringRequestOutput(PoolingRequestOutput[ScoringOutput]): + + @staticmethod + def from_base(request_output: PoolingRequestOutput): + return ScoringRequestOutput( + request_id=request_output.request_id, + outputs=ScoringOutput.from_base(request_output.outputs), + prompt_token_ids=request_output.prompt_token_ids, + finished=request_output.finished, + ) diff --git a/vllm/sequence.py b/vllm/sequence.py index b0f3c1cc3609f..ddb9ca5944f10 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -617,10 +617,9 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. - embeddings: The embeddings vectors of the prompt of the sequence group - for a pooling model. - pooling_params: The pooling parameters used to generate the pooling + pooling_params: The parameters used to generate the pooler for a pooling model. + pooled_data: The extracted hidden states from a pooling model. encoder_seq: Optional, the single encoder sequence. Should be None unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. @@ -635,8 +634,8 @@ class SequenceGroup: arrival_time: float, sampling_params: Optional[SamplingParams] = None, lora_request: Optional[LoRARequest] = None, - embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, + pooled_data: Optional[torch.Tensor] = None, encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -658,8 +657,8 @@ class SequenceGroup: self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() - self.embeddings = embeddings self.pooling_params = pooling_params + self.pooled_data = pooled_data self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers @@ -1033,8 +1032,8 @@ class CompletionSequenceGroupOutput( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True): # type: ignore[call-arg] - __metaclass__ = SequenceGroupOutput """The model output associated with a completion sequence group.""" + __metaclass__ = SequenceGroupOutput samples: List[SequenceOutput] # Prompt logprob for each prompt query token. 
prompt_logprobs: Optional[PromptLogprobs] @@ -1050,23 +1049,24 @@ class CompletionSequenceGroupOutput( and self.prompt_logprobs == other.prompt_logprobs) -class EmbeddingSequenceGroupOutput( +class PoolingSequenceGroupOutput( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True, # type: ignore[call-arg] ): - """The model output associated with an embedding sequence group.""" + """The model output associated with a pooling sequence group.""" __metaclass__ = SequenceGroupOutput - embeddings: List[int] + # Annotated as Any to be compatible with msgspec + # The actual type is in SequenceGroup.pooled_data + data: Any def __repr__(self) -> str: - return (f"EmbeddingSequenceGroupOutput(" - f"embeddings_shape={len(self.embeddings)})") + return f"PoolingSequenceGroupOutput(data={self.data}" def __eq__(self, other: object) -> bool: - if not isinstance(other, EmbeddingSequenceGroupOutput): + if not isinstance(other, PoolingSequenceGroupOutput): raise NotImplementedError() - return self.embeddings == other.embeddings + return self.data == other.data # cannot use msgspec.Struct here because Dynamo does not support it @@ -1085,7 +1085,7 @@ class IntermediateTensors: elif isinstance(key, slice): return self.__class__({k: v[key] for k, v in self.tensors.items()}) - def __setitem__(self, key: str, value): + def __setitem__(self, key: str, value: torch.Tensor): self.tensors[key] = value def __len__(self): @@ -1103,16 +1103,12 @@ class PoolerOutput( omit_defaults=True, # type: ignore[call-arg] array_like=True): # type: ignore[call-arg] """The output from a pooling operation in the pooling model.""" - outputs: List[EmbeddingSequenceGroupOutput] + outputs: List[PoolingSequenceGroupOutput] - # lazy import to avoid circular import - from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics - spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None - - def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput: + def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput: return self.outputs[idx] - def __setitem__(self, idx: int, value): + def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput): self.outputs[idx] = value def __len__(self): @@ -1385,8 +1381,8 @@ class ParallelSampleSequenceGroup(SequenceGroupBase): arrival_time=seq_group.arrival_time, sampling_params=original_params, lora_request=seq_group.lora_request, - embeddings=seq_group.embeddings, pooling_params=seq_group.pooling_params, + pooled_data=seq_group.pooled_data, encoder_seq=seq_group.encoder_seq, trace_headers=seq_group.trace_headers, prompt_adapter_request=seq_group.prompt_adapter_request,
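# End-to-end sketch (illustrative only, not part of the diff) of how the
# renamed fields above fit together: the pooler emits
# PoolingSequenceGroupOutput.data, the engine stores it on
# SequenceGroup.pooled_data, and it surfaces to users as PoolingOutput.data,
# which the new task-specific output classes validate and convert. The tensors
# here are constructed by hand purely for illustration; in normal use these
# objects come from LLM.embed/classify/score.
import torch

from vllm.outputs import (ClassificationOutput, EmbeddingOutput,
                          PoolingOutput, ScoringOutput)

embedding = EmbeddingOutput.from_base(PoolingOutput(torch.randn(8)))
print(embedding.hidden_size)    # 8

classification = ClassificationOutput.from_base(
    PoolingOutput(torch.tensor([0.2, 0.8])))
print(classification.num_classes)    # 2

score = ScoringOutput.from_base(PoolingOutput(torch.tensor(0.97)))
print(score.score)    # a plain Python float (~0.97)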