mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 19:57:08 +08:00
Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Boyuan Feng <boyuan@meta.com> Signed-off-by: Boyuan Feng <fby.1994@gmail.com> Signed-off-by: boyuanfeng <boyuan@meta.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: JartX <sagformas@epdcenter.es> Signed-off-by: Chendi Xue <Chendi.Xue@intel.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: wwl2755 <wangwenlong2755@gmail.com> Signed-off-by: Manoel Marques <manoel.marques@ibm.com> Signed-off-by: Manoel Marques <manoelmrqs@gmail.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: pengdrumli <pengdrumli@tencent.com> Signed-off-by: windsonsea <haifeng.yao@daocloud.io> Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Huamin Li <3ericli@gmail.com> Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com> Signed-off-by: Rahul Tuli <rtuli@redhat.com> Signed-off-by: Yang <lymailforjob@gmail.com> Signed-off-by: Debolina Roy <debroy@redhat.com> Signed-off-by: David Chen <530634352@qq.com> Signed-off-by: wangzi <3220100013@zju.edu.cn> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com> Signed-off-by: Sara Kokkila Schumacher <saraks@ibm.com> Signed-off-by: Csrayz <jover@cmbchina.com> Signed-off-by: ivyilike <pww123@cmbchina.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Bowen Wang <abmfy@icloud.com> Signed-off-by: qqma <qqma@amazon.com> Signed-off-by: ElizaWszola 
<ewszola@redhat.com> Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: Zhuohan Li <zhuohan123@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: luka <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Johnny Yang <johnnyyang@google.com> Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Signed-off-by: Russell Bryant <rbryant@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: liuye.hj <liuye.hj@alibaba-inc.com> Signed-off-by: Kunshang Ji <kunshang.ji@intel.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Zhikaiiii <1658973216@qq.com> Signed-off-by: Andreas Hartel <andreas.hartel@aleph-alpha.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com> Signed-off-by: wuxibin <wuxibin@bytedance.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Peter Pan <Peter.Pan@daocloud.io> Signed-off-by: Peter Pan <peter.pan@daocloud.io> Signed-off-by: Nicolò Lucchesi<nicolo.lucchesi@gmail.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Sage Moore <sage@neuralmagic.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com> Signed-off-by: Weida Hong <wdhongtw@google.com> Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> 
Signed-off-by: Hashem Hashemi <hashem.hashemi@amd.com> Signed-off-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: ElizaWszola <elizaw.9289@gmail.com> Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Signed-off-by: ilmarkov <markovilya197@gmail.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Signed-off-by: rouchenzi <ruochenwen@gmail.com> Signed-off-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Signed-off-by: Andrew Xia <axia@meta.com> Signed-off-by: Kourosh Hakhamaneshi <kourosh@anyscale.com> Signed-off-by: Corey Lowman <clowman1993@gmail.com> Signed-off-by: jpvillam <jpvillam@amd.com> Signed-off-by: dougbtv <dosmith@redhat.com> Signed-off-by: Chenxi Yang <cxyang@fb.com> Signed-off-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: ahao-anyscale <ahao@anyscale.com> Signed-off-by: Yan Lu <luyan@nvidia.com> Signed-off-by: baxingpiaochong <771405853@qq.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com> Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com> Signed-off-by: Yong Hoon Shin <yhshin@meta.com> Signed-off-by: Benjamin Chislett <benjamin.chislett@centml.ai> Signed-off-by: Benjamin Chislett <bchislett@nvidia.com> Signed-off-by: Ben Browning <bbrownin@redhat.com> Signed-off-by: Chengji Yao <chengjiyao@google.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Jackmin801 <ongjackm@gmail.com> Signed-off-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Signed-off-by: taohui <taohui3@gmail.com> Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io> Signed-off-by: Shu Wang <shuw@nvidia.com> Signed-off-by: Shu Wang. 
<shuw@nvidia.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: Duncan Moss <djm.moss@gmail.com> Signed-off-by: Shiyan Deng <dsy842974287@meta.com> Signed-off-by: Wei Wei <wwei6@meta.com> Signed-off-by: Saman Keon <samanamp@outlook.com> Signed-off-by: yangxurui <yangxurui@meituan.com> Signed-off-by: nicole-lihui <nicole.li@daocloud.io> Signed-off-by: courage17340 <courage17340@163.com> Signed-off-by: Jacob Kahn <jacobkahn1@gmail.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Agata Dobrzyniewicz <adobrzyniewicz@habana.ai> Signed-off-by: zxw <1020938856@qq.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Signed-off-by: chenlang <chen.lang5@zte.com.cn> Signed-off-by: Jonas Kuebler <kuebj@amazon.com> Signed-off-by: AlonKejzman <alonkeizman@gmail.com> Signed-off-by: Tao Hui <taohui3@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Signed-off-by: Eugene Khvedchenia <ekhvedchenia@nvidia.com> Signed-off-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Signed-off-by: yiting.jiang <yiting.jiang@daocloud.io> Signed-off-by: xaguilar <Xavier.AguilarFruto@amd.com> Signed-off-by: Iceber Gu <caiwei95@hotmail.com> Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com> Signed-off-by: Icey <1790571317@qq.com> Signed-off-by: 许文卿 <xwq391974@alibaba-inc.com> Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Lucas Kabela <lucasakabela@gmail.com> Co-authored-by: Maximilien de Bayser <mbayser@br.ibm.com> Co-authored-by: Andrew Sansom <andrew@protopia.ai> Co-authored-by: Boyuan Feng <boyuan@meta.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> 
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: JartX <sagformas@epdcenter.es> Co-authored-by: Chendi.Xue <chendi.xue@intel.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: xin.li <xin.li@daocloud.io> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Wenlong Wang <wangwenlong2755@gmail.com> Co-authored-by: Manoel Marques <manoelmrqs@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: lirong <56789630+lirong-lirong@users.noreply.github.com> Co-authored-by: Michael Yao <haifeng.yao@daocloud.io> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: Huamin Li <3ericli@gmail.com> Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com> Co-authored-by: Simon Danielsson <70206058+simondanielsson@users.noreply.github.com> Co-authored-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Yang Liu <127183760+KKSK-DON@users.noreply.github.com> Co-authored-by: Deboleina <debroy@redhat.com> Co-authored-by: yinz-aizip <yinz@aizip.ai> Co-authored-by: WeiQing Chen <40507679+david6666666@users.noreply.github.com> Co-authored-by: wangzi <3220100013@zju.edu.cn> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com> Co-authored-by: Yizhou <136800916+yiz-liu@users.noreply.github.com> Co-authored-by: Sara-KS <50249410+Sara-KS@users.noreply.github.com> Co-authored-by: Csrayz <jover@cmbchina.com> Co-authored-by: ivyilike <pww123@cmbchina.com> Co-authored-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Bowen Wang <abmfy@icloud.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Daisy-Ma-coder 
<daisy.ma.0117@gmail.com> Co-authored-by: qqma <qqma@amazon.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com> Co-authored-by: Simon Mo <simon.mo@hey.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Johnny Yang <24908445+jcyang43@users.noreply.github.com> Co-authored-by: Chengji Yao <chengjiyao@google.com> Co-authored-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Chris Bamford <chrisbam4d@gmail.com> Co-authored-by: Alexander Matveev <59768536+alexm-redhat@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: liuye.hj <liuye.hj@alibaba-inc.com> Co-authored-by: Kunshang Ji <kunshang.ji@intel.com> Co-authored-by: Lucia (Lu) Fang <fanglu@meta.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Ming Yang <yming@meta.com> Co-authored-by: Zhikaiiii <55917203+Zhikaiiii@users.noreply.github.com> Co-authored-by: Andreas Hartel <andreas@hartel.me> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: vllmellm <vllm.ellm@embeddedllm.com> Co-authored-by: Joel <wuxibin89@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Peter Pan <peter.pan@daocloud.io> Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com> Co-authored-by: Fanli Lin <fanli.lin@intel.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Lucas Wilkinson 
<LucasWilkinson@users.noreply.github.com> Co-authored-by: Sage Moore <sage@neuralmagic.com> Co-authored-by: yewentao256 <zhyanwentao@126.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: rivos-shreeasish <shreeasish@rivosinc.com> Co-authored-by: Chih-Chieh Yang <chih.chieh.yang@ibm.com> Co-authored-by: Weida Hong <wdhongtw@gmail.com> Co-authored-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Hashem Hashemi <159079214+amd-hhashemi@users.noreply.github.com> Co-authored-by: Amir Samani <samani@ualberta.ca> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: jiahanc <173873397+jiahanc@users.noreply.github.com> Co-authored-by: Ilya Markov <markovilya197@gmail.com> Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Co-authored-by: Jialin Ouyang <Jialin.Ouyang@gmail.com> Co-authored-by: rouchenzi <40842833+rouchenzi@users.noreply.github.com> Co-authored-by: Andrew Xia <axia@meta.com> Co-authored-by: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Co-authored-by: Corey Lowman <clowman1993@gmail.com> Co-authored-by: Juan Villamizar <100237675+jpvillam-amd@users.noreply.github.com> Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Doug Smith <dosmith@redhat.com> Co-authored-by: Chenxi Yang <cxyang@cs.utexas.edu> Co-authored-by: Chenxi Yang <cxyang@fb.com> Co-authored-by: ahao-anyscale <ahao@anyscale.com> Co-authored-by: 0xNullPath <luyanfcp@foxmail.com> Co-authored-by: baxingpiaochong <771405853@qq.com> Co-authored-by: Benjamin Chislett <bchislett@nvidia.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com> Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com> Co-authored-by: Yong Hoon Shin <48474650+sarckk@users.noreply.github.com> Co-authored-by: lhsjohn <huashuoli@tencent.com> Co-authored-by: Ben Browning <bbrownin@redhat.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: Jackmin801 
<56836461+Jackmin801@users.noreply.github.com> Co-authored-by: Jonas M. Kübler <44084297+jmkuebler@users.noreply.github.com> Co-authored-by: Tao Hui <taohui3@gmail.com> Co-authored-by: rongfu.leng <rongfu.leng@daocloud.io> Co-authored-by: Shu Wang <shuw@nvidia.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: Duncan Moss <djm.moss@gmail.com> Co-authored-by: Shiyan Deng <dsy842974287@meta.com> Co-authored-by: Wei Wei <wwei6@meta.com> Co-authored-by: Saman A. Pour <samanamp@outlook.com> Co-authored-by: XuruiYang <530534756@qq.com> Co-authored-by: yangxurui <yangxurui@meituan.com> Co-authored-by: Nicole LiHui 🥜 <nicolelihui@outlook.com> Co-authored-by: courage17340 <courage17340@users.noreply.github.com> Co-authored-by: Jacob Kahn <jacobkahn1@gmail.com> Co-authored-by: Nicole LiHui 🥜 <nicole.li@daocloud.io> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Co-authored-by: yyzxw <34639446+yyzxw@users.noreply.github.com> Co-authored-by: wang.yuqi <noooop@126.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: chenlang <chen.lang5@zte.com.cn> Co-authored-by: chenlang <10346245@zte.com.cn> Co-authored-by: AlonKejzman <alonkeizman@gmail.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com> Co-authored-by: yitingdc <59356937+yitingdc@users.noreply.github.com> Co-authored-by: xaguilar-amd <xavier.aguilarfruto@amd.com> Co-authored-by: Iceber Gu <caiwei95@hotmail.com> Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com> Co-authored-by: Icey <1790571317@qq.com> Co-authored-by: Xu Wenqing <121550081+Xu-Wenqing@users.noreply.github.com> Co-authored-by: Chih-Chieh 
Yang <7364402+cyang49@users.noreply.github.com> Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
420 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import argparse
import datetime
import os
from typing import Union

import albumentations
import albumentations.pytorch
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule

from vllm import LLM
|
|
|
|
# Make float16 the default floating-point dtype for tensors created by this
# script.
torch.set_default_dtype(torch.float16)

# Raw pixel value marking missing data in the input rasters; load_example
# skips normalization for these pixels.
NO_DATA = -9999
# Value that NO_DATA pixels are replaced with after normalization;
# process_channel_group treats pixels equal to it as invalid.
NO_DATA_FLOAT = 0.0001
# Lower bound of the contrast stretch applied in process_channel_group.
OFFSET = 0
# Percentile of the valid pixels used as the upper bound of that stretch.
PERCENTILE = 99

# Settings for Sen1Floods11NonGeoDataModule. Not every key is forwarded to the
# constructor by generate_datamodule.
# NOTE(review): data_root is an environment-specific path; run_model only uses
# the datamodule's test_transform/aug, so it is presumably not needed for
# inference — confirm.
datamodule_config = {
    "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    "batch_size": 16,
    "constant_scale": 0.0001,
    "data_root": "/dccstor/geofm-finetuning/datasets/sen1floods11",
    "drop_last": True,
    "no_data_replace": 0.0,
    "no_label_replace": -1,
    "num_workers": 8,
    "test_transform": [
        # Each img_size tile is resized to 448x448 before inference
        # (run_model resizes predictions back). interpolation=1 is
        # cv2.INTER_LINEAR. p=1 applies the transform every time, which makes
        # the deprecated `always_apply` argument unnecessary.
        albumentations.Resize(height=448, width=448, interpolation=1, p=1),
        albumentations.pytorch.ToTensorV2(transpose_mask=False, p=1.0),
    ],
}
|
|
|
|
|
|
class PrithviMAE:
    """Thin wrapper that runs a Prithvi model through vLLM's ``encode`` API.

    The model is loaded in float16 with the ``terratorch`` model
    implementation and without a tokenizer; image data is passed as
    multimodal data attached to a dummy prompt.
    """

    def __init__(self, model):
        """Load *model* (model name or path accepted by ``vllm.LLM``)."""
        self.model = LLM(
            model=model,
            skip_tokenizer_init=True,
            dtype="float16",
            enforce_eager=True,
            model_impl="terratorch",
        )

    def run(self, input_data, location_coords):
        """Run one forward pass and return the raw output tensor.

        Args:
            input_data: image tensor (or ``None``). float32 input is cast to
                float16 to match the model dtype, and the leading dimension
                is indexed away before use (presumably a batch dimension of
                size 1 — TODO confirm).
            location_coords: optional location tensor, forwarded unchanged.

        Returns:
            The ``outputs.data`` field of the first result of ``encode``.
        """
        if input_data is not None:
            if input_data.dtype == torch.float32:
                input_data = input_data.to(torch.float16)
            input_data = input_data[0]

        # merge the inputs into one data structure
        mm_data = {
            "pixel_values": input_data,
            "location_coords": location_coords,
        }

        # No tokenizer is loaded, so a single placeholder token id carries the
        # multimodal payload.
        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
        outputs = self.model.encode(prompt, use_tqdm=False)

        return outputs[0].outputs.data
|
|
|
|
|
|
def generate_datamodule():
    """Construct the Sen1Floods11 datamodule from ``datamodule_config``.

    Only data_root, batch_size, num_workers, bands, drop_last and
    test_transform are passed to the constructor; the remaining config keys
    are not forwarded.
    """
    cfg = datamodule_config
    constructor_kwargs = dict(
        data_root=cfg["data_root"],
        batch_size=cfg["batch_size"],
        num_workers=cfg["num_workers"],
        bands=cfg["bands"],
        drop_last=cfg["drop_last"],
        test_transform=cfg["test_transform"],
    )
    return Sen1Floods11NonGeoDataModule(**constructor_kwargs)
|
|
|
|
|
|
def process_channel_group(orig_img, channels):
    """Select *channels* from *orig_img* and contrast-stretch them to [0, 1].

    Args:
        orig_img: torch.Tensor representing original image (reference)
            with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
        torch.Tensor with shape (num_channels, height, width) and values in
        [0, 1]. Pixels equal to ``NO_DATA_FLOAT`` are set to 0.
    """
    orig_img = orig_img[channels, ...]
    valid_mask = orig_img != NO_DATA_FLOAT

    # Rescale (enhancing contrast): map [OFFSET, max_value] linearly to [0, 1].
    # max_value is the PERCENTILE-th percentile of the valid pixels, but never
    # below 3000. If no pixel is valid, the percentile is undefined, so fall
    # back to the floor.
    max_value = 3000
    if valid_mask.any():
        max_value = max(max_value, np.percentile(orig_img[valid_mask], PERCENTILE))
    min_value = OFFSET

    orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)

    # No data as zeros
    orig_img[~valid_mask] = 0

    return orig_img
|
|
|
|
|
|
def read_geotiff(file_path: str):
    """Load every band of the raster stored at *file_path*.

    Args:
        file_path: path to image file.

    Returns:
        Tuple ``(img, meta, coords)``:
        - the array returned by ``read()`` (bands, height, width),
        - the dataset's ``meta`` dict,
        - the result of ``lnglat()``, or ``None`` if that call raised.
    """
    with rasterio.open(file_path) as dataset:
        bands, meta = dataset.read(), dataset.meta
        try:
            coords = dataset.lnglat()
        except Exception:
            # Coordinates are optional: load_example skips None values.
            coords = None

    return bands, meta, coords
|
|
|
|
|
|
def save_geotiff(image, output_path: str, meta: dict):
    """Write *image* to a GeoTIFF at *output_path*, one band at a time.

    Args:
        image: np.ndarray with shape (bands, height, width). Slice ``k``
            along the first axis is written to raster band ``k + 1``
            (rasterio band indices start at 1).
        output_path: path where to save the image
        meta: keyword arguments passed to ``rasterio.open`` in write mode.
    """
    with rasterio.open(output_path, "w", **meta) as dst:
        for band_number, band in enumerate(image, start=1):
            dst.write(band, band_number)
|
|
|
|
|
|
def _convert_np_uint8(float_image: torch.Tensor):
|
|
image = float_image.numpy() * 255.0
|
|
image = image.astype(dtype=np.uint8)
|
|
|
|
return image
|
|
|
|
|
|
def _extract_temporal_coords(file_path):
    """Return ``[year, day_of_year]`` parsed from a timestamp in *file_path*.

    Recognizes ``YYYYDDDTHHMMSS`` (year + day of year) and
    ``YYYYMMDDTHHMMSS`` (calendar date). Returns ``None`` when no timestamp
    is present or it cannot be parsed; parse failures are printed.
    """
    try:
        match = re.search(r"(\d{7,8}T\d{6})", file_path)
        if not match:
            return None
        date_str = match.group(1).split("T")[0]
        year = int(date_str[:4])
        day_part = date_str[4:]
        if len(day_part) == 3:
            julian_day = int(day_part)
        else:
            # Parse with the real year: without it strptime assumes 1900
            # (not a leap year), which rejects Feb 29 and gives a day of year
            # that is off by one from March onward in leap years.
            julian_day = (
                datetime.datetime.strptime(date_str, "%Y%m%d").timetuple().tm_yday
            )
        return [year, julian_day]
    except Exception as e:
        print(f"Could not extract timestamp for {file_path} ({e})")
        return None


def load_example(
    file_paths: list[str],
    mean: Union[list[float], None] = None,
    std: Union[list[float], None] = None,
    indices: Union[list[int], None] = None,
):
    """Build an input example by loading images in *file_paths*.

    Each file becomes one frame of the example.

    Args:
        file_paths: non-empty list of file paths.
        mean: list containing mean values for each band in the
            images in *file_paths*. Used only together with *std*.
        std: list containing std values for each band in the
            images in *file_paths*. When both are given, pixels equal to
            ``NO_DATA`` are set to ``NO_DATA_FLOAT`` and all others are
            normalized as ``(img - mean) / std``.
        indices: optional band indices selected from each image before
            normalization.

    Returns:
        imgs: float32 np.ndarray with shape (1, C, num_frames, H, W).
        temporal_coords: list of ``[year, day_of_year]`` for files whose
            name contains a parsable timestamp.
        location_coords: list of coordinates from ``read_geotiff`` for files
            where they could be read.
        metas: list of meta info dicts, one per file.
        Note that temporal_coords and location_coords can be shorter than
        the number of frames.

    Raises:
        ValueError: if *file_paths* is empty.
    """
    if not file_paths:
        raise ValueError("load_example() requires at least one file path")

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        timestamp = _extract_temporal_coords(file)
        if timestamp is not None:
            temporal_coords.append(timestamp)

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas
|
|
|
|
|
|
def run_model(
    input_data,
    temporal_coords,
    location_coords,
    model,
    datamodule,
    img_size,
    lightning_model=None,
):
    """Segment *input_data* tile by tile and stitch the predictions.

    The image is reflect-padded to a multiple of *img_size* and cut into
    non-overlapping ``img_size x img_size`` windows. Each window is
    transformed with the datamodule and passed to ``model.run``. The
    per-pixel argmax predictions are then reassembled and cropped back to
    the original size.

    Args:
        input_data: np.ndarray indexed as (b, c, t, h, w), per the padding
            and rearrange patterns below; a batch of 1 is expected.
        temporal_coords: accepted for interface compatibility; not forwarded
            to ``model.run``.
        location_coords: list of coordinates; only the first entry is used.
        model: object exposing ``run(x, location_coords=...)``.
        datamodule: object providing ``test_transform`` and ``aug``.
        img_size: window edge length in pixels.
        lightning_model: unused; kept for interface compatibility.

    Returns:
        torch.Tensor of shape (1, H, W) with the predicted class index of
        each pixel, as float.
    """
    # Reflect pad if not divisible by img_size
    original_h, original_w = input_data.shape[-2:]
    pad_h = (img_size - (original_h % img_size)) % img_size
    pad_w = (img_size - (original_w % img_size)) % img_size
    input_data = np.pad(
        input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
    )

    # Build sliding window (non-overlapping tiles along H and W)
    batch_size = 1
    batch = torch.tensor(input_data)
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
    )

    # Split into chunks of at most batch_size windows (ceil division; floor
    # division could produce chunks larger than batch_size).
    num_windows = windows.shape[0]
    num_batches = max(1, (num_windows + batch_size - 1) // batch_size)
    windows = torch.tensor_split(windows, num_batches, dim=0)

    if location_coords:
        location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
    else:
        location_coords = None

    # Run Prithvi-EO-V2-300M-TL-Sen1Floods11
    pred_imgs = []
    for x in windows:
        # Apply standardization. squeeze() + transpose(1, 2, 0) require a
        # single window per chunk and a single time step (-> C, H, W).
        x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
        x = datamodule.aug(x)["image"]

        with torch.no_grad():
            pred = model.run(x, location_coords=location_coords)
        y_hat = pred.argmax(dim=1)

        # Bring the prediction back to the window size.
        y_hat = torch.nn.functional.interpolate(
            y_hat.unsqueeze(1).float(), size=img_size, mode="nearest"
        )

        pred_imgs.append(y_hat)

    pred_imgs = torch.concat(pred_imgs, dim=0)

    # Build images from patches
    pred_imgs = rearrange(
        pred_imgs,
        "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=1,
        h1=h1,
        w1=w1,
    )

    # Cut padded area back to original size
    pred_imgs = pred_imgs[..., :original_h, :original_w]

    # Squeeze (batch size 1)
    pred_imgs = pred_imgs[0]

    return pred_imgs
|
|
|
|
|
|
def main(
    data_file: str,
    model: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] = None,
):
    """Segment *data_file* with *model* and write GeoTIFF results.

    Files written to *output_dir* (``<stem>`` is *data_file*'s base name
    without extension):
    - ``pred_<stem>.tiff``: predicted class indices times 255, as uint8.
    - ``rgb_pred_<stem>.tiff``: RGB image blended with the prediction.
    - ``original_rgb_<stem>.tiff``: the contrast-stretched RGB image, only
      when *rgb_outputs* is true.
    """
    os.makedirs(output_dir, exist_ok=True)

    model_obj = PrithviMAE(model=model)
    datamodule = generate_datamodule()
    img_size = 512  # Size of Sen1Floods11

    input_data, temporal_coords, location_coords, metas = load_example(
        file_paths=[data_file],
        indices=input_indices,
    )
    meta_data = metas[0]  # only one image was loaded

    # Inputs stored on a 0-10000 scale are brought to 0-1.
    if input_data.mean() > 1:
        input_data = input_data / 10000

    # Positions of the RGB bands within the configured band order.
    channels = [
        datamodule_config["bands"].index(band) for band in ("RED", "GREEN", "BLUE")
    ]

    pred = run_model(
        input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
    )

    stem = os.path.splitext(os.path.basename(data_file))[0]

    def _save(tensor, prefix):
        # meta_data is looked up at call time, so update() calls made before
        # each save take effect.
        save_geotiff(
            image=_convert_np_uint8(tensor),
            output_path=os.path.join(output_dir, f"{prefix}{stem}.tiff"),
            meta=meta_data,
        )

    # Single-band prediction mask.
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    _save(pred, "pred_")

    # All remaining outputs are three-band RGB images.
    meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)

    # Scale back to 0-10000 for visualization.
    if input_data.mean() < 1:
        input_data = input_data * 10000

    rgb_orig = process_channel_group(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    ).to(torch.float32)

    # Blend: pixels with a non-zero prediction become 70% image + 30%
    # prediction; zero-prediction pixels (marked NaN) keep the plain RGB value.
    pred[pred == 0.0] = np.nan
    blended = rgb_orig * 0.7 + pred * 0.3
    nan_mask = blended.isnan()
    blended[nan_mask] = rgb_orig[nan_mask]
    _save(blended, "rgb_pred_")

    # Save image rgb
    if rgb_outputs:
        _save(rgb_orig, "original_rgb_")
|
|
|
|
|
|
if __name__ == "__main__":
    # add_help is left at its default so -h/--help works.
    parser = argparse.ArgumentParser(description="MAE run inference")

    parser.add_argument(
        "--data_file",
        type=str,
        default="./India_900498_S2Hand.tif",
        help="Path to the input GeoTIFF file.",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        help="Model name (e.g. Hugging Face repo id) or local path to load with vLLM.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[1, 2, 3, 8, 11, 12],
        type=int,
        nargs="+",
        help="""
        0-based indices of the six Prithvi channels to be selected from the input.
        By default selects [1,2,3,8,11,12] for S2L1C data.
        """,
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        help="If present, additionally save the RGB channels of the input image "
        "as original_rgb_<name>.tiff in the output directory.",
    )
    args = parser.parse_args()

    main(**vars(args))
|