[VLM] Calculate maximum number of multi-modal tokens by model (#6121)

Cyrus Leung 2024-07-05 07:37:23 +08:00 committed by GitHub
parent 69ec3ca14c
commit ae96ef8fbd
12 changed files with 265 additions and 95 deletions

View File

@@ -51,17 +51,16 @@ As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model
 2. Register input mappers
 -------------------------
-For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
+For each modality type that the model accepts as input, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
 This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.
 .. code-block:: diff
     from vllm.model_executor.models.interfaces import SupportsVision
   + from vllm.multimodal import MULTIMODAL_REGISTRY
-  + @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-  + @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+  + @MULTIMODAL_REGISTRY.register_image_input_mapper()
     class YourModelForImage2Seq(nn.Module, SupportsVision):
 A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
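For reference, a minimal sketch of what a custom mapper might look like, assuming the ``MultiModalInputMapper`` signature introduced in this commit (``Callable[[InputContext, object], MultiModalInputs]``); the ``pixel_values`` keyword and the tensor handling are illustrative, not taken from any specific model:

import torch
from torch import nn

from vllm.inputs.registry import InputContext
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs


def map_your_image_input(ctx: InputContext, data: object) -> MultiModalInputs:
    # `data` is whatever the user passed as the image; this sketch assumes it
    # is already a pixel tensor. Unsupported types raise TypeError, as the
    # docstrings added in this commit require.
    if not isinstance(data, torch.Tensor):
        raise TypeError(f"Unsupported image type: {type(data)}")
    # The returned keys must match the keyword arguments of forward().
    return MultiModalInputs({"pixel_values": data.unsqueeze(0)})


@MULTIMODAL_REGISTRY.register_image_input_mapper(map_your_image_input)
class YourModelForImage2Seq(nn.Module, SupportsVision):
    ...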
@@ -69,22 +68,22 @@ A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
     :ref:`input_processing_pipeline`
-3. (Optional) Register dummy data
----------------------------------
-During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
-In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
+3. Register maximum number of multi-modal tokens
+------------------------------------------------
+For each modality type that the model accepts as input, calculate the maximum possible number of tokens
+and register it via :meth:`MULTIMODAL_REGISTRY.register_max_multimodal_tokens <vllm.multimodal.MultiModalRegistry.register_max_multimodal_tokens>`.
 .. code-block:: diff
     from vllm.inputs import INPUT_REGISTRY
     from vllm.model_executor.models.interfaces import SupportsVision
     from vllm.multimodal import MULTIMODAL_REGISTRY
-    @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-    @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
-  + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+    @MULTIMODAL_REGISTRY.register_image_input_mapper()
+  + @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+    @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
     class YourModelForImage2Seq(nn.Module, SupportsVision):
 Here are some examples:
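As a concrete illustration of ``<your_calculation>``, here is a hedged sketch for a hypothetical model with a fixed-resolution ViT encoder; the ``vision_config`` attribute names are assumptions about your HuggingFace config, and any extra tokens (class token, image-break tokens, etc.) that your model feeds to the language model would need to be added on top:

from transformers import PretrainedConfig

from vllm.inputs.registry import InputContext


def get_max_your_model_image_tokens(ctx: InputContext) -> int:
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config  # assumed attribute layout
    # One feature token per image patch for a fixed-resolution encoder.
    grid = vision_config.image_size // vision_config.patch_size
    return grid * grid

Registration then takes either this callable or, if the bound is a constant, a plain integer: ``@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_your_model_image_tokens)``.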
@@ -95,7 +94,36 @@ Here are some examples:
     :ref:`input_processing_pipeline`
-4. (Optional) Register input processor
+4. (Optional) Register dummy data
+---------------------------------
+During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
+In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.
+.. code-block:: diff
+    from vllm.inputs import INPUT_REGISTRY
+    from vllm.model_executor.models.interfaces import SupportsVision
+    from vllm.multimodal import MULTIMODAL_REGISTRY
+    @MULTIMODAL_REGISTRY.register_image_input_mapper()
+    @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+  + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+    class YourModelForImage2Seq(nn.Module, SupportsVision):
+.. note::
+    The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
+Here are some examples:
+- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
+- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
+.. seealso::
+    :ref:`input_processing_pipeline`
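A rough sketch of ``<your_dummy_data_factory>``, loosely modelled on the LLaVA factory touched by this commit; it assumes the factory returns a sequence-data object plus a multi-modal data dict, and the token id, token count, and image size below are placeholders that must match your model and the maximum registered in step 3:

from PIL import Image

from vllm.inputs.registry import InputContext
from vllm.sequence import SequenceData

IMAGE_TOKEN_ID = 32000    # placeholder token id used by your model
MAX_IMAGE_TOKENS = 576    # must equal the maximum registered in step 3


def dummy_data_for_your_model(ctx: InputContext, seq_len: int):
    # Fill the dummy prompt with the maximum number of image placeholder
    # tokens so that profiling reserves enough memory, padding with 0s.
    token_ids = [IMAGE_TOKEN_ID] * MAX_IMAGE_TOKENS
    token_ids += [0] * (seq_len - MAX_IMAGE_TOKENS)
    seq_data = SequenceData(token_ids)

    # A dummy image whose resolution yields that many feature tokens.
    mm_data = {"image": Image.new("RGB", (336, 336), color=0)}
    return seq_data, mm_data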
+5. (Optional) Register input processor
 --------------------------------------
 Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
@@ -104,15 +132,15 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_processor <vllm.inputs.registry.InputRegistry.register_input_processor>`.
 .. code-block:: diff
     from vllm.inputs import INPUT_REGISTRY
     from vllm.model_executor.models.interfaces import SupportsVision
     from vllm.multimodal import MULTIMODAL_REGISTRY
-    @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-    @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
+    @MULTIMODAL_REGISTRY.register_image_input_mapper()
+    @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
     @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
   + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
     class YourModelForImage2Seq(nn.Module, SupportsVision):
 A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
 Here are some examples:
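To make the placeholder-token use case concrete, here is a hedged sketch of ``<your_input_processor>``, loosely modelled on the LLaVA processors referenced in this diff; ``IMAGE_TOKEN_ID`` and ``MAX_IMAGE_TOKENS`` are placeholders, and a real processor would usually derive the repeat count from the image size rather than always using the maximum:

from typing import List

from vllm.inputs import LLMInputs
from vllm.inputs.registry import InputContext

IMAGE_TOKEN_ID = 32000
MAX_IMAGE_TOKENS = 576


def input_processor_for_your_model(ctx: InputContext,
                                   llm_inputs: LLMInputs) -> LLMInputs:
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs  # text-only request, nothing to do

    # Expand each image placeholder so that attention masks and KV-cache
    # blocks account for every image feature token.
    new_token_ids: List[int] = []
    for token_id in llm_inputs["prompt_token_ids"]:
        if token_id == IMAGE_TOKEN_ID:
            new_token_ids.extend([IMAGE_TOKEN_ID] * MAX_IMAGE_TOKENS)
        else:
            new_token_ids.append(token_id)

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=llm_inputs.get("prompt"),
                     multi_modal_data=multi_modal_data)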

View File

@@ -25,13 +25,8 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
 .. important::
     We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
-    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
-    every model to perform profiling with.
-    This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
-    :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
-    for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
-    with a more accurate profiling strategy in the future.
+    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, as we now calculate it
+    internally for each model.
 To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
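For instance, offline inference with an image now looks roughly like the following; this is a hedged sketch of the post-0.5.1 API, the model name comes from the surrounding docs, and the chat template and image path are illustrative:

from PIL import Image

from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")

image = Image.open("example.jpg")  # placeholder path
outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)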
@@ -104,13 +99,8 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
 .. important::
     We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
-    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
-    every model to perform profiling with.
-    This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
-    :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
-    for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
-    with a more accurate profiling strategy in the future.
+    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, as we now calculate it
+    internally for each model.
 To consume the server, you can use the OpenAI client like in the example below:

View File

@@ -51,7 +51,7 @@ class InputContext:
         additionally checking its type.
 
         Raises:
-            ValueError: If the model is not of the specified type.
+            TypeError: If the model is not of the specified type.
         """
         hf_config = self.model_config.hf_config

View File

@@ -35,6 +35,10 @@ def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
                                  patch_size=hf_config.patch_size)
+
+
+def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
+    return get_clip_image_feature_size(hf_config)
 
 
 def dummy_seq_data_for_clip(
     hf_config: CLIPVisionConfig,
     seq_len: int,
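For a fixed-resolution CLIP tower, the maximum is simply the per-image feature size, which follows from the patch grid. Rough arithmetic for CLIP ViT-L/14 at 336x336 (the tower used by LLaVA-1.5); the exact count may differ by one depending on whether a class token is kept:

image_size, patch_size = 336, 14
num_patches = (image_size // patch_size) ** 2  # 24 * 24 patch features
print(num_patches)  # 576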

View File

@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   input_processor_for_clip)
+                   get_max_clip_image_tokens, input_processor_for_clip)
 from .interfaces import SupportsVision
 from .utils import merge_vision_embeddings
@@ -62,6 +62,17 @@ class LlavaImagePixelInputs(TypedDict):
 LlavaImageInputs = LlavaImagePixelInputs
 
+
+def get_max_llava_image_tokens(ctx: InputContext):
+    hf_config = ctx.get_hf_config(LlavaConfig)
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return get_max_clip_image_tokens(vision_config)
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
 def dummy_data_for_llava(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
@@ -102,6 +113,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
 class LlavaForConditionalGeneration(nn.Module, SupportsVision):

View File

@@ -127,6 +127,17 @@ def get_llava_next_image_feature_size(
     raise NotImplementedError(msg)
 
+
+def get_max_llava_next_image_tokens(ctx: InputContext):
+    # Result in the max possible feature size (2x2 grid of 336x336px tiles)
+    dummy_height = dummy_width = 448
+
+    return get_llava_next_image_feature_size(
+        ctx.get_hf_config(LlavaNextConfig),
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
+
+
 def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(LlavaNextConfig)
     vision_config = hf_config.vision_config
@@ -198,6 +209,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

View File

@@ -321,6 +321,17 @@ def get_phi3v_image_feature_size(
             + (new_height // 336 + 1) * 12
 
+
+def get_max_phi3v_image_tokens(ctx: InputContext):
+    # Result in the max possible feature size (h:w = 16:1)
+    dummy_height, dummy_width = 8000, 50
+
+    return get_phi3v_image_feature_size(
+        ctx.get_hf_config(PretrainedConfig),
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
+
+
 def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
     # Result in the max possible feature size (h:w = 16:1)
     dummy_height, dummy_width = 8000, 50
@@ -429,6 +440,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
 class Phi3VForCausalLM(nn.Module, SupportsVision):

View File

@@ -97,9 +97,19 @@ the corresponding plugin with the same modality key is applied.
 """
 
 MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
-"""Return a dictionary to be passed as keyword arguments to
+"""
+Return a dictionary to be passed as keyword arguments to
 :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
-and processors in HuggingFace Transformers."""
+and processors in HuggingFace Transformers.
+
+If the data is not supported, throw :exc:`TypeError`.
+"""
+
+MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
+"""
+Calculate the maximum number of multimodal tokens input to the language
+model. This does not include tokens that correspond to the input text.
+"""
 
 N = TypeVar("N", bound=Type[nn.Module])
@@ -117,6 +127,7 @@ class MultiModalPlugin(ABC):
 
     def __init__(self) -> None:
         self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
+        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}
 
     @abstractmethod
     def get_data_key(self) -> str:
@@ -128,9 +139,12 @@ class MultiModalPlugin(ABC):
     @abstractmethod
     def _default_input_mapper(self, ctx: InputContext,
                               data: object) -> MultiModalInputs:
-        """Return a dictionary to be passed as keyword arguments to
+        """
+        Return a dictionary to be passed as keyword arguments to
         :meth:`~torch.nn.Module.forward`. This is similar in concept to
         tokenizers and processors in HuggingFace Transformers.
+
+        If the data is not supported, throw :exc:`TypeError`.
         """
         raise NotImplementedError
@@ -140,9 +154,11 @@ class MultiModalPlugin(ABC):
     ):
         """
         Register an input mapper to a model class.
+
         When the model receives input data that matches the modality served by
-        this plugin (see :meth:`get_data_type`), the provided function is
+        this plugin (see :meth:`get_data_key`), the provided function is
         invoked to transform the data into a dictionary of model inputs.
+
         If `None` is provided, then the default input mapper is used instead.
 
         See also:
@@ -170,10 +186,11 @@
         Apply an input mapper to a data passed
         to the model, transforming the data into a dictionary of model inputs.
 
-        If the data is not something that the mapper expects, throws TypeError.
-
         The model is identified by ``model_config``.
 
+        Raises:
+            TypeError: If the data type is not supported.
+
         See also:
             :ref:`adding_a_new_multimodal_model`
         """
@@ -188,3 +205,79 @@
                 f"model class {model_cls.__name__}.")
 
         return mapper(InputContext(model_config), data)
+
+    @abstractmethod
+    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
+        """
+        Calculate the maximum number of multimodal tokens input to the language
+        model. This does not include tokens that correspond to the input text.
+        """
+        raise NotImplementedError
+
+    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
+        if max_mm_tokens < 1:
+            raise ValueError("You should set the number of tokens to a "
+                             f"positive integer. Found: {max_mm_tokens}")
+
+    def register_max_multimodal_tokens(
+        self,
+        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
+    ):
+        """
+        Register the maximum number of multi-modal tokens input to the
+        language model for a model class.
+
+        If `None` is provided, then the default calculation is used instead.
+
+        See also:
+            :ref:`adding_a_new_multimodal_model`
+        """
+
+        def wrapper(model_cls: N) -> N:
+            if model_cls in self._max_mm_tokens:
+                logger.warning(
+                    "Model class %s already calculates maximum number of "
+                    "tokens in %s. It is overwritten by the new one.",
+                    model_cls, self)
+
+            if isinstance(max_mm_tokens, int):
+                self._validate_max_multimodal_tokens(max_mm_tokens)
+
+            self._max_mm_tokens[model_cls] = max_mm_tokens \
+                or self._default_max_multimodal_tokens
+
+            return model_cls
+
+        return wrapper
+
+    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
+        """
+        Get the maximum number of multi-modal tokens
+        for profiling the memory usage of a model.
+
+        If this registry is not applicable to the model, `0` is returned.
+
+        The model is identified by ``model_config``.
+
+        See also:
+            :ref:`adding_a_new_multimodal_model`
+        """
+        # Avoid circular import
+        from vllm.model_executor.model_loader import get_model_architecture
+
+        model_cls, _ = get_model_architecture(model_config)
+
+        if model_cls not in self._input_mappers:
+            return 0
+
+        max_mm_tokens = self._max_mm_tokens.get(model_cls)
+        if max_mm_tokens is None:
+            raise KeyError(f"No maximum number of multi-modal tokens is given "
+                           f"for model class {model_cls.__name__} in {self}.")
+
+        if callable(max_mm_tokens):
+            max_mm_tokens = max_mm_tokens(InputContext(model_config))
+
+        self._validate_max_multimodal_tokens(max_mm_tokens)
+
+        return max_mm_tokens
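Since ``MultiModalTokensCalc`` is ``Union[int, Callable[[InputContext], int]]``, both of the following registrations are accepted; note from the code above that an integer is validated eagerly at registration time, while a callable is validated only when ``get_max_multimodal_tokens`` is evaluated during profiling (the model class and values below are placeholders):

from vllm.inputs.registry import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY


def _max_tokens_from_config(ctx: InputContext) -> int:
    return 576  # in a real model, derive this from ctx.get_hf_config(...)


# Fixed bound, checked immediately:
#   @MULTIMODAL_REGISTRY.register_max_image_tokens(576)
# Config-dependent bound, checked lazily during profiling:
#   @MULTIMODAL_REGISTRY.register_max_image_tokens(_max_tokens_from_config)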

View File

@@ -130,3 +130,6 @@ class ImagePlugin(MultiModalPlugin):
             raise NotImplementedError("Embeddings input is not supported yet")
 
         raise TypeError(f"Invalid image type: {type(data)}")
+
+    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
+        return 3000
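This 3000-token default is the same conservative figure that ``get_num_input_tokens`` previously hard-coded (see the documentation changes above); it only applies when a model registers the image modality without supplying its own calculation, as the LLaVA, LLaVA-NeXT, and Phi-3-Vision models in this commit now do.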

View File

@@ -7,7 +7,7 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 
 from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs,
-                   MultiModalPlugin)
+                   MultiModalPlugin, MultiModalTokensCalc)
 from .image import ImagePlugin
 
 logger = init_logger(__name__)
@@ -48,28 +48,9 @@ class MultiModalRegistry:
         msg = f"Unknown multi-modal data type: {data_type_key}"
         raise NotImplementedError(msg)
 
-    def register_image_input_mapper(
-        self,
-        mapper: Optional[MultiModalInputMapper] = None,
-    ):
-        """
-        Register an input mapper for image data to a model class.
-
-        See :meth:`MultiModalPlugin.register_input_mapper` for more details.
-        """
-        return self.register_input_mapper("image", mapper)
-
-    def _process_input(self, key: str, value: object,
-                       model_config: ModelConfig) -> MultiModalInputs:
-        plugin = self._plugins.get(key)
-        if plugin:
-            return plugin.map_input(model_config, value)
-
-        msg = f"Unknown multi-modal data type: {key}"
-        raise NotImplementedError(msg)
-
     def register_input_mapper(
         self,
-        data_type: str,
+        data_type_key: str,
         mapper: Optional[MultiModalInputMapper] = None,
     ):
         """
@@ -77,16 +58,14 @@ class MultiModalRegistry:
 
         See :meth:`MultiModalPlugin.register_input_mapper` for more details.
         """
-        plugin = self._plugins.get(data_type)
-        if not plugin:
-            msg = f"Unknown multi-modal data type: {data_type}"
-            raise NotImplementedError(msg)
-
-        return plugin.register_input_mapper(mapper)
+        return self._get_plugin(data_type_key).register_input_mapper(mapper)
 
-    def register_image_input(self,
-                             mapper: Optional[MultiModalInputMapper] = None):
+    def register_image_input_mapper(
+        self,
+        mapper: Optional[MultiModalInputMapper] = None,
+    ):
         """
-        Register an input mapper for image pixel data to a model class.
+        Register an input mapper for image data to a model class.
 
         See :meth:`MultiModalPlugin.register_input_mapper` for more details.
         """
@@ -102,8 +81,8 @@ class MultiModalRegistry:
         merged_dict: Dict[str, torch.Tensor] = {}
 
         for data_key, data_value in data.items():
-            input_dict = self._process_input(data_key, data_value,
-                                             model_config)
+            input_dict = self._get_plugin(data_key) \
+                .map_input(model_config, data_value)
 
             for input_key, input_tensor in input_dict.items():
                 if input_key in merged_dict:
@@ -121,9 +100,35 @@ class MultiModalRegistry:
         """
         return functools.partial(self.map_input, model_config)
 
-    def get_num_input_tokens(self):
+    def register_max_multimodal_tokens(
+        self,
+        data_type_key: str,
+        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
+    ):
         """
-        Get the number of input tokens for profiling purposes.
+        Register the maximum number of tokens, belonging to a
+        specific modality, input to the language model for a model class.
         """
-        # TODO: Provide this number on a per model basis.
-        return 3000
+        return self._get_plugin(data_type_key) \
+            .register_max_multimodal_tokens(max_mm_tokens)
+
+    def register_max_image_tokens(
+        self,
+        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
+    ):
+        """
+        Register the maximum number of image tokens
+        input to the language model for a model class.
+        """
+        return self.register_max_multimodal_tokens("image", max_mm_tokens)
+
+    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
+        """
+        Get the maximum number of multi-modal tokens
+        for profiling the memory usage of a model.
+
+        See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details.
+        """
+        return sum(
+            plugin.get_max_multimodal_tokens(model_config)
+            for plugin in self._plugins.values())

View File

@@ -820,12 +820,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         model_config = self.model_config
 
         if supports_vision(self.model):
-            max_num_seqs = max(
-                1,
-                min(
-                    max_num_seqs,
-                    int(max_num_batched_tokens /
-                        MULTIMODAL_REGISTRY.get_num_input_tokens())))
+            max_mm_tokens = MULTIMODAL_REGISTRY \
+                .get_max_multimodal_tokens(model_config)
+            max_num_seqs_orig = max_num_seqs
+            max_num_seqs = min(max_num_seqs,
+                               max_num_batched_tokens // max_mm_tokens)
+            if max_num_seqs < 1:
+                expr = (f"min({max_num_seqs_orig}, "
+                        f"{max_num_batched_tokens} // {max_mm_tokens})")
+                logger.warning(
+                    "Computed max_num_seqs (%s) to be less than 1. "
+                    "Setting it to the minimum value of 1.", expr)
+                max_num_seqs = 1
 
         batch_size = 0
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +
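In effect, the number of dummy sequences used for memory profiling is now capped per model so that each sequence can hold the registered maximum of multi-modal tokens. A small worked example with illustrative numbers (3000 is the image plugin default shown above; 256 and 8192 are assumed scheduler settings):

max_num_seqs = 256               # assumed scheduler setting
max_num_batched_tokens = 8192    # assumed scheduler setting
max_mm_tokens = 3000             # image plugin default from this commit

max_num_seqs = max(1, min(max_num_seqs, max_num_batched_tokens // max_mm_tokens))
print(max_num_seqs)  # 2 -> profiling builds 2 dummy sequences of ~4096 tokens each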

View File

@@ -168,14 +168,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
         model_config = self.model_config
 
         if supports_vision(self.model):
-            # TODO: properly inject these numbers from MultiModalRegistry.
-            # Right now, just use an overly conservative number.
-            max_num_seqs = max(
-                1,
-                min(
-                    max_num_seqs,
-                    int(max_num_batched_tokens /
-                        MULTIMODAL_REGISTRY.get_num_input_tokens())))
+            max_mm_tokens = MULTIMODAL_REGISTRY \
+                .get_max_multimodal_tokens(model_config)
+            max_num_seqs_orig = max_num_seqs
+            max_num_seqs = min(max_num_seqs,
+                               max_num_batched_tokens // max_mm_tokens)
+            if max_num_seqs < 1:
+                expr = (f"min({max_num_seqs_orig}, "
+                        f"{max_num_batched_tokens} // {max_mm_tokens})")
+                logger.warning(
+                    "Computed max_num_seqs (%s) to be less than 1. "
+                    "Setting it to the minimum value of 1.", expr)
+                max_num_seqs = 1
 
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +