From 9499e26e2ae18826bcda99ae7e0883268cde03db Mon Sep 17 00:00:00 2001 From: Raushan Turganbay Date: Sun, 20 Jul 2025 15:25:50 +0200 Subject: [PATCH] [Model] Support VLMs with transformers backend (#20543) Signed-off-by: raushan Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py Co-authored-by: Cyrus Leung --- docs/models/supported_models.md | 9 +- .../multimodal/generation/test_common.py | 75 +++ tests/models/registry.py | 1 + vllm/config.py | 39 +- vllm/model_executor/model_loader/utils.py | 49 +- vllm/model_executor/models/registry.py | 12 +- vllm/model_executor/models/transformers.py | 527 ++++++++++++++++-- 7 files changed, 625 insertions(+), 87 deletions(-) diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 306a7851a432d..0a2f69bd77111 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -18,7 +18,7 @@ These models are what we list in [supported-text-models][supported-text-models] ### Transformers -vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned! +vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs, and require setting `--disable_mm_preprocessor_cache` when running. Support for video inputs and caching of multi-modal preprocessors will be added in future releases. To check if the modeling backend is Transformers, you can simply do this: @@ -28,7 +28,7 @@ llm = LLM(model=..., task="generate") # Name or path of your model llm.apply_model(lambda model: print(type(model))) ``` -If it is `TransformersForCausalLM` then it means it's based on Transformers! +If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers! !!! tip You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md). @@ -36,6 +36,9 @@ If it is `TransformersForCausalLM` then it means it's based on Transformers! !!! note vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. +!!! note + In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. + #### Custom models If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM! @@ -99,7 +102,7 @@ Here is what happens in the background when this model is loaded: 1. The config is loaded. 2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. -3. `MyModel` is loaded into `TransformersForCausalLM` (see ) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. +3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see ) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. That's it! diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index 98461676aa47f..9859ac5a89dd7 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -35,6 +35,8 @@ if current_platform.is_rocm(): REQUIRES_V0_MODELS = [ # V1 Test: not enough KV cache space in C1. "fuyu", + # V1 Test: Deadlock issue when processing mm_inputs + "llava-onevision-transformers", ] # yapf: disable @@ -170,6 +172,79 @@ VLM_TEST_SETTINGS = { hf_output_post_proc=model_utils.ultravox_trunc_hf_output, marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + #### Transformers fallback to test + ## To reduce test burden, we only test batching arbitrary image size + # Dynamic image length and number of patches + "llava-onevision-transformers": VLMTestInfo( + models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], + test_type=VLMTestType.IMAGE, + prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + max_model_len=16384, + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + "disable_mm_preprocessor_cache": True, + "enable_prefix_caching": False, + }, + marks=[pytest.mark.core_model], + ), + # FIXME(Isotr0py): Enable this test after + # https://github.com/huggingface/transformers/pull/39470 released + # "idefics3-transformers": VLMTestInfo( + # models=["HuggingFaceTB/SmolVLM-256M-Instruct"], + # test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + # prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}\nAssistant:", # noqa: E501 + # img_idx_to_prompt=lambda idx: "", + # max_model_len=8192, + # max_num_seqs=2, + # auto_cls=AutoModelForImageTextToText, + # hf_output_post_proc=model_utils.idefics3_trunc_hf_output, + # image_size_factors=[(0.25, 0.5, 1.0)], + # vllm_runner_kwargs={ + # "model_impl": "transformers", + # "disable_mm_preprocessor_cache": True, + # "enable_prefix_caching": False, + # }, + # marks=[pytest.mark.core_model], + # ), + # Pixel values from processor are not 4D or 5D arrays + "qwen2_5_vl-transformers": VLMTestInfo( + models=["Qwen/Qwen2.5-VL-3B-Instruct"], + test_type=VLMTestType.IMAGE, + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, + image_size_factors=[(0.25, 0.2, 0.15)], + vllm_runner_kwargs={ + "model_impl": "transformers", + "disable_mm_preprocessor_cache": True, + "enable_prefix_caching": False, + }, + marks=[large_gpu_mark(min_gb=32)], + ), + # Check "auto" with fallback to transformers + "internvl-transformers": VLMTestInfo( + models=["OpenGVLab/InternVL3-1B-hf"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "", + max_model_len=4096, + use_tokenizer_eos=True, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "auto", + "disable_mm_preprocessor_cache": True, + "enable_prefix_caching": False, + }, + auto_cls=AutoModelForImageTextToText, + marks=[pytest.mark.core_model], + ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], diff --git a/tests/models/registry.py b/tests/models/registry.py index c2f1089af2ac6..19725acd6c451 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -499,6 +499,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { _TRANSFORMERS_MODELS = { "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 + "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), } _EXAMPLE_MODELS = { diff --git a/vllm/config.py b/vllm/config.py index c261f968e7fcb..44106dd279b6b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -562,6 +562,10 @@ class ModelConfig: self.task = "embed" + model_info, arch = self.registry.inspect_model_cls(self.architectures) + self._model_info = model_info + self._architecture = arch + all_supported_tasks = self._get_supported_tasks(self.task) logger.debug("Tasks supported by runner type: %s", all_supported_tasks) supported_runner_types = self._get_supported_runner_types( @@ -587,10 +591,6 @@ class ModelConfig: else: self.truncation_side = "right" - model_info, arch = self.registry.inspect_model_cls(self.architectures) - self._model_info = model_info - self._architecture = arch - self.pooler_config = self._init_pooler_config() self.dtype = _get_and_verify_dtype( @@ -674,6 +674,16 @@ class ModelConfig: "max_model_len must be an integer after __post_init__.") return self + def _get_transformers_backend_cls(self) -> str: + """Determine which Transformers backend class will be used if + `model_impl` is set to `transformers` or `auto`.""" + if self.hf_config != self.hf_text_config: + # If 'hf_text_config' is the same as 'hf_config'. If not, it is + # probably a composite config, i.e. multimodal + return "TransformersForMultimodalLM" + else: + return "TransformersForCausalLM" + @property def registry(self): return me_models.ModelRegistry @@ -681,7 +691,19 @@ class ModelConfig: @property def architectures(self) -> list[str]: # architectures in the model config. - return getattr(self.hf_config, "architectures", []) + architectures = getattr(self.hf_config, "architectures", []) + # The registry assumes that it can always inspect the vLLM model class + # for a given architecture. This assumption breaks down for the + # Transformers backend, which may use a different class depending on + # the model type. To work around this, we add the correct Transformers + # backend class to the architectures list. We must do this here because + # we need access to the `hf_config` to determine the backend class. + transformers_backend_cls = self._get_transformers_backend_cls() + if (self.model_impl != ModelImpl.VLLM.value + and all(arch != transformers_backend_cls + for arch in architectures)): + architectures.append(transformers_backend_cls) + return architectures @property def architecture(self) -> str: @@ -827,10 +849,9 @@ class ModelConfig: ("EmbeddingModel", "embed"), ("RewardModel", "reward"), ] - _, arch = self.registry.inspect_model_cls(architectures) for suffix, pref_task in suffix_to_preferred_task: - if arch.endswith(suffix): + if self.architecture.endswith(suffix): return pref_task return "embed" @@ -944,10 +965,10 @@ class ModelConfig: ("EmbeddingModel", "pooling"), ("RewardModel", "pooling"), ] - _, arch = self.registry.inspect_model_cls(self.architectures) for suffix, pref_runner in suffix_to_preferred_runner: - if arch.endswith(suffix) and pref_runner in supported_runner_types: + if self.architecture.endswith( + suffix) and pref_runner in supported_runner_types: return pref_runner if "generate" in supported_runner_types: diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 190d1f006bc46..42c5512905f21 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -25,6 +25,7 @@ from vllm.model_executor.models.adapters import (as_embedding_model, as_reward_model, as_seq_cls_model) from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.model_executor.models.registry import _TRANSFORMERS_MODELS from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -169,9 +170,22 @@ def device_loading_context(module: torch.nn.Module, def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]): + if model_config.model_impl == ModelImpl.VLLM: + raise ValueError( + "Attempting to resolve architecture from the Transformers library " + "but the model implementation is set to vLLM. This should never " + "happen.") + for i, arch in enumerate(architectures): - if arch == "TransformersForCausalLM": + if arch in _TRANSFORMERS_MODELS: continue + + if model_config.model_impl == ModelImpl.AUTO: + logger.warning( + "%s has no vLLM implementation, falling back to Transformers " + "implementation. Some features may not be supported and " + "performance may not be optimal.", arch) + auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", None) or dict() # Make sure that config class is always initialized before model class, @@ -199,25 +213,13 @@ def resolve_transformers_arch(model_config: ModelConfig, "not present in the model config's 'auto_map' (relevant " "if the model is custom).") model_module = auto_modules["AutoModel"] - # TODO(Isotr0py): Further clean up these raises. - # perhaps handled them in _ModelRegistry._raise_for_unsupported? - if model_config.model_impl == ModelImpl.TRANSFORMERS: - if not model_module.is_backend_compatible(): - raise ValueError( - f"The Transformers implementation of {arch} is not " - "compatible with vLLM.") - architectures[i] = "TransformersForCausalLM" - if model_config.model_impl == ModelImpl.AUTO: - if not model_module.is_backend_compatible(): - raise ValueError( - f"{arch} has no vLLM implementation and the Transformers " - "implementation is not compatible with vLLM. Try setting " - "VLLM_USE_V1=0.") - logger.warning( - "%s has no vLLM implementation, falling back to Transformers " - "implementation. Some features may not be supported and " - "performance may not be optimal.", arch) - architectures[i] = "TransformersForCausalLM" + + if not model_module.is_backend_compatible(): + raise ValueError( + f"The Transformers implementation of '{arch}' is not " + "compatible with vLLM.") + + architectures[i] = model_config._get_transformers_backend_cls() return architectures @@ -237,8 +239,9 @@ def get_model_architecture( ] vllm_supported_archs = ModelRegistry.get_supported_archs() - vllm_not_supported = not any(arch in vllm_supported_archs - for arch in architectures) + is_supported = lambda arch: (arch in vllm_supported_archs and arch not in + _TRANSFORMERS_MODELS) + vllm_not_supported = not any(is_supported(arch) for arch in architectures) if vllm_not_supported: # try automatic conversion in adapters.py @@ -259,7 +262,7 @@ def get_model_architecture( break if (model_config.model_impl == ModelImpl.TRANSFORMERS or - model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): + model_config.model_impl == ModelImpl.AUTO and vllm_not_supported): architectures = resolve_transformers_arch(model_config, architectures) logger.debug_once("Resolve transformers arch %s", str(architectures)) elif (model_config.quantization is not None diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b57130ec84c90..a85e8b0e7b1b6 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -253,6 +253,7 @@ _SPECULATIVE_DECODING_MODELS = { } _TRANSFORMERS_MODELS = { + "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), } # yapf: enable @@ -504,9 +505,14 @@ class _ModelRegistry: if causal_lm_arch in self.models: normalized_arch.append(arch) - # make sure Transformers backend is put at the last as a fallback - if len(normalized_arch) != len(architectures): - normalized_arch.append("TransformersForCausalLM") + # NOTE(Isotr0py): Be careful of architectures' order! + # Make sure Transformers backend architecture is at the end of the + # list, otherwise pooling models automatic conversion will fail! + for arch in normalized_arch: + if arch.startswith("TransformersFor"): + normalized_arch.remove(arch) + normalized_arch.append(arch) + return normalized_arch def inspect_model_cls( diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 04ee3a454f9d8..47cff29caab08 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -15,8 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Wrapper around `transformers` models""" -from collections.abc import Iterable -from contextlib import nullcontext +from collections.abc import Iterable, Mapping +from contextlib import contextmanager, nullcontext from typing import Literal, Optional, Union import regex as re @@ -41,11 +41,21 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, PlaceholderRange) +from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo) +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import is_list_of -from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, + SupportsQuant) from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - is_pp_missing_parameter, + flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, maybe_prefix) logger = init_logger(__name__) @@ -112,6 +122,269 @@ def replace_linear_class( ) +# Copied from `accelerate` +@contextmanager +def init_on_device_without_buffers(device: torch.device): + """ + A context manager under which models are initialized with all + parameters on the specified device. However buffers are not + initialized on specified device. + + Args: + device (`torch.device`): + Device to initialize all parameters on. + """ + + old_register_parameter = nn.Module.register_parameter + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs) + + tensor_constructors_to_patch = {} + + def patch_tensor_constructor(fn): + + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + try: + nn.Module.register_parameter = register_empty_parameter + for torch_function_name in tensor_constructors_to_patch: + setattr( + torch, torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + nn.Module.register_parameter = old_register_parameter + for torch_function_name, old_torch_function in ( + tensor_constructors_to_patch.items()): + setattr(torch, torch_function_name, old_torch_function) + + +class MultiModalProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_supported_mm_limits(self): + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len, mm_counts): + return {"image": self.get_max_image_tokens()} + + def get_max_image_tokens(self) -> int: + width, height = self.get_max_image_size() + processor = self.get_hf_processor() + mm_processor_kwargs = self.ctx.model_config.mm_processor_kwargs or {} + mm_tokens = processor._get_num_multimodal_tokens( + image_sizes=([height, width], ), **mm_processor_kwargs) + image_tokens = mm_tokens["num_image_tokens"][0] + return image_tokens + + def get_hf_processor(self): + processor = cached_get_processor(self.ctx.model_config.model) + return processor + + def get_max_image_size(self): + return 10_000, 10_000 # hardcode for arbitrary very large size + + +class MultiModalDummyInputsBuilder( + BaseDummyInputsBuilder[MultiModalProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + if "gemma3" in processor.__class__.__name__.lower(): + image_token = processor.boi_token + else: + image_token = getattr(processor, "image_token", "") + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_max_image_size() + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ): + """ + Given the original multi-modal items for this modality + and HF-processed data, output the updates to perform. + + The information returned by this method is used to update token inputs + which bypass the HF processor. It is also used to update the output of + HF processor if the HF process does not apply prompt updates to text + inputs. + + Moreover, this information is critical to determine the token positions + in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` + for each multi-modal item. + """ + return None + + def _get_mm_fields_config( + self, + hf_inputs, + hf_processor_mm_kwargs, + num_image_patches: torch.Tensor = None, + ): + # HF Processors always return a mask but vLLM doesn't need it + hf_inputs.pop("attention_mask", None) + mm_fields = { + key: MultiModalFieldConfig.flat_from_sizes("image", + num_image_patches) + for key in hf_inputs + } + mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( + "image", num_image_patches) + mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") + return mm_fields + + def _apply_hf_processor_text_mm( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ): + """ + Apply the HF processor on the prompt text and multi-modal data + together. + + In addition, return whether prompt replacements have been applied. + """ + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + processor_data["return_mm_token_type_ids"] = True + + processed_data = self._call_hf_processor( + prompt=prompt_text, + mm_data=processor_data, + mm_kwargs=hf_processor_mm_kwargs, + tok_kwargs=tokenization_kwargs, + ) + processed_data.update(passthrough_data) + + prompt_ids, = processed_data.pop("input_ids").tolist() + mm_token_type_ids = processed_data.pop( + "mm_token_type_ids" + ) if "mm_token_type_ids" in processed_data else processed_data.pop( + "token_type_ids") # for gemma3 only + + return prompt_ids, processed_data, mm_token_type_ids + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """ + Process multi-modal inputs to be used in vLLM. + + Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + """ + if return_mm_hashes: + raise ValueError( + "TransformersForMultimodalLM doesn't support mm hashing yet! " + "Probably you didn't set `disable_mm_preprocessor_cache=True`") + + if tokenization_kwargs is None: + tokenization_kwargs = {} + + mm_items = self._to_mm_items(mm_data) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + (prompt_ids, processed_data, + mm_token_type_ids) = self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + + # HF processor will return `mm_token_type_ids` from which + # we can infer mm_placeholders. Until then hardcode to make code run + # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1 + mm_positions = torch.where(mm_token_type_ids == 1)[1] + images = mm_items.get_items("image", ImageProcessorItems) + mm_processor_kwargs = (self.info.ctx.model_config.mm_processor_kwargs + or {}) + image_sizes = [] + for item_idx in range(len(images)): + image_size = images.get_image_size(item_idx) + image_sizes.append((image_size.height, image_size.width)) + + mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( + image_sizes=image_sizes, **mm_processor_kwargs) + + mm_placeholders = {} + split_sizes = mm_tokens_per_modality["num_image_tokens"] + if split_sizes: + chunked_mm_positions = torch.split(mm_positions, split_sizes) + mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()] + chunked_mm_tokens = torch.split(mm_tokens, split_sizes) + ranges = [ + PlaceholderRange( + offset=positions[0].item(), + length=positions.shape[0], + is_embed=(mm_tokens == hf_processor.image_token_id).bool()) + for positions, mm_tokens in zip(chunked_mm_positions, + chunked_mm_tokens) + ] + mm_placeholders = {"image": ranges} + + num_image_patches = torch.tensor( + mm_tokens_per_modality["num_image_patches"] + ) if "num_image_patches" in mm_tokens_per_modality else None + processed_data['num_image_patches'] = num_image_patches + mm_kwargs = MultiModalKwargs.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, + num_image_patches), + ) + + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_hashes=None, + mm_placeholders=mm_placeholders, + ) + + class ConfigOverride: """Context manager to temporarily override config attributes.""" @@ -153,6 +426,7 @@ class TransformersModel(nn.Module): quant_config: QuantizationConfig = vllm_config.quant_config self.config = config + self.text_config = config.get_text_config() self.cache_config = cache_config self.device_config = device_config self.model_config = model_config @@ -173,14 +447,16 @@ class TransformersModel(nn.Module): config_override = ConfigOverride( config, sliding_window=config.interleaved_sliding_window) - # Use meta device to delay allocating GPU tensors - with torch.device("meta"), config_override: + # Set correct attn and init on "meta" to delay allocating GPU tensors + # TODO: @raushan, use the public `model.set_attn_implementation()` + # method after v4.54.0 is released + self.text_config._attn_implementation = "vllm" + with init_on_device_without_buffers("meta"), config_override: # FIXME(Isotr0py): We need to refactor this part in the future to # avoid registering an extra model layer, otherwise we will need a # weights mapper to rename weights. self.model: PreTrainedModel = AutoModel.from_config( config, - attn_implementation="vllm", torch_dtype=model_config.dtype, trust_remote_code=model_config.trust_remote_code, ) @@ -189,27 +465,25 @@ class TransformersModel(nn.Module): self.tensor_parallel() # Input embeddings + text_config = config.get_text_config() if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): self.model.set_input_embeddings( VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, + text_config.vocab_size, + text_config.hidden_size, + org_num_embeddings=text_config.vocab_size, quant_config=quant_config, )) # Attention layers self.attention_instances = self.create_attention_instances() - # Initialize buffers (e.g. rotary embedding inverse frequency) - self.init_buffers(self.model) - # Initialize any parameters that have not had their modules replaced self.init_parameters(self.model) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + text_config.hidden_size)) def pipeline_parallel(self): """ @@ -240,14 +514,15 @@ class TransformersModel(nn.Module): # Layers before module list for name in pp_plan[:module_list_idx]: - if self.pp_group.is_first_rank or (self.config.tie_word_embeddings - and self.pp_group.is_last_rank): + if self.pp_group.is_first_rank or ( + self.text_config.tie_word_embeddings + and self.pp_group.is_last_rank): continue setattr(self.model, name, PPMissingLayer()) # Module list - start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers, - self.pp_rank, self.pp_size) + start_layer, end_layer = get_pp_indices( + self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) layers_name = pp_plan[module_list_idx] layers = getattr(self.model, layers_name) for i in range(len(layers)): @@ -298,7 +573,7 @@ class TransformersModel(nn.Module): self.parallel_config) head_size = self.model_config.get_head_size() num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - start, end = get_pp_indices(self.config.num_hidden_layers, + start, end = get_pp_indices(self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) attention_instances = {} @@ -323,35 +598,6 @@ class TransformersModel(nn.Module): prefix=f"{i}.attn") return attention_instances - def init_buffers(self, module: nn.Module): - """ - If a `buffer` is on the `meta` device, then its parent - `module` is the original module created by: - - ```python - with torch.device("meta"): - self.model: PreTrainedModel = AutoModel.from_config(...) - ``` - - This means that: - - `type(module)` is a class from `transformers` - - This class is constructed using a `PretrainedConfig` - """ - for name, buffer in module.named_buffers(recurse=False): - if buffer.device == torch.device("meta"): - if module == self.model: - logger.warning( - "To initialize buffers correctly, we instantiate the " - "parent module and and extract the value of the " - "buffer from it. In this case, the parent module is " - "the base model. Instantiating the entire model here " - "risks GPU OOM. Could this buffer be moved to a child " - "module?") - new_buffer = getattr(type(module)(self.config), name) - setattr(module, name, new_buffer) - for child in module.children(): - self.init_buffers(child) - def init_parameters(self, module: nn.Module): """ If a `parameter` is on the `meta` device, then its parent @@ -366,6 +612,7 @@ class TransformersModel(nn.Module): if param.device == torch.device("meta"): new_param = nn.Parameter( torch.empty_like(param.data, + dtype=self.model_config.dtype, device=self.device_config.device)) setattr(module, name, new_param) for child in module.children(): @@ -391,11 +638,16 @@ class TransformersModel(nn.Module): if inputs_embeds is not None: inputs_embeds = inputs_embeds[None, ...] + if self.model_config.uses_mrope: + position_ids = positions[:, None] + else: + position_ids = positions[None, ...] + hidden_states = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, use_cache=False, - position_ids=positions[None, ...], + position_ids=position_ids, attention_instances=self.attention_instances, return_dict=False)[0][0, ...] # we remove batch dimension for now @@ -507,3 +759,180 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder) +class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, + SupportsPP, SupportsMultiModal): + embedding_padding_modules = ["lm_head"] + embedding_modules = ["embed_tokens"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: PretrainedConfig = vllm_config.model_config.hf_config + quant_config: QuantizationConfig = vllm_config.quant_config + + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix) + text_config = config.get_text_config() + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = text_config.vocab_size + self.lm_head = ParallelLMHead( + text_config.vocab_size, + text_config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if text_config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.get_input_embeddings()) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + text_config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + @property + def hf_to_vllm_mapper(self): + # Backwards compatibility for prev released models + # State dicts back then had different formats + # and cannot be loaded with `AutoModel` mapping + # as is + prefix_mapper = { + "language_model.model": "model.language_model", + "text_model.model": "model.text_model", + "vision_tower": "model.vision_tower", + "vqmodel": "model.vqmodel", + "vision_model": "model.vision_model", + "vision_embed_tokens": "model.vision_embed_tokens", + "image_newline": "model.image_newline", + "multi_modal_projector": "model.multi_modal_projector", + "text_model.lm_head": "lm_head", + "language_model.lm_head": "lm_head", + } + # Don't change the order for QwenVL + if 'Qwen2' in self.config.__class__.__name__: + prefix_mapper["model"] = "model.language_model" + prefix_mapper["visual"] = "model.visual" + + return WeightsMapper(orig_to_new_prefix=prefix_mapper, ) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + if inputs_embeds is None: + multimodal_embeds = self.get_multimodal_embeddings(**kwargs) + if multimodal_embeds is not None: + inputs_embeds = self.get_input_embeddings( + input_ids, multimodal_embeds) + input_ids = None + + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=([ + "lm_head." + ] if self.config.get_text_config().tie_word_embeddings else None), + ) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_multimodal_embeddings(self, **kwargs): + pixel_values = kwargs.pop("pixel_values", None) + pixel_values = pixel_values if pixel_values is not None else kwargs.pop( + "image_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + + if image_embeds is not None: + return image_embeds + + if pixel_values is None and image_embeds is None: + return None + + num_image_patches = kwargs.pop("num_image_patches") + if pixel_values is not None: + if isinstance(pixel_values, torch.Tensor): + pixel_values = flatten_bn(pixel_values).to(self.dtype) + elif is_list_of(pixel_values, torch.Tensor): + pixel_values = flatten_bn(flatten_bn(pixel_values), + concat=True).to(self.dtype) + else: + raise ValueError( + f"Unsupported pixel_values type {type(pixel_values)}. " + "Expected `torch.Tensor` or list of `torch.Tensor`.") + + if isinstance(num_image_patches, list): + num_image_patches = torch.cat(num_image_patches) + + vision_embeddings = self.model.model.get_image_features( + pixel_values, + **{ + k: v.flatten(0, 1) + for k, v in kwargs.items() + }, + ) + + if isinstance(vision_embeddings, torch.Tensor): + if vision_embeddings.ndim == 2: + vision_embeddings = vision_embeddings.unsqueeze(0) + + # Embeddings have to be 2D tensors of length `num_images` + # but transformers returns concat tensors if each patch + # is of different size. We split it back to make vLLM happy + vision_embeddings = torch.split( + vision_embeddings, + num_image_patches.flatten().tolist()) + vision_embeddings = [ + embed.flatten(start_dim=0, end_dim=-2) + for embed in vision_embeddings + ] + + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + ) -> torch.Tensor: + inputs_embeds = self.model.model.get_input_embeddings()(input_ids) + if (multimodal_embeddings is not None + and len(multimodal_embeddings) != 0): + mask = (input_ids == self.config.image_token_id) + mask = mask.unsqueeze(-1).expand_as(inputs_embeds) + multimodal_embeddings = torch.cat(multimodal_embeddings) + + inputs_embeds = inputs_embeds.masked_scatter( + mask, multimodal_embeddings) + return inputs_embeds