mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:04:53 +08:00
Signed-off-by: Christian Pinto <christian.pinto@ibm.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from tests.conftest import VllmRunner
|
|
from vllm.utils import set_default_torch_num_threads
|
|
|
|
|
|
@pytest.mark.parametrize(
    "model",
    [
        "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11",
        "mgazz/Prithvi_v2_eo_300_tl_unet_agb"
    ],
)
def test_inference(
    vllm_runner: type[VllmRunner],
    model: str,
) -> None:
    """Smoke-test pooling inference for Prithvi geospatial models.

    Builds a dummy multi-modal prompt (constant-valued pixel and
    location-coordinate tensors) and runs it through the pooling runner,
    asserting that the pooled output contains no NaNs.

    Args:
        vllm_runner: Project fixture that constructs a ``VllmRunner``
            context manager around an LLM instance.
        model: Hugging Face model id under test (parametrized above).
    """
    # Constant dummy inputs; 6-channel 512x512 image matches the
    # Prithvi sensor input — TODO confirm expected channel count.
    pixel_values = torch.full((6, 512, 512), 1.0, dtype=torch.float16)
    location_coords = torch.full((1, 2), 1.0, dtype=torch.float16)
    prompt = dict(prompt_token_ids=[1],
                  multi_modal_data=dict(pixel_values=pixel_values,
                                        location_coords=location_coords))

    with (
            set_default_torch_num_threads(1),
            vllm_runner(
                model,
                runner="pooling",
                dtype=torch.float16,
                enforce_eager=True,
                skip_tokenizer_init=True,
                # Limit the maximum number of sequences to avoid the
                # test going OOM during the warmup run
                max_num_seqs=32,
            ) as vllm_model,
    ):
        vllm_output = vllm_model.llm.encode(prompt)

        # Directly assert no NaNs in the pooled output. This is
        # equivalent to, but clearer than, the previous
        # torch.equal(isnan(...).any(), torch.tensor(False)) comparison.
        assert not torch.isnan(vllm_output[0].outputs.data).any()
|