[VLM] Calculate maximum number of multi-modal tokens by model (#6121)

This commit is contained in:
Cyrus Leung 2024-07-05 07:37:23 +08:00 committed by GitHub
parent 69ec3ca14c
commit ae96ef8fbd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 265 additions and 95 deletions

View File

@ -51,17 +51,16 @@ As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model
2. Register input mappers
-------------------------
For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
For each modality type that the model accepts as input, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.
.. code-block:: diff
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.model_executor.models.interfaces import SupportsVision
+ from vllm.multimodal import MULTIMODAL_REGISTRY
+ @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+ @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):
+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
@ -69,22 +68,22 @@ A default mapper is available for each modality in the core vLLM library. This i
:ref:`input_processing_pipeline`
3. (Optional) Register dummy data
---------------------------------
3. Register maximum number of multimodal tokens
----------------------------------------------------------
During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
For each modality type that the model accepts as input, calculate the maximum possible number of tokens
and register it via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_max_multimodal_tokens>`.
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
@MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
Here are some examples:
@ -95,7 +94,36 @@ Here are some examples:
:ref:`input_processing_pipeline`
4. (Optional) Register input processor
4. (Optional) Register dummy data
---------------------------------
During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
.. note::
The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
Here are some examples:
- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
.. seealso::
:ref:`input_processing_pipeline`
5. (Optional) Register input processor
--------------------------------------
Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
@ -104,15 +132,15 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce
.. code-block:: diff
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsVision):
class YourModelForImage2Seq(nn.Module, SupportsVision):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
Here are some examples:

View File

@ -25,13 +25,8 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
.. important::
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
every model to perform profiling with.
This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
:meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
with a more accurate profiling strategy in the future.
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
@ -104,13 +99,8 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
.. important::
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
every model to perform profiling with.
This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
:meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
with a more accurate profiling strategy in the future.
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
internally for each model.
To consume the server, you can use the OpenAI client like in the example below:

View File

@ -51,7 +51,7 @@ class InputContext:
additionally checking its type.
Raises:
ValueError: If the model is not of the specified type.
TypeError: If the model is not of the specified type.
"""
hf_config = self.model_config.hf_config

View File

@ -35,6 +35,10 @@ def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
patch_size=hf_config.patch_size)
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
return get_clip_image_feature_size(hf_config)
def dummy_seq_data_for_clip(
hf_config: CLIPVisionConfig,
seq_len: int,

View File

@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
get_max_clip_image_tokens, input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
@ -62,6 +62,17 @@ class LlavaImagePixelInputs(TypedDict):
LlavaImageInputs = LlavaImagePixelInputs
def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
if isinstance(vision_config, CLIPVisionConfig):
return get_max_clip_image_tokens(vision_config)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
def dummy_data_for_llava(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
@ -102,6 +113,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):

View File

@ -127,6 +127,17 @@ def get_llava_next_image_feature_size(
raise NotImplementedError(msg)
def get_max_llava_next_image_tokens(ctx: InputContext):
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
dummy_height = dummy_width = 448
return get_llava_next_image_feature_size(
ctx.get_hf_config(LlavaNextConfig),
input_height=dummy_height,
input_width=dummy_width,
)
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
@ -198,6 +209,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

View File

@ -321,6 +321,17 @@ def get_phi3v_image_feature_size(
+ (new_height // 336 + 1) * 12
def get_max_phi3v_image_tokens(ctx: InputContext):
# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50
return get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
input_height=dummy_height,
input_width=dummy_width,
)
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50
@ -429,6 +440,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):

View File

@ -97,9 +97,19 @@ the corresponding plugin with the same modality key is applied.
"""
MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
"""Return a dictionary to be passed as keyword arguments to
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
and processors in HuggingFace Transformers.
If the data is not supported, throw :exc:`TypeError`.
"""
MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""
N = TypeVar("N", bound=Type[nn.Module])
@ -117,6 +127,7 @@ class MultiModalPlugin(ABC):
def __init__(self) -> None:
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}
@abstractmethod
def get_data_key(self) -> str:
@ -128,9 +139,12 @@ class MultiModalPlugin(ABC):
@abstractmethod
def _default_input_mapper(self, ctx: InputContext,
data: object) -> MultiModalInputs:
"""Return a dictionary to be passed as keyword arguments to
"""
Return a dictionary to be passed as keyword arguments to
:meth:`~torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
If the data is not supported, throw :exc:`TypeError`.
"""
raise NotImplementedError
@ -140,9 +154,11 @@ class MultiModalPlugin(ABC):
):
"""
Register an input mapper to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_type`), the provided function is
this plugin (see :meth:`get_data_key`), the provided function is
invoked to transform the data into a dictionary of model inputs.
If `None` is provided, then the default input mapper is used instead.
See also:
@ -170,10 +186,11 @@ class MultiModalPlugin(ABC):
Apply an input mapper to a data passed
to the model, transforming the data into a dictionary of model inputs.
If the data is not something that the mapper expects, throws TypeError.
The model is identified by ``model_config``.
Raises:
TypeError: If the data type is not supported.
See also:
:ref:`adding_a_new_multimodal_model`
"""
@ -188,3 +205,79 @@ class MultiModalPlugin(ABC):
f"model class {model_cls.__name__}.")
return mapper(InputContext(model_config), data)
@abstractmethod
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
"""
Calculate the maximum number of multimodal tokens input to the language
model. This does not include tokens that correspond to the input text.
"""
raise NotImplementedError
def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
if max_mm_tokens < 1:
raise ValueError("You should set the number of tokens to a "
f"positive integer. Found: {max_mm_tokens}")
def register_max_multimodal_tokens(
self,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of multi-modal tokens input to the
language model for a model class.
If `None` is provided, then the default calculation is used instead.
See also:
:ref:`adding_a_new_multimodal_model`
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens:
logger.warning(
"Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.",
model_cls, self)
if isinstance(max_mm_tokens, int):
self._validate_max_multimodal_tokens(max_mm_tokens)
self._max_mm_tokens[model_cls] = max_mm_tokens \
or self._default_max_multimodal_tokens
return model_cls
return wrapper
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
If this registry is not applicable to the model, `0` is returned.
The model is identified by ``model_config``.
See also:
:ref:`adding_a_new_multimodal_model`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
model_cls, _ = get_model_architecture(model_config)
if model_cls not in self._input_mappers:
return 0
max_mm_tokens = self._max_mm_tokens.get(model_cls)
if max_mm_tokens is None:
raise KeyError(f"No maximum number of multi-modal tokens is given "
f"for model class {model_cls.__name__} in {self}.")
if callable(max_mm_tokens):
max_mm_tokens = max_mm_tokens(InputContext(model_config))
self._validate_max_multimodal_tokens(max_mm_tokens)
return max_mm_tokens

View File

@ -130,3 +130,6 @@ class ImagePlugin(MultiModalPlugin):
raise NotImplementedError("Embeddings input is not supported yet")
raise TypeError(f"Invalid image type: {type(data)}")
def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
return 3000

View File

@ -7,7 +7,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
MultiModalPlugin)
MultiModalPlugin, MultiModalTokensCalc)
from .image import ImagePlugin
logger = init_logger(__name__)
@ -48,28 +48,9 @@ class MultiModalRegistry:
msg = f"Unknown multi-modal data type: {data_type_key}"
raise NotImplementedError(msg)
def register_image_input_mapper(
self,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper for image data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
return self.register_input_mapper("image", mapper)
def _process_input(self, key: str, value: object,
model_config: ModelConfig) -> MultiModalInputs:
plugin = self._plugins.get(key)
if plugin:
return plugin.map_input(model_config, value)
msg = f"Unknown multi-modal data type: {key}"
raise NotImplementedError(msg)
def register_input_mapper(
self,
data_type: str,
data_type_key: str,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
@ -77,16 +58,14 @@ class MultiModalRegistry:
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
plugin = self._plugins.get(data_type)
if not plugin:
msg = f"Unknown multi-modal data type: {data_type}"
raise NotImplementedError(msg)
return plugin.register_input_mapper(mapper)
return self._get_plugin(data_type_key).register_input_mapper(mapper)
def register_image_input(self,
mapper: Optional[MultiModalInputMapper] = None):
def register_image_input_mapper(
self,
mapper: Optional[MultiModalInputMapper] = None,
):
"""
Register an input mapper for image pixel data to a model class.
Register an input mapper for image data to a model class.
See :meth:`MultiModalPlugin.register_input_mapper` for more details.
"""
@ -102,8 +81,8 @@ class MultiModalRegistry:
merged_dict: Dict[str, torch.Tensor] = {}
for data_key, data_value in data.items():
input_dict = self._process_input(data_key, data_value,
model_config)
input_dict = self._get_plugin(data_key) \
.map_input(model_config, data_value)
for input_key, input_tensor in input_dict.items():
if input_key in merged_dict:
@ -121,9 +100,35 @@ class MultiModalRegistry:
"""
return functools.partial(self.map_input, model_config)
def get_num_input_tokens(self):
def register_max_multimodal_tokens(
self,
data_type_key: str,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Get the number of input tokens for profiling purposes.
Register the maximum number of tokens, belonging to a
specific modality, input to the language model for a model class.
"""
# TODO: Provide this number on a per model basis.
return 3000
return self._get_plugin(data_type_key) \
.register_max_multimodal_tokens(max_mm_tokens)
def register_max_image_tokens(
self,
max_mm_tokens: Optional[MultiModalTokensCalc] = None,
):
"""
Register the maximum number of image tokens
input to the language model for a model class.
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)
def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.
See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
"""
return sum(
plugin.get_max_multimodal_tokens(model_config)
for plugin in self._plugins.values())

View File

@ -820,12 +820,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
model_config = self.model_config
if supports_vision(self.model):
max_num_seqs = max(
1,
min(
max_num_seqs,
int(max_num_batched_tokens /
MULTIMODAL_REGISTRY.get_num_input_tokens())))
max_mm_tokens = MULTIMODAL_REGISTRY \
.get_max_multimodal_tokens(model_config)
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens)
if max_num_seqs < 1:
expr = (f"min({max_num_seqs_orig}, "
f"{max_num_batched_tokens} // {max_mm_tokens})")
logger.warning(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
batch_size = 0
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +

View File

@ -168,14 +168,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
model_config = self.model_config
if supports_vision(self.model):
# TODO: properly inject these numbers from MultiModalRegistry.
# Right now, just use an overly conservative number.
max_num_seqs = max(
1,
min(
max_num_seqs,
int(max_num_batched_tokens /
MULTIMODAL_REGISTRY.get_num_input_tokens())))
max_mm_tokens = MULTIMODAL_REGISTRY \
.get_max_multimodal_tokens(model_config)
max_num_seqs_orig = max_num_seqs
max_num_seqs = min(max_num_seqs,
max_num_batched_tokens // max_mm_tokens)
if max_num_seqs < 1:
expr = (f"min({max_num_seqs_orig}, "
f"{max_num_batched_tokens} // {max_mm_tokens})")
logger.warning(
"Computed max_num_seqs (%s) to be less than 1. "
"Setting it to the minimum value of 1.", expr)
max_num_seqs = 1
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +