diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index ea7a793d026b..e55181e4f490 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -289,6 +289,106 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData: ) +def load_llava(question: str, image_urls: list[str]) -> ModelRequestData: + # NOTE: CAUTION! Original Llava models wasn't really trained on multi-image inputs, + # it will generate poor response for multi-image inputs! + model_name = "llava-hf/llava-1.5-7b-hf" + engine_args = EngineArgs( + model=model_name, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + +def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "llava-hf/llava-v1.6-mistral-7b-hf" + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + +def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "llava-hf/llava-onevision-qwen2-7b-ov-hf" + engine_args = EngineArgs( + model=model_name, + max_model_len=16384, + max_num_seqs=16, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + } + ] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct" @@ -737,6 +837,9 @@ model_example_map = { "idefics3": load_idefics3, "internvl_chat": load_internvl, "kimi_vl": load_kimi_vl, + "llava": load_llava, + "llava-next": load_llava_next, + "llava-onevision": load_llava_onevision, "llama4": load_llama4, "mistral3": load_mistral3, "mllama": load_mllama,