diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index 1704130d9131..5411ecbb27b2 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -1711,6 +1711,11 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
     ):
         dataset_class = MTBenchDataset
         args.hf_split = "train"
+    elif (
+        args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
+        or args.hf_name in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
+    ):
+        dataset_class = MultiModalConversationDataset
     elif (
         args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS
         or args.hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS
@@ -2272,11 +2277,70 @@ class HuggingFaceDataset(BenchmarkDataset):
 
 
 class ConversationDataset(HuggingFaceDataset):
-    """Dataset for conversation data with multimodal support."""
+    """Dataset for text-only conversation data."""
+
+    SUPPORTED_DATASET_PATHS = {
+        "Aeala/ShareGPT_Vicuna_unfiltered",
+    }
+    IS_MULTIMODAL = False
+
+    def sample(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        num_requests: int,
+        output_len: int | None = None,
+        enable_multimodal_chat: bool = False,
+        request_id_prefix: str = "",
+        no_oversample: bool = False,
+        **kwargs,
+    ) -> list:
+        # Filter examples with at least 2 conversations
+        filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
+        sampled_requests = []
+        ind = 0
+        dynamic_output = output_len is None
+
+        for item in filtered_data:
+            if len(sampled_requests) >= num_requests:
+                break
+            conv = item["conversations"]
+            prompt, completion = conv[0]["value"], conv[1]["value"]
+
+            prompt_ids = tokenizer(prompt).input_ids
+            completion_ids = tokenizer(completion).input_ids
+            prompt_len = len(prompt_ids)
+            completion_len = len(completion_ids)
+            output_len = completion_len if dynamic_output else output_len
+            assert isinstance(output_len, int) and output_len > 0
+            if dynamic_output and not is_valid_sequence(prompt_len, completion_len):
+                continue
+            mm_content = process_image(item["image"]) if "image" in item else None
+            if enable_multimodal_chat:
+                # Note: when chat is enabled the request prompt_len is no longer
+                # accurate and we will be using request output to count the
+                # actual prompt len and output len
+                prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
+            sampled_requests.append(
+                SampleRequest(
+                    prompt=prompt,
+                    prompt_len=prompt_len,
+                    expected_output_len=output_len,
+                    multi_modal_data=mm_content,
+                    request_id=request_id_prefix + str(ind),
+                )
+            )
+            ind += 1
+        self.maybe_oversample_requests(
+            sampled_requests, num_requests, request_id_prefix, no_oversample
+        )
+        return sampled_requests
+
+
+class MultiModalConversationDataset(HuggingFaceDataset):
+    """Dataset for multimodal conversation data."""
 
     SUPPORTED_DATASET_PATHS = {
         "lmms-lab/LLaVA-OneVision-Data",
-        "Aeala/ShareGPT_Vicuna_unfiltered",
     }
     IS_MULTIMODAL = True
 
diff --git a/vllm/benchmarks/throughput.py b/vllm/benchmarks/throughput.py
index 78c0f8bbbda7..23b5faa1b2c3 100644
--- a/vllm/benchmarks/throughput.py
+++ b/vllm/benchmarks/throughput.py
@@ -21,6 +21,7 @@ from vllm.benchmarks.datasets import (
     BurstGPTDataset,
     ConversationDataset,
     InstructCoderDataset,
+    MultiModalConversationDataset,
     PrefixRepetitionRandomDataset,
     RandomDataset,
     SampleRequest,
@@ -367,6 +368,11 @@ def get_requests(args, tokenizer):
     elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
         dataset_cls = InstructCoderDataset
         common_kwargs["dataset_split"] = "train"
+    elif args.dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS:
+        dataset_cls = MultiModalConversationDataset
+        common_kwargs["dataset_subset"] = args.hf_subset
+        common_kwargs["dataset_split"] = args.hf_split
+        sample_kwargs["enable_multimodal_chat"] = True
     elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
         dataset_cls = ConversationDataset
         common_kwargs["dataset_subset"] = args.hf_subset
@@ -456,6 +462,7 @@ def validate_args(args):
     elif args.dataset_name == "hf":
         if args.dataset_path in (
             VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys()
+            | MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
             | ConversationDataset.SUPPORTED_DATASET_PATHS
         ):
             assert args.backend == "vllm-chat", (
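
A minimal sketch of the new dispatch order, for review purposes only and not part of the patch. It mirrors the elif-chain in get_samples(): the MultiModalConversationDataset check now runs before the ConversationDataset one, so "lmms-lab/LLaVA-OneVision-Data" can no longer fall through to the text-only class. The standalone pick_dataset_class helper is hypothetical, and the imports assume a vLLM build with this patch applied.

from vllm.benchmarks.datasets import (
    ConversationDataset,
    MultiModalConversationDataset,
)


def pick_dataset_class(dataset_path: str, hf_name: str | None = None):
    # Hypothetical helper mirroring the elif-chain in get_samples();
    # the real code reads args.dataset_path and args.hf_name, modeled
    # here as parameters.
    if (
        dataset_path in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
        or hf_name in MultiModalConversationDataset.SUPPORTED_DATASET_PATHS
    ):
        return MultiModalConversationDataset
    if (
        dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS
        or hf_name in ConversationDataset.SUPPORTED_DATASET_PATHS
    ):
        return ConversationDataset
    return None


# The two dataset paths that previously shared one class now split:
assert pick_dataset_class("lmms-lab/LLaVA-OneVision-Data") is MultiModalConversationDataset
assert pick_dataset_class("Aeala/ShareGPT_Vicuna_unfiltered") is ConversationDataset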