[Core][Model] Terratorch backend integration (#23513)
Signed-off-by: Michele Gazzetti <michele.gazzetti1@ibm.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Parent: e7fc70016f
Commit: 51d5e9be7d
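The changes below wire Prithvi-style geospatial checkpoints through the new Terratorch backend. As a quick orientation, here is a minimal offline sketch assembled from the example scripts and tests touched by this commit; the checkpoint name, input shapes, and pooling parameters are the ones those files use, not general requirements:

```python
import torch
from vllm import LLM
from vllm.pooling_params import PoolingParams

llm = LLM(
    model="mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    model_impl="terratorch",
    skip_tokenizer_init=True,
    dtype="float16",
    enforce_eager=True,
    max_num_seqs=32,
)

# The model is attention-free and consumes raw image tensors, so the prompt
# carries them via multi_modal_data instead of text tokens.
prompt = dict(
    prompt_token_ids=[1],
    multi_modal_data=dict(
        pixel_values=torch.full((6, 512, 512), 1.0, dtype=torch.float16),
        location_coords=torch.full((1, 2), 1.0, dtype=torch.float16),
    ),
)

outputs = llm.encode(prompt, pooling_params=PoolingParams(task="encode", softmax=False))
print(outputs[0].outputs.data.shape)
```

The tokenizer is skipped entirely because the backend works directly on image tensors; batching across requests is handled by vLLM itself.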
@@ -45,7 +45,11 @@ datamodule_config = {
 class PrithviMAE:
     def __init__(self, model):
         self.model = LLM(
-            model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
+            model=model,
+            skip_tokenizer_init=True,
+            dtype="float16",
+            enforce_eager=True,
+            model_impl="terratorch",
         )

     def run(self, input_data, location_coords):
@@ -37,6 +37,7 @@ def main():
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
         io_processor_plugin="prithvi_to_tiff_india",
+        model_impl="terratorch",
     )

     pooling_params = PoolingParams(task="encode", softmax=False)
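The same backend can be paired with an IO-processor plugin, in which case the prompt is the plugin's data descriptor instead of raw tensors. A sketch reconstructed from the plugin example and tests in this commit; it assumes the prithvi_to_tiff plugin from the repository referenced below is installed, and the runner and plugin names are taken from those tests:

```python
from vllm import LLM
from vllm.pooling_params import PoolingParams

llm = LLM(
    model="mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
    runner="pooling",  # the tests select the pooling runner explicitly
    model_impl="terratorch",
    skip_tokenizer_init=True,
    trust_remote_code=True,
    enforce_eager=True,
    max_num_seqs=1,
    io_processor_plugin="prithvi_to_tiff_valencia",
)

# The plugin consumes a data descriptor rather than tensors.
img_prompt = dict(
    data="https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff",
    data_format="url",
    image_format="tiff",
    out_data_format="b64_json",
)

pooler_output = llm.encode(
    img_prompt,
    pooling_params=PoolingParams(task="encode", softmax=False),
)
output = pooler_output[0].outputs
# The plugin output carries type, format, data and request_id fields,
# with data holding a base64-encoded TIFF.
print(output.format, len(output.data))
```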
@@ -15,6 +15,7 @@ import requests
 # https://github.com/christian-pinto/prithvi_io_processor_plugin
 # - start vllm in serving mode with the below args
 #   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
+#   --model-impl terratorch
 #   --task embed --trust-remote-code
 #   --skip-tokenizer-init --enforce-eager
 #   --io-processor-plugin prithvi_to_tiff_india
@@ -53,5 +53,5 @@ runai-model-streamer==0.11.0
 runai-model-streamer-s3==0.11.0
 fastsafetensors>=0.1.10
 pydantic>=2.10 # 2.9 leads to error on python 3.10
-terratorch==1.1rc2 # required for PrithviMAE test
 decord==0.6.0
+terratorch==1.1rc3 # required for PrithviMAE test

@@ -1042,7 +1042,7 @@ tensorboardx==2.6.4
     # via lightning
 tensorizer==2.10.1
     # via -r requirements/test.in
-terratorch==1.1rc2
+terratorch==1.1rc3
     # via -r requirements/test.in
 threadpoolctl==3.5.0
     # via scikit-learn
@@ -298,6 +298,8 @@ def _compare_tp(
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
     hf_config = get_config(model_id, trust_remote_code)
+    skip_tokenizer_init = model_info.skip_tokenizer_init
+    max_num_seqs = model_info.max_num_seqs

     dtype = "float16"
     if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:

@@ -351,6 +353,10 @@ def _compare_tp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    if skip_tokenizer_init:
+        common_args.append("--skip-tokenizer-init")
+    if max_num_seqs:
+        common_args.extend(["--max-num-seqs", f"{max_num_seqs}"])

     specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
     testing_ray_compiled_graph = False

@@ -178,6 +178,7 @@ def _compare_sp(
     trust_remote_code = model_info.trust_remote_code
     tokenizer_mode = model_info.tokenizer_mode
     hf_overrides = model_info.hf_overrides
+    skip_tokenizer_init = model_info.skip_tokenizer_init

     if load_format == "dummy":
         # Avoid OOM

@@ -227,6 +228,8 @@ def _compare_sp(
         common_args.extend(["--load-format", load_format])
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])
+    if skip_tokenizer_init:
+        common_args.append("--skip-tokenizer-init")

     compilation_config = {
         'level': 3,
@@ -104,7 +104,9 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
        trust_remote_code=model_info.trust_remote_code,
        revision=model_info.revision,
        hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     # Initialize the tokenizer
     tokenizer = get_tokenizer(
@@ -11,7 +11,7 @@ import torch

 from ...utils import RemoteOpenAIServer

-MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
+MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
 DTYPE = "float16"


@@ -35,7 +35,9 @@ def server():
         "--trust-remote-code",
         "--skip-tokenizer-init",
         "--max-num-seqs",
-        "32"
+        "32",
+        "--model-impl",
+        "terratorch"
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -1266,7 +1266,9 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     # Build the tokenizer group and grab the underlying tokenizer
     tokenizer_group = TokenizerGroup(

@@ -1322,7 +1324,9 @@ def test_resolve_content_format_hf_defined(model, expected_format):
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     tokenizer_group = TokenizerGroup(
         model,

@@ -1382,7 +1386,9 @@ def test_resolve_content_format_fallbacks(model, expected_format):
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     tokenizer_group = TokenizerGroup(
         model_config.tokenizer,
@@ -69,6 +69,9 @@ def run_test(
         vllm_runner_kwargs_["tokenizer_mode"] = model_info.tokenizer_mode
     if model_info.hf_overrides:
         vllm_runner_kwargs_["hf_overrides"] = model_info.hf_overrides
+    if model_info.skip_tokenizer_init:
+        vllm_runner_kwargs_[
+            "skip_tokenizer_init"] = model_info.skip_tokenizer_init

     if vllm_runner_kwargs:
         vllm_runner_kwargs_.update(vllm_runner_kwargs)

@@ -46,7 +46,7 @@ def _run_test(
         vllm_model.encode(prompt)


-MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"]
+MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]


 @pytest.mark.core_model
@@ -66,7 +66,9 @@ def _test_processing_correctness(
        hf_overrides=model_info.hf_overrides,
        # Ensure that the cache can fit all of the data
        mm_processor_cache_gb=2048,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)

     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]

@@ -196,7 +196,9 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=hf_overrides_fn,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]


@@ -59,7 +59,9 @@ def test_hf_model_weights_mapper(model_arch: str):
        revision=model_info.revision,
        trust_remote_code=model_info.trust_remote_code,
        hf_overrides=model_info.hf_overrides,
-    )
+        skip_tokenizer_init=model_info.skip_tokenizer_init,
+        enforce_eager=model_info.enforce_eager,
+        dtype=model_info.dtype)
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)

     original_weights = create_repo_dummy_weights(model_id)
@@ -6,10 +6,11 @@ from dataclasses import dataclass, field
 from typing import Any, Literal, Optional

 import pytest
+import torch
 from packaging.version import Version
 from transformers import __version__ as TRANSFORMERS_VERSION

-from vllm.config import TokenizerMode
+from vllm.config import ModelDType, TokenizerMode


 @dataclass(frozen=True)

@@ -47,6 +48,23 @@ class _HfExamplesInfo:
     The reason for the minimum/maximum version requirement.
     """

+    skip_tokenizer_init: bool = False
+    """
+    If true, skip initialization of tokenizer and detokenizer.
+    """
+
+    dtype: ModelDType = "auto"
+    """
+    The data type for the model weights and activations.
+    """
+
+    enforce_eager: bool = False
+    """
+    Whether to enforce eager execution. If True, we will
+    disable CUDA graph and always execute the model in eager mode.
+    If False, we will use CUDA graph and eager execution in hybrid.
+    """
+
     is_available_online: bool = True
     """
     Set this to ``False`` if the name of this architecture no longer exists on

@@ -76,6 +94,9 @@ class _HfExamplesInfo:
     If not specified, the default revision will be used.
     """

+    max_num_seqs: Optional[int] = None
+    """Maximum number of sequences to be processed in a single iteration."""
+
     def check_transformers_version(
         self,
         *,
@@ -361,8 +382,21 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
                                         trust_remote_code=True),
     "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"),  # noqa: E501
-    "PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
-                                            is_available_online=False),  # noqa: E501
+    "PrithviGeoSpatialMAE": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",  # noqa: E501
+                                            dtype=torch.float16,
+                                            enforce_eager=True,
+                                            skip_tokenizer_init=True,
+                                            # This is to avoid the model
+                                            # going OOM in CI
+                                            max_num_seqs=32,
+                                            ),
+    "Terratorch": _HfExamplesInfo("mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
+                                  dtype=torch.float16,
+                                  enforce_eager=True,
+                                  skip_tokenizer_init=True,
+                                  # This is to avoid the model going OOM in CI
+                                  max_num_seqs=32,
+                                  ),
 }

 _SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = {
@@ -73,6 +73,9 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
        tokenizer=model_info.tokenizer,
        tokenizer_mode=model_info.tokenizer_mode,
        revision=model_info.revision,
+       enforce_eager=model_info.enforce_eager,
+       skip_tokenizer_init=model_info.skip_tokenizer_init,
+       dtype=model_info.dtype,
        speculative_config={
            "model": model_info.speculative_model,
            "num_speculative_tokens": 1,

@@ -85,7 +88,7 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
        model_impl=ModelImpl.TRANSFORMERS
        if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
        hf_overrides=hf_overrides_fn,
-    )
+       max_num_seqs=model_info.max_num_seqs)


 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
tests/models/test_terratorch.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+from tests.conftest import VllmRunner
+from vllm.utils import set_default_torch_num_threads
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
+        "mgazz/Prithvi_v2_eo_300_tl_unet_agb"
+    ],
+)
+def test_inference(
+    vllm_runner: type[VllmRunner],
+    model: str,
+) -> None:
+
+    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
+    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
+    prompt = dict(prompt_token_ids=[1],
+                  multi_modal_data=dict(pixel_values=pixel_values,
+                                        location_coords=location_coords))
+    with (
+            set_default_torch_num_threads(1),
+            vllm_runner(
+                model,
+                runner="pooling",
+                dtype=torch.float16,
+                enforce_eager=True,
+                skip_tokenizer_init=True,
+                # Limit the maximum number of sequences to avoid the
+                # test going OOM during the warmup run
+                max_num_seqs=32,
+            ) as vllm_model,
+    ):
+
+        vllm_output = vllm_model.llm.encode(prompt)
+        assert torch.equal(
+            torch.isnan(vllm_output[0].outputs.data).any(),
+            torch.tensor(False))
@@ -294,6 +294,8 @@ def build_model_context(
        limit_mm_per_prompt=limit_mm_per_prompt,
        mm_processor_cache_gb=mm_processor_cache_gb,
        hf_overrides=model_info.hf_overrides,
+       skip_tokenizer_init=model_info.skip_tokenizer_init,
+       enforce_eager=model_info.enforce_eager,
        **model_config_kwargs,
    )
    return InputContext(model_config)
@@ -7,12 +7,11 @@ import requests

 from tests.utils import RemoteOpenAIServer
 from vllm.config import VllmConfig
-from vllm.entrypoints.llm import LLM
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor
 from vllm.pooling_params import PoolingParams

-MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
+MODEL_NAME = "mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"

 image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff"  # noqa: E501

@@ -23,61 +22,7 @@ def test_loading_missing_plugin():
         get_io_processor(vllm_config, "wrong_plugin")


-def test_loading_engine_with_wrong_plugin():
-
-    with pytest.raises(ValueError):
-        LLM(
-            model=MODEL_NAME,
-            skip_tokenizer_init=True,
-            trust_remote_code=True,
-            enforce_eager=True,
-            # Limit the maximum number of parallel requests
-            # to avoid the model going OOM in CI.
-            max_num_seqs=32,
-            io_processor_plugin="wrong_plugin",
-        )
-
-
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
-
-    img_prompt = dict(
-        data=image_url,
-        data_format="url",
-        image_format="tiff",
-        out_data_format="b64_json",
-    )
-
-    pooling_params = PoolingParams(task="encode", softmax=False)
-
-    with vllm_runner(
-            model_name,
-            runner="pooling",
-            skip_tokenizer_init=True,
-            trust_remote_code=True,
-            enforce_eager=True,
-            # Limit the maximum number of parallel requests
-            # to avoid the model going OOM in CI.
-            max_num_seqs=1,
-            io_processor_plugin="prithvi_to_tiff_valencia",
-    ) as llm_runner:
-        pooler_output = llm_runner.get_llm().encode(
-            img_prompt,
-            pooling_params=pooling_params,
-        )
-        output = pooler_output[0].outputs
-
-        # verify the output is formatted as expected for this plugin
-        assert all(
-            hasattr(output, attr)
-            for attr in ["type", "format", "data", "request_id"])
-
-        # We just check that the output is a valid base64 string.
-        # Raises an exception and fails the test if the string is corrupted.
-        base64.b64decode(output.data)
-
-
-@pytest.fixture(scope="module")
+@pytest.fixture(scope="function")
 def server():
     args = [
         "--runner",

@@ -90,7 +35,9 @@ def server():
         "--max-num-seqs",
         "32",
         "--io-processor-plugin",
-        "prithvi_to_tiff_valencia"
+        "prithvi_to_tiff_valencia",
+        "--model-impl",
+        "terratorch",
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@@ -136,3 +83,43 @@ async def test_prithvi_mae_plugin_online(
     # We just check that the output is a valid base64 string.
     # Raises an exception and fails the test if the string is corrupted.
     base64.b64decode(plugin_data["data"])
+
+
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
+
+    img_prompt = dict(
+        data=image_url,
+        data_format="url",
+        image_format="tiff",
+        out_data_format="b64_json",
+    )
+
+    pooling_params = PoolingParams(task="encode", softmax=False)
+
+    with vllm_runner(
+            model_name,
+            runner="pooling",
+            skip_tokenizer_init=True,
+            trust_remote_code=True,
+            enforce_eager=True,
+            # Limit the maximum number of parallel requests
+            # to avoid the model going OOM in CI.
+            max_num_seqs=1,
+            model_impl="terratorch",
+            io_processor_plugin="prithvi_to_tiff_valencia",
+    ) as llm_runner:
+        pooler_output = llm_runner.get_llm().encode(
+            img_prompt,
+            pooling_params=pooling_params,
+        )
+        output = pooler_output[0].outputs
+
+        # verify the output is formatted as expected for this plugin
+        assert all(
+            hasattr(output, attr)
+            for attr in ["type", "format", "data", "request_id"])
+
+        # We just check that the output is a valid base64 string.
+        # Raises an exception and fails the test if the string is corrupted.
+        base64.b64decode(output.data)
@@ -171,6 +171,7 @@ class ModelImpl(str, enum.Enum):
     AUTO = "auto"
     VLLM = "vllm"
     TRANSFORMERS = "transformers"
+    TERRATORCH = "terratorch"


 def get_attr_docs(cls: type[Any]) -> dict[str, str]:

@@ -496,7 +497,9 @@ class ModelConfig:
    back to the Transformers implementation if no vLLM implementation is
    available.\n
    - "vllm" will use the vLLM model implementation.\n
-    - "transformers" will use the Transformers model implementation."""
+    - "transformers" will use the Transformers model implementation.\n
+    - "terratorch" will use the TerraTorch model implementation.
+    """
    override_attention_dtype: Optional[str] = None
    """Override dtype for attention"""
    logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
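Since ModelImpl is a string-valued enum, the new backend can be requested either with the enum member or with the plain "terratorch" string accepted by --model-impl / model_impl. A small check, assuming ModelImpl is importable from vllm.config as the tests in this diff use it:

```python
from vllm.config import ModelImpl

# The CLI string maps onto the new enum member.
assert ModelImpl("terratorch") is ModelImpl.TERRATORCH
assert ModelImpl.TERRATORCH.value == "terratorch"
```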
@@ -184,10 +184,11 @@ _EMBEDDING_MODELS = {
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
     "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
-    # Technically PrithviGeoSpatialMAE is a model that works on images, both in
-    # input and output. I am adding it here because it piggybacks on embedding
+    # Technically Terratorch models work on images, both in
+    # input and output. I am adding it here because it piggy-backs on embedding
     # models for the time being.
-    "PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
+    "PrithviGeoSpatialMAE": ("terratorch", "Terratorch"),
+    "Terratorch": ("terratorch", "Terratorch"),
 }

 _CROSS_ENCODER_MODELS = {

@@ -639,6 +640,9 @@ class _ModelRegistry:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)
+           elif model_config.model_impl == ModelImpl.TERRATORCH:
+               model_info = self._try_inspect_model_cls("Terratorch")
+               return (model_info, "Terratorch")

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)

@@ -687,6 +691,11 @@ class _ModelRegistry:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)
+           elif model_config.model_impl == ModelImpl.TERRATORCH:
+               arch = "Terratorch"
+               model_cls = self._try_load_model_cls(arch)
+               if model_cls is not None:
+                   return (model_cls, arch)

        # Fallback to transformers impl (after resolving convert_type)
        if (all(arch not in self.models for arch in architectures)
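Net effect of the registry changes: the PrithviGeoSpatialMAE architecture name now resolves to the generic Terratorch wrapper, and when model_impl is set to terratorch, architectures that vLLM does not know natively are routed to that wrapper as well. A hedged sanity check, assuming the public ModelRegistry instance and its get_supported_archs() helper behave as in current vLLM:

```python
from vllm import ModelRegistry

archs = ModelRegistry.get_supported_archs()
# Both the legacy name and the new generic one are registered,
# and both map onto the same Terratorch wrapper class.
assert "PrithviGeoSpatialMAE" in archs
assert "Terratorch" in archs
```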
@@ -15,13 +15,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only IBM/NASA Prithvi Geospatial model."""
+"""Wrapper around `Terratorch` models"""

+from collections import OrderedDict
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.nn as nn
+from terratorch.vllm import (DummyDataGenerator, InferenceRunner,
+                             InputDefinition, InputTypeEnum)
 from transformers import BatchFeature

 from vllm.config import VllmConfig

@@ -29,6 +32,7 @@ from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.cache import MultiModalProcessorOnlyCache
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputs, MultiModalKwargsItems,

@@ -45,52 +49,46 @@ from .interfaces import (IsAttentionFree, MultiModalEmbeddings,
 from .interfaces_base import default_pooling_type


-def _prithvi_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    # This model receives in input a multi-dimensional tensor representing
-    # a single image patch and therefore it is not to be split
-    # into multiple elements, but rather to be considered a single one.
-    # Hence, the decision of using a MultiModalSharedField.
-    # The expected shape is (num_channels, width, height).
-
-    # This model however allows the user to also submit multiple image
-    # patches as a batch, adding a further dimension to the above shape.
-    # At this stage we only support submitting one patch per request and
-    # batching is achieved via vLLM batching.
-    # TODO (christian-pinto): enable support for multi patch requests
-    # in tandem with vLLM batching.
-    return dict(
-        pixel_values=MultiModalFieldConfig.shared(batch_size=1,
-                                                  modality="image"),
-        location_coords=MultiModalFieldConfig.shared(batch_size=1,
-                                                     modality="image"),
-    )
-
-
-class PrithviGeoSpatialMAEMultiModalDataParser(MultiModalDataParser):
-
-    def _parse_image_data(
-        self,
-        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
-    ) -> Optional[ModalityDataItems[Any, Any]]:
-        if isinstance(data, dict):
-            return DictEmbeddingItems(
-                data,
-                modality="image",
-                required_fields={"pixel_values", "location_coords"},
-                fields_factory=_prithvi_field_config,
-            )
-
-        return super()._parse_image_data(data)
+def _terratorch_field_names(pretrained_cfg: dict):
+    input_definition = InputDefinition(**pretrained_cfg["input"])
+    return set(input_definition.data.keys())
+
+
+def _terratorch_field_factory(
+    pretrained_cfg: dict
+) -> Callable[
+    [Mapping[str, torch.Tensor]],
+    Mapping[str, MultiModalFieldConfig],
+]:
+
+    def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+        input_definition = InputDefinition(**pretrained_cfg["input"])
+        fields = {}
+        for input_name, input in input_definition.data.items():
+            if input.type == InputTypeEnum.tensor:
+                fields[input_name] = "image"
+
+        mm_fields_config = {}
+        for field_name, field_modality in fields.items():
+            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
+                batch_size=1, modality=field_modality)
+        return mm_fields_config
+
+    return _terratorch_field_config


-class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
+class TerratorchProcessingInfo(BaseProcessingInfo):

     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}


-class PrithviGeoSpatialMAEInputBuilder(
-        BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
+class TerratorchInputBuilder(BaseDummyInputsBuilder[TerratorchProcessingInfo]):
+
+    def __init__(self, info: TerratorchProcessingInfo):
+        super().__init__(info)
+        self.dummy_data_generator = DummyDataGenerator(
+            self.info.get_hf_config().to_dict()["pretrained_cfg"])

     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         return ""

@@ -100,29 +98,57 @@ class PrithviGeoSpatialMAEInputBuilder(
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> MultiModalDataDict:
-        # This model input is fixed and is in the form of a torch Tensor.
-        # The size of pixel_values might change in the cases where we resize
-        # the input but never exceeds the dimensions below.
-        image_data = {
-            "pixel_values": torch.full((6, 512, 512), 1.0,
-                                       dtype=torch.float16),
-            "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16),
-        }
-
-        return {"image": image_data}
-
-
-class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
+        # Dummy data is generated based on the 'input' section
+        # defined in the HF configuration file
+        return self.dummy_data_generator.get_dummy_mm_data()
+
+
+class TerratorchMultiModalDataParser(MultiModalDataParser):
+
+    def __init__(self, pretrained_cfg: dict, *args, **kwargs):
+        self._pretrained_cfg = pretrained_cfg
+        super().__init__(*args, **kwargs)
+
+    def _parse_image_data(
+        self,
+        data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
+    ) -> Optional[ModalityDataItems[Any, Any]]:
+        if isinstance(data, dict):
+
+            terratorch_fields = _terratorch_field_names(self._pretrained_cfg)
+
+            return DictEmbeddingItems(
+                data,
+                modality="image",
+                required_fields=terratorch_fields,
+                fields_factory=_terratorch_field_factory(self._pretrained_cfg),
+            )
+
+        return super()._parse_image_data(data)
+
+
+class TerratorchMultiModalProcessor(BaseMultiModalProcessor):
+
+    def __init__(
+        self,
+        info: TerratorchProcessingInfo,
+        dummy_inputs: "BaseDummyInputsBuilder[TerratorchProcessingInfo]",
+        *,
+        cache: Optional[MultiModalProcessorOnlyCache] = None) -> None:
+
+        self.pretrained_cfg = info.get_hf_config().to_dict()["pretrained_cfg"]
+        super().__init__(info=info, dummy_inputs=dummy_inputs, cache=cache)

     def _get_data_parser(self) -> MultiModalDataParser:
-        return PrithviGeoSpatialMAEMultiModalDataParser()
+        return TerratorchMultiModalDataParser(
+            pretrained_cfg=self.pretrained_cfg)

     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _prithvi_field_config(hf_inputs)
+        return _terratorch_field_factory(self.pretrained_cfg)(hf_inputs)

     def _get_prompt_updates(
         self,

@@ -173,13 +199,11 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):

 @default_pooling_type("All")
 @MULTIMODAL_REGISTRY.register_processor(
-    PrithviGeoSpatialMAEMultiModalProcessor,
-    info=PrithviGeoSpatialMAEProcessingInfo,
-    dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
+    TerratorchMultiModalProcessor,
+    info=TerratorchProcessingInfo,
+    dummy_inputs=TerratorchInputBuilder,
 )
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
-    """Prithvi Masked Autoencoder"""
+class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):

     supports_multimodal_raw_input_only = True
     is_pooling_model = True

@@ -190,43 +214,13 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):

         raise ValueError("Only image modality is supported")

-    def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
-
-        # We might be able/need to support different tasks with this same model
-        if config["task_args"]["task"] == "SemanticSegmentationTask":
-            from terratorch.cli_tools import SemanticSegmentationTask
-
-            task = SemanticSegmentationTask(
-                config["model_args"],
-                config["task_args"]["model_factory"],
-                loss=config["task_args"]["loss"],
-                lr=config["task_args"]["lr"],
-                ignore_index=config["task_args"]["ignore_index"],
-                optimizer=config["task_args"]["optimizer"],
-                optimizer_hparams=config["optimizer_params"],
-                scheduler=config["task_args"]["scheduler"],
-                scheduler_hparams=config["scheduler_params"],
-                plot_on_val=config["task_args"]["plot_on_val"],
-                freeze_decoder=config["task_args"]["freeze_decoder"],
-                freeze_backbone=config["task_args"]["freeze_backbone"],
-            )
-
-            return task.model
-        else:
-            return None
-
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()

-        # the actual model is dynamically instantiated using terratorch
-        # allowing us to perform changes to the model architecture
-        # at startup time (e.g., change the model decoder class.)
-        self.model = self._instantiate_model(
-            vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
-        if self.model is None:
-            raise ValueError(
-                "Unsupported task. "
-                "Only SemanticSegmentationTask is supported for now "
-                "by PrithviGeospatialMAE.")
+        config = vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"]
+
+        self.inference_runner = InferenceRunner(config)
+        self.model = self.inference_runner.model

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

@@ -234,23 +228,6 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
         self.pooler = DispatchPooler(
             {"encode": Pooler.for_encode(pooler_config)}, )

-    def _parse_and_validate_multimodal_data(
-            self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-        pixel_values = kwargs.pop("pixel_values", None)
-        if not isinstance(pixel_values, torch.Tensor):
-            raise ValueError(f"Incorrect type of pixel_values. "
-                             f"Got type: {type(pixel_values)}")
-
-        location_coords = kwargs.pop("location_coords", None)
-        if not isinstance(location_coords, torch.Tensor):
-            raise ValueError(f"Incorrect type of location_coords. "
-                             f"Got type: {type(location_coords)}")
-        location_coords = torch.unbind(location_coords, dim=0)[0]
-        if location_coords.shape == torch.Size([0]):
-            location_coords = None
-
-        return pixel_values, location_coords
-
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,

@@ -270,10 +247,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ):
-        pixel_values, location_coords = (
-            self._parse_and_validate_multimodal_data(**kwargs))
-        model_output = self.model(pixel_values,
-                                  location_coords=location_coords)
+        model_output = self.inference_runner.forward(**kwargs)

         return model_output.output

@@ -283,9 +257,12 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
         model_buffers = dict(self.named_buffers())
         loaded_buffers = []
         for key, value in weights:
+            if isinstance(value, (dict, OrderedDict)):
                 if key == "state_dict":
                     weights_to_parse = value
                     for name, weight in weights_to_parse.items():
+                        name = f"inference_runner.{name}"
+
                         if "pos_embed" in name:
                             continue

@@ -306,6 +283,9 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
                         params_list.append((name, weight))
                         break

+            elif isinstance(value, torch.Tensor):
+                params_list.append((f"inference_runner.model.{key}", value))
+
         # Load the remaining model parameters
         loader = AutoWeightsLoader(self)
         autoloaded_weights = loader.load_weights(params_list)
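One consequence of the wrapper above is that the multimodal field configuration is no longer hard-coded to pixel_values and location_coords; _terratorch_field_names and _terratorch_field_factory derive it from the "input" section of the checkpoint's pretrained_cfg. A rough illustration with plain dicts follows; the nested layout is an assumption for illustration only, since the real schema is terratorch's InputDefinition:

```python
# Hypothetical shape of pretrained_cfg["input"] for a Prithvi-style checkpoint;
# the real schema is defined by terratorch.vllm.InputDefinition.
pretrained_cfg = {
    "input": {
        "data": {
            "pixel_values": {"type": "tensor"},
            "location_coords": {"type": "tensor"},
        },
    },
}

# _terratorch_field_names reduces to the keys of that mapping.
field_names = set(pretrained_cfg["input"]["data"].keys())
assert field_names == {"pixel_values", "location_coords"}

# _terratorch_field_factory then marks every tensor-typed input as a shared
# "image" field with batch_size=1, mirroring the old hard-coded behaviour.
```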