mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 19:29:09 +08:00
[Model] IBM/NASA Prithvi Geospatial model (#12830)
This commit is contained in:
parent
3ee696a63d
commit
974dfd4971
530
examples/offline_inference/prithvi_geospatial_mae.py
Normal file
530
examples/offline_inference/prithvi_geospatial_mae.py
Normal file
@ -0,0 +1,530 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
"""
|
||||||
|
This is a demo script showing how to use the
|
||||||
|
PrithviGeospatialMAE model with vLLM
|
||||||
|
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
|
||||||
|
|
||||||
|
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
|
||||||
|
|
||||||
|
The requirements for running this script are:
|
||||||
|
- Installing [terratorch, albumentations, rasterio] in your python environment
|
||||||
|
- downloading the model weights in a 'model' folder local to the script
|
||||||
|
(temporary measure until the proper config.json file is uploaded to HF)
|
||||||
|
- download an input example image (India_900498_S2Hand.tif) and place it in
|
||||||
|
the same folder with the script (or specify with the --data_file argument)
|
||||||
|
|
||||||
|
Run the example:
|
||||||
|
python prithvi_geospatial_mae.py
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import albumentations
|
||||||
|
import numpy as np
|
||||||
|
import rasterio
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||||
|
|
||||||
|
from vllm import LLM
|
||||||
|
|
||||||
|
# Raw pixel value treated as "no data" when normalizing in load_example().
NO_DATA = -9999
# Value written in place of NO_DATA pixels after normalization; it is also the
# sentinel process_channel_group() uses to build its valid-pixel mask.
NO_DATA_FLOAT = 1e-4
# Lower bound used when rescaling channels for visualization.
OFFSET = 0
# Percentile of the valid pixels used as the upper rescaling bound.
PERCENTILE = 99
|
||||||
|
|
||||||
|
model_config = """{
|
||||||
|
"architectures": ["PrithviGeoSpatialMAE"],
|
||||||
|
"num_classes": 0,
|
||||||
|
"pretrained_cfg": {
|
||||||
|
"task_args": {
|
||||||
|
"task": "SemanticSegmentationTask",
|
||||||
|
"model_factory": "EncoderDecoderFactory",
|
||||||
|
"loss": "ce",
|
||||||
|
"ignore_index": -1,
|
||||||
|
"lr": 0.001,
|
||||||
|
"freeze_backbone": false,
|
||||||
|
"freeze_decoder": false,
|
||||||
|
"plot_on_val": 10,
|
||||||
|
"optimizer": "AdamW",
|
||||||
|
"scheduler": "CosineAnnealingLR"
|
||||||
|
},
|
||||||
|
"model_args": {
|
||||||
|
"backbone_pretrained": false,
|
||||||
|
"backbone": "prithvi_eo_v2_300_tl",
|
||||||
|
"decoder": "UperNetDecoder",
|
||||||
|
"decoder_channels": 256,
|
||||||
|
"decoder_scale_modules": true,
|
||||||
|
"num_classes": 2,
|
||||||
|
"rescale": true,
|
||||||
|
"backbone_bands": [
|
||||||
|
"BLUE",
|
||||||
|
"GREEN",
|
||||||
|
"RED",
|
||||||
|
"NIR_NARROW",
|
||||||
|
"SWIR_1",
|
||||||
|
"SWIR_2"
|
||||||
|
],
|
||||||
|
"head_dropout": 0.1,
|
||||||
|
"necks": [
|
||||||
|
{
|
||||||
|
"name": "SelectIndices",
|
||||||
|
"indices": [
|
||||||
|
5,
|
||||||
|
11,
|
||||||
|
17,
|
||||||
|
23
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "ReshapeTokensToImage"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"optimizer_params" : {
|
||||||
|
"lr": 5.0e-05,
|
||||||
|
"betas": [0.9, 0.999],
|
||||||
|
"eps": [1.0e-08],
|
||||||
|
"weight_decay": 0.05,
|
||||||
|
"amsgrad": false,
|
||||||
|
"maximize": false,
|
||||||
|
"capturable": false,
|
||||||
|
"differentiable": false
|
||||||
|
},
|
||||||
|
"scheduler_params" : {
|
||||||
|
"T_max": 50,
|
||||||
|
"eta_min": 0,
|
||||||
|
"last_epoch": -1,
|
||||||
|
"verbose": "deprecated"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
|
||||||
|
"torch_dtype": "float32"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Temporarily creating the "config.json" for the model.
# This is going to disappear once the correct config.json is available on HF.
# The "model" folder is created on demand so that the write below does not
# fail with FileNotFoundError when the folder has not been prepared yet.
_model_dir = os.path.join(os.path.dirname(__file__), "model")
os.makedirs(_model_dir, exist_ok=True)
with open(os.path.join(_model_dir, "config.json"), "w",
          encoding="utf-8") as config_file:
    config_file.write(model_config)
|
||||||
|
|
||||||
|
# Settings consumed by generate_datamodule() (and by main() for the band
# order). Only a subset of the keys is read by the code in this script.
datamodule_config = dict(
    bands=['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'],
    batch_size=16,
    constant_scale=0.0001,
    data_root='/dccstor/geofm-finetuning/datasets/sen1floods11',
    drop_last=True,
    no_data_replace=0.0,
    no_label_replace=-1,
    num_workers=8,
    test_transform=[
        albumentations.Resize(always_apply=False,
                              height=448,
                              interpolation=1,
                              p=1,
                              width=448),
        albumentations.pytorch.ToTensorV2(transpose_mask=False,
                                          always_apply=True,
                                          p=1.0),
    ],
)
|
||||||
|
|
||||||
|
|
||||||
|
class PrithviMAE:
    """Thin wrapper that runs the Prithvi model through a vLLM ``LLM``."""

    def __init__(self):
        """Load the model from the ``model`` folder next to this script."""
        print("Initializing PrithviMAE model")
        self.model = LLM(model=os.path.join(os.path.dirname(__file__),
                                            "./model"),
                         skip_tokenizer_init=True,
                         dtype="float32")

    def run(self, input_data, location_coords):
        """Run one inference request and return the raw output data.

        Args:
            input_data: pixel values passed to the model as ``pixel_values``;
                ``None`` is replaced by an empty tensor.
            location_coords: value passed as ``location_coords``; ``None`` is
                replaced by an empty tensor.

        Returns:
            ``outputs[0].outputs.data`` of the ``encode`` call.
        """
        # Local import: keeps this block self-contained.
        import time

        print("################ Running inference on vLLM ##############")
        # merge the inputs into one data structure
        mm_data = {
            "pixel_values":
            torch.empty(0) if input_data is None else input_data,
            "location_coords":
            torch.empty(0) if location_coords is None else location_coords
        }

        # A single placeholder token id is sent along with the tensors.
        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

        start = time.perf_counter()
        outputs = self.model.encode(prompt, use_tqdm=False)
        elapsed = time.perf_counter() - start
        # The original message announced a duration but never printed one.
        print("################ Inference done "
              f"(it took {elapsed:.2f} seconds) ##############")

        return outputs[0].outputs.data
|
||||||
|
|
||||||
|
|
||||||
|
def generate_datamodule():
    """Build a Sen1Floods11NonGeoDataModule from ``datamodule_config``."""
    cfg = datamodule_config
    return Sen1Floods11NonGeoDataModule(
        data_root=cfg['data_root'],
        batch_size=cfg["batch_size"],
        num_workers=cfg["num_workers"],
        bands=cfg["bands"],
        drop_last=cfg["drop_last"],
        test_transform=cfg["test_transform"],
    )
|
||||||
|
|
||||||
|
|
||||||
|
def process_channel_group(orig_img, channels):
    """Select channels of an image and rescale them to [0, 1] for display.

    Args:
        orig_img: torch.Tensor representing original image (reference)
                  with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
        torch.Tensor with shape (num_channels, height, width) for original
        image, clamped to [0, 1], with no-data pixels set to 0.
    """

    orig_img = orig_img[channels, ...]
    # Pixels equal to the NO_DATA_FLOAT sentinel do not take part in the
    # contrast statistics and are zeroed at the end.
    valid_mask = orig_img != NO_DATA_FLOAT

    # Rescale (enhancing contrast). The upper bound is never below 3000.
    valid_values = orig_img[valid_mask]
    if valid_values.numel() > 0:
        max_value = max(
            3000, float(np.percentile(valid_values.numpy(), PERCENTILE)))
    else:
        # No valid pixel at all: skip the percentile of an empty array
        # (which warns or fails, depending on the NumPy version).
        max_value = 3000
    min_value = OFFSET

    orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0,
                           1)

    # No data as zeros
    orig_img[~valid_mask] = 0

    return orig_img
|
||||||
|
|
||||||
|
|
||||||
|
def read_geotiff(file_path: str):
    """Read all bands from *file_path* and return image, meta and coords.

    Args:
        file_path: path to image file.

    Returns:
        np.ndarray with shape (bands, height, width)
        meta info dict
        result of the dataset's ``lnglat()`` call, or None when that call
        raises
    """

    with rasterio.open(file_path) as dataset:
        pixels = dataset.read()
        metadata = dataset.meta
        try:
            lnglat = dataset.lnglat()
        except Exception:
            # Cannot read coords
            lnglat = None

    return pixels, metadata, lnglat
|
||||||
|
|
||||||
|
|
||||||
|
def save_geotiff(image, output_path: str, meta: dict):
    """Write a multi-band image to a GeoTIFF file, one band at a time.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info, forwarded as keyword arguments to
            ``rasterio.open``.
    """

    num_bands = image.shape[0]
    with rasterio.open(output_path, "w", **meta) as dest:
        for band in range(num_bands):
            # The second argument is the 1-based index of the band written.
            dest.write(image[band, :, :], band + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_np_uint8(float_image: torch.Tensor):
|
||||||
|
image = float_image.numpy() * 255.0
|
||||||
|
image = image.astype(dtype=np.uint8)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def load_example(
    file_paths: List[str],
    mean: List[float] = None,
    std: List[float] = None,
    indices: Union[list[int], None] = None,
):
    """Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths .
        mean: list containing mean values for each band in the images
            in *file_paths*.
        std: list containing std values for each band in the images
            in *file_paths*.
        indices: optional band indices to select from every image; all
            bands are kept when None.

    Returns:
        np.array containing created example, float32 with a leading batch
            axis of size 1
        list of [year, day_of_year] pairs, one per file whose name contains
            a timestamp
        list of coordinates, one per file for which they could be read
        list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        # Best-effort timestamp extraction: a failure is reported, not fatal.
        try:
            match = re.search(r'(\d{7,8}T\d{6})', file)
            if match:
                date_part = match.group(1).split('T')[0]
                year = int(date_part[:4])
                julian_day = date_part[4:]
                if len(julian_day) == 3:
                    # YYYYDDD: the day of year is given directly.
                    julian_day = int(julian_day)
                else:
                    # YYYYMMDD: parse together with the year. Parsing only
                    # '%m%d' falls back to year 1900, which is off by one
                    # after February in leap years and rejects 29 February.
                    julian_day = datetime.datetime.strptime(
                        date_part, '%Y%m%d').timetuple().tm_yday
                temporal_coords.append([year, julian_day])
        except Exception as e:
            print(f'Could not extract timestamp for {file} ({e})')

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas
|
||||||
|
|
||||||
|
|
||||||
|
def run_model(input_data,
              temporal_coords,
              location_coords,
              model,
              datamodule,
              img_size,
              lightning_model=None):
    """Run the model over an image, tile by tile, and stitch the result.

    The input is reflect-padded to a multiple of *img_size*, cut into
    non-overlapping *img_size* x *img_size* windows, each window is passed
    through the datamodule transforms and ``model.run``, and the per-window
    argmax maps are reassembled and cropped back to the original size.

    Args:
        input_data: 5-D array whose last two axes are height and width.
        temporal_coords: list of temporal coordinates (may be empty); only
            forwarded to *lightning_model*.
        location_coords: list of location coordinates (may be empty); only
            the first entry is used.
        model: object exposing ``run(x, location_coords=...)``.
        datamodule: object exposing ``test_transform`` and ``aug``.
        img_size: side length of the square windows.
        lightning_model: optional reference model; when given, its output is
            compared against *model*'s and a message is printed on mismatch.

    Returns:
        torch.Tensor with the predicted class map, cropped to the original
        height and width.
    """
    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = -original_h % img_size
    pad_w = -original_w % img_size
    padding = ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w))
    input_data = np.pad(input_data, padding, mode="reflect")

    # Build sliding window
    batch_size = 1
    batch = torch.tensor(input_data, device="cpu")
    tiles = batch.unfold(3, img_size, img_size)
    tiles = tiles.unfold(4, img_size, img_size)
    h1, w1 = tiles.shape[3:5]
    tiles = rearrange(tiles,
                      "b c t h1 w1 h w -> (b h1 w1) c t h w",
                      h=img_size,
                      w=img_size)

    # Split into batches if number of windows > batch_size
    if tiles.shape[0] > batch_size:
        num_batches = tiles.shape[0] // batch_size
    else:
        num_batches = 1
    tiles = torch.tensor_split(tiles, num_batches, dim=0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if temporal_coords:
        temporal_coords = torch.tensor(temporal_coords,
                                       device=device).unsqueeze(0)
    else:
        temporal_coords = None
    if location_coords:
        location_coords = torch.tensor(location_coords[0],
                                       device=device).unsqueeze(0)
    else:
        location_coords = None

    # Run model
    predictions = []
    for tile in tiles:
        # Apply standardization
        channels_last = tile.squeeze().numpy().transpose(1, 2, 0)
        sample = datamodule.test_transform(image=channels_last)
        sample = datamodule.aug(sample)['image']

        with torch.no_grad():
            sample = sample.to(device)
            pred = model.run(sample, location_coords=location_coords)
            if lightning_model:
                pred_lightning = lightning_model(
                    sample,
                    temporal_coords=temporal_coords,
                    location_coords=location_coords)
                pred_lightning = pred_lightning.output.detach().cpu()
                if not torch.equal(pred, pred_lightning):
                    print("Inference output is not equal")

        # Class index per pixel, resized back to the window size.
        class_map = pred.argmax(dim=1).unsqueeze(1).float()
        class_map = torch.nn.functional.interpolate(class_map,
                                                    size=img_size,
                                                    mode="nearest")
        predictions.append(class_map)

    stitched = torch.concat(predictions, dim=0)

    # Build images from patches
    stitched = rearrange(
        stitched,
        "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=1,
        h1=h1,
        w1=w1,
    )

    # Cut padded area back to original size, then drop the batch axis
    # (batch size 1).
    return stitched[..., :original_h, :original_w][0]
|
||||||
|
|
||||||
|
|
||||||
|
def main(
    data_file: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] = None,
):
    """Run flood segmentation on one GeoTIFF and save the results.

    Writes ``pred_<name>.tiff`` (prediction) and ``rgb_pred_<name>.tiff``
    (prediction blended over the RGB image) into *output_dir*; when
    *rgb_outputs* is true, ``original_rgb_<name>.tiff`` is written as well.

    Args:
        data_file: path of the input GeoTIFF.
        output_dir: directory for the output files (created if missing).
        rgb_outputs: whether to also save the rescaled RGB input image.
        input_indices: band indices to select from the input file.
    """
    os.makedirs(output_dir, exist_ok=True)
    file_stem = os.path.splitext(os.path.basename(data_file))[0]

    # Load model ---------------------------------------------------------------

    model_obj = PrithviMAE()
    datamodule = generate_datamodule()
    img_size = 256  # Size of Sen1Floods11

    # Loading data -------------------------------------------------------------

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=[data_file],
        indices=input_indices,
    )
    meta_data = meta_data[0]  # only one image

    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

    # Running model ------------------------------------------------------------

    band_names = datamodule_config['bands']
    channels = [band_names.index(b) for b in ["RED", "GREEN", "BLUE"]
                ]  # BGR -> RGB

    pred = run_model(input_data, temporal_coords, location_coords, model_obj,
                     datamodule, img_size)

    # Save pred
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(output_dir, f"pred_{file_stem}.tiff")
    save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)

    # Save image + pred
    meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)

    if input_data.mean() < 1:
        input_data = input_data * 10000  # Scale to 0-10000

    rgb_orig = process_channel_group(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )

    # Pixels predicted as class 0 become NaN so that the blend below can fall
    # back to the plain RGB value there.
    pred[pred == 0.] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
    nan_mask = img_pred.isnan()
    img_pred[nan_mask] = rgb_orig[nan_mask]

    img_pred_file = os.path.join(output_dir, f"rgb_pred_{file_stem}.tiff")
    save_geotiff(
        image=_convert_np_uint8(img_pred),
        output_path=img_pred_file,
        meta=meta_data,
    )

    # Save image rgb
    if not rgb_outputs:
        return
    rgb_file = os.path.join(output_dir, f"original_rgb_{file_stem}.tiff")
    save_geotiff(
        image=_convert_np_uint8(rgb_orig),
        output_path=rgb_file,
        meta=meta_data,
    )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Help is left enabled (the original passed add_help=False, which made
    # "-h/--help" an error instead of printing the options below).
    parser = argparse.ArgumentParser("MAE run inference")

    parser.add_argument(
        "--data_file",
        type=str,
        default="./India_900498_S2Hand.tif",
        help="Path to the file.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[1, 2, 3, 8, 11, 12],
        type=int,
        nargs="+",
        help=
        "0-based indices of the six Prithvi channels to be selected from the "
        "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        # Describes what main() actually does with this flag.
        help="If present, the RGB channels of the original input image are "
        "additionally saved as 'original_rgb_<name>.tiff'.",
    )
    args = parser.parse_args()

    main(**vars(args))
|
||||||
@ -214,6 +214,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
|||||||
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
|
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
|
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
|
||||||
|
# The model on Huggingface is currently being updated,
|
||||||
|
# hence I temporarily mark it as not available online
|
||||||
|
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
|
||||||
|
is_available_online=False),
|
||||||
}
|
}
|
||||||
|
|
||||||
_CROSS_ENCODER_EXAMPLE_MODELS = {
|
_CROSS_ENCODER_EXAMPLE_MODELS = {
|
||||||
|
|||||||
@ -320,9 +320,14 @@ class PlaceholderAttentionMetadataBuilder(
|
|||||||
-1 if cuda graph is not used.
|
-1 if cuda graph is not used.
|
||||||
batch_size: The maybe padded batch size.
|
batch_size: The maybe padded batch size.
|
||||||
"""
|
"""
|
||||||
for inter_data in self.input_builder.inter_data_list:
|
|
||||||
self._add_seq_group(inter_data,
|
# Some input builders such as ModelInputForCPUBuilder do not have the
|
||||||
self.input_builder.chunked_prefill_enabled)
|
# "inter_data_list" attribute.
|
||||||
|
# Let's check inter_data_list exists before we reference it.
|
||||||
|
if hasattr(self.input_builder, "inter_data_list"):
|
||||||
|
for inter_data in self.input_builder.inter_data_list:
|
||||||
|
self._add_seq_group(inter_data,
|
||||||
|
self.input_builder.chunked_prefill_enabled)
|
||||||
|
|
||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
use_captured_graph = cuda_graph_pad_size != -1
|
use_captured_graph = cuda_graph_pad_size != -1
|
||||||
|
|||||||
@ -254,8 +254,14 @@ class InputPreprocessor:
|
|||||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||||
returning the corresponding token IDs and metadata.
|
returning the corresponding token IDs and metadata.
|
||||||
"""
|
"""
|
||||||
tokenizer_group = self.get_tokenizer_group()
|
# At the moment one model (PrithviGeoSpatialMAE) requires to be
|
||||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
# initialized without a tokenizer while using also multi-modal
|
||||||
|
# input.
|
||||||
|
if not self.tokenizer:
|
||||||
|
tokenizer = None
|
||||||
|
else:
|
||||||
|
tokenizer_group = self.get_tokenizer_group()
|
||||||
|
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
||||||
|
|
||||||
mm_processor = self.mm_registry.create_processor(
|
mm_processor = self.mm_registry.create_processor(
|
||||||
self.model_config, tokenizer)
|
self.model_config, tokenizer)
|
||||||
@ -273,9 +279,15 @@ class InputPreprocessor:
|
|||||||
lora_request: Optional[LoRARequest],
|
lora_request: Optional[LoRARequest],
|
||||||
) -> MultiModalInputs:
|
) -> MultiModalInputs:
|
||||||
"""Async version of :meth:`_process_multimodal`."""
|
"""Async version of :meth:`_process_multimodal`."""
|
||||||
tokenizer_group = self.get_tokenizer_group()
|
# At the moment one model (PrithviGeoSpatialMAE) requires to be
|
||||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request
|
# initialized without a tokenizer while using also multi-modal
|
||||||
)
|
# input.
|
||||||
|
if not self.tokenizer:
|
||||||
|
tokenizer = None
|
||||||
|
else:
|
||||||
|
tokenizer_group = self.get_tokenizer_group()
|
||||||
|
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
|
||||||
|
lora_request)
|
||||||
|
|
||||||
mm_processor = self.mm_registry.create_processor(
|
mm_processor = self.mm_registry.create_processor(
|
||||||
self.model_config, tokenizer)
|
self.model_config, tokenizer)
|
||||||
|
|||||||
238
vllm/model_executor/models/prithvi_geospatial_mae.py
Normal file
238
vllm/model_executor/models/prithvi_geospatial_mae.py
Normal file
@ -0,0 +1,238 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
# Copyright 2025 The vLLM team.
|
||||||
|
# Copyright 2025 IBM.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Inference-only IBM/NASA Prithvi Geospatial model."""
|
||||||
|
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers import BatchFeature
|
||||||
|
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
|
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||||
|
SupportsMultiModal)
|
||||||
|
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||||
|
MultiModalInputs, MultiModalKwargs)
|
||||||
|
from vllm.multimodal.parse import MultiModalDataItems
|
||||||
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
|
BaseProcessingInfo, PromptReplacement)
|
||||||
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||||
|
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||||
|
PoolingSequenceGroupOutput)
|
||||||
|
|
||||||
|
|
||||||
|
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
    """Processing info for the Prithvi geospatial model."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Declare the "image" modality with a limit of ``None``."""
        limits = {"image": None}
        return limits

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Placeholder implementation; returns ``None``."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
class PrithviGeoSpatialMAEInputBuilder(
        BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
    """Builds the dummy inputs for the Prithvi geospatial model."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return fixed-size dummy tensors; both arguments are ignored."""
        # This model input is fixed and is in the form of a torch Tensor.
        # The size of pixel_values might change in the cases where we resize
        # the input but never exceeds the dimensions below.
        dummy_mm_data = {
            "pixel_values": torch.full((1, 6, 512, 512), 1.0),
            "location_coords": torch.full((1, 2), 1.0)
        }
        return ProcessorInputs(prompt_text="", mm_data=dummy_mm_data)
|
||||||
|
|
||||||
|
|
||||||
|
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor that forwards the input tensors untouched."""

    # NOTE: this class used to define _get_mm_fields_config twice; the second
    # definition (a bare "pass") replaced this one in the class namespace, so
    # the method returned None. Only the real implementation is kept.
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare both inputs as batched fields of the "image" modality."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            location_coords=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Placeholder implementation; returns ``None``."""
        pass

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Wrap *mm_data* into ``MultiModalInputs`` without processing it.

        Every key/value pair of *mm_data* is copied as-is into the
        multi-modal kwargs; the token ids are always ``[1]`` and no
        placeholders are produced.
        """
        mm_kwargs = dict(mm_data)

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=[1],
            mm_kwargs=MultiModalKwargs(mm_kwargs),
            mm_placeholders={},
        )
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
    PrithviGeoSpatialMAEMultiModalProcessor,
    info=PrithviGeoSpatialMAEProcessingInfo,
    dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
    """ Prithvi Masked Autoencoder"""

    # Optional[...] is used instead of "X | None": the latter is evaluated
    # when the function is defined and raises TypeError on Python 3.9.
    def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
        """Build the underlying model from the ``pretrained_cfg`` dict.

        Returns:
            The ``model`` attribute of the terratorch task, or None when the
            configured task is not supported.
        """

        # We might be able/need to support different tasks with this same model
        if config["task_args"]["task"] == "SemanticSegmentationTask":
            from terratorch.cli_tools import SemanticSegmentationTask
            task_args = config["task_args"]
            task = SemanticSegmentationTask(
                config["model_args"],
                task_args["model_factory"],
                loss=task_args["loss"],
                lr=task_args["lr"],
                ignore_index=task_args["ignore_index"],
                optimizer=task_args["optimizer"],
                optimizer_hparams=config["optimizer_params"],
                scheduler=task_args["scheduler"],
                scheduler_hparams=config["scheduler_params"],
                plot_on_val=task_args["plot_on_val"],
                freeze_decoder=task_args["freeze_decoder"],
                freeze_backbone=task_args["freeze_backbone"])

            return task.model
        else:
            return None

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate the model described by the HF config.

        Raises:
            ValueError: if the configured task is not supported.
        """
        super().__init__()

        # the actual model is dynamically instantiated using terratorch
        # allowing us to perform changes to the model architecture
        # at startup time (e.g., change the model decoder class.)
        self.model = self._instantiate_model(
            vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
        if self.model is None:
            # Trailing spaces keep the concatenated literals readable.
            raise ValueError(
                "Unsupported task. "
                "Only SemanticSegmentationTask is supported for now "
                "by PrithviGeospatialMAE.")

    def _parse_and_validate_multimodal_data(
            self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Extract ``pixel_values`` and ``location_coords`` from *kwargs*.

        Only the first element along dim 0 of each tensor is used. An empty
        ``location_coords`` (shape ``[0]``) is turned into None.

        Raises:
            ValueError: if either input is not a torch.Tensor.
        """

        pixel_values = kwargs.pop("pixel_values", None)
        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel_values. "
                             f"Got type: {type(pixel_values)}")
        pixel_values = torch.unbind(pixel_values, dim=0)[0]

        location_coords = kwargs.pop("location_coords", None)
        if not isinstance(location_coords, torch.Tensor):
            raise ValueError("Incorrect type of location_coords. "
                             f"Got type: {type(location_coords)}")
        location_coords = torch.unbind(location_coords, dim=0)[0]
        if location_coords.shape == torch.Size([0]):
            location_coords = None

        return pixel_values, location_coords

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the wrapped model on the multi-modal inputs in *kwargs*.

        Only *kwargs* is read; every other argument is ignored here.

        Returns:
            The ``output`` attribute of the wrapped model's result.
        """

        pixel_values, location_coords = (
            self._parse_and_validate_multimodal_data(**kwargs))
        model_output = self.model(pixel_values,
                                  location_coords=location_coords)

        return model_output.output

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Wrap *hidden_states* unchanged into a single-entry PoolerOutput."""
        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load weights from the ``state_dict`` entry of *weights*.

        Entries whose name contains ``pos_embed`` are skipped and the
        ``_timm_module.`` fragment is removed from names. Buffers are loaded
        directly; everything else goes through ``AutoWeightsLoader``.

        Returns:
            The names of all loaded parameters and buffers.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []
        for key, value in weights:
            if key == "state_dict":
                weights_to_parse = value
                for name, weight in weights_to_parse.items():
                    if "pos_embed" in name:
                        continue

                    name = name.replace("_timm_module.", "")

                    # this model requires a couple of buffers to be loaded
                    # that are not loadable with the AutoWeightsLoader
                    if name in model_buffers:
                        buffer = model_buffers[name]
                        weight_loader = getattr(buffer, "weight_loader",
                                                default_weight_loader)
                        weight_loader(buffer, weight)
                        loaded_buffers.append(name)
                    else:
                        params_list.append((name, weight))
                # NOTE(review): placement assumed - stop iterating once the
                # "state_dict" entry has been handled; confirm against the
                # original indentation.
                break

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))
|
||||||
@ -137,6 +137,10 @@ _EMBEDDING_MODELS = {
|
|||||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||||
# [Auto-converted (see adapters.py)]
|
# [Auto-converted (see adapters.py)]
|
||||||
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
|
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
|
||||||
|
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
|
||||||
|
# input and output. I am adding it here because it piggy-backs on embedding
|
||||||
|
# models for the time being.
|
||||||
|
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
|
||||||
}
|
}
|
||||||
|
|
||||||
_CROSS_ENCODER_MODELS = {
|
_CROSS_ENCODER_MODELS = {
|
||||||
|
|||||||
@ -74,7 +74,16 @@ class PoolingModelRunner(
|
|||||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||||
decode_meta = model_input.attn_metadata.decode_metadata
|
decode_meta = model_input.attn_metadata.decode_metadata
|
||||||
virtual_engine = model_input.virtual_engine
|
virtual_engine = model_input.virtual_engine
|
||||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
# Pooling models are (ab-)used also to integrate non text models that
|
||||||
|
# are not autoregressive (PrithviGeoSpatialMAE).
|
||||||
|
# These models might not use attention and do not really have a prefill
|
||||||
|
# and decode phase. The model input is processed in one shot and both
|
||||||
|
# decode_metadata and prefill_metadata would be None for such models.
|
||||||
|
# See the PlaceholderAttentionMetadata class.
|
||||||
|
# TODO: Figure out if cuda_graph is of any use for these models and
|
||||||
|
# explore how to leverage it.
|
||||||
|
if (prefill_meta is None and decode_meta is not None
|
||||||
|
and decode_meta.use_cuda_graph):
|
||||||
assert model_input.input_tokens is not None
|
assert model_input.input_tokens is not None
|
||||||
graph_batch_size = model_input.input_tokens.shape[0]
|
graph_batch_size = model_input.input_tokens.shape[0]
|
||||||
model_executable = self.graph_runners[virtual_engine][
|
model_executable = self.graph_runners[virtual_engine][
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user