diff --git a/docs/source/dev/multimodal/adding_multimodal_model.rst b/docs/source/dev/multimodal/adding_multimodal_model.rst
index 0e9590639b22..32f62003f0e2 100644
--- a/docs/source/dev/multimodal/adding_multimodal_model.rst
+++ b/docs/source/dev/multimodal/adding_multimodal_model.rst
@@ -51,17 +51,16 @@ As usual, follow :ref:`these steps ` to implement the model
 2. Register input mappers
 -------------------------

-For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper `.
+For each modality type that the model accepts as input, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper `.
 This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.

 .. code-block:: diff

-      from vllm.model_executor.models.interfaces import SupportsVision
+      from vllm.model_executor.models.interfaces import SupportsVision
+    + from vllm.multimodal import MULTIMODAL_REGISTRY

-    + @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-    + @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
-      class YourModelForImage2Seq(nn.Module, SupportsVision):
+    + @MULTIMODAL_REGISTRY.register_image_input_mapper()
+      class YourModelForImage2Seq(nn.Module, SupportsVision):

 A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
@@ -69,22 +68,22 @@ A default mapper is available for each modality in the core vLLM library. This i
     :ref:`input_processing_pipeline`

-3. (Optional) Register dummy data
----------------------------------
+3. Register maximum number of multimodal tokens
+----------------------------------------------------------

-During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
-In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data `.
+For each modality type that the model accepts as input, calculate the maximum possible number of tokens
+and register it via :meth:`MULTIMODAL_REGISTRY.register_max_multimodal_tokens `.

 .. code-block:: diff

-      from vllm.inputs import INPUT_REGISTRY
-      from vllm.model_executor.models.interfaces import SupportsVision
-      from vllm.multimodal import MULTIMODAL_REGISTRY
+      from vllm.inputs import INPUT_REGISTRY
+      from vllm.model_executor.models.interfaces import SupportsVision
+      from vllm.multimodal import MULTIMODAL_REGISTRY

-      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
-    + @INPUT_REGISTRY.register_dummy_data()
-      class YourModelForImage2Seq(nn.Module, SupportsVision):
+      @MULTIMODAL_REGISTRY.register_image_input_mapper()
+    + @MULTIMODAL_REGISTRY.register_max_image_tokens()
+      @INPUT_REGISTRY.register_dummy_data()
+      class YourModelForImage2Seq(nn.Module, SupportsVision):

 Here are some examples:
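For instance (an illustrative sketch only -- ``YourModelForImage2Seq`` and the 336/14 patch geometry are assumptions, not part of this patch), a model with a fixed-resolution vision encoder could register its maximum image-token count like this:

.. code-block:: python

    from torch import nn
    from transformers import PretrainedConfig

    from vllm.inputs import InputContext
    from vllm.model_executor.models.interfaces import SupportsVision
    from vllm.multimodal import MULTIMODAL_REGISTRY


    def get_max_yourmodel_image_tokens(ctx: InputContext) -> int:
        # One token per patch: e.g. a 336x336 input with 14x14 patches
        # gives (336 // 14) ** 2 = 576 image tokens.
        vision_config = ctx.get_hf_config(PretrainedConfig).vision_config
        return (vision_config.image_size // vision_config.patch_size) ** 2


    @MULTIMODAL_REGISTRY.register_image_input_mapper()
    @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_yourmodel_image_tokens)
    class YourModelForImage2Seq(nn.Module, SupportsVision):
        ...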
@@ -95,7 +94,36 @@ Here are some examples:
     :ref:`input_processing_pipeline`


-4. (Optional) Register input processor
+4. (Optional) Register dummy data
+---------------------------------
+
+During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
+In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data `.
+
+.. code-block:: diff
+
+      from vllm.inputs import INPUT_REGISTRY
+      from vllm.model_executor.models.interfaces import SupportsVision
+      from vllm.multimodal import MULTIMODAL_REGISTRY
+
+      @MULTIMODAL_REGISTRY.register_image_input_mapper()
+      @MULTIMODAL_REGISTRY.register_max_image_tokens()
+    + @INPUT_REGISTRY.register_dummy_data()
+      class YourModelForImage2Seq(nn.Module, SupportsVision):
+
+.. note::
+    The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
+
+Here are some examples:
+
+- Image inputs (static feature size): `LLaVA-1.5 Model `__
+- Image inputs (dynamic feature size): `LLaVA-NeXT Model `__
+
+.. seealso::
+    :ref:`input_processing_pipeline`
+
+
+5. (Optional) Register input processor
 --------------------------------------

 Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
@@ -104,15 +132,15 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce

 .. code-block:: diff

-      from vllm.inputs import INPUT_REGISTRY
-      from vllm.model_executor.models.interfaces import SupportsVision
-      from vllm.multimodal import MULTIMODAL_REGISTRY
+      from vllm.inputs import INPUT_REGISTRY
+      from vllm.model_executor.models.interfaces import SupportsVision
+      from vllm.multimodal import MULTIMODAL_REGISTRY

-      @MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-      @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper()
-      @INPUT_REGISTRY.register_dummy_data()
+      @MULTIMODAL_REGISTRY.register_image_input_mapper()
+      @MULTIMODAL_REGISTRY.register_max_image_tokens()
+      @INPUT_REGISTRY.register_dummy_data()
+    + @INPUT_REGISTRY.register_input_processor()
-      class YourModelForImage2Seq(nn.Module, SupportsVision):
+      class YourModelForImage2Seq(nn.Module, SupportsVision):

 A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
 Here are some examples:
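Taken together, a model that implements all of the optional steps ends up with a decorator stack like the one below. This is only a sketch: the model class and the three helper functions are placeholders standing in for what you would write in steps 3-5 (the LLaVA and Phi-3-Vision changes later in this patch show real implementations).

.. code-block:: python

    from torch import nn

    from vllm.inputs import INPUT_REGISTRY, InputContext
    from vllm.model_executor.models.interfaces import SupportsVision
    from vllm.multimodal import MULTIMODAL_REGISTRY


    # Placeholder implementations for illustration only.
    def get_max_yourmodel_image_tokens(ctx: InputContext) -> int:
        return 576


    def dummy_data_for_yourmodel(ctx: InputContext, seq_len: int):
        raise NotImplementedError  # step 4: build dummy sequence + image data


    def input_processor_for_yourmodel(ctx: InputContext, llm_inputs):
        return llm_inputs  # step 5: e.g. expand image placeholder tokens


    @MULTIMODAL_REGISTRY.register_image_input_mapper()                              # step 2
    @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_yourmodel_image_tokens)  # step 3
    @INPUT_REGISTRY.register_dummy_data(dummy_data_for_yourmodel)                   # step 4
    @INPUT_REGISTRY.register_input_processor(input_processor_for_yourmodel)         # step 5
    class YourModelForImage2Seq(nn.Module, SupportsVision):
        ...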
diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst
index f9e5dbea1c9d..906f4d054a35 100644
--- a/docs/source/models/vlm.rst
+++ b/docs/source/models/vlm.rst
@@ -25,13 +25,8 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``

 .. important::
     We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
-    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
-    every model to perform profiling with.
-
-    This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
-    :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens `
-    for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
-    with a more accurate profiling strategy in the future.
+    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
+    internally for each model.

 To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
@@ -104,13 +99,8 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with

 .. important::
     We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
-    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
-    every model to perform profiling with.
-
-    This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
-    :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens `
-    for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
-    with a more accurate profiling strategy in the future.
+    the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
+    internally for each model.

 To consume the server, you can use the OpenAI client like in the example below:
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index 2c87e3d92582..9396296ffd90 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -51,7 +51,7 @@ class InputContext:
         additionally checking its type.

         Raises:
-            ValueError: If the model is not of the specified type.
+            TypeError: If the model is not of the specified type.
         """
         hf_config = self.model_config.hf_config
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 4533e8cbdb41..d8fbf796b5d3 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -35,6 +35,10 @@ def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
                                  patch_size=hf_config.patch_size)


+def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
+    return get_clip_image_feature_size(hf_config)
+
+
 def dummy_seq_data_for_clip(
     hf_config: CLIPVisionConfig,
     seq_len: int,
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 526b080bf77b..840e40c946fc 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -21,7 +21,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput

 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   input_processor_for_clip)
+                   get_max_clip_image_tokens, input_processor_for_clip)
 from .interfaces import SupportsVision
 from .utils import merge_vision_embeddings

@@ -62,6 +62,17 @@ class LlavaImagePixelInputs(TypedDict):
 LlavaImageInputs = LlavaImagePixelInputs


+def get_max_llava_image_tokens(ctx: InputContext):
+    hf_config = ctx.get_hf_config(LlavaConfig)
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return get_max_clip_image_tokens(vision_config)
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
 def dummy_data_for_llava(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(LlavaConfig)
     vision_config = hf_config.vision_config
@@ -102,6 +113,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):


 @MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
 class LlavaForConditionalGeneration(nn.Module, SupportsVision):
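As a quick sanity check on what ``get_max_llava_image_tokens`` now reports (my arithmetic, not part of the patch): LLaVA-1.5 uses a CLIP ViT-L/14 vision tower at 336x336, so the CLIP feature size resolves to one token per 14-pixel patch (the exact count can shift by one depending on whether the class embedding is kept by the feature-select strategy):

.. code-block:: python

    # CLIP ViT-L/14 at 336px (LLaVA-1.5's vision tower), one token per patch:
    image_size, patch_size = 336, 14
    print((image_size // patch_size) ** 2)  # 576 -- well under the old fixed 3000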
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 4b03a5f9f7c8..c37a68978c85 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -127,6 +127,17 @@ def get_llava_next_image_feature_size(
     raise NotImplementedError(msg)


+def get_max_llava_next_image_tokens(ctx: InputContext):
+    # Result in the max possible feature size (2x2 grid of 336x336px tiles)
+    dummy_height = dummy_width = 448
+
+    return get_llava_next_image_feature_size(
+        ctx.get_hf_config(LlavaNextConfig),
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
+
+
 def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     hf_config = ctx.get_hf_config(LlavaNextConfig)
     vision_config = hf_config.vision_config
@@ -198,6 +209,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):


 @MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 9f12a8b2b11b..0259960abbbf 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -321,6 +321,17 @@ def get_phi3v_image_feature_size(
            + (new_height // 336 + 1) * 12


+def get_max_phi3v_image_tokens(ctx: InputContext):
+    # Result in the max possible feature size (h:w = 16:1)
+    dummy_height, dummy_width = 8000, 50
+
+    return get_phi3v_image_feature_size(
+        ctx.get_hf_config(PretrainedConfig),
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
+
+
 def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
     # Result in the max possible feature size (h:w = 16:1)
     dummy_height, dummy_width = 8000, 50
@@ -429,6 +440,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):


 @MULTIMODAL_REGISTRY.register_image_input_mapper()
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
 class Phi3VForCausalLM(nn.Module, SupportsVision):
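Because ``MultiModalTokensCalc`` (introduced in ``vllm/multimodal/base.py`` just below) is ``Union[int, Callable[[InputContext], int]]``, a model can register either a constant upper bound or a per-config calculation. A hedged illustration -- both class names and the grid arithmetic are made up, not part of this patch:

.. code-block:: python

    from torch import nn

    from vllm.inputs import InputContext
    from vllm.model_executor.models.interfaces import SupportsVision
    from vllm.multimodal import MULTIMODAL_REGISTRY


    @MULTIMODAL_REGISTRY.register_image_input_mapper()
    @MULTIMODAL_REGISTRY.register_max_image_tokens(576)  # constant upper bound
    class FixedResolutionVLM(nn.Module, SupportsVision):
        ...


    def max_image_tokens_for_grid(ctx: InputContext) -> int:
        # e.g. up to a 2x2 grid of 336px tiles, one token per 14px patch per tile
        return 4 * (336 // 14) ** 2


    @MULTIMODAL_REGISTRY.register_image_input_mapper()
    @MULTIMODAL_REGISTRY.register_max_image_tokens(max_image_tokens_for_grid)
    class DynamicResolutionVLM(nn.Module, SupportsVision):
        ...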
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index e7b45649d728..56cee73bd388 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -97,9 +97,19 @@ the corresponding plugin with the same modality key is applied.
 """

 MultiModalInputMapper = Callable[[InputContext, object], MultiModalInputs]
-"""Return a dictionary to be passed as keyword arguments to
+"""
+Return a dictionary to be passed as keyword arguments to
 :meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
-and processors in HuggingFace Transformers."""
+and processors in HuggingFace Transformers.
+
+If the data is not supported, throw :exc:`TypeError`.
+"""
+
+MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
+"""
+Calculate the maximum number of multimodal tokens input to the language
+model. This does not include tokens that correspond to the input text.
+"""

 N = TypeVar("N", bound=Type[nn.Module])

@@ -117,6 +127,7 @@ class MultiModalPlugin(ABC):

     def __init__(self) -> None:
         self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
+        self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

     @abstractmethod
     def get_data_key(self) -> str:
@@ -128,9 +139,12 @@ class MultiModalPlugin(ABC):
     @abstractmethod
     def _default_input_mapper(self, ctx: InputContext,
                               data: object) -> MultiModalInputs:
-        """Return a dictionary to be passed as keyword arguments to
+        """
+        Return a dictionary to be passed as keyword arguments to
         :meth:`~torch.nn.Module.forward`. This is similar in concept to
         tokenizers and processors in HuggingFace Transformers.
+
+        If the data is not supported, throw :exc:`TypeError`.
         """
         raise NotImplementedError

@@ -140,9 +154,11 @@ class MultiModalPlugin(ABC):
     ):
         """
         Register an input mapper to a model class.
+
         When the model receives input data that matches the modality served by
-        this plugin (see :meth:`get_data_type`), the provided function is
+        this plugin (see :meth:`get_data_key`), the provided function is
         invoked to transform the data into a dictionary of model inputs.
+
         If `None` is provided, then the default input mapper is used instead.

         See also:
@@ -170,10 +186,11 @@ class MultiModalPlugin(ABC):
         Apply an input mapper to a data passed to the model, transforming
         the data into a dictionary of model inputs.

-        If the data is not something that the mapper expects, throws TypeError.
-
         The model is identified by ``model_config``.

+        Raises:
+            TypeError: If the data type is not supported.
+
         See also:
             :ref:`adding_a_new_multimodal_model`
         """
@@ -188,3 +205,79 @@ class MultiModalPlugin(ABC):
                              f"model class {model_cls.__name__}.")

         return mapper(InputContext(model_config), data)
+
+    @abstractmethod
+    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
+        """
+        Calculate the maximum number of multimodal tokens input to the language
+        model. This does not include tokens that correspond to the input text.
+        """
+        raise NotImplementedError
+
+    def _validate_max_multimodal_tokens(self, max_mm_tokens: int):
+        if max_mm_tokens < 1:
+            raise ValueError("You should set the number of tokens to a "
+                             f"positive integer. Found: {max_mm_tokens}")
+
+    def register_max_multimodal_tokens(
+        self,
+        max_mm_tokens: Optional[MultiModalTokensCalc] = None,
+    ):
+        """
+        Register the maximum number of multi-modal tokens
+        input to the language model for a model class.
+
+        If `None` is provided, then the default calculation is used instead.
+
+        See also:
+            :ref:`adding_a_new_multimodal_model`
+        """
+
+        def wrapper(model_cls: N) -> N:
+            if model_cls in self._max_mm_tokens:
+                logger.warning(
+                    "Model class %s already calculates maximum number of "
+                    "tokens in %s. It is overwritten by the new one.",
+                    model_cls, self)
+
+            if isinstance(max_mm_tokens, int):
+                self._validate_max_multimodal_tokens(max_mm_tokens)
+
+            self._max_mm_tokens[model_cls] = max_mm_tokens \
+                or self._default_max_multimodal_tokens
+
+            return model_cls
+
+        return wrapper
+
+    def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int:
+        """
+        Get the maximum number of multi-modal tokens
+        for profiling the memory usage of a model.
+
+        If this registry is not applicable to the model, `0` is returned.
+
+        The model is identified by ``model_config``.
+
+        See also:
+            :ref:`adding_a_new_multimodal_model`
+        """
+        # Avoid circular import
+        from vllm.model_executor.model_loader import get_model_architecture
+
+        model_cls, _ = get_model_architecture(model_config)
+
+        if model_cls not in self._input_mappers:
+            return 0
+
+        max_mm_tokens = self._max_mm_tokens.get(model_cls)
+        if max_mm_tokens is None:
+            raise KeyError(f"No maximum number of multi-modal tokens is given "
+                           f"for model class {model_cls.__name__} in {self}.")
+
+        if callable(max_mm_tokens):
+            max_mm_tokens = max_mm_tokens(InputContext(model_config))
+
+        self._validate_max_multimodal_tokens(max_mm_tokens)
+
+        return max_mm_tokens
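Any ``MultiModalPlugin`` subclass must now implement the ``_default_max_multimodal_tokens`` hook added above; it is the fallback used when a model registers an input mapper but no token calculation. A hypothetical non-image plugin (purely illustrative, not part of this patch) might look like:

.. code-block:: python

    from vllm.inputs import InputContext
    from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin


    class AudioPlugin(MultiModalPlugin):
        """Hypothetical plugin, shown only to illustrate the new abstract hook."""

        def get_data_key(self) -> str:
            return "audio"

        def _default_input_mapper(self, ctx: InputContext,
                                  data: object) -> MultiModalInputs:
            raise TypeError(f"No default mapper for audio data of type {type(data)}")

        def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
            # Conservative fixed budget used when a model registers no
            # calculator, mirroring the ImagePlugin default in this patch.
            return 1500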
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index 27010fa6ed81..b6c73512350f 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -130,3 +130,6 @@ class ImagePlugin(MultiModalPlugin):
             raise NotImplementedError("Embeddings input is not supported yet")

         raise TypeError(f"Invalid image type: {type(data)}")
+
+    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
+        return 3000
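At the registry level (see the ``MultiModalRegistry`` changes that follow), the per-plugin values are summed so that profiling code has a single entry point. Roughly, and assuming ``model_config`` is the ``ModelConfig`` of the loaded model:

.. code-block:: python

    from vllm.multimodal import MULTIMODAL_REGISTRY

    # model_config: the vllm.config.ModelConfig of the loaded model (assumed here).
    max_mm_tokens = MULTIMODAL_REGISTRY.get_max_multimodal_tokens(model_config)
    # An image-only VLM gets exactly the image plugin's value (e.g. 576 for
    # LLaVA-1.5); plugins that do not apply to the model contribute 0.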
""" @@ -102,8 +81,8 @@ class MultiModalRegistry: merged_dict: Dict[str, torch.Tensor] = {} for data_key, data_value in data.items(): - input_dict = self._process_input(data_key, data_value, - model_config) + input_dict = self._get_plugin(data_key) \ + .map_input(model_config, data_value) for input_key, input_tensor in input_dict.items(): if input_key in merged_dict: @@ -121,9 +100,35 @@ class MultiModalRegistry: """ return functools.partial(self.map_input, model_config) - def get_num_input_tokens(self): + def register_max_multimodal_tokens( + self, + data_type_key: str, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): """ - Get the number of input tokens for profiling purposes. + Register the maximum number of tokens, belonging to a + specific modality, input to the language model for a model class. """ - # TODO: Provide this number on a per model basis. - return 3000 + return self._get_plugin(data_type_key) \ + .register_max_multimodal_tokens(max_mm_tokens) + + def register_max_image_tokens( + self, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of image tokens + input to the language model for a model class. + """ + return self.register_max_multimodal_tokens("image", max_mm_tokens) + + def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + """ + return sum( + plugin.get_max_multimodal_tokens(model_config) + for plugin in self._plugins.values()) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2ae5263baa18..d0c82d6bbedf 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -820,12 +820,19 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): model_config = self.model_config if supports_vision(self.model): - max_num_seqs = max( - 1, - min( - max_num_seqs, - int(max_num_batched_tokens / - MULTIMODAL_REGISTRY.get_num_input_tokens()))) + max_mm_tokens = MULTIMODAL_REGISTRY \ + .get_max_multimodal_tokens(model_config) + max_num_seqs_orig = max_num_seqs + max_num_seqs = min(max_num_seqs, + max_num_batched_tokens // max_mm_tokens) + if max_num_seqs < 1: + expr = (f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})") + logger.warning( + "Computed max_num_seqs (%s) to be less than 1. " + "Setting it to the minimum value of 1.", expr) + max_num_seqs = 1 + batch_size = 0 for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index c3a24c89f302..03b9cce5ae79 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -168,14 +168,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]): model_config = self.model_config if supports_vision(self.model): - # TODO: properly inject these numbers from MultiModalRegistry. - # Right now, just use an overly conservative number. 
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index c3a24c89f302..03b9cce5ae79 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -168,14 +168,18 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
         model_config = self.model_config

         if supports_vision(self.model):
-            # TODO: properly inject these numbers from MultiModalRegistry.
-            # Right now, just use an overly conservative number.
-            max_num_seqs = max(
-                1,
-                min(
-                    max_num_seqs,
-                    int(max_num_batched_tokens /
-                        MULTIMODAL_REGISTRY.get_num_input_tokens())))
+            max_mm_tokens = MULTIMODAL_REGISTRY \
+                .get_max_multimodal_tokens(model_config)
+            max_num_seqs_orig = max_num_seqs
+            max_num_seqs = min(max_num_seqs,
+                               max_num_batched_tokens // max_mm_tokens)
+            if max_num_seqs < 1:
+                expr = (f"min({max_num_seqs_orig}, "
+                        f"{max_num_batched_tokens} // {max_mm_tokens})")
+                logger.warning(
+                    "Computed max_num_seqs (%s) to be less than 1. "
+                    "Setting it to the minimum value of 1.", expr)
+                max_num_seqs = 1

         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +