From 2a50ef57605ffe332a73e50597276b71e9d52676 Mon Sep 17 00:00:00 2001
From: Satyajith Chilappagari
Date: Sat, 31 May 2025 03:39:11 -0700
Subject: [PATCH] [Neuron] Add Multi-Modal model support for Neuron (#18921)

Signed-off-by: Satyajith Chilappagari
Co-authored-by: Ashraf Mahgoub
Co-authored-by: Rohith Nallamaddi
Co-authored-by: FeliciaLuo
Co-authored-by: Elaine Zhao
---
 .../offline_inference/neuron_multimodal.py    | 105 ++++++++++++++++++
 vllm/config.py                                |  10 ++
 .../model_loader/neuronx_distributed.py       |  65 ++++++++++-
 vllm/worker/neuron_model_runner.py            |  13 +++
 .../neuronx_distributed_model_runner.py       |  88 ++++++++-------
 5 files changed, 235 insertions(+), 46 deletions(-)
 create mode 100644 examples/offline_inference/neuron_multimodal.py

diff --git a/examples/offline_inference/neuron_multimodal.py b/examples/offline_inference/neuron_multimodal.py
new file mode 100644
index 000000000000..a9478650b16f
--- /dev/null
+++ b/examples/offline_inference/neuron_multimodal.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+import requests
+import torch
+from neuronx_distributed_inference.models.mllama.utils import add_instruct
+from PIL import Image
+
+from vllm import LLM, SamplingParams, TextPrompt
+
+
+def get_image(image_url):
+    image = Image.open(requests.get(image_url, stream=True).raw)
+    return image
+
+
+# Model Inputs
+PROMPTS = [
+    "What is in this image? Tell me a story",
+    "What is the recipe of mayonnaise in two sentences?",
+    "Describe this image",
+    "What is the capital of Italy famous for?",
+]
+IMAGES = [
+    get_image(
+        "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
+    ),
+    None,
+    get_image(
+        "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
+    ),
+    None,
+]
+SAMPLING_PARAMS = [
+    dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16)
+    for _ in range(len(PROMPTS))
+]
+
+
+def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params):
+    # Prepare all inputs for mllama generation, including:
+    # 1. put text prompt into instruct chat template
+    # 2. compose single text and single image prompt into Vllm's prompt class
+    # 3. prepare sampling parameters
+    input_image = single_image
+    has_image = torch.tensor([1])
+    if isinstance(single_image, torch.Tensor) and single_image.numel() == 0:
+        has_image = torch.tensor([0])
+
+    instruct_prompt = add_instruct(prompt, has_image)
+    inputs = TextPrompt(prompt=instruct_prompt)
+
+    if input_image is not None:
+        inputs["multi_modal_data"] = {"image": input_image}
+
+    sampling_params = SamplingParams(**sampling_params)
+    return inputs, sampling_params
+
+
+def print_outputs(outputs):
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+
+if __name__ == "__main__":
+    assert (
+        len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS)
+    ), f"""Text, image prompts and sampling parameters should have the
+        same batch size; but got {len(PROMPTS)}, {len(IMAGES)},
+        and {len(SAMPLING_PARAMS)}"""
+
+    # Create an LLM.
+    llm = LLM(
+        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
+        max_num_seqs=1,
+        max_model_len=4096,
+        block_size=4096,
+        device="neuron",
+        tensor_parallel_size=32,
+        override_neuron_config={
+            "sequence_parallel_enabled": False,
+            "skip_warmup": True,
+            "save_sharded_checkpoint": True,
+            "on_device_sampling_config": {
+                "global_topk": 1,
+                "dynamic": False,
+                "deterministic": False,
+            },
+        },
+    )
+
+    batched_inputs = []
+    batched_sample_params = []
+    for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS):
+        inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params)
+        # test batch-size = 1
+        outputs = llm.generate(inputs, sampling_params)
+        print_outputs(outputs)
+        batched_inputs.append(inputs)
+        batched_sample_params.append(sampling_params)
+
+    # test batch-size = 4
+    outputs = llm.generate(batched_inputs, batched_sample_params)
+    print_outputs(outputs)
diff --git a/vllm/config.py b/vllm/config.py
index 6cec97a5f11b..dfa44b0440a7 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1360,6 +1360,16 @@ class ModelConfig:
     @property
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
+        """
+        For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
+        True to enable cross-attention
+        Neuron needs all multimodal data to be in the decoder and does not
+        need to explicitly enable cross-attention
+        """
+        if (current_platform.is_neuron()
+                and self.hf_config.model_type == "mllama"):
+            return False
+
         return is_encoder_decoder(self.hf_config)
 
     @property
diff --git a/vllm/model_executor/model_loader/neuronx_distributed.py b/vllm/model_executor/model_loader/neuronx_distributed.py
index 624bd476c031..72ad4da296ac 100644
--- a/vllm/model_executor/model_loader/neuronx_distributed.py
+++ b/vllm/model_executor/model_loader/neuronx_distributed.py
@@ -204,6 +204,11 @@ class NeuronMllamaForCausalLM(nn.Module):
                  config: PretrainedConfig,
                  on_device_sampling_disabled: bool = False) -> None:
         super().__init__()
+        # has_image is the only multimodal input that is used in
+        # token-generation
+        # This is a cache (on CPU) that saves has_image data per sequence id
+        # The number of entries in this cache is <= Batch-Size
+        self.has_image_cache: dict[int, torch.Tensor] = {}
         self.config = config
         self.logits_processor = LogitsProcessor(
             config.get_text_config().vocab_size, logits_as_input=True)
@@ -215,11 +220,57 @@ class NeuronMllamaForCausalLM(nn.Module):
 
         # Lazy initialized
         self.model: nn.Module
+        self.is_reorder_needed: bool = True
+
+    def read_from_has_image_cache(self, seq_ids: torch.Tensor):
+        has_image_list = []
+        for index in range(len(seq_ids)):
+            seq_id = seq_ids[index].item()
+            if seq_id in self.has_image_cache:
+                has_image_list.append(self.has_image_cache[seq_id])
+            else:
+                has_image_list.append(torch.tensor([0]))
+        return torch.tensor(has_image_list)
+
+    def write_to_has_image_cache(self, seq_ids: torch.Tensor,
+                                 has_image: torch.Tensor):
+        for index in range(len(seq_ids)):
+            seq_id = seq_ids[index].item()
+            if index < len(has_image):
+                self.has_image_cache[seq_id] = has_image[index]
+            else:
+                self.has_image_cache[seq_id] = torch.zeros(1)
 
     def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                 seq_ids: torch.Tensor, pixel_values: torch.Tensor,
                 aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
                 has_image: torch.Tensor, sampling_params) -> torch.Tensor:
+
+        # We update the has_image cache during prefill
+        # and read the has_image cache during decode
+        if input_ids.shape[-1] > 1:  # prefill
+            self.write_to_has_image_cache(seq_ids, has_image)
+        else:
+            has_image = self.read_from_has_image_cache(seq_ids)
+            bs = input_ids.shape[0]
+            num_chunks = torch.zeros((bs, 1))
+            aspect_ratios = torch.zeros((bs, 1, 2))
+
+        input_block_ids = seq_ids
+        origin_input_block_ids = seq_ids
+        if self.is_reorder_needed:
+            # sort block ids sequentially for perf/neuron support reasons
+            input_block_ids, sorted_indices = torch.sort(input_block_ids)
+            input_ids = torch.index_select(input_ids, 0, sorted_indices)
+            positions = torch.index_select(positions, 0, sorted_indices)
+            sampling_params = torch.index_select(sampling_params, 0,
+                                                 sorted_indices)
+            pixel_values = torch.index_select(pixel_values, 0, sorted_indices)
+            aspect_ratios = torch.index_select(aspect_ratios, 0,
+                                               sorted_indices)
+            num_chunks = torch.index_select(num_chunks, 0, sorted_indices)
+            has_image = torch.index_select(has_image, 0, sorted_indices)
+
         self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
         output = self.model(
             input_ids.to(torch.int32),
@@ -235,8 +286,14 @@ class NeuronMllamaForCausalLM(nn.Module):
             has_image=has_image.to(torch.int32),
         )
         if self.config.neuron_config.on_device_sampling_config:
-            return output.hidden_states
-        return output.logits[:, -1, :]
+            output = output.hidden_states
+        else:
+            output = output.logits[:, -1, :]
+
+        if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1:
+            restored_indices = torch.argsort(sorted_indices)
+            output = torch.index_select(output, 0, restored_indices)
+        return output
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
@@ -299,7 +356,7 @@ class NeuronMllamaForCausalLM(nn.Module):
             self.model = neuronx_model_cls(compiled_model_path)
             tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
             self.vision_token_id = tokenizer(
-                "<|image|>", add_special_tokens=False).input_ids
+                "<|image|>", add_special_tokens=False).input_ids[0]
             self.model.load(compiled_model_path)
             return
         except (FileNotFoundError, ValueError):
@@ -326,7 +383,7 @@ class NeuronMllamaForCausalLM(nn.Module):
 
         # Read "<|image|>" token_id from the tokenizer
         self.vision_token_id = tokenizer("<|image|>",
-                                         add_special_tokens=False).input_ids
+                                         add_special_tokens=False).input_ids[0]
         logger.info("\nLoading model from compiled checkpoint...")
         self.model.load(compiled_model_path)
 
diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py
index 292fe57f32ea..3aff3e01aef1 100644
--- a/vllm/worker/neuron_model_runner.py
+++ b/vllm/worker/neuron_model_runner.py
@@ -169,6 +169,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
 
             mm_kwargs = seq_group_metadata.multi_modal_data
             if mm_kwargs:
+                mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs)
                 multi_modal_kwargs_list.append(mm_kwargs)
 
         max_seq_len = max(seq_lens)
@@ -274,6 +275,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
                 sampling_params.top_p = top_p
                 sampling_params.temperature = temperature
 
+        # we need multi_modal_data for later tokens as well
+        multi_modal_kwargs_list: List[MultiModalKwargs] = []
+        for seq_group_metadata in seq_group_metadata_list:
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                multi_modal_kwargs_list.append(mm_data)
+        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
@@ -422,6 +431,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
 
+    def process_multi_modal_data_neuron(self, mm_data):
+        # this is a no-op for NeuronModelRunner
+        return mm_data
+
     def remove_all_loras(self):
         raise NotImplementedError(
             "LoRAs are not supported for Transformers NeuronX framework")
diff --git a/vllm/worker/neuronx_distributed_model_runner.py b/vllm/worker/neuronx_distributed_model_runner.py
index aa94706c8059..9cd4f88d32f0 100644
--- a/vllm/worker/neuronx_distributed_model_runner.py
+++ b/vllm/worker/neuronx_distributed_model_runner.py
@@ -3,6 +3,8 @@
 from typing import List, Optional, Set
 
 import torch
+from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import (
+    get_all_supported_aspect_ratios)
 from neuronx_distributed_inference.modules.generation.sampling import (
     prepare_sampling_params)
 from neuronx_distributed_inference.modules.lora_serving import (
@@ -17,7 +19,7 @@ from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.neuronx_distributed import (
     _get_model_architecture, get_neuron_model)
-from vllm.platforms import current_platform
+from vllm.multimodal import MultiModalKwargs
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
                                              NeuronModelRunner)
@@ -121,42 +123,28 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
         sampling_params = self.get_nxd_sampling_params(
             model_input.sampling_metadata)
 
-        if model_input.multi_modal_kwargs.get('image') is not None:
-            pixel_values = []
-            aspect_ratios = []
-            num_chunks = []
-            has_image = []
-            for multi_modal_input in model_input.multi_modal_kwargs.get(
-                    'image'):
-                image_tensors = self.get_multi_modal_data_neuron(
-                    multi_modal_input.squeeze(0))
-                pixel_values.append(image_tensors[0])
-                aspect_ratios.append(image_tensors[1])
-                num_chunks.append(image_tensors[2])
-                has_image.append(image_tensors[3])
-
-            pixel_values = torch.cat(pixel_values, dim=0)
-            aspect_ratios = torch.cat(aspect_ratios, dim=0)
-            num_chunks = torch.cat(num_chunks, dim=0)
-            has_image = torch.cat(has_image, dim=0)
-
+        if model_input.multi_modal_kwargs.get('pixel_values') is not None:
             hidden_states = self.model(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
                 seq_ids=model_input.input_block_ids,
-                pixel_values=pixel_values,
-                aspect_ratios=aspect_ratios,
+                pixel_values=model_input.multi_modal_kwargs.get(
+                    'pixel_values'),
+                aspect_ratios=model_input.multi_modal_kwargs.get(
+                    'aspect_ratios'),
                 sampling_params=sampling_params,
-                num_chunks=num_chunks,
-                has_image=has_image,
+                num_chunks=model_input.multi_modal_kwargs.get('num_chunks'),
+                has_image=model_input.multi_modal_kwargs.get(
+                    'has_image').squeeze(1),
             )
         else:
-            empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560],
+            bs = model_input.input_tokens.shape[0] if (model_input.input_tokens
+                                                       is not None) else 1
+            empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560],
                                              dtype=torch.bfloat16)
-            empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64)
-            num_chunks = torch.tensor([[1]
-                                       ])  # dummy num_chunks, will not be used
-            has_image = torch.tensor([0])
+            empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64)
+            num_chunks = torch.zeros((bs, 1), dtype=torch.int32)
+            has_image = torch.zeros([bs], dtype=torch.int32)
             hidden_states = self.model(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -175,6 +163,27 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
 
         return [output]
 
+    def process_multi_modal_data_neuron(self, mm_data):
+        # Neuron uses aspect_ratios instead of aspect_ratio_ids
+        all_supported_aspect_ratios = get_all_supported_aspect_ratios(
+            self.model.config.vision_config.max_num_tiles)
+        aspect_ratio_ids = mm_data.get("aspect_ratio_ids")
+        mm_data["aspect_ratios"] = torch.tensor(
+            all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0)
+
+        # Neuron's num_chunks is HF's num_tiles
+        mm_data["num_chunks"] = mm_data.get("num_tiles")
+
+        # Input has an image if it has pixel_values
+        bs = mm_data["num_chunks"].shape[0]
+        pixel_values = mm_data.get("pixel_values")
+        if pixel_values is not None and not torch.all(pixel_values == 0):
+            mm_data["has_image"] = torch.ones(bs)
+
+        else:
+            mm_data["has_image"] = torch.zeros(bs)
+        return mm_data
+
     def _get_lora_adapter_ids(self, seq_group_metadata_list):
         # set LoRA adapter IDs for multi-lora serving
         batch_size = len(seq_group_metadata_list)
@@ -200,7 +209,6 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None
     ) -> ModelInputForNeuron:
-        multi_modal_kwargs = None
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
         is_prompt = seq_group_metadata_list[0].is_prompt
@@ -223,6 +231,14 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
                 sampling_params.top_p = top_p
                 sampling_params.temperature = temperature
 
+        # we need multi_modal_data for later tokens as well
+        multi_modal_kwargs_list: List[MultiModalKwargs] = []
+        for seq_group_metadata in seq_group_metadata_list:
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                multi_modal_kwargs_list.append(mm_data)
+        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+
         lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)
 
         sampling_metadata = SamplingMetadata.prepare(
@@ -236,18 +252,6 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
             self.pin_memory,
             generators=self.get_generators(finished_requests_ids))
 
-        if current_platform.use_transformers_neuronx(
-        ) and not self._on_device_sampling_disabled:
-            # Once the request IDs are changed in current iteration, we will
-            # update the on-device sampling parameters.
-            current_batch_request_ids = [
-                seq_group_meta_data.request_id
-                for seq_group_meta_data in seq_group_metadata_list
-            ]
-            if current_batch_request_ids != self._previous_batch_request_ids:
-                self._update_neuron_sampling_params(seq_group_metadata_list)
-            self._previous_batch_request_ids = current_batch_request_ids
-
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,