[Core][Model] Terratorch backend integration (#23513)

Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
mgazz authored on 2025-09-04 08:22:41 +01:00, committed by GitHub
parent e7fc70016f
commit 51d5e9be7d
23 changed files with 305 additions and 208 deletions

View File

@@ -45,7 +45,11 @@ datamodule_config = {
 class PrithviMAE:
     def __init__(self, model):
         self.model = LLM(
-            model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
+            model=model,
+            skip_tokenizer_init=True,
+            dtype="float16",
+            enforce_eager=True,
+            model_impl="terratorch",
         )

     def run(self, input_data, location_coords):
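A minimal offline sketch of the new backend, assembled from this example and the new test added later in this PR (the checkpoint name, tensor shapes, and pooling parameters are taken from those files, not independently verified):

import torch
from vllm import LLM
from vllm.pooling_params import PoolingParams

llm = LLM(
    model="mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    skip_tokenizer_init=True,
    dtype="float16",
    enforce_eager=True,
    model_impl="terratorch",
)

# The model is attention-free and consumes raw multimodal tensors, so the
# prompt carries a placeholder token id plus the image tensors.
prompt = dict(
    prompt_token_ids=[1],
    multi_modal_data=dict(
        pixel_values=torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        location_coords=torch.full((1, 2), 1.0, dtype=torch.float16),
    ),
)

outputs = llm.encode(prompt,
                     pooling_params=PoolingParams(task="encode", softmax=False))
print(outputs[0].outputs.data.shape)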

View File

@@ -37,6 +37,7 @@ def main():
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
         io_processor_plugin="prithvi_to_tiff_india",
+        model_impl="terratorch",
     )

     pooling_params = PoolingParams(task="encode", softmax=False)
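A comparable sketch for the IO-processor path, mirroring the offline plugin test relocated further down in this PR; it assumes the prithvi_to_tiff_valencia plugin is installed and the referenced TIFF is reachable:

from vllm import LLM
from vllm.pooling_params import PoolingParams

llm = LLM(
    model="mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    runner="pooling",
    skip_tokenizer_init=True,
    trust_remote_code=True,
    enforce_eager=True,
    max_num_seqs=32,
    model_impl="terratorch",
    io_processor_plugin="prithvi_to_tiff_valencia",
)

# The plugin consumes a URL to a GeoTIFF and returns base64-encoded output.
img_prompt = dict(
    data="https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff",
    data_format="url",
    image_format="tiff",
    out_data_format="b64_json",
)
out = llm.encode(img_prompt,
                 pooling_params=PoolingParams(task="encode", softmax=False))[0].outputs
print(out.type, out.format)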

View File

@@ -15,6 +15,7 @@ import requests
 # https://github.com/christian-pinto/prithvi_io_processor_plugin
 # - start vllm in serving mode with the below args
 #   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
+#   --model-impl terratorch
 #   --task embed --trust-remote-code
 #   --skip-tokenizer-init --enforce-eager
 #   --io-processor-plugin prithvi_to_tiff_india
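Collapsed into a single command, the serving setup these comments describe would look roughly like this (assuming the IO-processor plugin from the linked repository is installed):

vllm serve christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM \
    --model-impl terratorch --task embed --trust-remote-code \
    --skip-tokenizer-init --enforce-eager \
    --io-processor-plugin prithvi_to_tiff_india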

View File

@@ -53,5 +53,5 @@ runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
 fastsafetensors>=0.1.10
 pydantic>=2.10 # 2.9 leads to error on python 3.10
-terratorch==1.1rc2 # required for PrithviMAE test
 decord==0.6.0
+terratorch==1.1rc3 # required for PrithviMAE test

View File

@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
     # via lightning
 tensorizer==2.10.1
     # via -r requirements/test.in
-terratorch==1.1rc2
+terratorch==1.1rc3
     # via -r requirements/test.in
 threadpoolctl==3.5.0
     # via scikit-learn

View File

@@ -298,6 +298,8 @@ def _compare_tp(
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
     hf_config = get_config(model_id, trust_remote_code)
+    skip_tokenizer_init = model_info.skip_tokenizer_init
+    max_num_seqs = model_info.max_num_seqs

     dtype = "float16"
     if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
@@ -351,6 +353,10 @@ def _compare_tp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    if skip_tokenizer_init:
+        common_args.append("--skip-tokenizer-init")
+    if max_num_seqs:
+        common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])

     specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
     testing_ray_compiled_graph = False

View File

@@ -178,6 +178,7 @@ def _compare_sp(
     trust_remote_code = model_info.trust_remote_code
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
+    skip_tokenizer_init = model_info.skip_tokenizer_init

     if load_format == "dummy":
         # Avoid OOM
@@ -227,6 +228,8 @@ def _compare_sp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    if skip_tokenizer_init:
+        common_args.append("--skip-tokenizer-init")

     compilation_config = {
         'level': 3,

View File

@@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         trust_remote_code=model_info.trust_remote_code,
         revision=model_info.revision,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     # Initialize the tokenizer
     tokenizer = get_tokenizer(

View File

@@ -11,7 +11,7 @@ import torch
 from ...utils import RemoteOpenAIServer

-MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
+MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
 DTYPE = "float16"
@@ -35,7 +35,9 @@ def server():
         "--trust-remote-code",
         "--skip-tokenizer-init",
         "--max-num-seqs",
-        "32"
+        "32",
+        "--model-impl",
+        "terratorch"
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

View File

@@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     # Build the tokenizer group and grab the underlying tokenizer
     tokenizer_group = TokenizerGroup(
@@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     tokenizer_group = TokenizerGroup(
         model,
@@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     tokenizer_group = TokenizerGroup(
         model_config.tokenizer,

View File

@@ -69,6 +69,9 @@ def run_test(
         vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
     if model_info.hf_overrides:
         vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
+    if model_info.skip_tokenizer_init:
+        vllm_runner_kwargs_[
+            "skip_tokenizer_init"] = model_info.skip_tokenizer_init

     if vllm_runner_kwargs:
         vllm_runner_kwargs_.update(vllm_runner_kwargs)

View File

@@ -46,7 +46,7 @@ def _run_test(
         vllm_model.encode(prompt)

-MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
+MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]

 @pytest.mark.core_model

View File

@@ -66,7 +66,9 @@ def _test_processing_correctness(
         hf_overrides=model_info.hf_overrides,
         # Ensure that the cache can fit all of the data
         mm_processor_cache_gb=2048,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

View File

@@ -196,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=hf_overrides_fn,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

View File

@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
         revision=model_info.revision,
         trust_remote_code=model_info.trust_remote_code,
         hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     original_weights = create_repo_dummy_weights(model_id)

View File

@@ -6,10 +6,11 @@ from dataclasses import dataclass, field
 from typing import Any, Literal, Optional

 import pytest
+import torch
 from packaging.version import Version
 from transformers import __version__ as TRANSFORMERS_VERSION

-from vllm.config import TokenizerMode
+from vllm.config import ModelDType, TokenizerMode

 @dataclass(frozen=True)
@@ -47,6 +48,23 @@ class _HfExamplesInfo:
     The reason for the minimum/maximum version requirement.
     """

+    skip_tokenizer_init: bool = False
+    """
+    If true, skip initialization of tokenizer and detokenizer.
+    """
+
+    dtype: ModelDType = "auto"
+    """
+    The data type for the model weights and activations.
+    """
+
+    enforce_eager: bool = False
+    """
+    Whether to enforce eager execution. If True, we will
+    disable CUDA graph and always execute the model in eager mode.
+    If False, we will use CUDA graph and eager execution in hybrid.
+    """
+
     is_available_online: bool = True
     """
     Set this to ``False`` if the name of this architecture no longer exists on
@@ -76,6 +94,9 @@ class _HfExamplesInfo:
     If not specified, the default revision will be used.
     """

+    max_num_seqs: Optional[int] = None
+    """Maximum number of sequences to be processed in a single iteration."""
+
     def check_transformers_version(
         self,
         *,
@@ -361,8 +382,21 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
                                         trust_remote_code=True),
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),  # noqa: E501
-    "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
-                                            is_available_online=False),  # noqa: E501
+    "PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
+                                            dtype=torch.float16,
+                                            enforce_eager=True,
+                                            skip_tokenizer_init=True,
+                                            # This is to avoid the model
+                                            # going OOM in CI
+                                            max_num_seqs=32,
+                                            ),
+    "Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
+                                  dtype=torch.float16,
+                                  enforce_eager=True,
+                                  skip_tokenizer_init=True,
+                                  # This is to avoid the model going OOM in CI
+                                  max_num_seqs=32,
+                                  ),
 }

 _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {

View File

@@ -73,6 +73,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
         tokenizer=model_info.tokenizer,
         tokenizer_mode=model_info.tokenizer_mode,
         revision=model_info.revision,
+        enforce_eager=model_info.enforce_eager,
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        dtype=model_info.dtype,
         speculative_config={
             "model": model_info.speculative_model,
             "num_speculative_tokens": 1,
@@ -85,7 +88,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
         model_impl=ModelImpl.TRANSFORMERS
         if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
         hf_overrides=hf_overrides_fn,
-    )
+        max_num_seqs=model_info.max_num_seqs)

 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())

View File

@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from tests.conftest import VllmRunner
+from vllm.utils import set_default_torch_num_threads
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
+        "mgazz/Prithvi_v2_eo_300_tl_unet_agb"
+    ],
+)
+def test_inference(
+    vllm_runner: type[VllmRunner],
+    model: str,
+) -> None:
+    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
+    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
+
+    prompt = dict(prompt_token_ids=[1],
+                  multi_modal_data=dict(pixel_values=pixel_values,
+                                        location_coords=location_coords))
+
+    with (
+            set_default_torch_num_threads(1),
+            vllm_runner(
+                model,
+                runner="pooling",
+                dtype=torch.float16,
+                enforce_eager=True,
+                skip_tokenizer_init=True,
+                # Limit the maximum number of sequences to avoid the
+                # test going OOM during the warmup run
+                max_num_seqs=32,
+            ) as vllm_model,
+    ):
+        vllm_output = vllm_model.llm.encode(prompt)
+
+    assert torch.equal(
+        torch.isnan(vllm_output[0].outputs.data).any(),
+        torch.tensor(False))

View File

@@ -294,6 +294,8 @@ def build_model_context(
         limit_mm_per_prompt=limit_mm_per_prompt,
         mm_processor_cache_gb=mm_processor_cache_gb,
         hf_overrides=model_info.hf_overrides,
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
         **model_config_kwargs,
     )
     return InputContext(model_config)

View File

@@ -7,12 +7,11 @@ import requests
 from tests.utils import RemoteOpenAIServer
 from vllm.config import VllmConfig
-from vllm.entrypoints.llm import LLM
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor
 from vllm.pooling_params import PoolingParams

-MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
+MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"

 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501
@@ -23,61 +22,7 @@ def test_loading_missing_plugin():
         get_io_processor(vllm_config, "wrong_plugin")

-def test_loading_engine_with_wrong_plugin():
-    with pytest.raises(ValueError):
-        LLM(
-            model=MODEL_NAME,
-            skip_tokenizer_init=True,
-            trust_remote_code=True,
-            enforce_eager=True,
-            # Limit the maximum number of parallel requests
-            # to avoid the model going OOM in CI.
-            max_num_seqs=32,
-            io_processor_plugin="wrong_plugin",
-        )
-
-
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
-    img_prompt = dict(
-        data=image_url,
-        data_format="url",
-        image_format="tiff",
-        out_data_format="b64_json",
-    )
-
-    pooling_params = PoolingParams(task="encode", softmax=False)
-
-    with vllm_runner(
-            model_name,
-            runner="pooling",
-            skip_tokenizer_init=True,
-            trust_remote_code=True,
-            enforce_eager=True,
-            # Limit the maximum number of parallel requests
-            # to avoid the model going OOM in CI.
-            max_num_seqs=1,
-            io_processor_plugin="prithvi_to_tiff_valencia",
-    ) as llm_runner:
-        pooler_output = llm_runner.get_llm().encode(
-            img_prompt,
-            pooling_params=pooling_params,
-        )
-        output = pooler_output[0].outputs
-
-    # verify the output is formatted as expected for this plugin
-    assert all(
-        hasattr(output, attr)
-        for attr in ["type", "format", "data", "request_id"])
-
-    # We just check that the output is a valid base64 string.
-    # Raises an exception and fails the test if the string is corrupted.
-    base64.b64decode(output.data)
-
-
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def server():
     args = [
         "--runner",
@@ -90,7 +35,9 @@ def server():
         "--max-num-seqs",
         "32",
         "--io-processor-plugin",
-        "prithvi_to_tiff_valencia"
+        "prithvi_to_tiff_valencia",
+        "--model-impl",
+        "terratorch",
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online(
     # We just check that the output is a valid base64 string.
     # Raises an exception and fails the test if the string is corrupted.
     base64.b64decode(plugin_data["data"])
+
+
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
+    img_prompt = dict(
+        data=image_url,
+        data_format="url",
+        image_format="tiff",
+        out_data_format="b64_json",
+    )
+
+    pooling_params = PoolingParams(task="encode", softmax=False)
+
+    with vllm_runner(
+            model_name,
+            runner="pooling",
+            skip_tokenizer_init=True,
+            trust_remote_code=True,
+            enforce_eager=True,
+            # Limit the maximum number of parallel requests
+            # to avoid the model going OOM in CI.
+            max_num_seqs=1,
+            model_impl="terratorch",
+            io_processor_plugin="prithvi_to_tiff_valencia",
+    ) as llm_runner:
+        pooler_output = llm_runner.get_llm().encode(
+            img_prompt,
+            pooling_params=pooling_params,
+        )
+        output = pooler_output[0].outputs
+
+    # verify the output is formatted as expected for this plugin
+    assert all(
+        hasattr(output, attr)
+        for attr in ["type", "format", "data", "request_id"])
+
+    # We just check that the output is a valid base64 string.
+    # Raises an exception and fails the test if the string is corrupted.
+    base64.b64decode(output.data)

View File

@@ -171,6 +171,7 @@ class ModelImpl(str, enum.Enum):
     AUTO = "auto"
     VLLM = "vllm"
     TRANSFORMERS = "transformers"
+    TERRATORCH = "terratorch"

 def get_attr_docs(cls: type[Any]) -> dict[str, str]:
@@ -496,7 +497,9 @@ class ModelConfig:
     back to the Transformers implementation if no vLLM implementation is
     available.\n
     - "vllm" will use the vLLM model implementation.\n
-    - "transformers" will use the Transformers model implementation."""
+    - "transformers" will use the Transformers model implementation.\n
+    - "terratorch" will use the TerraTorch model implementation.
+    """
     override_attention_dtype: Optional[str] = None
     """Override dtype for attention"""
     logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
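A small sanity sketch of what the new enum member buys (assuming ModelImpl is importable from vllm.config, as vLLM's own tests do): since ModelImpl subclasses str, the same literal works for the --model-impl CLI flag, the model_impl engine argument, and direct string comparisons.

from vllm.config import ModelImpl

# Value lookup on the str-valued enum returns the new member.
assert ModelImpl("terratorch") is ModelImpl.TERRATORCH
# Plain string comparisons, as used in the registry, also hold.
assert ModelImpl.TERRATORCH == "terratorch"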

View File

@@ -184,10 +184,11 @@ _EMBEDDING_MODELS = {
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
-    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
-    # input and output. I am adding it here because it piggybacks on embedding
+    # Technically Terratorch models work on images, both in
+    # input and output. I am adding it here because it piggy-backs on embedding
     # models for the time being.
-    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
+    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
+    "Terratorch": ("terratorch", "Terratorch"),
 }

 _CROSS_ENCODER_MODELS = {
@@ -639,6 +640,9 @@ class _ModelRegistry:
             model_info = self._try_inspect_model_cls(arch)
             if model_info is not None:
                 return (model_info, arch)
+        elif model_config.model_impl == ModelImpl.TERRATORCH:
+            model_info = self._try_inspect_model_cls("Terratorch")
+            return (model_info, "Terratorch")

         # Fallback to transformers impl (after resolving convert_type)
         if (all(arch not in self.models for arch in architectures)
@@ -687,6 +691,11 @@ class _ModelRegistry:
             model_cls = self._try_load_model_cls(arch)
             if model_cls is not None:
                 return (model_cls, arch)
+        elif model_config.model_impl == ModelImpl.TERRATORCH:
+            arch = "Terratorch"
+            model_cls = self._try_load_model_cls(arch)
+            if model_cls is not None:
+                return (model_cls, arch)

         # Fallback to transformers impl (after resolving convert_type)
         if (all(arch not in self.models for arch in architectures)

View File

@@ -15,13 +15,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only IBM/NASA Prithvi Geospatial model."""
+"""Wrapper around `Terratorch` models"""
+from collections import OrderedDict
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.nn as nn
+from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
+                             InputDefinition, InputTypeEnum)
 from transformers import BatchFeature

 from vllm.config import VllmConfig
@@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.cache import MultiModalProcessorOnlyCache
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputs, MultiModalKwargsItems,
@@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
 from .interfaces_base import default_pooling_type

-def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    # This model receives in input a multi-dimensional tensor representing
-    # a single image patch and therefore it is not to be split
-    # into multiple elements, but rather to be considered a single one.
-    # Hence, the decision of using a MultiModalSharedField.
-    # The expected shape is (num_channels, width, height).
-    # This model however allows the user to also submit multiple image
-    # patches as a batch, adding a further dimension to the above shape.
-    # At this stage we only support submitting one patch per request and
-    # batching is achieved via vLLM batching.
-    # TODO (christian-pinto): enable support for multi patch requests
-    # in tandem with vLLM batching.
-    return dict(
-        pixel_values=MultiModalFieldConfig.shared(batch_size=1,
-                                                  modality="image"),
-        location_coords=MultiModalFieldConfig.shared(batch_size=1,
-                                                     modality="image"),
-    )
-
-
-class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser):
-
-    def _parse_image_data(
-        self,
-        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
-    ) -> Optional[ModalityDataItems[Any, Any]]:
-        if isinstance(data, dict):
-            return DictEmbeddingItems(
-                data,
-                modality="image",
-                required_fields={"pixel_values", "location_coords"},
-                fields_factory=_prithvi_field_config,
-            )
-
-        return super()._parse_image_data(data)
+def _terratorch_field_names(pretrained_cfg: dict):
+    input_definition = InputDefinition(**pretrained_cfg["input"])
+    return set(input_definition.data.keys())
+
+
+def _terratorch_field_factory(
+    pretrained_cfg: dict
+) -> Callable[
+        [Mapping[str, torch.Tensor]],
+        Mapping[str, MultiModalFieldConfig],
+]:
+
+    def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+        input_definition = InputDefinition(**pretrained_cfg["input"])
+        fields = {}
+        for input_name, input in input_definition.data.items():
+            if input.type == InputTypeEnum.tensor:
+                fields[input_name] = "image"
+
+        mm_fields_config = {}
+        for field_name, field_modality in fields.items():
+            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
+                batch_size=1, modality=field_modality)
+
+        return mm_fields_config
+
+    return _terratorch_field_config

-class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
+class TerratorchProcessingInfo(BaseProcessingInfo):

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}

-class PrithviGeoSpatialMAEInputBuilder(
-        BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
+class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
+
+    def __init__(self, info: TerratorchProcessingInfo):
+        super().__init__(info)
+        self.dummy_data_generator = DummyDataGenerator(
+            self.info.get_hf_config().to_dict()["pretrained_cfg"])

     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         return ""
@@ -100,29 +98,57 @@
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> MultiModalDataDict:
-        # This model input is fixed and is in the form of a torch Tensor.
-        # The size of pixel_values might change in the cases where we resize
-        # the input but never exceeds the dimensions below.
-        image_data = {
-            "pixel_values": torch.full((6, 512, 512), 1.0,
-                                       dtype=torch.float16),
-            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
-        }
-
-        return {"image": image_data}
-
-
-class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
+        # Dummy data is generated based on the 'input' section
+        # defined in the HF configuration file
+        return self.dummy_data_generator.get_dummy_mm_data()
+
+
+class TerratorchMultiModalDataParser(MultiModalDataParser):
+
+    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
+        self._pretrained_cfg = pretrained_cfg
+        super().__init__(*args, **kwargs)
+
+    def _parse_image_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
+    ) -> Optional[ModalityDataItems[Any, Any]]:
+        if isinstance(data, dict):
+            terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
+            return DictEmbeddingItems(
+                data,
+                modality="image",
+                required_fields=terratorch_fields,
+                fields_factory=_terratorch_field_factory(self._pretrained_cfg),
+            )
+
+        return super()._parse_image_data(data)
+
+
+class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
+
+    def __init__(
+            self,
+            info: TerratorchProcessingInfo,
+            dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
+            *,
+            cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
+        self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
+        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

     def _get_data_parser(self) -> MultiModalDataParser:
-        return PrithviGeoSpatialMAEMultiModalDataParser()
+        return TerratorchMultiModalDataParser(
+            pretrained_cfg=self.pretrained_cfg)

     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _prithvi_field_config(hf_inputs)
+        return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)

     def _get_prompt_updates(
         self,
@@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
 @default_pooling_type("All")
 @MULTIMODAL_REGISTRY.register_processor(
-    PrithviGeoSpatialMAEMultiModalProcessor,
-    info=PrithviGeoSpatialMAEProcessingInfo,
-    dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
+    TerratorchMultiModalProcessor,
+    info=TerratorchProcessingInfo,
+    dummy_inputs=TerratorchInputBuilder,
 )
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
-    """Prithvi Masked Autoencoder"""
+class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):

     supports_multimodal_raw_input_only = True
     is_pooling_model = True
@@ -190,43 +214,13 @@
             raise ValueError("Only image modality is supported")

-    def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
-        # We might be able/need to support different tasks with this same model
-        if config["task_args"]["task"] == "SemanticSegmentationTask":
-            from terratorch.cli_tools import SemanticSegmentationTask
-            task = SemanticSegmentationTask(
-                config["model_args"],
-                config["task_args"]["model_factory"],
-                loss=config["task_args"]["loss"],
-                lr=config["task_args"]["lr"],
-                ignore_index=config["task_args"]["ignore_index"],
-                optimizer=config["task_args"]["optimizer"],
-                optimizer_hparams=config["optimizer_params"],
-                scheduler=config["task_args"]["scheduler"],
-                scheduler_hparams=config["scheduler_params"],
-                plot_on_val=config["task_args"]["plot_on_val"],
-                freeze_decoder=config["task_args"]["freeze_decoder"],
-                freeze_backbone=config["task_args"]["freeze_backbone"],
-            )
-            return task.model
-        else:
-            return None
-
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

-        # the actual model is dynamically instantiated using terratorch
-        # allowing us to perform changes to the model architecture
-        # at startup time (e.g., change the model decoder class.)
-        self.model = self._instantiate_model(
-            vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
-        if self.model is None:
-            raise ValueError(
-                "Unsupported task. "
-                "Only SemanticSegmentationTask is supported for now "
-                "by PrithviGeospatialMAE.")
+        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]
+
+        self.inference_runner = InferenceRunner(config)
+        self.model = self.inference_runner.model

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -234,23 +228,6 @@
         self.pooler = DispatchPooler(
             {"encode": Pooler.for_encode(pooler_config)}, )

-    def _parse_and_validate_multimodal_data(
-            self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        pixel_values = kwargs.pop("pixel_values", None)
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError(f"Incorrect type of pixel_values. "
-                             f"Got type: {type(pixel_values)}")
-
-        location_coords = kwargs.pop("location_coords", None)
-        if not isinstance(location_coords, torch.Tensor):
-            raise ValueError(f"Incorrect type of location_coords. "
-                             f"Got type: {type(location_coords)}")
-        location_coords = torch.unbind(location_coords, dim=0)[0]
-        if location_coords.shape == torch.Size([0]):
-            location_coords = None
-
-        return pixel_values, location_coords
-
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
@@ -270,10 +247,7 @@
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ):
-        pixel_values, location_coords = (
-            self._parse_and_validate_multimodal_data(**kwargs))
-        model_output = self.model(pixel_values,
-                                  location_coords=location_coords)
+        model_output = self.inference_runner.forward(**kwargs)

         return model_output.output
@@ -283,9 +257,12 @@
         model_buffers = dict(self.named_buffers())
         loaded_buffers = []
         for key, value in weights:
+            if isinstance(value, (dict, OrderedDict)):
                 if key == "state_dict":
                     weights_to_parse = value

                     for name, weight in weights_to_parse.items():
+                        name = f"inference_runner.{name}"
                         if "pos_embed" in name:
                             continue
@@ -306,6 +283,9 @@
                             params_list.append((name, weight))
                             break
+            elif isinstance(value, torch.Tensor):
+                params_list.append((f"inference_runner.model.{key}", value))

         # Load the remaining model parameters
         loader = AutoWeightsLoader(self)
         autoloaded_weights = loader.load_weights(params_list)