diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py
index b6007b9f4630..1a5879a6d35f 100644
--- a/examples/offline_inference/prithvi_geospatial_mae.py
+++ b/examples/offline_inference/prithvi_geospatial_mae.py
@@ -45,7 +45,11 @@ datamodule_config = {
 class PrithviMAE:
     def __init__(self, model):
         self.model = LLM(
-            model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
+            model=model,
+            skip_tokenizer_init=True,
+            dtype="float16",
+            enforce_eager=True,
+            model_impl="terratorch",
         )
 
     def run(self, input_data, location_coords):
diff --git a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py
index adc27859a1cd..5d629fabf0a2 100644
--- a/examples/offline_inference/prithvi_geospatial_mae_io_processor.py
+++ b/examples/offline_inference/prithvi_geospatial_mae_io_processor.py
@@ -37,6 +37,7 @@ def main():
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
         io_processor_plugin="prithvi_to_tiff_india",
+        model_impl="terratorch",
     )
 
     pooling_params = PoolingParams(task="encode", softmax=False)
diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py
index 359162c470f0..c6eed64838ea 100644
--- a/examples/online_serving/prithvi_geospatial_mae.py
+++ b/examples/online_serving/prithvi_geospatial_mae.py
@@ -15,6 +15,7 @@ import requests
 # https://github.com/christian-pinto/prithvi_io_processor_plugin
 # - start vllm in serving mode with the below args
 #   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
+#   --model-impl terratorch
 #   --task embed --trust-remote-code
 #   --skip-tokenizer-init --enforce-eager
 #   --io-processor-plugin prithvi_to_tiff_india
diff --git a/requirements/test.in b/requirements/test.in
index 5b1688c76c95..5db9cd797904 100644
--- a/requirements/test.in
+++ b/requirements/test.in
@@ -53,5 +53,5 @@ runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
 fastsafetensors>=0.1.10
 pydantic>=2.10 # 2.9 leads to error on python 3.10
-terratorch==1.1rc2 # required for PrithviMAE test
 decord==0.6.0
+terratorch==1.1rc3 # required for PrithviMAE test
diff --git a/requirements/test.txt b/requirements/test.txt
index 0b728ebfb007..332a9b9cfbf5 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
     # via lightning
 tensorizer==2.10.1
     # via -r requirements/test.in
-terratorch==1.1rc2
+terratorch==1.1rc3
     # via -r requirements/test.in
 threadpoolctl==3.5.0
     # via scikit-learn
diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 1afe9ea970c9..fffab1a984c2 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -298,6 +298,8 @@ def _compare_tp(
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
     hf_config = get_config(model_id, trust_remote_code)
+    skip_tokenizer_init = model_info.skip_tokenizer_init
+    max_num_seqs = model_info.max_num_seqs
 
     dtype = "float16"
     if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
@@ -351,6 +353,10 @@ def _compare_tp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    if skip_tokenizer_init:
+        common_args.append("--skip-tokenizer-init")
+    if max_num_seqs:
+        common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])
 
     specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
     testing_ray_compiled_graph = False
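Note: the example scripts above opt into the TerraTorch-backed implementation through `model_impl`. A minimal offline sketch of that flow follows; the checkpoint name, input shapes, and pooling parameters are copied from the tests added later in this PR and should be read as illustrative assumptions rather than a stable API:

    # Sketch: offline encode request against a TerraTorch-backed Prithvi checkpoint.
    import torch
    from vllm import LLM
    from vllm.pooling_params import PoolingParams

    llm = LLM(
        model="mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # checkpoint used by the tests in this PR
        skip_tokenizer_init=True,
        dtype="float16",
        enforce_eager=True,
        model_impl="terratorch",
    )
    prompt = dict(
        prompt_token_ids=[1],
        multi_modal_data=dict(
            # 6-band 512x512 patch, mirroring tests/models/test_terratorch.py
            pixel_values=torch.full((6, 512, 512), 1.0, dtype=torch.float16),
            location_coords=torch.full((1, 2), 1.0, dtype=torch.float16),
        ),
    )
    outputs = llm.encode(prompt, pooling_params=PoolingParams(task="encode", softmax=False))
    print(outputs[0].outputs.data.shape)
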
diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py
index c93b436f384b..65c5e6896844 100644
--- a/tests/distributed/test_sequence_parallel.py
+++ b/tests/distributed/test_sequence_parallel.py
@@ -178,6 +178,7 @@ def _compare_sp(
     trust_remote_code = model_info.trust_remote_code
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
+    skip_tokenizer_init = model_info.skip_tokenizer_init
 
     if load_format == "dummy":
         # Avoid OOM
@@ -227,6 +228,8 @@ def _compare_sp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    if skip_tokenizer_init:
+        common_args.append("--skip-tokenizer-init")
 
     compilation_config = {
         'level': 3,
diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py
index 5b6e2a4146b1..ce90a67c0151 100644
--- a/tests/entrypoints/openai/test_chat_template.py
+++ b/tests/entrypoints/openai/test_chat_template.py
@@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         trust_remote_code=model_info.trust_remote_code,
         revision=model_info.revision,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
 
     # Initialize the tokenizer
     tokenizer = get_tokenizer(
diff --git a/tests/entrypoints/openai/test_skip_tokenizer.py b/tests/entrypoints/openai/test_skip_tokenizer.py
index 0bb42ed8aa7f..af520ac61d8d 100644
--- a/tests/entrypoints/openai/test_skip_tokenizer.py
+++ b/tests/entrypoints/openai/test_skip_tokenizer.py
@@ -11,7 +11,7 @@ import torch
 
 from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
+MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
 
 DTYPE = "float16"
 
@@ -35,7 +35,9 @@ def server():
         "--trust-remote-code",
         "--skip-tokenizer-init",
         "--max-num-seqs",
-        "32"
+        "32",
+        "--model-impl",
+        "terratorch"
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py
index 0c1f19371a16..18db1027c004 100644
--- a/tests/entrypoints/test_chat_utils.py
+++ b/tests/entrypoints/test_chat_utils.py
@@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
 
     # Build the tokenizer group and grab the underlying tokenizer
     tokenizer_group = TokenizerGroup(
@@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
 
     tokenizer_group = TokenizerGroup(
         model,
@@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
 
     tokenizer_group = TokenizerGroup(
         model_config.tokenizer,
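Note: the entrypoints tests above start the server with `--model-impl terratorch`. For reference, the flag set they exercise looks roughly like the sketch below (flags are taken from this PR's fixtures and example comments; the endpoint you query afterwards depends on the IO processor plugin in use):

    # Sketch: server-side flags exercised by the updated OpenAI entrypoint tests.
    from tests.utils import RemoteOpenAIServer

    args = [
        "--runner", "pooling",
        "--trust-remote-code",
        "--skip-tokenizer-init",
        "--enforce-eager",
        "--max-num-seqs", "32",
        "--model-impl", "terratorch",
    ]
    with RemoteOpenAIServer("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", args) as server:
        ...  # issue pooling requests against the running server, as the tests do
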
diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py
index a5d6948f06ef..ae7083833695 100644
--- a/tests/models/multimodal/generation/vlm_utils/core.py
+++ b/tests/models/multimodal/generation/vlm_utils/core.py
@@ -69,6 +69,9 @@ def run_test(
         vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
     if model_info.hf_overrides:
         vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
+    if model_info.skip_tokenizer_init:
+        vllm_runner_kwargs_[
+            "skip_tokenizer_init"] = model_info.skip_tokenizer_init
 
     if vllm_runner_kwargs:
         vllm_runner_kwargs_.update(vllm_runner_kwargs)
diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py
index e9be79fba911..b503d4256702 100644
--- a/tests/models/multimodal/pooling/test_prithvi_mae.py
+++ b/tests/models/multimodal/pooling/test_prithvi_mae.py
@@ -46,7 +46,7 @@ def _run_test(
         vllm_model.encode(prompt)
 
 
-MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
+MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
 
 
 @pytest.mark.core_model
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 8ffd65cf087b..ced0ab3377a9 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -66,7 +66,9 @@ def _test_processing_correctness(
         hf_overrides=model_info.hf_overrides,
         # Ensure that the cache can fit all of the data
         mm_processor_cache_gb=2048,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
 
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
diff --git a/tests/models/multimodal/processing/test_tensor_schema.py b/tests/models/multimodal/processing/test_tensor_schema.py
index 615564f70ea3..b678313752d6 100644
--- a/tests/models/multimodal/processing/test_tensor_schema.py
+++ b/tests/models/multimodal/processing/test_tensor_schema.py
@@ -196,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=hf_overrides_fn,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
 
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
diff --git a/tests/models/multimodal/test_mapping.py b/tests/models/multimodal/test_mapping.py
index 7096810d8e15..caf1966ab513 100644
--- a/tests/models/multimodal/test_mapping.py
+++ b/tests/models/multimodal/test_mapping.py
@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
 
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     original_weights = create_repo_dummy_weights(model_id)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index c9e2eec5117d..38efb01341eb 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -6,10 +6,11 @@ from dataclasses import dataclass, field
 from typing import Any, Literal, Optional
 
 import pytest
+import torch
 from packaging.version import Version
 from transformers import __version__ as TRANSFORMERS_VERSION
 
-from vllm.config import TokenizerMode
+from vllm.config import ModelDType, TokenizerMode
 
 
 @dataclass(frozen=True)
@@ -47,6 +48,23 @@ class _HfExamplesInfo:
     The reason for the minimum/maximum version requirement.
     """
 
+    skip_tokenizer_init: bool = False
+    """
+    If true, skip initialization of tokenizer and detokenizer.
+    """
+
+    dtype: ModelDType = "auto"
+    """
+    The data type for the model weights and activations.
+    """
+
+    enforce_eager: bool = False
+    """
+    Whether to enforce eager execution. If True, we will
+    disable CUDA graph and always execute the model in eager mode.
+    If False, we will use CUDA graph and eager execution in hybrid.
+    """
+
     is_available_online: bool = True
     """
     Set this to ``False`` if the name of this architecture no longer exists on
@@ -76,6 +94,9 @@ class _HfExamplesInfo:
     If not specified, the default revision will be used.
     """
 
+    max_num_seqs: Optional[int] = None
+    """Maximum number of sequences to be processed in a single iteration."""
+
     def check_transformers_version(
         self,
         *,
@@ -361,8 +382,21 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
                                         trust_remote_code=True),
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),  # noqa: E501
-    "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
-                                            is_available_online=False),  # noqa: E501
+    "PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
+                                            dtype=torch.float16,
+                                            enforce_eager=True,
+                                            skip_tokenizer_init=True,
+                                            # This is to avoid the model
+                                            # going OOM in CI
+                                            max_num_seqs=32,
+                                            ),
+    "Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
+                                  dtype=torch.float16,
+                                  enforce_eager=True,
+                                  skip_tokenizer_init=True,
+                                  # This is to avoid the model going OOM in CI
+                                  max_num_seqs=32,
+                                  ),
 }
 
 _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py
index b4d516233b4b..aaa04f52f779 100644
--- a/tests/models/test_initialization.py
+++ b/tests/models/test_initialization.py
@@ -73,6 +73,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
         tokenizer=model_info.tokenizer,
         tokenizer_mode=model_info.tokenizer_mode,
         revision=model_info.revision,
+        enforce_eager=model_info.enforce_eager,
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        dtype=model_info.dtype,
         speculative_config={
             "model": model_info.speculative_model,
             "num_speculative_tokens": 1,
@@ -85,7 +88,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
         model_impl=ModelImpl.TRANSFORMERS
         if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
         hf_overrides=hf_overrides_fn,
-    )
+        max_num_seqs=model_info.max_num_seqs)
 
 
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
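Note: the new `_HfExamplesInfo` fields mirror engine arguments, so a registry entry can force per-model settings in CI. A condensed sketch of the consumption pattern used by the updated tests follows; the `get_hf_info` lookup and the `default` attribute holding the checkpoint name are assumptions about the existing test registry, not part of this PR:

    # Sketch: how the new per-model fields feed the test ModelConfig.
    from tests.models.registry import HF_EXAMPLE_MODELS
    from vllm.config import ModelConfig

    model_info = HF_EXAMPLE_MODELS.get_hf_info("Terratorch")  # assumed lookup helper
    model_config = ModelConfig(
        model_info.default,  # assumed attribute carrying the checkpoint id
        skip_tokenizer_init=model_info.skip_tokenizer_init,
        enforce_eager=model_info.enforce_eager,
        dtype=model_info.dtype,
    )
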
diff --git a/tests/models/test_terratorch.py b/tests/models/test_terratorch.py
new file mode 100644
index 000000000000..bfa54280dc02
--- /dev/null
+++ b/tests/models/test_terratorch.py
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from tests.conftest import VllmRunner
+from vllm.utils import set_default_torch_num_threads
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
+        "mgazz/Prithvi_v2_eo_300_tl_unet_agb"
+    ],
+)
+def test_inference(
+    vllm_runner: type[VllmRunner],
+    model: str,
+) -> None:
+
+    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
+    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
+    prompt = dict(prompt_token_ids=[1],
+                  multi_modal_data=dict(pixel_values=pixel_values,
+                                        location_coords=location_coords))
+    with (
+            set_default_torch_num_threads(1),
+            vllm_runner(
+                model,
+                runner="pooling",
+                dtype=torch.float16,
+                enforce_eager=True,
+                skip_tokenizer_init=True,
+                # Limit the maximum number of sequences to avoid the
+                # test going OOM during the warmup run
+                max_num_seqs=32,
+            ) as vllm_model,
+    ):
+
+        vllm_output = vllm_model.llm.encode(prompt)
+        assert torch.equal(
+            torch.isnan(vllm_output[0].outputs.data).any(),
+            torch.tensor(False))
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 40a41afff828..ab0b27af4d69 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -294,6 +294,8 @@ def build_model_context(
         limit_mm_per_prompt=limit_mm_per_prompt,
         mm_processor_cache_gb=mm_processor_cache_gb,
         hf_overrides=model_info.hf_overrides,
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
         **model_config_kwargs,
     )
     return InputContext(model_config)
diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py
index b2fbef2ee25c..825165e89b33 100644
--- a/tests/plugins_tests/test_io_processor_plugins.py
+++ b/tests/plugins_tests/test_io_processor_plugins.py
@@ -7,12 +7,11 @@ import requests
 
 from tests.utils import RemoteOpenAIServer
 from vllm.config import VllmConfig
-from vllm.entrypoints.llm import LLM
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor
 from vllm.pooling_params import PoolingParams
 
-MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
+MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
 
 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501
 
@@ -23,61 +22,7 @@ def test_loading_missing_plugin():
         get_io_processor(vllm_config, "wrong_plugin")
 
 
-def test_loading_engine_with_wrong_plugin():
-
-    with pytest.raises(ValueError):
-        LLM(
-            model=MODEL_NAME,
-            skip_tokenizer_init=True,
-            trust_remote_code=True,
-            enforce_eager=True,
-            # Limit the maximum number of parallel requests
-            # to avoid the model going OOM in CI.
-            max_num_seqs=32,
-            io_processor_plugin="wrong_plugin",
-        )
-
-
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
-
-    img_prompt = dict(
-        data=image_url,
-        data_format="url",
-        image_format="tiff",
-        out_data_format="b64_json",
-    )
-
-    pooling_params = PoolingParams(task="encode", softmax=False)
-
-    with vllm_runner(
-            model_name,
-            runner="pooling",
-            skip_tokenizer_init=True,
-            trust_remote_code=True,
-            enforce_eager=True,
-            # Limit the maximum number of parallel requests
-            # to avoid the model going OOM in CI.
-            max_num_seqs=1,
-            io_processor_plugin="prithvi_to_tiff_valencia",
-    ) as llm_runner:
-        pooler_output = llm_runner.get_llm().encode(
-            img_prompt,
-            pooling_params=pooling_params,
-        )
-        output = pooler_output[0].outputs
-
-    # verify the output is formatted as expected for this plugin
-    assert all(
-        hasattr(output, attr)
-        for attr in ["type", "format", "data", "request_id"])
-
-    # We just check that the output is a valid base64 string.
-    # Raises an exception and fails the test if the string is corrupted.
-    base64.b64decode(output.data)
-
-
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def server():
     args = [
         "--runner",
@@ -90,7 +35,9 @@ def server():
         "--max-num-seqs",
         "32",
         "--io-processor-plugin",
-        "prithvi_to_tiff_valencia"
+        "prithvi_to_tiff_valencia",
+        "--model-impl",
+        "terratorch",
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online(
     # We just check that the output is a valid base64 string.
     # Raises an exception and fails the test if the string is corrupted.
     base64.b64decode(plugin_data["data"])
+
+
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
+
+    img_prompt = dict(
+        data=image_url,
+        data_format="url",
+        image_format="tiff",
+        out_data_format="b64_json",
+    )
+
+    pooling_params = PoolingParams(task="encode", softmax=False)
+
+    with vllm_runner(
+            model_name,
+            runner="pooling",
+            skip_tokenizer_init=True,
+            trust_remote_code=True,
+            enforce_eager=True,
+            # Limit the maximum number of parallel requests
+            # to avoid the model going OOM in CI.
+            max_num_seqs=1,
+            model_impl="terratorch",
+            io_processor_plugin="prithvi_to_tiff_valencia",
+    ) as llm_runner:
+        pooler_output = llm_runner.get_llm().encode(
+            img_prompt,
+            pooling_params=pooling_params,
+        )
+        output = pooler_output[0].outputs
+
+    # verify the output is formatted as expected for this plugin
+    assert all(
+        hasattr(output, attr)
+        for attr in ["type", "format", "data", "request_id"])
+
+    # We just check that the output is a valid base64 string.
+    # Raises an exception and fails the test if the string is corrupted.
+    base64.b64decode(output.data)
diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index 2cea2695a66e..7c2b49702265 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -171,6 +171,7 @@ class ModelImpl(str, enum.Enum):
     AUTO = "auto"
     VLLM = "vllm"
     TRANSFORMERS = "transformers"
+    TERRATORCH = "terratorch"
 
 
 def get_attr_docs(cls: type[Any]) -> dict[str, str]:
@@ -496,7 +497,9 @@ class ModelConfig:
     back to the Transformers implementation if no vLLM implementation is
     available.\n
     - "vllm" will use the vLLM model implementation.\n
-    - "transformers" will use the Transformers model implementation."""
+    - "transformers" will use the Transformers model implementation.\n
+    - "terratorch" will use the TerraTorch model implementation.
+    """
     override_attention_dtype: Optional[str] = None
     """Override dtype for attention"""
     logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
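Note: with the new enum member, `model_impl="terratorch"` becomes a valid value for `ModelConfig.model_impl`, and the registry changes below route any architecture to the `Terratorch` wrapper when it is selected. A tiny sanity sketch (the string round-trip works because `ModelImpl` is a `str` enum):

    # Sketch: the new enum member accepts the string value used by the examples and tests.
    from vllm.config import ModelImpl

    assert ModelImpl("terratorch") is ModelImpl.TERRATORCH
    assert ModelImpl.TERRATORCH.value == "terratorch"
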
- "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"), + "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"), + "Terratorch": ("terratorch", "Terratorch"), } _CROSS_ENCODER_MODELS = { @@ -639,6 +640,9 @@ class _ModelRegistry: model_info = self._try_inspect_model_cls(arch) if model_info is not None: return (model_info, arch) + elif model_config.model_impl == ModelImpl.TERRATORCH: + model_info = self._try_inspect_model_cls("Terratorch") + return (model_info, "Terratorch") # Fallback to transformers impl (after resolving convert_type) if (all(arch not in self.models for arch in architectures) @@ -687,6 +691,11 @@ class _ModelRegistry: model_cls = self._try_load_model_cls(arch) if model_cls is not None: return (model_cls, arch) + elif model_config.model_impl == ModelImpl.TERRATORCH: + arch = "Terratorch" + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) # Fallback to transformers impl (after resolving convert_type) if (all(arch not in self.models for arch in architectures) diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/terratorch.py similarity index 52% rename from vllm/model_executor/models/prithvi_geospatial_mae.py rename to vllm/model_executor/models/terratorch.py index 2edc357d2df1..739396a4932c 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/terratorch.py @@ -15,13 +15,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only IBM/NASA Prithvi Geospatial model.""" +"""Wrapper around `Terratorch` models""" +from collections import OrderedDict from collections.abc import Iterable, Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch import torch.nn as nn +from terratorch.vllm import (DummyDataGenerator, InferenceRunner, + InputDefinition, InputTypeEnum) from transformers import BatchFeature from vllm.config import VllmConfig @@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargsItems, @@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings, from .interfaces_base import default_pooling_type -def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]): - # This model receives in input a multi-dimensional tensor representing - # a single image patch and therefore it is not to be split - # into multiple elements, but rather to be considered a single one. - # Hence, the decision of using a MultiModalSharedField. - # The expected shape is (num_channels, width, height). - - # This model however allows the user to also submit multiple image - # patches as a batch, adding a further dimension to the above shape. - # At this stage we only support submitting one patch per request and - # batching is achieved via vLLM batching. - # TODO (christian-pinto): enable support for multi patch requests - # in tandem with vLLM batching. 
-    return dict(
-        pixel_values=MultiModalFieldConfig.shared(batch_size=1,
-                                                  modality="image"),
-        location_coords=MultiModalFieldConfig.shared(batch_size=1,
-                                                     modality="image"),
-    )
+def _terratorch_field_names(pretrained_cfg: dict):
+    input_definition = InputDefinition(**pretrained_cfg["input"])
+    return set(input_definition.data.keys())
 
 
-class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser):
+def _terratorch_field_factory(
+    pretrained_cfg: dict
+) -> Callable[
+    [Mapping[str, torch.Tensor]],
+    Mapping[str, MultiModalFieldConfig],
+]:
 
-    def _parse_image_data(
-        self,
-        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
-    ) -> Optional[ModalityDataItems[Any, Any]]:
-        if isinstance(data, dict):
-            return DictEmbeddingItems(
-                data,
-                modality="image",
-                required_fields={"pixel_values", "location_coords"},
-                fields_factory=_prithvi_field_config,
-            )
+    def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+        input_definition = InputDefinition(**pretrained_cfg["input"])
+        fields = {}
+        for input_name, input in input_definition.data.items():
+            if input.type == InputTypeEnum.tensor:
+                fields[input_name] = "image"
 
-        return super()._parse_image_data(data)
+        mm_fields_config = {}
+        for field_name, field_modality in fields.items():
+            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
+                batch_size=1, modality=field_modality)
+        return mm_fields_config
+
+    return _terratorch_field_config
 
 
-class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
+class TerratorchProcessingInfo(BaseProcessingInfo):
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
 
-class PrithviGeoSpatialMAEInputBuilder(
-        BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
+class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
+
+    def __init__(self, info: TerratorchProcessingInfo):
+        super().__init__(info)
+        self.dummy_data_generator = DummyDataGenerator(
+            self.info.get_hf_config().to_dict()["pretrained_cfg"])
 
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         return ""
@@ -100,29 +98,57 @@ class PrithviGeoSpatialMAEInputBuilder(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> MultiModalDataDict:
-        # This model input is fixed and is in the form of a torch Tensor.
-        # The size of pixel_values might change in the cases where we resize
-        # the input but never exceeds the dimensions below.
-        image_data = {
-            "pixel_values": torch.full((6, 512, 512), 1.0,
-                                       dtype=torch.float16),
-            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
-        }
-
-        return {"image": image_data}
+        # Dummy data is generated based on the 'input' section
+        # defined in the HF configuration file
+        return self.dummy_data_generator.get_dummy_mm_data()
 
 
-class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
+class TerratorchMultiModalDataParser(MultiModalDataParser):
+
+    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
+        self._pretrained_cfg = pretrained_cfg
+        super().__init__(*args, **kwargs)
+
+    def _parse_image_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
+    ) -> Optional[ModalityDataItems[Any, Any]]:
+        if isinstance(data, dict):
+
+            terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
+
+            return DictEmbeddingItems(
+                data,
+                modality="image",
+                required_fields=terratorch_fields,
+                fields_factory=_terratorch_field_factory(self._pretrained_cfg),
+            )
+
+        return super()._parse_image_data(data)
+
+
+class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
+
+    def __init__(
+            self,
+            info: TerratorchProcessingInfo,
+            dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
+            *,
+            cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
+
+        self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
+        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)
 
     def _get_data_parser(self) -> MultiModalDataParser:
-        return PrithviGeoSpatialMAEMultiModalDataParser()
+        return TerratorchMultiModalDataParser(
+            pretrained_cfg=self.pretrained_cfg)
 
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _prithvi_field_config(hf_inputs)
+        return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)
 
     def _get_prompt_updates(
         self,
@@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
 
 @default_pooling_type("All")
 @MULTIMODAL_REGISTRY.register_processor(
-    PrithviGeoSpatialMAEMultiModalProcessor,
-    info=PrithviGeoSpatialMAEProcessingInfo,
-    dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
+    TerratorchMultiModalProcessor,
+    info=TerratorchProcessingInfo,
+    dummy_inputs=TerratorchInputBuilder,
 )
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
-    """Prithvi Masked Autoencoder"""
-
+class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
     supports_multimodal_raw_input_only = True
     is_pooling_model = True
 
@@ -190,43 +214,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
 
         raise ValueError("Only image modality is supported")
 
-    def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
-        # We might be able/need to support different tasks with this same model
-        if config["task_args"]["task"] == "SemanticSegmentationTask":
-            from terratorch.cli_tools import SemanticSegmentationTask
-
-            task = SemanticSegmentationTask(
-                config["model_args"],
-                config["task_args"]["model_factory"],
-                loss=config["task_args"]["loss"],
-                lr=config["task_args"]["lr"],
-                ignore_index=config["task_args"]["ignore_index"],
-                optimizer=config["task_args"]["optimizer"],
-                optimizer_hparams=config["optimizer_params"],
-                scheduler=config["task_args"]["scheduler"],
-                scheduler_hparams=config["scheduler_params"],
-                plot_on_val=config["task_args"]["plot_on_val"],
-                freeze_decoder=config["task_args"]["freeze_decoder"],
-                freeze_backbone=config["task_args"]["freeze_backbone"],
-            )
-
-            return task.model
-        else:
-            return None
-
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
 
-        # the actual model is dynamically instantiated using terratorch
-        # allowing us to perform changes to the model architecture
-        # at startup time (e.g., change the model decoder class.)
-        self.model = self._instantiate_model(
-            vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
-        if self.model is None:
-            raise ValueError(
-                "Unsupported task. "
-                "Only SemanticSegmentationTask is supported for now "
-                "by PrithviGeospatialMAE.")
+        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]
+
+        self.inference_runner = InferenceRunner(config)
+        self.model = self.inference_runner.model
 
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -234,23 +228,6 @@
         self.pooler = DispatchPooler(
             {"encode": Pooler.for_encode(pooler_config)}, )
 
-    def _parse_and_validate_multimodal_data(
-            self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        pixel_values = kwargs.pop("pixel_values", None)
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError(f"Incorrect type of pixel_values. "
-                             f"Got type: {type(pixel_values)}")
-
-        location_coords = kwargs.pop("location_coords", None)
-        if not isinstance(location_coords, torch.Tensor):
-            raise ValueError(f"Incorrect type of location_coords. "
-                             f"Got type: {type(location_coords)}")
-        location_coords = torch.unbind(location_coords, dim=0)[0]
-        if location_coords.shape == torch.Size([0]):
-            location_coords = None
-
-        return pixel_values, location_coords
-
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
@@ -270,10 +247,7 @@
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ):
-        pixel_values, location_coords = (
-            self._parse_and_validate_multimodal_data(**kwargs))
-        model_output = self.model(pixel_values,
-                                  location_coords=location_coords)
+        model_output = self.inference_runner.forward(**kwargs)
 
         return model_output.output
 
@@ -283,28 +257,34 @@
         model_buffers = dict(self.named_buffers())
         loaded_buffers = []
         for key, value in weights:
-            if key == "state_dict":
-                weights_to_parse = value
-                for name, weight in weights_to_parse.items():
-                    if "pos_embed" in name:
-                        continue
+            if isinstance(value, (dict, OrderedDict)):
+                if key == "state_dict":
+                    weights_to_parse = value
+                    for name, weight in weights_to_parse.items():
+                        name = f"inference_runner.{name}"
 
-                    if "_timm_module." in name:
-                        name = name.replace("_timm_module.", "")
+                        if "pos_embed" in name:
+                            continue
 
-                    # this model requires a couple of buffers to be loaded
-                    # that are not loadable with the AutoWeightsLoader
-                    if name in model_buffers:
                        if "_timm_module." in name:
                            name = name.replace("_timm_module.", "")
-                        buffer = model_buffers[name]
-                        weight_loader = getattr(buffer, "weight_loader",
-                                                default_weight_loader)
-                        weight_loader(buffer, weight)
-                        loaded_buffers.append(name)
-                    else:
-                        params_list.append((name, weight))
-                break
+
+                        # this model requires a couple of buffers to be loaded
+                        # that are not loadable with the AutoWeightsLoader
+                        if name in model_buffers:
+                            if "_timm_module." in name:
+                                name = name.replace("_timm_module.", "")
+                            buffer = model_buffers[name]
+                            weight_loader = getattr(buffer, "weight_loader",
+                                                    default_weight_loader)
+                            weight_loader(buffer, weight)
+                            loaded_buffers.append(name)
+                        else:
+                            params_list.append((name, weight))
+                    break
+
+            elif isinstance(value, torch.Tensor):
+                params_list.append((f"inference_runner.model.{key}", value))
 
         # Load the remaining model parameters
         loader = AutoWeightsLoader(self)
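
Note: the wrapper above derives its accepted multimodal fields from the checkpoint's `pretrained_cfg["input"]` section instead of hard-coding `pixel_values`/`location_coords`, so each TerraTorch checkpoint can declare its own tensor inputs. A hedged sketch of how a caller might discover those field names (assumes the checkpoint ships a `pretrained_cfg` with an `input.data` mapping, as `_terratorch_field_names` expects):

    # Sketch: list the multimodal tensors a TerraTorch checkpoint expects.
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained(
        "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11", trust_remote_code=True)
    input_section = cfg.to_dict()["pretrained_cfg"]["input"]  # assumed layout, mirroring the code above
    print(sorted(input_section["data"].keys()))  # e.g. ['location_coords', 'pixel_values']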