diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 6dc03e85baa99..4fdc7a3cf709e 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -1,122 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This is a demo script showing how to use the -PrithviGeospatialMAE model with vLLM -This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa - -Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa - -The requirements for running this script are: -- Installing [terratorch, albumentations, rasterio] in your python environment -- downloading the model weights in a 'model' folder local to the script - (temporary measure until the proper config.json file is uploaded to HF) -- download an input example image (India_900498_S2Hand.tif) and place it in - the same folder with the script (or specify with the --data_file argument) - -Run the example: -python prithvi_geospatial_mae.py - -""" # noqa: E501 - import argparse import datetime import os +import re from typing import Union import albumentations import numpy as np import rasterio -import regex as re import torch from einops import rearrange from terratorch.datamodules import Sen1Floods11NonGeoDataModule from vllm import LLM +torch.set_default_dtype(torch.float16) + NO_DATA = -9999 NO_DATA_FLOAT = 0.0001 OFFSET = 0 PERCENTILE = 99 -model_config = """{ - "architectures": ["PrithviGeoSpatialMAE"], - "num_classes": 0, - "pretrained_cfg": { - "task_args": { - "task": "SemanticSegmentationTask", - "model_factory": "EncoderDecoderFactory", - "loss": "ce", - "ignore_index": -1, - "lr": 0.001, - "freeze_backbone": false, - "freeze_decoder": false, - "plot_on_val": 10, - "optimizer": "AdamW", - "scheduler": "CosineAnnealingLR" - }, - "model_args": { - "backbone_pretrained": false, - "backbone": "prithvi_eo_v2_300_tl", - "decoder": "UperNetDecoder", - "decoder_channels": 256, - "decoder_scale_modules": true, - "num_classes": 2, - "rescale": true, - "backbone_bands": [ - "BLUE", - "GREEN", - "RED", - "NIR_NARROW", - "SWIR_1", - "SWIR_2" - ], - "head_dropout": 0.1, - "necks": [ - { - "name": "SelectIndices", - "indices": [ - 5, - 11, - 17, - 23 - ] - }, - { - "name": "ReshapeTokensToImage" - } - ] - }, - "optimizer_params" : { - "lr": 5.0e-05, - "betas": [0.9, 0.999], - "eps": [1.0e-08], - "weight_decay": 0.05, - "amsgrad": false, - "maximize": false, - "capturable": false, - "differentiable": false - }, - "scheduler_params" : { - "T_max": 50, - "eta_min": 0, - "last_epoch": -1, - "verbose": "deprecated" - } - }, - - - "torch_dtype": "float32" -} -""" - -# Temporarily creating the "config.json" for the model. 
-# This is going to disappear once the correct config.json is available on HF -with open( - os.path.join(os.path.dirname(__file__), "./model/config.json"), "w" -) as config_file: - config_file.write(model_config) - datamodule_config = { "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], "batch_size": 16, @@ -138,28 +43,24 @@ datamodule_config = { class PrithviMAE: - def __init__(self): - print("Initializing PrithviMAE model") - self.llm = LLM( - model=os.path.join(os.path.dirname(__file__), "./model"), - skip_tokenizer_init=True, - dtype="float32", + def __init__(self, model): + self.model = LLM( + model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True ) def run(self, input_data, location_coords): - print("################ Running inference on vLLM ##############") # merge the inputs into one data structure + if input_data is not None and input_data.dtype == torch.float32: + input_data = input_data.to(torch.float16) + input_data = input_data[0] + mm_data = { - "pixel_values": torch.empty(0) if input_data is None else input_data, - "location_coords": torch.empty(0) - if location_coords is None - else location_coords, + "pixel_values": input_data, + "location_coords": location_coords, } prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} - - outputs = self.llm.encode(prompt, use_tqdm=False) - print("################ Inference done (it took seconds) ##############") + outputs = self.model.encode(prompt, use_tqdm=False) return outputs[0].outputs.data @@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels): """ Args: orig_img: torch.Tensor representing original image (reference) - with shape = (bands, H, W). + with shape = (bands, H, W). channels: list of indices representing RGB channels. Returns: - torch.Tensor with shape (num_channels, height, width) for original image + torch.Tensor with shape (num_channels, height, width) + for original image """ orig_img = orig_img[channels, ...] @@ -260,10 +162,10 @@ def load_example( Args: file_paths: list of file paths . - mean: list containing mean values for each band in the images - in *file_paths*. - std: list containing std values for each band in the images - in *file_paths*. + mean: list containing mean values for each band in the + images in *file_paths*. + std: list containing std values for each band in the + images in *file_paths*. 
Returns: np.array containing created example @@ -308,7 +210,7 @@ def load_example( print(f"Could not extract timestamp for {file} ({e})") imgs = np.stack(imgs, axis=0) # num_frames, H, W, C - imgs = np.moveaxis(imgs, -1, 0).astype("float32") + imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W imgs = np.expand_dims(imgs, axis=0) # add batch di return imgs, temporal_coords, location_coords, metas @@ -332,8 +234,10 @@ def run_model( ) # Build sliding window + batch_size = 1 - batch = torch.tensor(input_data, device="cpu") + # batch = torch.tensor(input_data, device="cpu") + batch = torch.tensor(input_data) windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size) h1, w1 = windows.shape[3:5] windows = rearrange( @@ -344,18 +248,16 @@ def run_model( num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1 windows = torch.tensor_split(windows, num_batches, dim=0) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - if temporal_coords: - temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0) + temporal_coords = torch.tensor(temporal_coords).unsqueeze(0) else: temporal_coords = None if location_coords: - location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0) + location_coords = torch.tensor(location_coords[0]).unsqueeze(0) else: location_coords = None - # Run model + # Run Prithvi-EO-V2-300M-TL-Sen1Floods11 pred_imgs = [] for x in windows: # Apply standardization @@ -363,15 +265,7 @@ def run_model( x = datamodule.aug(x)["image"] with torch.no_grad(): - x = x.to(device) pred = model.run(x, location_coords=location_coords) - if lightning_model: - pred_lightning = lightning_model( - x, temporal_coords=temporal_coords, location_coords=location_coords - ) - pred_lightning = pred_lightning.output.detach().cpu() - if not torch.equal(pred, pred_lightning): - print("Inference output is not equal") y_hat = pred.argmax(dim=1) y_hat = torch.nn.functional.interpolate( @@ -403,52 +297,18 @@ def run_model( return pred_imgs -def parse_args(): - parser = argparse.ArgumentParser("MAE run inference", add_help=False) - - parser.add_argument( - "--data_file", - type=str, - default="./India_900498_S2Hand.tif", - help="Path to the file.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Path to the directory where to save outputs.", - ) - parser.add_argument( - "--input_indices", - default=[1, 2, 3, 8, 11, 12], - type=int, - nargs="+", - help="0-based indices of the six Prithvi channels to be selected from the " - "input. By default selects [1,2,3,8,11,12] for S2L1C data.", - ) - parser.add_argument( - "--rgb_outputs", - action="store_true", - help="If present, output files will only contain RGB channels. 
" - "Otherwise, all bands will be saved.", - ) - - def main( data_file: str, + model: str, output_dir: str, rgb_outputs: bool, input_indices: list[int] = None, ): os.makedirs(output_dir, exist_ok=True) - # Load model --------------------------------------------------------------- - - model_obj = PrithviMAE() + model_obj = PrithviMAE(model=model) datamodule = generate_datamodule() - img_size = 256 # Size of Sen1Floods11 - - # Loading data ------------------------------------------------------------- + img_size = 512 # Size of Sen1Floods11 input_data, temporal_coords, location_coords, meta_data = load_example( file_paths=[data_file], @@ -460,8 +320,6 @@ def main( if input_data.mean() > 1: input_data = input_data / 10000 # Convert to range 0-1 - # Running model ------------------------------------------------------------ - channels = [ datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"] ] # BGR -> RGB @@ -469,7 +327,6 @@ def main( pred = run_model( input_data, temporal_coords, location_coords, model_obj, datamodule, img_size ) - # Save pred meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) pred_file = os.path.join( @@ -487,6 +344,7 @@ def main( orig_img=torch.Tensor(input_data[0, :, 0, ...]), channels=channels, ) + rgb_orig = rgb_orig.to(torch.float32) pred[pred == 0.0] = np.nan img_pred = rgb_orig * 0.7 + pred * 0.3 @@ -503,9 +361,10 @@ def main( # Save image rgb if rgb_outputs: + name_suffix = os.path.splitext(os.path.basename(data_file))[0] rgb_file = os.path.join( output_dir, - f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff", + f"original_rgb_{name_suffix}.tiff", ) save_geotiff( image=_convert_np_uint8(rgb_orig), @@ -515,6 +374,42 @@ def main( if __name__ == "__main__": - args = parse_args() + parser = argparse.ArgumentParser("MAE run inference", add_help=False) + + parser.add_argument( + "--data_file", + type=str, + default="./India_900498_S2Hand.tif", + help="Path to the file.", + ) + parser.add_argument( + "--model", + type=str, + default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + help="Path to a checkpoint file to load from.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Path to the directory where to save outputs.", + ) + parser.add_argument( + "--input_indices", + default=[1, 2, 3, 8, 11, 12], + type=int, + nargs="+", + help=""" + 0-based indices of the six Prithvi channels to be selected from the input. + By default selects [1,2,3,8,11,12] for S2L1C data. + """, + ) + parser.add_argument( + "--rgb_outputs", + action="store_true", + help="If present, output files will only contain RGB channels. 
" + "Otherwise, all bands will be saved.", + ) + args = parser.parse_args() main(**vars(args)) diff --git a/requirements/test.in b/requirements/test.in index c6c68891d6a6a..9f66e2d6919a5 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -54,3 +54,4 @@ runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 +terratorch==1.1rc2 # required for PrithviMAE test \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt index aadbab03f6fc8..a2b230102d4ea 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,6 +6,10 @@ accelerate==1.0.1 # via # lm-eval # peft +aenum==3.1.16 + # via lightly +affine==2.4.0 + # via rasterio aiohappyeyeballs==2.4.3 # via aiohttp aiohttp==3.10.11 @@ -21,8 +25,18 @@ aiosignal==1.3.1 # via # aiohttp # ray +albucore==0.0.16 + # via terratorch +albumentations==1.4.6 + # via terratorch +alembic==1.16.4 + # via mlflow annotated-types==0.7.0 # via pydantic +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf anyio==4.6.2.post1 # via # httpx @@ -34,10 +48,12 @@ arrow==1.3.0 attrs==24.2.0 # via # aiohttp + # fiona # hypothesis # jsonlines # jsonschema # pytest-subtests + # rasterio # referencing audioread==3.0.1 # via librosa @@ -46,9 +62,13 @@ backoff==2.2.1 # -r requirements/test.in # schemathesis bitsandbytes==0.46.1 - # via -r requirements/test.in + # via + # -r requirements/test.in + # lightning black==24.10.0 # via datamodel-code-generator +blinker==1.9.0 + # via flask blobfile==3.0.0 # via -r requirements/test.in bm25s==0.2.13 @@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3 buildkite-test-collector==0.1.9 # via -r requirements/test.in cachetools==5.5.2 - # via google-auth + # via + # google-auth + # mlflow-skinny certifi==2024.8.30 # via + # fiona # httpcore # httpx + # lightly + # pyogrio + # pyproj + # rasterio # requests cffi==1.17.1 # via soundfile @@ -79,11 +106,28 @@ charset-normalizer==3.4.0 click==8.1.7 # via # black + # click-plugins + # cligj + # fiona + # flask # jiwer + # mlflow-skinny # nltk + # rasterio # ray # schemathesis # typer + # uvicorn +click-plugins==1.1.1.2 + # via + # fiona + # rasterio +cligj==0.7.2 + # via + # fiona + # rasterio +cloudpickle==3.1.1 + # via mlflow-skinny colorama==0.4.6 # via # sacrebleu @@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0 # via ray cycler==0.12.1 # via matplotlib +databricks-sdk==0.59.0 + # via mlflow-skinny datamodel-code-generator==0.26.3 # via -r requirements/test.in dataproperty==1.0.1 @@ -122,13 +168,21 @@ distlib==0.3.9 # via virtualenv dnspython==2.7.0 # via email-validator +docker==7.1.0 + # via mlflow docopt==0.6.2 # via num2words -einops==0.8.0 +docstring-parser==0.17.0 + # via jsonargparse +efficientnet-pytorch==0.7.1 + # via segmentation-models-pytorch +einops==0.8.1 # via # -r requirements/test.in # encodec # mamba-ssm + # terratorch + # torchgeo # vector-quantize-pytorch # vocos einx==0.3.0 @@ -141,6 +195,8 @@ eval-type-backport==0.2.2 # via mteb evaluate==0.4.3 # via lm-eval +fastapi==0.116.1 + # via mlflow-skinny fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -156,6 +212,10 @@ filelock==3.16.1 # torch # transformers # virtualenv +fiona==1.10.1 + # via torchgeo +flask==3.1.1 + # via mlflow fonttools==4.54.1 # via matplotlib fqdn==1.5.1 @@ -173,6 +233,8 @@ fsspec==2024.9.0 # evaluate # fastparquet # huggingface-hub + # lightning + # pytorch-lightning # torch ftfy==6.3.1 # via open-clip-torch @@ -180,18 +242,41 @@ genai-perf==0.0.8 # via -r 
requirements/test.in genson==1.3.0 # via datamodel-code-generator +geopandas==1.0.1 + # via terratorch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.44 + # via mlflow-skinny google-api-core==2.24.2 # via opencensus google-auth==2.40.2 - # via google-api-core + # via + # databricks-sdk + # google-api-core googleapis-common-protos==1.70.0 # via google-api-core +graphene==3.4.3 + # via mlflow graphql-core==3.2.6 - # via hypothesis-graphql + # via + # graphene + # graphql-relay + # hypothesis-graphql +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.3 + # via sqlalchemy grpcio==1.71.0 # via ray +gunicorn==23.0.0 + # via mlflow h11==0.14.0 - # via httpcore + # via + # httpcore + # uvicorn +h5py==3.13.0 + # via terratorch harfile==0.3.0 # via schemathesis hf-xet==1.1.3 @@ -204,7 +289,7 @@ httpx==0.27.2 # via # -r requirements/test.in # schemathesis -huggingface-hub==0.33.0 +huggingface-hub==0.33.1 # via # -r requirements/test.in # accelerate @@ -212,13 +297,19 @@ huggingface-hub==0.33.0 # evaluate # open-clip-torch # peft + # segmentation-models-pytorch # sentence-transformers + # terratorch # timm # tokenizers # transformers # vocos humanize==4.11.0 # via runai-model-streamer +hydra-core==1.3.2 + # via + # lightly + # lightning hypothesis==6.131.0 # via # hypothesis-graphql @@ -236,6 +327,14 @@ idna==3.10 # jsonschema # requests # yarl +imageio==2.37.0 + # via scikit-image +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +importlib-resources==6.5.2 + # via typeshed-client inflect==5.6.2 # via datamodel-code-generator iniconfig==2.0.0 @@ -244,9 +343,13 @@ isoduration==20.11.0 # via jsonschema isort==5.13.2 # via datamodel-code-generator +itsdangerous==2.2.0 + # via flask jinja2==3.1.6 # via # datamodel-code-generator + # flask + # mlflow # torch jiwer==3.0.5 # via -r requirements/test.in @@ -259,6 +362,10 @@ joblib==1.4.2 # librosa # nltk # scikit-learn +jsonargparse==4.35.0 + # via + # lightning + # terratorch jsonlines==4.0.0 # via lm-eval jsonpointer==3.0.0 @@ -277,12 +384,33 @@ kaleido==0.2.1 # via genai-perf kiwisolver==1.4.7 # via matplotlib +kornia==0.8.1 + # via torchgeo +kornia-rs==0.1.9 + # via kornia lazy-loader==0.4 - # via librosa + # via + # librosa + # scikit-image libnacl==2.1.0 # via tensorizer librosa==0.10.2.post1 # via -r requirements/test.in +lightly==1.5.20 + # via + # terratorch + # torchgeo +lightly-utils==0.0.2 + # via lightly +lightning==2.5.1.post0 + # via + # terratorch + # torchgeo +lightning-utilities==0.14.3 + # via + # lightning + # pytorch-lightning + # torchmetrics llvmlite==0.44.0 # via numba lm-eval==0.4.8 @@ -291,16 +419,27 @@ lxml==5.3.0 # via # blobfile # sacrebleu +mako==1.3.10 + # via alembic mamba-ssm==2.2.4 # via -r requirements/test.in +markdown==3.8.2 + # via mlflow markdown-it-py==3.0.0 # via rich markupsafe==3.0.1 # via + # flask # jinja2 + # mako # werkzeug matplotlib==3.9.2 - # via -r requirements/test.in + # via + # -r requirements/test.in + # lightning + # mlflow + # pycocotools + # torchgeo mbstrdecoder==1.1.3 # via # dataproperty @@ -310,6 +449,10 @@ mdurl==0.1.2 # via markdown-it-py mistral-common==1.8.0 # via -r requirements/test.in +mlflow==2.22.0 + # via terratorch +mlflow-skinny==2.22.0 + # via mlflow more-itertools==10.5.0 # via lm-eval mpmath==1.3.0 @@ -328,10 +471,14 @@ multiprocess==0.70.16 # via # datasets # evaluate +munch==4.0.0 + # via pretrainedmodels mypy-extensions==1.0.0 # via black networkx==3.2.1 - # via torch + # via + # scikit-image + # torch ninja==1.11.1.3 # via mamba-ssm nltk==3.9.1 @@ 
-348,6 +495,8 @@ numpy==1.26.4 # via # -r requirements/test.in # accelerate + # albucore + # albumentations # bitsandbytes # bm25s # contourpy @@ -358,9 +507,15 @@ numpy==1.26.4 # evaluate # fastparquet # genai-perf + # geopandas + # h5py + # imageio # librosa + # lightly + # lightly-utils # matplotlib # mistral-common + # mlflow # mteb # numba # numexpr @@ -368,18 +523,30 @@ numpy==1.26.4 # pandas # patsy # peft + # pycocotools + # pyogrio + # rasterio + # rioxarray # rouge-score # runai-model-streamer # sacrebleu + # scikit-image # scikit-learn # scipy + # segmentation-models-pytorch + # shapely # soxr # statsmodels + # tensorboardx # tensorizer + # tifffile + # torchgeo + # torchmetrics # torchvision # transformers # tritonclient # vocos + # xarray nvidia-cublas-cu12==12.8.3.14 # via # nvidia-cudnn-cu12 @@ -417,6 +584,10 @@ nvidia-nvjitlink-cu12==12.8.61 # torch nvidia-nvtx-cu12==12.8.55 # via torch +omegaconf==2.3.0 + # via + # hydra-core + # lightning open-clip-torch==2.32.0 # via -r requirements/test.in opencensus==0.11.4 @@ -426,7 +597,18 @@ opencensus-context==0.1.3 opencv-python-headless==4.11.0.86 # via # -r requirements/test.in + # albucore + # albumentations # mistral-common +opentelemetry-api==1.35.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.35.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.56b0 + # via opentelemetry-sdk packaging==24.2 # via # accelerate @@ -435,26 +617,44 @@ packaging==24.2 # datasets # evaluate # fastparquet + # geopandas + # gunicorn # huggingface-hub + # hydra-core + # kornia # lazy-loader + # lightning + # lightning-utilities # mamba-ssm # matplotlib + # mlflow-skinny # peft # plotly # pooch + # pyogrio # pytest # pytest-rerunfailures + # pytorch-lightning # ray + # rioxarray + # scikit-image # statsmodels + # tensorboardx + # torchmetrics # transformers # typepy + # xarray pandas==2.2.3 # via # datasets # evaluate # fastparquet # genai-perf + # geopandas + # mlflow # statsmodels + # torchgeo + # xarray pathspec==0.12.1 # via black pathvalidate==3.2.1 @@ -468,9 +668,14 @@ peft==0.13.2 pillow==10.4.0 # via # genai-perf + # imageio + # lightly-utils # matplotlib # mistral-common + # scikit-image + # segmentation-models-pytorch # sentence-transformers + # torchgeo # torchvision platformdirs==4.3.6 # via @@ -489,6 +694,8 @@ portalocker==2.10.1 # via sacrebleu pqdm==0.2.0 # via -r requirements/test.in +pretrainedmodels==0.7.4 + # via segmentation-models-pytorch prometheus-client==0.22.0 # via ray propcache==0.2.0 @@ -499,8 +706,10 @@ protobuf==5.28.3 # via # google-api-core # googleapis-common-protos + # mlflow-skinny # proto-plus # ray + # tensorboardx # tensorizer psutil==6.1.0 # via @@ -515,6 +724,7 @@ pyarrow==18.0.0 # via # datasets # genai-perf + # mlflow pyasn1==0.6.1 # via # pyasn1-modules @@ -523,6 +733,8 @@ pyasn1-modules==0.4.2 # via google-auth pybind11==2.13.6 # via lm-eval +pycocotools==2.0.8 + # via terratorch pycountry==24.6.1 # via pydantic-extra-types pycparser==2.22 @@ -532,8 +744,12 @@ pycryptodomex==3.22.0 pydantic==2.11.5 # via # -r requirements/test.in + # albumentations # datamodel-code-generator + # fastapi + # lightly # mistral-common + # mlflow-skinny # mteb # pydantic-extra-types # ray @@ -543,15 +759,24 @@ pydantic-extra-types==2.10.5 # via mistral-common pygments==2.18.0 # via rich +pyogrio==0.11.0 + # via geopandas pyparsing==3.2.0 - # via matplotlib + # via + # matplotlib + # rasterio +pyproj==3.7.1 + # via + # geopandas + # rioxarray + # 
torchgeo pyrate-limiter==3.7.0 # via schemathesis pystemmer==3.0.0 # via mteb pytablewriter==1.2.0 # via lm-eval -pytest==8.3.3 +pytest==8.3.5 # via # -r requirements/test.in # buildkite-test-collector @@ -564,6 +789,7 @@ pytest==8.3.3 # pytest-subtests # pytest-timeout # schemathesis + # terratorch pytest-asyncio==0.24.0 # via -r requirements/test.in pytest-forked==1.6.0 @@ -578,15 +804,23 @@ pytest-subtests==0.14.1 # via schemathesis pytest-timeout==2.3.1 # via -r requirements/test.in +python-box==7.3.2 + # via terratorch python-dateutil==2.9.0.post0 # via # arrow # botocore + # graphene + # lightly # matplotlib # pandas # typepy python-rapidjson==1.20 # via tritonclient +pytorch-lightning==2.5.2 + # via + # lightly + # lightning pytrec-eval-terrier==0.5.7 # via mteb pytz==2024.2 @@ -596,11 +830,17 @@ pytz==2024.2 pyyaml==6.0.2 # via # accelerate + # albumentations # datamodel-code-generator # datasets # genai-perf # huggingface-hub + # jsonargparse + # lightning + # mlflow-skinny + # omegaconf # peft + # pytorch-lightning # ray # responses # schemathesis @@ -609,6 +849,11 @@ pyyaml==6.0.2 # vocos rapidfuzz==3.12.1 # via jiwer +rasterio==1.4.3 + # via + # rioxarray + # terratorch + # torchgeo ray==2.43.0 # via -r requirements/test.in redis==5.2.0 @@ -627,12 +872,16 @@ regex==2024.9.11 requests==2.32.3 # via # buildkite-test-collector + # databricks-sdk # datasets + # docker # evaluate # google-api-core # huggingface-hub + # lightly # lm-eval # mistral-common + # mlflow-skinny # mteb # pooch # ray @@ -650,8 +899,11 @@ rfc3987==1.3.8 rich==13.9.4 # via # genai-perf + # lightning # mteb # typer +rioxarray==0.19.0 + # via terratorch rouge-score==0.1.2 # via lm-eval rpds-py==0.20.1 @@ -660,6 +912,8 @@ rpds-py==0.20.1 # referencing rsa==4.9.1 # via google-auth +rtree==1.4.0 + # via torchgeo runai-model-streamer==0.11.0 # via -r requirements/test.in runai-model-streamer-s3==0.11.0 @@ -677,21 +931,32 @@ safetensors==0.4.5 # transformers schemathesis==3.39.15 # via -r requirements/test.in +scikit-image==0.25.2 + # via albumentations scikit-learn==1.5.2 # via + # albumentations # librosa # lm-eval + # mlflow # mteb # sentence-transformers scipy==1.13.1 # via + # albumentations # bm25s # librosa + # mlflow # mteb + # scikit-image # scikit-learn # sentence-transformers # statsmodels # vocos +segmentation-models-pytorch==0.4.0 + # via + # terratorch + # torchgeo sentence-transformers==3.2.1 # via # -r requirements/test.in @@ -700,21 +965,30 @@ sentencepiece==0.2.0 # via mistral-common setuptools==77.0.3 # via + # lightning-utilities # mamba-ssm # pytablewriter # torch # triton +shapely==2.1.1 + # via + # geopandas + # torchgeo shellingham==1.5.4 # via typer six==1.16.0 # via # junit-xml + # lightly # opencensus # python-dateutil # rfc3339-validator # rouge-score + # segmentation-models-pytorch smart-open==7.1.0 # via ray +smmap==5.0.2 + # via gitdb sniffio==1.3.1 # via # anyio @@ -727,10 +1001,17 @@ soundfile==0.12.1 # librosa soxr==0.5.0.post1 # via librosa +sqlalchemy==2.0.41 + # via + # alembic + # mlflow sqlitedict==2.1.0 # via lm-eval +sqlparse==0.5.3 + # via mlflow-skinny starlette==0.46.2 # via + # fastapi # schemathesis # starlette-testclient starlette-testclient==0.4.1 @@ -751,18 +1032,29 @@ tenacity==9.0.0 # via # lm-eval # plotly +tensorboardx==2.6.4 + # via lightning tensorizer==2.10.1 # via -r requirements/test.in +terratorch==1.1rc2 + # via -r requirements/test.in threadpoolctl==3.5.0 # via scikit-learn +tifffile==2025.3.30 + # via + # scikit-image + # terratorch tiktoken==0.7.0 # via # 
lm-eval # mistral-common -timm==1.0.11 +timm==1.0.15 # via # -r requirements/test.in # open-clip-torch + # segmentation-models-pytorch + # terratorch + # torchgeo tokenizers==0.21.1 # via # -r requirements/test.in @@ -776,18 +1068,28 @@ torch==2.7.1+cu128 # -r requirements/test.in # accelerate # bitsandbytes + # efficientnet-pytorch # encodec # fastsafetensors + # kornia + # lightly + # lightning # lm-eval # mamba-ssm # mteb # open-clip-torch # peft + # pretrainedmodels + # pytorch-lightning # runai-model-streamer + # segmentation-models-pytorch # sentence-transformers # tensorizer + # terratorch # timm # torchaudio + # torchgeo + # torchmetrics # torchvision # vector-quantize-pytorch # vocos @@ -796,22 +1098,40 @@ torchaudio==2.7.1+cu128 # -r requirements/test.in # encodec # vocos +torchgeo==0.7.0 + # via terratorch +torchmetrics==1.7.4 + # via + # lightning + # pytorch-lightning + # terratorch + # torchgeo torchvision==0.22.1+cu128 # via # -r requirements/test.in + # lightly # open-clip-torch + # pretrainedmodels + # segmentation-models-pytorch + # terratorch # timm + # torchgeo tqdm==4.66.6 # via # datasets # evaluate # huggingface-hub + # lightly + # lightning # lm-eval # mteb # nltk # open-clip-torch # peft # pqdm + # pretrainedmodels + # pytorch-lightning + # segmentation-models-pytorch # sentence-transformers # tqdm-multiprocess # transformers @@ -843,18 +1163,34 @@ typer==0.15.2 # via fastsafetensors types-python-dateutil==2.9.0.20241206 # via arrow +typeshed-client==2.8.2 + # via jsonargparse typing-extensions==4.12.2 # via + # albumentations + # alembic + # fastapi + # graphene # huggingface-hub # librosa + # lightning + # lightning-utilities # mistral-common + # mlflow-skinny # mteb + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions # pqdm # pydantic # pydantic-core # pydantic-extra-types + # pytorch-lightning + # sqlalchemy # torch + # torchgeo # typer + # typeshed-client # typing-inspection typing-inspection==0.4.1 # via pydantic @@ -866,9 +1202,13 @@ urllib3==2.2.3 # via # blobfile # botocore + # docker + # lightly # requests # responses # tritonclient +uvicorn==0.35.0 + # via mlflow-skinny vector-quantize-pytorch==1.21.2 # via -r requirements/test.in virtualenv==20.31.2 @@ -880,11 +1220,15 @@ wcwidth==0.2.13 webcolors==24.11.1 # via jsonschema werkzeug==3.1.3 - # via schemathesis + # via + # flask + # schemathesis word2number==1.1 # via lm-eval wrapt==1.17.2 # via smart-open +xarray==2025.7.1 + # via rioxarray xxhash==3.5.0 # via # datasets @@ -893,5 +1237,7 @@ yarl==1.17.1 # via # aiohttp # schemathesis +zipp==3.23.0 + # via importlib-metadata zstandard==0.23.0 # via lm-eval diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py new file mode 100644 index 0000000000000..f08d83c082125 --- /dev/null +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm.utils import set_default_torch_num_threads + +from ....conftest import VllmRunner + + +def generate_test_mm_data(): + mm_data = { + "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), + } + return mm_data + + +def _run_test( + vllm_runner: type[VllmRunner], + model: str, +) -> None: + + prompt = [ + { + # This model deals with no text input + "prompt_token_ids": [1], + 
"multi_modal_data": generate_test_mm_data(), + } for _ in range(10) + ] + + with ( + set_default_torch_num_threads(1), + vllm_runner( + model, + task="embed", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + ) as vllm_model, + ): + vllm_model.encode(prompt) + + +MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", MODELS) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, +) -> None: + _run_test( + vllm_runner, + model, + ) diff --git a/vllm/config.py b/vllm/config.py index 223c1968c2750..764472c47ef64 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -651,6 +651,8 @@ class ModelConfig: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.multimodal_config = self._init_multimodal_config() + self.model_supports_multimodal_raw_input = ( + self.registry.supports_multimodal_raw_input(self.architectures)) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() @@ -1243,10 +1245,10 @@ class ModelConfig: return self.get_hf_config_sliding_window() def get_vocab_size(self) -> int: - return self.hf_text_config.vocab_size + return getattr(self.hf_text_config, "vocab_size", 0) def get_hidden_size(self) -> int: - return self.hf_text_config.hidden_size + return getattr(self.hf_text_config, "hidden_size", 0) @property def is_deepseek_mla(self) -> bool: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e2f8de1990b5f..3081995e693f2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -238,14 +238,14 @@ class LLMEngine: self.log_stats = log_stats self.use_cached_outputs = use_cached_outputs - if not self.model_config.skip_tokenizer_init: - self.tokenizer = self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) - tokenizer_group = self.get_tokenizer_group() - else: + if self.model_config.skip_tokenizer_init: self.tokenizer = None self.detokenizer = None tokenizer_group = None + else: + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() # Ensure that the function doesn't contain a reference to self, # to avoid engine GC issues diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 8f6a7db7aa8db..957b57276b4ca 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -136,6 +136,40 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) +@runtime_checkable +class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): + """The interface required for all multi-modal models.""" + + supports_multimodal_raw_input: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + +@overload +def supports_multimodal_raw_input( + model: object) -> TypeIs[SupportsMultiModalWithRawInput]: + ... + + +@overload +def supports_multimodal_raw_input( + model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: + ... 
+ + +def supports_multimodal_raw_input( + model: Union[type[object], object] +) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], + TypeIs[SupportsMultiModalWithRawInput]]: + return getattr(model, "supports_multimodal_raw_input", False) + + @runtime_checkable class SupportsScoreTemplate(Protocol): """The interface required for all models that support score template.""" diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index d51fcec07fd6a..0f00fd47fe4fc 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only IBM/NASA Prithvi Geospatial model.""" + from collections.abc import Iterable, Mapping, Sequence from typing import Optional, Union @@ -27,13 +28,14 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.pooler import (AllPool, PoolerHead, PoolerIdentity, SimplePooler) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (IsAttentionFree, - SupportsMultiModal, - SupportsV0Only) +from vllm.model_executor.models.interfaces import ( + IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput) from vllm.model_executor.models.utils import AutoWeightsLoader from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalFieldElem, MultiModalInputs, + MultiModalKwargs, MultiModalKwargsItem, + MultiModalSharedField, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) @@ -62,8 +64,9 @@ class PrithviGeoSpatialMAEInputBuilder( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. return { - "pixel_values": torch.full((1, 6, 512, 512), 1.0), - "location_coords": torch.full((1, 2), 1.0), + "pixel_values": torch.full((6, 512, 512), 1.0, + dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } @@ -75,8 +78,10 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - location_coords=MultiModalFieldConfig.batched("image"), + pixel_values=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), ) def _get_prompt_updates( @@ -99,23 +104,48 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor): for k, v in mm_data.items(): mm_kwargs[k] = v + mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} + + # This model receives in input a multi-dimensional tensor representing + # a single image patch and therefore it is not to be split + # into multiple elements, but rather to be considered a single one. + # Hence, the decision of using a MultiModalSharedField. + # The expected shape is (num_channels, width, height). + + # This model however allows the user to also submit multiple image + # patches as a batch, adding a further dimension to the above shape. 
+        # At this stage we only support submitting one patch per request and
+        # batching is achieved via vLLM batching.
+        # TODO (christian-pinto): enable support for multi patch requests
+        # in tandem with vLLM batching.
+        multimodal_kwargs_items = [
+            MultiModalKwargsItem.from_elems([
+                MultiModalFieldElem(
+                    modality="image",
+                    key=key,
+                    data=data,
+                    field=MultiModalSharedField(1),
+                ) for key, data in mm_kwargs.items()
+            ])
+        ]
         return MultiModalInputs(
             type="multimodal",
             prompt=prompt,
             prompt_token_ids=[1],
-            mm_kwargs=MultiModalKwargs(mm_kwargs),
+            mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items),
             mm_hashes=None,
-            mm_placeholders={},
+            mm_placeholders=mm_placeholders,
         )
 
 
 @MULTIMODAL_REGISTRY.register_processor(
     PrithviGeoSpatialMAEMultiModalProcessor,
     info=PrithviGeoSpatialMAEProcessingInfo,
-    dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
-class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
-                           SupportsV0Only):
+    dummy_inputs=PrithviGeoSpatialMAEInputBuilder,
+)
+class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree,
+                           SupportsMultiModalWithRawInput):
     """Prithvi Masked Autoencoder"""
 
     is_pooling_model = True
@@ -128,10 +158,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
             raise ValueError("Only image modality is supported")
 
     def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
-        # We might be able/need to support different tasks with this same model
         if config["task_args"]["task"] == "SemanticSegmentationTask":
             from terratorch.cli_tools import SemanticSegmentationTask
+
             task = SemanticSegmentationTask(
                 config["model_args"],
                 config["task_args"]["model_factory"],
@@ -144,7 +174,8 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
                 scheduler_hparams=config["scheduler_params"],
                 plot_on_val=config["task_args"]["plot_on_val"],
                 freeze_decoder=config["task_args"]["freeze_decoder"],
-                freeze_backbone=config["task_args"]["freeze_backbone"])
+                freeze_backbone=config["task_args"]["freeze_backbone"],
+            )
 
             return task.model
         else:
@@ -168,12 +199,10 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
 
     def _parse_and_validate_multimodal_data(
             self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-
         pixel_values = kwargs.pop("pixel_values", None)
         if not isinstance(pixel_values, torch.Tensor):
             raise ValueError(f"Incorrect type of pixel_values. "
                              f"Got type: {type(pixel_values)}")
-        pixel_values = torch.unbind(pixel_values, dim=0)[0]
 
         location_coords = kwargs.pop("location_coords", None)
         if not isinstance(location_coords, torch.Tensor):
@@ -185,6 +214,17 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal,
 
         return pixel_values, location_coords
 
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+    ) -> torch.Tensor:
+        # We do not really use any input tokens and therefore no embeddings
+        # need to be calculated. However, due to the mandatory token ids in
+        # the input prompt we pass one token and the size of the dummy
+        # embedding tensors must reflect that.
+ return torch.empty((input_ids.shape[0], 0)) + def forward( self, input_ids: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index fafb6a704383b..2aaac7798fc01 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -22,8 +22,8 @@ from vllm.logger import init_logger from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_pp, - supports_transcription, supports_v0_only) + supports_multimodal, supports_multimodal_raw_input, + supports_pp, supports_transcription, supports_v0_only) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -287,6 +287,7 @@ class _ModelInfo: is_pooling_model: bool supports_cross_encoding: bool supports_multimodal: bool + supports_multimodal_raw_input: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -304,6 +305,7 @@ class _ModelInfo: is_pooling_model=True, # Can convert any model into a pooling model supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), + supports_multimodal_raw_input=supports_multimodal_raw_input(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), @@ -573,6 +575,13 @@ class _ModelRegistry: model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_multimodal + def supports_multimodal_raw_input( + self, + architectures: Union[str, list[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal_raw_input + def is_pp_supported_model( self, architectures: Union[str, list[str]], diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 27aaa661c35c8..c44fcacd246c4 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -266,7 +266,7 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if tokenizer is None: + if tokenizer is None and not model_config.skip_tokenizer_init: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: mm_config = model_config.get_multimodal_config() diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 79b5d5ae4a23e..95a474228d4f9 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -94,11 +94,14 @@ class AsyncLLM(EngineClient): self.log_requests = log_requests self.log_stats = log_stats - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + if self.model_config.skip_tokenizer_init: + self.tokenizer = None + else: + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) # Processor (converts Inputs --> EngineCoreRequests). 
self.processor = Processor( @@ -525,6 +528,10 @@ class AsyncLLM(EngineClient): self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: + if self.tokenizer is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + return self.tokenizer.get_lora_tokenizer(lora_request) async def is_tracing_enabled(self) -> bool: diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a2328c37ba0c5..29aca1ad698e7 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -82,11 +82,14 @@ class LLMEngine: self.dp_group = None self.should_execute_dummy_batch = False - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + if self.model_config.skip_tokenizer_init: + self.tokenizer = None + else: + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bcd61d1f0aa1..3be6c48212140 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -327,14 +327,16 @@ class OutputProcessor: if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - req_state = RequestState.from_new_request( - tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + tokenizer = None if not self.tokenizer else \ + self.tokenizer.get_lora_tokenizer(request.lora_request) + + req_state = RequestState.from_new_request(tokenizer=tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 7af4ed54a2207..725152f978d64 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -380,7 +380,6 @@ class Processor: prompt_type: Literal["encoder", "decoder"], ): model_config = self.model_config - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) prompt_ids = prompt_inputs["prompt_token_ids"] if not prompt_ids: @@ -389,9 +388,14 @@ class Processor: else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError(f"Token id {max_input_id} is out of vocabulary") + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2078fedac9223..864cf91e78508 100644 --- 
a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -126,6 +126,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.is_multimodal_model = model_config.is_multimodal_model
         self.is_pooling_model = model_config.pooler_config is not None
+        self.model_supports_multimodal_raw_input = (
+            model_config.model_supports_multimodal_raw_input)
         self.max_model_len = model_config.max_model_len
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
@@ -328,6 +330,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         Args:
             scheduler_output: The scheduler output.
         """
+        # Attention-free models have zero kv_cache_groups; however, models
+        # like Mamba are also attention-free but use the kv_cache for
+        # keeping their internal state. This is why we check the number
+        # of kv_cache groups instead of solely checking
+        # for self.model_config.is_attention_free.
+        if len(self.kv_cache_config.kv_cache_groups) == 0:
+            return
+
         self.attn_metadata_builders[0].reorder_batch(self.input_batch,
                                                      scheduler_output)
@@ -565,6 +575,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Refresh batch metadata with any pending updates.
         self.input_batch.refresh_metadata()
 
+    def _init_model_kwargs_for_multimodal_model(
+        self,
+        scheduler_output: Optional["SchedulerOutput"] = None,
+        num_reqs: int = -1,
+    ) -> dict[str, Any]:
+
+        model_kwargs: dict[str, Any] = {}
+        if self.model_supports_multimodal_raw_input:
+            # This model requires the raw multimodal data in input.
+            if scheduler_output:
+                multi_modal_kwargs_list = []
+                for req in scheduler_output.scheduled_new_reqs:
+                    req_mm_inputs = req.mm_inputs
+                    if not isinstance(req_mm_inputs, list):
+                        req_mm_inputs = list(req_mm_inputs)
+                    multi_modal_kwargs_list.extend(req_mm_inputs)
+                multi_modal_kwargs = MultiModalKwargs.batch(
+                    multi_modal_kwargs_list)
+            else:
+                # The only case where SchedulerOutput is None is for
+                # a dummy run, so let's get some dummy data.
+                dummy_data = [
+                    self.mm_registry.get_decoder_dummy_data(
+                        model_config=self.model_config,
+                        seq_len=1).multi_modal_data for i in range(num_reqs)
+                ]
+                multi_modal_kwargs = MultiModalKwargs.batch(dummy_data)
+
+            model_kwargs.update(multi_modal_kwargs)
+
+        return model_kwargs
+
     def _get_cumsum_and_arange(
         self,
         num_tokens: np.ndarray,
@@ -1359,10 +1401,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # embeddings), we always use embeddings (rather than token ids)
             # as input to the multimodal model, even when the input is text.
             input_ids = self.input_ids[:num_scheduled_tokens]
+
+            model_kwargs = self._init_model_kwargs_for_multimodal_model(
+                scheduler_output=scheduler_output)
             inputs_embeds = self.model.get_input_embeddings(
                 input_ids=input_ids,
                 multimodal_embeddings=mm_embeds or None,
             )
+
             # TODO(woosuk): Avoid the copy. Optimize.
             self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
             inputs_embeds = self.inputs_embeds[:num_input_tokens]
@@ -1374,6 +1420,7 @@
             # then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None + model_kwargs = {} if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] else: @@ -1406,6 +1453,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ), ) self.maybe_wait_for_kv_save() @@ -2084,11 +2135,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): num_scheduled_tokens): model = self.model if self.is_multimodal_model: + model_kwargs = self._init_model_kwargs_for_multimodal_model( + num_reqs=num_reqs) input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None + model_kwargs = {} + if self.uses_mrope: positions = self.mrope_positions[:, :num_tokens] else: @@ -2117,7 +2172,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ), ) + if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: