From 6eaf1e5c52d5e72a577ad03d378a28b39f0e849e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 17 Mar 2025 18:00:17 +0800 Subject: [PATCH] [Misc] Add `--seed` option to offline multi-modal examples (#14934) Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 7 +- examples/offline_inference/audio_language.py | 132 +++-- .../encoder_decoder_multimodal.py | 48 +- examples/offline_inference/vision_language.py | 455 ++++++++++++------ .../vision_language_embedding.py | 31 +- .../vision_language_multi_image.py | 179 ++++--- 6 files changed, 537 insertions(+), 315 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f85572e7c234c..f5be8dca05f1d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -226,10 +226,13 @@ steps: - python3 offline_inference/basic/chat.py - python3 offline_inference/prefix_caching.py - python3 offline_inference/llm_engine_example.py - - python3 offline_inference/vision_language.py - - python3 offline_inference/vision_language_multi_image.py + - python3 offline_inference/audio_language.py --seed 0 + - python3 offline_inference/vision_language.py --seed 0 + - python3 offline_inference/vision_language_embedding.py --seed 0 + - python3 offline_inference/vision_language_multi_image.py --seed 0 - VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference/encoder_decoder.py + - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0 - python3 offline_inference/basic/classify.py - python3 offline_inference/basic/embed.py - python3 offline_inference/basic/score.py diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 293b9fddac89e..02dbdcb64232f 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -7,11 +7,13 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. """ import os +from dataclasses import asdict +from typing import NamedTuple, Optional from huggingface_hub import snapshot_download from transformers import AutoTokenizer -from vllm import LLM, SamplingParams +from vllm import LLM, EngineArgs, SamplingParams from vllm.assets.audio import AudioAsset from vllm.lora.request import LoRARequest from vllm.utils import FlexibleArgumentParser @@ -23,21 +25,31 @@ question_per_audio_count = { 2: "What sport and what nursery rhyme are referenced?" } + +class ModelRequestData(NamedTuple): + engine_args: EngineArgs + prompt: str + stop_token_ids: Optional[list[int]] = None + lora_requests: Optional[list[LoRARequest]] = None + + # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on # lower-end GPUs. # Unless specified, these settings have been tested to work on a single L4. 
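Every example refactored by this patch follows the same pattern: each `run_*` helper now returns a `ModelRequestData` holding `EngineArgs` instead of constructing an `LLM` directly, and `main()` merges the CLI `--seed` into those arguments before building the engine. The minimal sketch below (not part of the patch; `run_example` and `build_llm` are illustrative names, and the engine settings are copied from the Ultravox example further down) shows that flow in isolation:

from dataclasses import asdict
from typing import NamedTuple, Optional

from vllm import LLM, EngineArgs


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompt: str
    stop_token_ids: Optional[list[int]] = None


def run_example(question: str) -> ModelRequestData:
    # Build the engine arguments lazily; the LLM itself is created later.
    engine_args = EngineArgs(
        model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
        trust_remote_code=True,
        max_model_len=4096,
        limit_mm_per_prompt={"audio": 1},
    )
    return ModelRequestData(engine_args=engine_args, prompt=question)


def build_llm(req_data: ModelRequestData, seed: Optional[int]) -> LLM:
    # Merge the CLI --seed into the example's arguments (dict union, Python 3.9+)
    # so that `--seed 0` makes the CI runs deterministic.
    return LLM(**(asdict(req_data.engine_args) | {"seed": seed}))

This is the same `asdict(req_data.engine_args) | {"seed": args.seed}` merge that each updated `main()` in the diff performs before calling `LLM(**engine_args)`.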
# MiniCPM-O -def run_minicpmo(question: str, audio_count: int): +def run_minicpmo(question: str, audio_count: int) -> ModelRequestData: model_name = "openbmb/MiniCPM-o-2_6" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - llm = LLM(model=model_name, - trust_remote_code=True, - max_model_len=4096, - max_num_seqs=5, - limit_mm_per_prompt={"audio": audio_count}) + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) stop_tokens = ['<|im_end|>', '<|endoftext|>'] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] @@ -52,11 +64,16 @@ def run_minicpmo(question: str, audio_count: int): tokenize=False, add_generation_prompt=True, chat_template=audio_chat_template) - return llm, prompt, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + stop_token_ids=stop_token_ids, + ) # Phi-4-multimodal-instruct -def run_phi4mm(questions: str, audio_count: int): +def run_phi4mm(question: str, audio_count: int) -> ModelRequestData: """ Phi-4-multimodal-instruct supports both image and audio inputs. Here, we show how to process audio inputs. @@ -67,9 +84,9 @@ def run_phi4mm(questions: str, audio_count: int): speech_lora_path = os.path.join(model_path, "speech-lora") placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)]) - prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>" + prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>" - llm = LLM( + engine_args = EngineArgs( model=model_path, trust_remote_code=True, max_model_len=4096, @@ -79,24 +96,24 @@ def run_phi4mm(questions: str, audio_count: int): lora_extra_vocab_size=0, limit_mm_per_prompt={"audio": audio_count}, ) - lora_request = LoRARequest("speech", 1, speech_lora_path) - # To maintain code compatibility in this script, we add LoRA here. - llm.llm_engine.add_lora(lora_request=lora_request) - # You can also add LoRA using: - # llm.generate(prompts, lora_request=lora_request,...) 
- stop_token_ids = None - return llm, prompts, stop_token_ids + return ModelRequestData( + engine_args=engine_args, + prompt=prompts, + lora_requests=[LoRARequest("speech", 1, speech_lora_path)], + ) # Qwen2-Audio -def run_qwen2_audio(question: str, audio_count: int): +def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData: model_name = "Qwen/Qwen2-Audio-7B-Instruct" - llm = LLM(model=model_name, - max_model_len=4096, - max_num_seqs=5, - limit_mm_per_prompt={"audio": audio_count}) + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) audio_in_prompt = "".join([ f"Audio {idx+1}: " @@ -107,12 +124,15 @@ def run_qwen2_audio(question: str, audio_count: int): "<|im_start|>user\n" f"{audio_in_prompt}{question}<|im_end|>\n" "<|im_start|>assistant\n") - stop_token_ids = None - return llm, prompt, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) # Ultravox 0.5-1B -def run_ultravox(question: str, audio_count: int): +def run_ultravox(question: str, audio_count: int) -> ModelRequestData: model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -124,29 +144,39 @@ def run_ultravox(question: str, audio_count: int): tokenize=False, add_generation_prompt=True) - llm = LLM(model=model_name, - max_model_len=4096, - max_num_seqs=5, - trust_remote_code=True, - limit_mm_per_prompt={"audio": audio_count}) - stop_token_ids = None - return llm, prompt, stop_token_ids + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + trust_remote_code=True, + limit_mm_per_prompt={"audio": audio_count}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) # Whisper -def run_whisper(question: str, audio_count: int): +def run_whisper(question: str, audio_count: int) -> ModelRequestData: assert audio_count == 1, ( "Whisper only support single audio input per prompt") model_name = "openai/whisper-large-v3-turbo" prompt = "<|startoftranscript|>" - llm = LLM(model=model_name, - max_model_len=448, - max_num_seqs=5, - limit_mm_per_prompt={"audio": audio_count}) - stop_token_ids = None - return llm, prompt, stop_token_ids + engine_args = EngineArgs( + model=model_name, + max_model_len=448, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) model_example_map = { @@ -164,14 +194,24 @@ def main(args): raise ValueError(f"Model type {model} is not supported.") audio_count = args.num_audios - llm, prompt, stop_token_ids = model_example_map[model]( - question_per_audio_count[audio_count], audio_count) + req_data = model_example_map[model](question_per_audio_count[audio_count], + audio_count) + + engine_args = asdict(req_data.engine_args) | {"seed": args.seed} + llm = LLM(**engine_args) + + # To maintain code compatibility in this script, we add LoRA here. + # You can also add LoRA using: + # llm.generate(prompts, lora_request=lora_request,...) + if req_data.lora_requests: + for lora_request in req_data.lora_requests: + llm.llm_engine.add_lora(lora_request=lora_request) # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. 
sampling_params = SamplingParams(temperature=0.2, max_tokens=64, - stop_token_ids=stop_token_ids) + stop_token_ids=req_data.stop_token_ids) mm_data = {} if audio_count > 0: @@ -183,7 +223,7 @@ def main(args): } assert args.num_prompts > 0 - inputs = {"prompt": prompt, "multi_modal_data": mm_data} + inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data} if args.num_prompts > 1: # Batch inference inputs = [inputs] * args.num_prompts @@ -214,6 +254,10 @@ if __name__ == "__main__": default=1, choices=[0, 1, 2], help="Number of audio items per prompt.") + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") args = parser.parse_args() main(args) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index f44bc423658ec..6d0c3ac1ee09a 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -4,16 +4,23 @@ This example shows how to use vLLM for running offline inference with the explicit/implicit prompt format on enc-dec LMMs for text generation. """ import time +from collections.abc import Sequence +from dataclasses import asdict +from typing import NamedTuple -from vllm import LLM, SamplingParams +from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset from vllm.utils import FlexibleArgumentParser +class ModelRequestData(NamedTuple): + engine_args: EngineArgs + prompts: Sequence[PromptType] + + def run_florence2(): - # Create a Florence-2 encoder/decoder model instance - llm = LLM( + engine_args = EngineArgs( model="microsoft/Florence-2-large", tokenizer="facebook/bart-large", max_num_seqs=8, @@ -39,12 +46,15 @@ def run_florence2(): "decoder_prompt": "", }, ] - return llm, prompts + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) def run_mllama(): - # Create a Mllama encoder/decoder model instance - llm = LLM( + engine_args = EngineArgs( model="meta-llama/Llama-3.2-11B-Vision-Instruct", max_model_len=4096, max_num_seqs=2, @@ -69,12 +79,15 @@ def run_mllama(): "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.", # noqa: E501 }, ] - return llm, prompts + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) def run_whisper(): - # Create a Whisper encoder/decoder model instance - llm = LLM( + engine_args = EngineArgs( model="openai/whisper-large-v3-turbo", max_model_len=448, max_num_seqs=16, @@ -99,7 +112,11 @@ def run_whisper(): "decoder_prompt": "<|startoftranscript|>", } ] - return llm, prompts + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) model_example_map = { @@ -114,7 +131,12 @@ def main(args): if model not in model_example_map: raise ValueError(f"Model type {model} is not supported.") - llm, prompts = model_example_map[model]() + req_data = model_example_map[model]() + + engine_args = asdict(req_data.engine_args) | {"seed": args.seed} + llm = LLM(**engine_args) + + prompts = req_data.prompts # Create a sampling params object. 
sampling_params = SamplingParams( @@ -153,6 +175,10 @@ if __name__ == "__main__": default="mllama", choices=model_example_map.keys(), help='Huggingface "model_type".') + parser.add_argument("--seed", + type=int, + default=None, + help="Set the seed when initializing `vllm.LLM`.") args = parser.parse_args() main(args) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 432cda5e24396..58fd5e53bf8dc 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -8,122 +8,164 @@ on HuggingFace model repository. """ import os import random +from dataclasses import asdict +from typing import NamedTuple, Optional from huggingface_hub import snapshot_download from transformers import AutoTokenizer -from vllm import LLM, SamplingParams +from vllm import LLM, EngineArgs, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.lora.request import LoRARequest from vllm.utils import FlexibleArgumentParser + +class ModelRequestData(NamedTuple): + engine_args: EngineArgs + prompts: list[str] + stop_token_ids: Optional[list[int]] = None + lora_requests: Optional[list[LoRARequest]] = None + + # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on # lower-end GPUs. # Unless specified, these settings have been tested to work on a single L4. # Aria -def run_aria(questions: list[str], modality: str): +def run_aria(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "rhymes-ai/Aria" # NOTE: Need L40 (or equivalent) to avoid OOM - llm = LLM(model=model_name, - max_model_len=4096, - max_num_seqs=2, - dtype="bfloat16", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + dtype="bfloat16", + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) prompts = [(f"<|im_start|>user\n<|img|>{question}" "<|im_end|>\n<|im_start|>assistant\n") for question in questions] stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] - return llm, prompts, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + stop_token_ids=stop_token_ids, + ) # BLIP-2 -def run_blip2(questions: list[str], modality: str): +def run_blip2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" # BLIP-2 prompt format is inaccurate on HuggingFace model repository. 
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa prompts = [f"Question: {question} Answer:" for question in questions] - llm = LLM(model="Salesforce/blip2-opt-2.7b", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - stop_token_ids = None - return llm, prompts, stop_token_ids + engine_args = EngineArgs( + model="Salesforce/blip2-opt-2.7b", + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) # Chameleon -def run_chameleon(questions: list[str], modality: str): +def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" prompts = [f"{question}" for question in questions] - llm = LLM(model="facebook/chameleon-7b", - max_model_len=4096, - max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - stop_token_ids = None - return llm, prompts, stop_token_ids + engine_args = EngineArgs( + model="facebook/chameleon-7b", + max_model_len=4096, + max_num_seqs=2, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) # Deepseek-VL2 -def run_deepseek_vl2(questions: list[str], modality: str): +def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "deepseek-ai/deepseek-vl2-tiny" - llm = LLM(model=model_name, - max_model_len=4096, - max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, - hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}) + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}, + ) prompts = [ f"<|User|>: \n{question}\n\n<|Assistant|>:" for question in questions ] - stop_token_ids = None - return llm, prompts, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) # Florence2 -def run_florence2(question: str, modality: str): +def run_florence2(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" - llm = LLM(model="microsoft/Florence-2-large", - tokenizer="facebook/bart-large", - max_num_seqs=8, - trust_remote_code=True, - dtype="bfloat16", - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + engine_args = EngineArgs( + model="microsoft/Florence-2-large", + tokenizer="facebook/bart-large", + max_num_seqs=8, + trust_remote_code=True, + dtype="bfloat16", + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) - prompt = "" - stop_token_ids = None - return llm, prompt, stop_token_ids + prompts = ["" for _ in questions] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) # Fuyu -def run_fuyu(questions: list[str], modality: str): +def run_fuyu(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" prompts = [f"{question}\n" for question in questions] - llm = LLM(model="adept/fuyu-8b", - max_model_len=2048, - max_num_seqs=2, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - stop_token_ids = None - return llm, prompts, stop_token_ids + engine_args = EngineArgs( + model="adept/fuyu-8b", + max_model_len=2048, + max_num_seqs=2, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) + + return ModelRequestData( + 
engine_args=engine_args, + prompts=prompts, + ) # Gemma 3 -def run_gemma3(questions: list[str], modality: str): +def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "google/gemma-3-4b-it" - llm = LLM( + engine_args = EngineArgs( model=model_name, max_model_len=2048, max_num_seqs=2, @@ -135,22 +177,27 @@ def run_gemma3(questions: list[str], modality: str): prompts = [("user\n" f"{question}\n" "model\n") for question in questions] - stop_token_ids = None - return llm, prompts, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) # GLM-4v -def run_glm4v(questions: list[str], modality: str): +def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "THUDM/glm-4v-9b" - llm = LLM(model=model_name, - max_model_len=2048, - max_num_seqs=2, - trust_remote_code=True, - enforce_eager=True, - hf_overrides={"architectures": ["GLM4VForCausalLM"]}, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) + engine_args = EngineArgs( + model=model_name, + max_model_len=2048, + max_num_seqs=2, + trust_remote_code=True, + enforce_eager=True, + hf_overrides={"architectures": ["GLM4VForCausalLM"]}, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) prompts = [ f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ @@ -158,16 +205,21 @@ def run_glm4v(questions: list[str], modality: str): ] stop_token_ids = [151329, 151336, 151338] - return llm, prompts, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + stop_token_ids=stop_token_ids, + ) # H2OVL-Mississippi -def run_h2ovl(questions: list[str], modality: str): +def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "h2oai/h2ovl-mississippi-800m" - llm = LLM( + engine_args = EngineArgs( model=model_name, trust_remote_code=True, max_model_len=8192, @@ -187,15 +239,20 @@ def run_h2ovl(questions: list[str], modality: str): # Stop tokens for H2OVL-Mississippi # https://huggingface.co/h2oai/h2ovl-mississippi-800m stop_token_ids = [tokenizer.eos_token_id] - return llm, prompts, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + stop_token_ids=stop_token_ids, + ) # Idefics3-8B-Llama3 -def run_idefics3(questions: list[str], modality: str): +def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "HuggingFaceM4/Idefics3-8B-Llama3" - llm = LLM( + engine_args = EngineArgs( model=model_name, max_model_len=8192, max_num_seqs=2, @@ -212,17 +269,20 @@ def run_idefics3(questions: list[str], modality: str): prompts = [( f"<|begin_of_text|>User:{question}\nAssistant:" ) for question in questions] - stop_token_ids = None - return llm, prompts, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) # InternVL -def run_internvl(questions: list[str], modality: str): +def run_internvl(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" model_name = "OpenGVLab/InternVL2-2B" - llm = LLM( + engine_args = EngineArgs( model=model_name, trust_remote_code=True, max_model_len=4096, @@ -245,53 +305,75 @@ def run_internvl(questions: list[str], modality: str): # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = 
[tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompts, stop_token_ids + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + stop_token_ids=stop_token_ids, + ) # LLaVA-1.5 -def run_llava(questions: list[str], modality: str): +def run_llava(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" prompts = [ f"USER: \n{question}\nASSISTANT:" for question in questions ] - llm = LLM(model="llava-hf/llava-1.5-7b-hf", - max_model_len=4096, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - stop_token_ids = None - return llm, prompts, stop_token_ids + engine_args = EngineArgs( + model="llava-hf/llava-1.5-7b-hf", + max_model_len=4096, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) # LLaVA-1.6/LLaVA-NeXT -def run_llava_next(questions: list[str], modality: str): +def run_llava_next(questions: list[str], modality: str) -> ModelRequestData: assert modality == "image" prompts = [f"[INST] \n{question} [/INST]" for question in questions] - llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", - max_model_len=8192, - disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - stop_token_ids = None - return llm, prompts, stop_token_ids + engine_args = EngineArgs( + model="llava-hf/llava-v1.6-mistral-7b-hf", + max_model_len=8192, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) # LlaVA-NeXT-Video # Currently only support for video input -def run_llava_next_video(questions: list[str], modality: str): +def run_llava_next_video(questions: list[str], + modality: str) -> ModelRequestData: assert modality == "video" prompts = [ f"USER: