[Model] Support VLMs with transformers backend (#20543)

Signed-off-by: raushan <raushan@huggingface.co>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Raushan Turganbay 2025-07-20 15:25:50 +02:00 committed by GitHub
parent 51ba839555
commit 9499e26e2a
7 changed files with 625 additions and 87 deletions


@ -18,7 +18,7 @@ These models are what we list in [supported-text-models][supported-text-models]
### Transformers
vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned!
vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs and require setting `--disable_mm_preprocessor_cache` when running. Support for video inputs and caching of multi-modal preprocessors will be added in future releases.
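For example, offline inference with a vision-language model through the Transformers backend looks roughly like this (a minimal sketch; the model name comes from the test settings in this PR, and the exact chat template and image placeholder vary per model):

```python
from PIL import Image

from vllm import LLM, SamplingParams

# Force the Transformers backend and disable the multi-modal preprocessor
# cache, which is not yet supported for these models.
llm = LLM(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    model_impl="transformers",
    disable_mm_preprocessor_cache=True,
)

image = Image.open("example.jpg")  # any RGB image
prompt = ("<|im_start|>user\n<image>\nDescribe this image.<|im_end|>\n"
          "<|im_start|>assistant\n")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```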
To check if the modeling backend is Transformers, you can simply do this:
@ -28,7 +28,7 @@ llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
```
If it is `TransformersForCausalLM` then it means it's based on Transformers!
If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers!
!!! tip
You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md).
@ -36,6 +36,9 @@ If it is `TransformersForCausalLM` then it means it's based on Transformers!
!!! note
vLLM may not fully optimise the Transformers implementation, so you may see degraded performance when comparing a native model to a Transformers model in vLLM.
!!! note
For vision-language models loaded with `dtype="auto"`, vLLM loads the whole model using the config's `dtype` if it is set, whereas native Transformers respects the `dtype` attribute of each backbone in the model. This might cause a slight difference in performance.
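If this difference matters for your use case, you can pin the dtype explicitly instead of relying on `dtype="auto"` (illustrative only; the model name and dtype are placeholders):

```python
from vllm import LLM

# Pinning the dtype avoids depending on whether the config's `dtype`
# field is present.
llm = LLM(
    model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
    model_impl="transformers",
    disable_mm_preprocessor_cache=True,
    dtype="bfloat16",
)
```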
#### Custom models
If a model is supported natively by neither vLLM nor Transformers, it can still be used in vLLM!
@ -99,7 +102,7 @@ Here is what happens in the background when this model is loaded:
1. The config is loaded.
2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
That's it!
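Schematically, such a remote-code model exposes its classes via `auto_map` and opts into pluggable attention so that step 3 can set `_attn_implementation = "vllm"`. The sketch below is illustrative only; `MyConfig` and `MyModel` are placeholder names, the layers and `forward` are omitted, and the exact compatibility requirements are defined by Transformers' `is_backend_compatible()` check:

```python
# config.json (excerpt): `auto_map` tells Transformers where the custom
# classes live, which is how step 2 above finds `MyModel`.
# {
#   "architectures": ["MyModel"],
#   "auto_map": {
#     "AutoConfig": "configuration_my_model.MyConfig",
#     "AutoModel": "modeling_my_model.MyModel"
#   }
# }
from transformers import PretrainedConfig, PreTrainedModel


class MyConfig(PretrainedConfig):
    model_type = "my_model"


class MyModel(PreTrainedModel):
    config_class = MyConfig
    # Illustrative flag: the model advertises compatibility with pluggable
    # attention backends so that the backend-compatibility check can pass.
    _supports_attention_backend = True
```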


@ -35,6 +35,8 @@ if current_platform.is_rocm():
REQUIRES_V0_MODELS = [
# V1 Test: not enough KV cache space in CI.
"fuyu",
# V1 Test: Deadlock issue when processing mm_inputs
"llava-onevision-transformers",
]
# yapf: disable
@ -170,6 +172,79 @@ VLM_TEST_SETTINGS = {
hf_output_post_proc=model_utils.ultravox_trunc_hf_output,
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Transformers fallback models to test
## To reduce test burden, we only test batching with arbitrary image sizes
# Dynamic image length and number of patches
"llava-onevision-transformers": VLMTestInfo(
models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
max_model_len=16384,
hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "transformers",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
marks=[pytest.mark.core_model],
),
# FIXME(Isotr0py): Enable this test after
# https://github.com/huggingface/transformers/pull/39470 released
# "idefics3-transformers": VLMTestInfo(
# models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
# test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
# prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501
# img_idx_to_prompt=lambda idx: "<image>",
# max_model_len=8192,
# max_num_seqs=2,
# auto_cls=AutoModelForImageTextToText,
# hf_output_post_proc=model_utils.idefics3_trunc_hf_output,
# image_size_factors=[(0.25, 0.5, 1.0)],
# vllm_runner_kwargs={
# "model_impl": "transformers",
# "disable_mm_preprocessor_cache": True,
# "enable_prefix_caching": False,
# },
# marks=[pytest.mark.core_model],
# ),
# Pixel values from processor are not 4D or 5D arrays
"qwen2_5_vl-transformers": VLMTestInfo(
models=["Qwen/Qwen2.5-VL-3B-Instruct"],
test_type=VLMTestType.IMAGE,
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForImageTextToText,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(0.25, 0.2, 0.15)],
vllm_runner_kwargs={
"model_impl": "transformers",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
marks=[large_gpu_mark(min_gb=32)],
),
# Check "auto" with fallback to transformers
"internvl-transformers": VLMTestInfo(
models=["OpenGVLab/InternVL3-1B-hf"],
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>",
max_model_len=4096,
use_tokenizer_eos=True,
image_size_factors=[(0.25, 0.5, 1.0)],
vllm_runner_kwargs={
"model_impl": "auto",
"disable_mm_preprocessor_cache": True,
"enable_prefix_caching": False,
},
auto_cls=AutoModelForImageTextToText,
marks=[pytest.mark.core_model],
),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],


@ -499,6 +499,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
_TRANSFORMERS_MODELS = {
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
"TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"),
}
_EXAMPLE_MODELS = {


@ -562,6 +562,10 @@ class ModelConfig:
self.task = "embed"
model_info, arch = self.registry.inspect_model_cls(self.architectures)
self._model_info = model_info
self._architecture = arch
all_supported_tasks = self._get_supported_tasks(self.task)
logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
supported_runner_types = self._get_supported_runner_types(
@ -587,10 +591,6 @@ class ModelConfig:
else:
self.truncation_side = "right"
model_info, arch = self.registry.inspect_model_cls(self.architectures)
self._model_info = model_info
self._architecture = arch
self.pooler_config = self._init_pooler_config()
self.dtype = _get_and_verify_dtype(
@ -674,6 +674,16 @@ class ModelConfig:
"max_model_len must be an integer after __post_init__.")
return self
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' is not the same as 'hf_config', this is
# probably a composite config, i.e. a multimodal model
return "TransformersForMultimodalLM"
else:
return "TransformersForCausalLM"
@property
def registry(self):
return me_models.ModelRegistry
@ -681,7 +691,19 @@ class ModelConfig:
@property
def architectures(self) -> list[str]:
# architectures in the model config.
return getattr(self.hf_config, "architectures", [])
architectures = getattr(self.hf_config, "architectures", [])
# The registry assumes that it can always inspect the vLLM model class
# for a given architecture. This assumption breaks down for the
# Transformers backend, which may use a different class depending on
# the model type. To work around this, we add the correct Transformers
# backend class to the architectures list. We must do this here because
# we need access to the `hf_config` to determine the backend class.
transformers_backend_cls = self._get_transformers_backend_cls()
if (self.model_impl != ModelImpl.VLLM.value
and all(arch != transformers_backend_cls
for arch in architectures)):
architectures.append(transformers_backend_cls)
return architectures
@property
def architecture(self) -> str:
@ -827,10 +849,9 @@ class ModelConfig:
("EmbeddingModel", "embed"),
("RewardModel", "reward"),
]
_, arch = self.registry.inspect_model_cls(architectures)
for suffix, pref_task in suffix_to_preferred_task:
if arch.endswith(suffix):
if self.architecture.endswith(suffix):
return pref_task
return "embed"
@ -944,10 +965,10 @@ class ModelConfig:
("EmbeddingModel", "pooling"),
("RewardModel", "pooling"),
]
_, arch = self.registry.inspect_model_cls(self.architectures)
for suffix, pref_runner in suffix_to_preferred_runner:
if arch.endswith(suffix) and pref_runner in supported_runner_types:
if self.architecture.endswith(
suffix) and pref_runner in supported_runner_types:
return pref_runner
if "generate" in supported_runner_types:


@ -25,6 +25,7 @@ from vllm.model_executor.models.adapters import (as_embedding_model,
as_reward_model,
as_seq_cls_model)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.registry import _TRANSFORMERS_MODELS
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@ -169,9 +170,22 @@ def device_loading_context(module: torch.nn.Module,
def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]):
if model_config.model_impl == ModelImpl.VLLM:
raise ValueError(
"Attempting to resolve architecture from the Transformers library "
"but the model implementation is set to vLLM. This should never "
"happen.")
for i, arch in enumerate(architectures):
if arch == "TransformersForCausalLM":
if arch in _TRANSFORMERS_MODELS:
continue
if model_config.model_impl == ModelImpl.AUTO:
logger.warning(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
None) or dict()
# Make sure that config class is always initialized before model class,
@ -199,25 +213,13 @@ def resolve_transformers_arch(model_config: ModelConfig,
"not present in the model config's 'auto_map' (relevant "
"if the model is custom).")
model_module = auto_modules["AutoModel"]
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not model_module.is_backend_compatible():
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM. Try setting "
"VLLM_USE_V1=0.")
logger.warning(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
architectures[i] = "TransformersForCausalLM"
if not model_module.is_backend_compatible():
raise ValueError(
f"The Transformers implementation of '{arch}' is not "
"compatible with vLLM.")
architectures[i] = model_config._get_transformers_backend_cls()
return architectures
@ -237,8 +239,9 @@ def get_model_architecture(
]
vllm_supported_archs = ModelRegistry.get_supported_archs()
vllm_not_supported = not any(arch in vllm_supported_archs
for arch in architectures)
is_supported = lambda arch: (arch in vllm_supported_archs and arch not in
_TRANSFORMERS_MODELS)
vllm_not_supported = not any(is_supported(arch) for arch in architectures)
if vllm_not_supported:
# try automatic conversion in adapters.py
@ -259,7 +262,7 @@ def get_model_architecture(
break
if (model_config.model_impl == ModelImpl.TRANSFORMERS or
model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
model_config.model_impl == ModelImpl.AUTO and vllm_not_supported):
architectures = resolve_transformers_arch(model_config, architectures)
logger.debug_once("Resolve transformers arch %s", str(architectures))
elif (model_config.quantization is not None


@ -253,6 +253,7 @@ _SPECULATIVE_DECODING_MODELS = {
}
_TRANSFORMERS_MODELS = {
"TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
}
# yapf: enable
@ -504,9 +505,14 @@ class _ModelRegistry:
if causal_lm_arch in self.models:
normalized_arch.append(arch)
# make sure Transformers backend is put at the last as a fallback
if len(normalized_arch) != len(architectures):
normalized_arch.append("TransformersForCausalLM")
# NOTE(Isotr0py): Be careful of architectures' order!
# Make sure the Transformers backend architecture is at the end of the
# list, otherwise automatic conversion of pooling models will fail!
for arch in normalized_arch:
if arch.startswith("TransformersFor"):
normalized_arch.remove(arch)
normalized_arch.append(arch)
return normalized_arch
def inspect_model_cls(


@ -15,8 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
from collections.abc import Iterable
from contextlib import nullcontext
from collections.abc import Iterable, Mapping
from contextlib import contextmanager, nullcontext
from typing import Literal, Optional, Union
import regex as re
@ -41,11 +41,21 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter,
flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__)
@ -112,6 +122,269 @@ def replace_linear_class(
)
# Copied from `accelerate`
@contextmanager
def init_on_device_without_buffers(device: torch.device):
"""
A context manager under which models are initialized with all
parameters on the specified device. However, buffers are not
initialized on the specified device.
Args:
device (`torch.device`):
Device to initialize all parameters on.
"""
old_register_parameter = nn.Module.register_parameter
def register_empty_parameter(module, name, param):
old_register_parameter(module, name, param)
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
kwargs["requires_grad"] = param.requires_grad
module._parameters[name] = param_cls(
module._parameters[name].to(device), **kwargs)
tensor_constructors_to_patch = {}
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
kwargs["device"] = device
return fn(*args, **kwargs)
return wrapper
try:
nn.Module.register_parameter = register_empty_parameter
for torch_function_name in tensor_constructors_to_patch:
setattr(
torch, torch_function_name,
patch_tensor_constructor(getattr(torch, torch_function_name)))
yield
finally:
nn.Module.register_parameter = old_register_parameter
for torch_function_name, old_torch_function in (
tensor_constructors_to_patch.items()):
setattr(torch, torch_function_name, old_torch_function)
class MultiModalProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.model_config.hf_config
def get_supported_mm_limits(self):
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len, mm_counts):
return {"image": self.get_max_image_tokens()}
def get_max_image_tokens(self) -> int:
width, height = self.get_max_image_size()
processor = self.get_hf_processor()
mm_processor_kwargs = self.ctx.model_config.mm_processor_kwargs or {}
mm_tokens = processor._get_num_multimodal_tokens(
image_sizes=([height, width], ), **mm_processor_kwargs)
image_tokens = mm_tokens["num_image_tokens"][0]
return image_tokens
def get_hf_processor(self):
processor = cached_get_processor(self.ctx.model_config.model)
return processor
def get_max_image_size(self):
return 10_000, 10_000 # hardcode for arbitrary very large size
class MultiModalDummyInputsBuilder(
BaseDummyInputsBuilder[MultiModalProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
processor = self.info.get_hf_processor()
if "gemma3" in processor.__class__.__name__.lower():
image_token = processor.boi_token
else:
image_token = getattr(processor, "image_token", "")
return image_token * num_images
def get_dummy_mm_data(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = self.info.get_max_image_size()
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images),
}
class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
):
"""
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.
The information returned by this method is used to update token inputs
which bypass the HF processor. It is also used to update the output of
the HF processor if the HF processor does not apply prompt updates to text
inputs.
Moreover, this information is critical to determine the token positions
in order to construct :class:`~vllm.multimodal.inputs.PlaceholderRange`
for each multi-modal item.
"""
return None
def _get_mm_fields_config(
self,
hf_inputs,
hf_processor_mm_kwargs,
num_image_patches: torch.Tensor = None,
):
# HF Processors always return a mask but vLLM doesn't need it
hf_inputs.pop("attention_mask", None)
mm_fields = {
key: MultiModalFieldConfig.flat_from_sizes("image",
num_image_patches)
for key in hf_inputs
}
mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
"image", num_image_patches)
mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
return mm_fields
def _apply_hf_processor_text_mm(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
):
"""
Apply the HF processor on the prompt text and multi-modal data
together.
Return the prompt token IDs, the processed outputs, and the
multi-modal token type IDs.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
processor_data["return_mm_token_type_ids"] = True
processed_data = self._call_hf_processor(
prompt=prompt_text,
mm_data=processor_data,
mm_kwargs=hf_processor_mm_kwargs,
tok_kwargs=tokenization_kwargs,
)
processed_data.update(passthrough_data)
prompt_ids, = processed_data.pop("input_ids").tolist()
mm_token_type_ids = processed_data.pop(
"mm_token_type_ids"
) if "mm_token_type_ids" in processed_data else processed_data.pop(
"token_type_ids") # for gemma3 only
return prompt_ids, processed_data, mm_token_type_ids
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
"""
if return_mm_hashes:
raise ValueError(
"TransformersForMultimodalLM doesn't support mm hashing yet! "
"Probably you didn't set `disable_mm_preprocessor_cache=True`")
if tokenization_kwargs is None:
tokenization_kwargs = {}
mm_items = self._to_mm_items(mm_data)
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
(prompt_ids, processed_data,
mm_token_type_ids) = self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
# The HF processor returns `mm_token_type_ids`, from which we can infer
# mm_placeholders; until this is standardized, parts are hardcoded to make
# the code run. Tested below on LLaVA; prompts and `mm_token_type_ids`
# always have bs=1
mm_positions = torch.where(mm_token_type_ids == 1)[1]
images = mm_items.get_items("image", ImageProcessorItems)
mm_processor_kwargs = (self.info.ctx.model_config.mm_processor_kwargs
or {})
image_sizes = []
for item_idx in range(len(images)):
image_size = images.get_image_size(item_idx)
image_sizes.append((image_size.height, image_size.width))
mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens(
image_sizes=image_sizes, **mm_processor_kwargs)
mm_placeholders = {}
split_sizes = mm_tokens_per_modality["num_image_tokens"]
if split_sizes:
chunked_mm_positions = torch.split(mm_positions, split_sizes)
mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()]
chunked_mm_tokens = torch.split(mm_tokens, split_sizes)
ranges = [
PlaceholderRange(
offset=positions[0].item(),
length=positions.shape[0],
is_embed=(mm_tokens == hf_processor.image_token_id).bool())
for positions, mm_tokens in zip(chunked_mm_positions,
chunked_mm_tokens)
]
mm_placeholders = {"image": ranges}
num_image_patches = torch.tensor(
mm_tokens_per_modality["num_image_patches"]
) if "num_image_patches" in mm_tokens_per_modality else None
processed_data['num_image_patches'] = num_image_patches
mm_kwargs = MultiModalKwargs.from_hf_inputs(
processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs,
num_image_patches),
)
return MultiModalInputs(
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=None,
mm_placeholders=mm_placeholders,
)
class ConfigOverride:
"""Context manager to temporarily override config attributes."""
@ -153,6 +426,7 @@ class TransformersModel(nn.Module):
quant_config: QuantizationConfig = vllm_config.quant_config
self.config = config
self.text_config = config.get_text_config()
self.cache_config = cache_config
self.device_config = device_config
self.model_config = model_config
@ -173,14 +447,16 @@ class TransformersModel(nn.Module):
config_override = ConfigOverride(
config, sliding_window=config.interleaved_sliding_window)
# Use meta device to delay allocating GPU tensors
with torch.device("meta"), config_override:
# Set correct attn and init on "meta" to delay allocating GPU tensors
# TODO: @raushan, use the public `model.set_attn_implementation()`
# method after v4.54.0 is released
self.text_config._attn_implementation = "vllm"
with init_on_device_without_buffers("meta"), config_override:
# FIXME(Isotr0py): We need to refactor this part in the future to
# avoid registering an extra model layer, otherwise we will need a
# weights mapper to rename weights.
self.model: PreTrainedModel = AutoModel.from_config(
config,
attn_implementation="vllm",
torch_dtype=model_config.dtype,
trust_remote_code=model_config.trust_remote_code,
)
@ -189,27 +465,25 @@ class TransformersModel(nn.Module):
self.tensor_parallel()
# Input embeddings
text_config = config.get_text_config()
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
self.model.set_input_embeddings(
VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
text_config.vocab_size,
text_config.hidden_size,
org_num_embeddings=text_config.vocab_size,
quant_config=quant_config,
))
# Attention layers
self.attention_instances = self.create_attention_instances()
# Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model)
# Initialize any parameters that have not had their modules replaced
self.init_parameters(self.model)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
text_config.hidden_size))
def pipeline_parallel(self):
"""
@ -240,14 +514,15 @@ class TransformersModel(nn.Module):
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (self.config.tie_word_embeddings
and self.pp_group.is_last_rank):
if self.pp_group.is_first_rank or (
self.text_config.tie_word_embeddings
and self.pp_group.is_last_rank):
continue
setattr(self.model, name, PPMissingLayer())
# Module list
start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
start_layer, end_layer = get_pp_indices(
self.text_config.num_hidden_layers, self.pp_rank, self.pp_size)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
for i in range(len(layers)):
@ -298,7 +573,7 @@ class TransformersModel(nn.Module):
self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
start, end = get_pp_indices(self.config.num_hidden_layers,
start, end = get_pp_indices(self.text_config.num_hidden_layers,
self.pp_rank, self.pp_size)
attention_instances = {}
@ -323,35 +598,6 @@ class TransformersModel(nn.Module):
prefix=f"{i}.attn")
return attention_instances
def init_buffers(self, module: nn.Module):
"""
If a `buffer` is on the `meta` device, then its parent
`module` is the original module created by:
```python
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
This means that:
- `type(module)` is a class from `transformers`
- This class is constructed using a `PretrainedConfig`
"""
for name, buffer in module.named_buffers(recurse=False):
if buffer.device == torch.device("meta"):
if module == self.model:
logger.warning(
"To initialize buffers correctly, we instantiate the "
"parent module and and extract the value of the "
"buffer from it. In this case, the parent module is "
"the base model. Instantiating the entire model here "
"risks GPU OOM. Could this buffer be moved to a child "
"module?")
new_buffer = getattr(type(module)(self.config), name)
setattr(module, name, new_buffer)
for child in module.children():
self.init_buffers(child)
def init_parameters(self, module: nn.Module):
"""
If a `parameter` is on the `meta` device, then its parent
@ -366,6 +612,7 @@ class TransformersModel(nn.Module):
if param.device == torch.device("meta"):
new_param = nn.Parameter(
torch.empty_like(param.data,
dtype=self.model_config.dtype,
device=self.device_config.device))
setattr(module, name, new_param)
for child in module.children():
@ -391,11 +638,16 @@ class TransformersModel(nn.Module):
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
if self.model_config.uses_mrope:
position_ids = positions[:, None]
else:
position_ids = positions[None, ...]
hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=positions[None, ...],
position_ids=position_ids,
attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now
@ -507,3 +759,180 @@ class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
@MULTIMODAL_REGISTRY.register_processor(
MultiModalProcessor,
info=MultiModalProcessingInfo,
dummy_inputs=MultiModalDummyInputsBuilder)
class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA,
SupportsPP, SupportsMultiModal):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"]
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: PretrainedConfig = vllm_config.model_config.hf_config
quant_config: QuantizationConfig = vllm_config.quant_config
self.config = config
self.dtype = vllm_config.model_config.dtype
self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
text_config = config.get_text_config()
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = text_config.vocab_size
self.lm_head = ParallelLMHead(
text_config.vocab_size,
text_config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if text_config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
text_config.vocab_size,
logit_scale)
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
@property
def hf_to_vllm_mapper(self):
# Backwards compatibility for previously released models whose
# state dicts had different formats and cannot be loaded
# with the `AutoModel` mapping as-is
prefix_mapper = {
"language_model.model": "model.language_model",
"text_model.model": "model.text_model",
"vision_tower": "model.vision_tower",
"vqmodel": "model.vqmodel",
"vision_model": "model.vision_model",
"vision_embed_tokens": "model.vision_embed_tokens",
"image_newline": "model.image_newline",
"multi_modal_projector": "model.multi_modal_projector",
"text_model.lm_head": "lm_head",
"language_model.lm_head": "lm_head",
}
# Don't change the order for QwenVL
if 'Qwen2' in self.config.__class__.__name__:
prefix_mapper["model"] = "model.language_model"
prefix_mapper["visual"] = "model.visual"
return WeightsMapper(orig_to_new_prefix=prefix_mapper, )
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
# NOTE: In v1, inputs_embeds is always generated at model runner from
# `get_multimodal_embeddings` and `get_input_embeddings`, this
# condition is only for v0 compatibility.
if inputs_embeds is None:
multimodal_embeds = self.get_multimodal_embeddings(**kwargs)
if multimodal_embeds is not None:
inputs_embeds = self.get_input_embeddings(
input_ids, multimodal_embeds)
input_ids = None
model_output = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=([
"lm_head."
] if self.config.get_text_config().tie_word_embeddings else None),
)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def get_multimodal_embeddings(self, **kwargs):
pixel_values = kwargs.pop("pixel_values", None)
pixel_values = pixel_values if pixel_values is not None else kwargs.pop(
"image_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
if image_embeds is not None:
return image_embeds
if pixel_values is None and image_embeds is None:
return None
num_image_patches = kwargs.pop("num_image_patches")
if pixel_values is not None:
if isinstance(pixel_values, torch.Tensor):
pixel_values = flatten_bn(pixel_values).to(self.dtype)
elif is_list_of(pixel_values, torch.Tensor):
pixel_values = flatten_bn(flatten_bn(pixel_values),
concat=True).to(self.dtype)
else:
raise ValueError(
f"Unsupported pixel_values type {type(pixel_values)}. "
"Expected `torch.Tensor` or list of `torch.Tensor`.")
if isinstance(num_image_patches, list):
num_image_patches = torch.cat(num_image_patches)
vision_embeddings = self.model.model.get_image_features(
pixel_values,
**{
k: v.flatten(0, 1)
for k, v in kwargs.items()
},
)
if isinstance(vision_embeddings, torch.Tensor):
if vision_embeddings.ndim == 2:
vision_embeddings = vision_embeddings.unsqueeze(0)
# vLLM expects a list of 2D embedding tensors, one per image, but
# Transformers returns a single concatenated tensor when patches
# have different sizes. Split it back to make vLLM happy
vision_embeddings = torch.split(
vision_embeddings,
num_image_patches.flatten().tolist())
vision_embeddings = [
embed.flatten(start_dim=0, end_dim=-2)
for embed in vision_embeddings
]
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings=None,
) -> torch.Tensor:
inputs_embeds = self.model.model.get_input_embeddings()(input_ids)
if (multimodal_embeddings is not None
and len(multimodal_embeddings) != 0):
mask = (input_ids == self.config.image_token_id)
mask = mask.unsqueeze(-1).expand_as(inputs_embeds)
multimodal_embeddings = torch.cat(multimodal_embeddings)
inputs_embeds = inputs_embeds.masked_scatter(
mask, multimodal_embeddings)
return inputs_embeds