diff --git a/tests/models/language/pooling/test_mm_classifier_conversion.py b/tests/models/language/pooling/test_mm_classifier_conversion.py
new file mode 100644
index 0000000000000..166b953de43e7
--- /dev/null
+++ b/tests/models/language/pooling/test_mm_classifier_conversion.py
@@ -0,0 +1,114 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.platforms import current_platform
+
+
+def test_idefics_multimodal(
+    vllm_runner,
+    monkeypatch,
+) -> None:
+    if current_platform.is_rocm():
+        # ROCm Triton FA does not currently support sliding window attention
+        # switch to use ROCm CK FA backend
+        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+
+    with vllm_runner(model_name="HuggingFaceM4/Idefics3-8B-Llama3",
+                     runner="pooling",
+                     task="classify",
+                     convert="classify",
+                     load_format="dummy",
+                     max_model_len=512,
+                     enforce_eager=True,
+                     tensor_parallel_size=1,
+                     disable_log_stats=True,
+                     dtype="bfloat16") as vllm_model:
+        llm = vllm_model.get_llm()
+        outputs = llm.classify(prompts)
+        for output in outputs:
+            assert len(output.outputs.probs) == 2
+
+
+def update_config(config):
+    config.text_config.update({
+        "architectures": ["Gemma3ForSequenceClassification"],
+        "classifier_from_token": ["A", "B", "C", "D", "E"],
+        "method":
+        "no_post_processing",
+        "id2label": {
+            "A": "Chair",
+            "B": "Couch",
+            "C": "Table",
+            "D": "Bed",
+            "E": "Cupboard"
+        },
+    })
+    return config
+
+
+def test_gemma_multimodal(
+    vllm_runner,
+    monkeypatch,
+) -> None:
+    if current_platform.is_rocm():
+        # ROCm Triton FA does not currently support sliding window attention
+        # switch to use ROCm CK FA backend
+        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")
+
+    messages = [{
+        "role":
+        "system",
+        "content":
+        """
+        You are a helpful assistant. You will be given a product description
+        which may also include an image. Classify the following product into
+        one of the categories:
+
+        A = chair
+        B = couch
+        C = table
+        D = bed
+        E = cupboard
+
+        You'll answer with exactly one letter (A, B, C, D, or E)."""
+    }, {
+        "role":
+        "user",
+        "content": [{
+            "type": "image_url",
+            "image_url": {
+                "url":
+                "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
+            }
+        }, {
+            "type": "text",
+            "text": "A fine 19th century piece of furniture."
+        }]
+    }]
+
+    with vllm_runner(model_name="google/gemma-3-4b-it",
+                     runner="pooling",
+                     task="classify",
+                     convert="classify",
+                     load_format="auto",
+                     hf_overrides=update_config,
+                     override_pooler_config={"pooling_type": "LAST"},
+                     max_model_len=512,
+                     enforce_eager=True,
+                     tensor_parallel_size=1,
+                     disable_log_stats=True,
+                     dtype="bfloat16") as vllm_model:
+
+        llm = vllm_model.get_llm()
+        prompts = llm.preprocess_chat(messages)
+
+        result = llm.classify(prompts)
+        assert result[0].outputs.probs[0] > 0.95
+        assert all(c < 0.05 for c in result[0].outputs.probs[1:])
\ No newline at end of file
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index a6174161f115a..94749eb884512 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -703,6 +703,106 @@ class LLM:

         return outputs

+    def preprocess_chat(
+        self,
+        messages: Union[list[ChatCompletionMessageParam],
+                        list[list[ChatCompletionMessageParam]]],
+        lora_request: Optional[LoRARequest] = None,
+        chat_template: Optional[str] = None,
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+        add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
+        tools: Optional[list[dict[str, Any]]] = None,
+        chat_template_kwargs: Optional[dict[str, Any]] = None,
+        mm_processor_kwargs: Optional[dict[str, Any]] = None,
+    ) -> list[TokensPrompt]:
+        """
+        Generate prompts for a chat conversation. The pre-processed
+        prompts can then be used as input for the other LLM methods.
+
+        Refer to `chat` for a complete description of the arguments.
+        Returns:
+            A list of `TokensPrompt` objects containing the tokenized
+            prompt after chat template interpolation, and the
+            pre-processed multi-modal inputs.
+        """
+        list_of_messages: list[list[ChatCompletionMessageParam]]
+
+        # Handle multi and single conversations
+        if is_list_of(messages, list):
+            # messages is list[list[...]]
+            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
+                                    messages)
+        else:
+            # messages is list[...]
+            list_of_messages = [
+                cast(list[ChatCompletionMessageParam], messages)
+            ]
+
+        tokenizer = self.get_tokenizer(lora_request)
+        model_config = self.llm_engine.get_model_config()
+        resolved_content_format = resolve_chat_template_content_format(
+            chat_template,
+            tools,
+            chat_template_content_format,
+            tokenizer,
+            model_config=model_config,
+        )
+
+        _chat_template_kwargs: dict[str, Any] = dict(
+            chat_template=chat_template,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=continue_final_message,
+            tools=tools,
+        )
+        _chat_template_kwargs.update(chat_template_kwargs or {})
+
+        prompts: list[TokensPrompt] = []
+
+        for msgs in list_of_messages:
+            # NOTE: _parse_chat_message_content_parts() currently doesn't
+            # handle mm_processor_kwargs, since there is no implementation in
+            # the chat message parsing for it.
+            conversation, mm_data, mm_uuids = parse_chat_messages(
+                msgs,
+                model_config,
+                tokenizer,
+                content_format=resolved_content_format,
+            )
+
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt_token_ids = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=msgs,
+                    **_chat_template_kwargs,
+                )
+            else:
+                prompt_str = apply_hf_chat_template(
+                    tokenizer=tokenizer,
+                    conversation=conversation,
+                    model_config=model_config,
+                    **_chat_template_kwargs,
+                )
+                # Special tokens are already included in chat templates so
+                # should not be added by the tokenizer in this case.
+                prompt_token_ids = tokenizer.encode(prompt_str,
+                                                    add_special_tokens=False)
+
+            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
+
+            if mm_data is not None:
+                prompt["multi_modal_data"] = mm_data
+
+            if mm_uuids is not None:
+                prompt["multi_modal_uuids"] = mm_uuids
+
+            if mm_processor_kwargs is not None:
+                prompt["mm_processor_kwargs"] = mm_processor_kwargs
+
+            prompts.append(prompt)
+
+        return prompts
+
     def chat(
         self,
         messages: Union[list[ChatCompletionMessageParam],
@@ -769,80 +869,18 @@ class LLM:
             A list of `RequestOutput` objects containing the generated
             responses in the same order as the input messages.
         """
-        list_of_messages: list[list[ChatCompletionMessageParam]]
-
-        # Handle multi and single conversations
-        if is_list_of(messages, list):
-            # messages is list[list[...]]
-            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
-                                    messages)
-        else:
-            # messages is list[...]
-            list_of_messages = [
-                cast(list[ChatCompletionMessageParam], messages)
-            ]
-
-        tokenizer = self.get_tokenizer(lora_request)
-        model_config = self.llm_engine.get_model_config()
-        resolved_content_format = resolve_chat_template_content_format(
-            chat_template,
-            tools,
-            chat_template_content_format,
-            tokenizer,
-            model_config=model_config,
-        )
-
-        _chat_template_kwargs: dict[str, Any] = dict(
+        prompts = self.preprocess_chat(
+            messages=messages,
+            lora_request=lora_request,
             chat_template=chat_template,
+            chat_template_content_format=chat_template_content_format,
             add_generation_prompt=add_generation_prompt,
             continue_final_message=continue_final_message,
             tools=tools,
+            chat_template_kwargs=chat_template_kwargs,
+            mm_processor_kwargs=mm_processor_kwargs,
         )
-        _chat_template_kwargs.update(chat_template_kwargs or {})
-
-        prompts: list[Union[TokensPrompt, TextPrompt]] = []
-
-        for msgs in list_of_messages:
-            # NOTE: _parse_chat_message_content_parts() currently doesn't
-            # handle mm_processor_kwargs, since there is no implementation in
-            # the chat message parsing for it.
-            conversation, mm_data, mm_uuids = parse_chat_messages(
-                msgs,
-                model_config,
-                tokenizer,
-                content_format=resolved_content_format,
-            )
-
-            if isinstance(tokenizer, MistralTokenizer):
-                prompt_token_ids = apply_mistral_chat_template(
-                    tokenizer,
-                    messages=msgs,
-                    **_chat_template_kwargs,
-                )
-            else:
-                prompt_str = apply_hf_chat_template(
-                    tokenizer=tokenizer,
-                    conversation=conversation,
-                    model_config=model_config,
-                    **_chat_template_kwargs,
-                )
-                # Special tokens are already included in chat templates so
-                # should not be added by the tokenizer in this case.
-                prompt_token_ids = tokenizer.encode(prompt_str,
-                                                    add_special_tokens=False)
-
-            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)
-
-            if mm_data is not None:
-                prompt["multi_modal_data"] = mm_data
-
-            if mm_uuids is not None:
-                prompt["multi_modal_uuids"] = mm_uuids
-
-            if mm_processor_kwargs is not None:
-                prompt["mm_processor_kwargs"] = mm_processor_kwargs
-
-            prompts.append(prompt)

         return self.generate(
             prompts,
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index c82fa5a40aa53..0c2441a6db44d 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -19,10 +19,11 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.models.adapters import (as_embedding_model,
-                                                 as_reward_model,
-                                                 as_seq_cls_model)
-from vllm.model_executor.models.interfaces import SupportsQuant
+from vllm.model_executor.models.adapters import (
+    as_embedding_model, as_reward_model, as_seq_cls_model,
+    try_create_mm_pooling_model_cls)
+from vllm.model_executor.models.interfaces import (SupportsQuant,
+                                                   supports_multimodal)
 from vllm.utils import is_pin_memory_available

 logger = init_logger(__name__)
@@ -183,6 +184,15 @@ def get_model_architecture(
             "performance may not be optimal.", arch)

     convert_type = model_config.convert_type
+    if convert_type != "none" and supports_multimodal(model_cls):
+        logger.debug_once("Detected conversion of Multi Modal model.")
+        converted = try_create_mm_pooling_model_cls(model_cls)
+        if converted is not None:
+            logger.debug_once("Creating wrapper class to forward pooler.")
+            return converted, arch
+        else:
+            logger.debug_once("Attempting direct conversion.")
+
     if convert_type == "none":
         pass
     elif convert_type == "embed":
diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py
index 78ad9a433e314..c4328a176a5de 100644
--- a/vllm/model_executor/models/adapters.py
+++ b/vllm/model_executor/models/adapters.py
@@ -1,12 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import ast
+import inspect
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast

 import torch
 import torch.nn as nn

+from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.models.config import VerifyAndUpdateConfig
@@ -129,6 +132,41 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str:
     return model_name + pooling_suffix


+def try_create_mm_pooling_model_cls(orig_cls: _T) -> _T:
+
+    class CallVisitor(ast.NodeVisitor):
+
+        def __init__(self):
+            self.calls = []
+
+        def visit_Call(self, node):
+            if isinstance(node.func, ast.Name):
+                self.calls.append(node.func.id)
+            self.generic_visit(node)
+
+    visitor = CallVisitor()
+    visitor.visit(ast.parse(inspect.getsource(orig_cls)))
+    if "init_vllm_registered_model" not in visitor.calls:
+        return None
+
+    class ModelForPooling(orig_cls, VllmModelForPooling):
+
+        is_pooling_model = True
+
+        def __init__(
+            self,
+            *,
+            vllm_config: "VllmConfig",
+            prefix: str = "",
+            **kwargs: Any,
+        ) -> None:
+            super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
+
+            self.pooler = self.get_language_model().pooler
+
+    return ModelForPooling  # type: ignore
+
+
 def _create_pooling_model_cls(orig_cls: _T) -> _T:
     # Lazy import
     from .utils import AutoWeightsLoader, WeightsMapper
@@ -399,6 +437,7 @@ def load_weights_using_from_2_way_softmax(
     from vllm.model_executor.models.utils import AutoWeightsLoader

     model_config = model.vllm_config.model_config
+
     tokens = getattr(model.config, "classifier_from_token", [])
     tokens = cast(list[int], tokens)
     assert len(tokens) == 2
@@ -406,9 +445,10 @@ def load_weights_using_from_2_way_softmax(
     if model.config.tie_word_embeddings:
         model.lm_head = model.model.embed_tokens
     else:
+        quant_config = model.vllm_config.quant_config
         model.lm_head = ParallelLMHead(model.config.vocab_size,
                                        model.config.hidden_size,
-                                       quant_config=model.quant_config)
+                                       quant_config=quant_config)

     loader = AutoWeightsLoader(model)
     loaded_weights = loader.load_weights(weights)
@@ -452,9 +492,10 @@ def load_weights_no_post_processing(model,
     if model.config.tie_word_embeddings:
         model.lm_head = model.model.embed_tokens
     else:
+        quant_config = model.vllm_config.quant_config
         model.lm_head = ParallelLMHead(model.config.vocab_size,
                                        model.config.hidden_size,
-                                       quant_config=model.quant_config)
+                                       quant_config=quant_config)

     loader = AutoWeightsLoader(model)
     loaded_weights = loader.load_weights(weights)
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index f3dc7dde46bdf..e652ba2f1c7fe 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -512,7 +512,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             architectures=["Gemma3ForCausalLM"],
         )
         logit_scale = getattr(config, "logit_scale", 1.0)
-        self.language_model.logits_processor.scale *= logit_scale
+
+        if hasattr(self.language_model, "logits_processor"):
+            # The logits processor can be unset if we're using
+            # automatic conversion to pooling model.
+            self.language_model.logits_processor.scale *= logit_scale

         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
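
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch above): a minimal offline example of
# the new LLM.preprocess_chat() + classify() flow introduced in this diff.
# The model name, runner/convert flags, pooler override, config overrides and
# A-E labels are copied from test_gemma_multimodal; passing them directly as
# LLM(...) keyword arguments (as vllm_runner does) and omitting the system
# prompt are assumptions made for brevity, not a verified recipe.
from vllm import LLM


def update_config(config):
    # Same override helper as in the test: expose the Gemma3 text backbone as
    # a 5-way sequence classifier whose class scores come from the logits of
    # the letter tokens A-E ("no_post_processing" uses those logits directly).
    config.text_config.update({
        "architectures": ["Gemma3ForSequenceClassification"],
        "classifier_from_token": ["A", "B", "C", "D", "E"],
        "method": "no_post_processing",
        "id2label": {"A": "Chair", "B": "Couch", "C": "Table",
                     "D": "Bed", "E": "Cupboard"},
    })
    return config


llm = LLM(
    model="google/gemma-3-4b-it",
    runner="pooling",
    convert="classify",
    hf_overrides=update_config,
    override_pooler_config={"pooling_type": "LAST"},
    max_model_len=512,
    enforce_eager=True,
)

messages = [{
    "role": "user",
    "content": [{
        "type": "image_url",
        "image_url": {
            "url":
            "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
        },
    }, {
        "type": "text",
        "text": "A fine 19th century piece of furniture.",
    }],
}]

# preprocess_chat() applies the chat template and multi-modal preprocessing,
# returning TokensPrompt objects that the pooling entrypoints accept directly.
prompts = llm.preprocess_chat(messages)
for output in llm.classify(prompts):
    print(output.outputs.probs)  # one probability per label A..E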