[Neuron] Add Multi-Modal model support for Neuron (#18921)
Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
Co-authored-by: Ashraf Mahgoub <ashymahg@amazon.com>
Co-authored-by: Rohith Nallamaddi <nalrohit@amazon.com>
Co-authored-by: FeliciaLuo <luof@amazon.com>
Co-authored-by: Elaine Zhao <elaineyz@amazon.com>
This commit is contained in:
parent b8b904795d
commit 2a50ef5760
examples/offline_inference/neuron_multimodal.py (new file, 105 lines added)
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0

import requests
import torch
from neuronx_distributed_inference.models.mllama.utils import add_instruct
from PIL import Image

from vllm import LLM, SamplingParams, TextPrompt


def get_image(image_url):
    image = Image.open(requests.get(image_url, stream=True).raw)
    return image


# Model Inputs
PROMPTS = [
    "What is in this image? Tell me a story",
    "What is the recipe of mayonnaise in two sentences?",
    "Describe this image",
    "What is the capital of Italy famous for?",
]
IMAGES = [
    get_image(
        "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
    ),
    None,
    get_image(
        "https://images.pexels.com/photos/1108099/pexels-photo-1108099.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
    ),
    None,
]
SAMPLING_PARAMS = [
    dict(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16)
    for _ in range(len(PROMPTS))
]


def get_VLLM_mllama_model_inputs(prompt, single_image, sampling_params):
    # Prepare all inputs for mllama generation, including:
    # 1. put text prompt into instruct chat template
    # 2. compose single text and single image prompt into Vllm's prompt class
    # 3. prepare sampling parameters
    input_image = single_image
    has_image = torch.tensor([1])
    if isinstance(single_image, torch.Tensor) and single_image.numel() == 0:
        has_image = torch.tensor([0])

    instruct_prompt = add_instruct(prompt, has_image)
    inputs = TextPrompt(prompt=instruct_prompt)

    if input_image is not None:
        inputs["multi_modal_data"] = {"image": input_image}

    sampling_params = SamplingParams(**sampling_params)
    return inputs, sampling_params


def print_outputs(outputs):
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


if __name__ == "__main__":
    assert (
        len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS)
    ), f"""Text, image prompts and sampling parameters should have the
    same batch size; but got {len(PROMPTS)}, {len(IMAGES)},
    and {len(SAMPLING_PARAMS)}"""

    # Create an LLM.
    llm = LLM(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        max_num_seqs=1,
        max_model_len=4096,
        block_size=4096,
        device="neuron",
        tensor_parallel_size=32,
        override_neuron_config={
            "sequence_parallel_enabled": False,
            "skip_warmup": True,
            "save_sharded_checkpoint": True,
            "on_device_sampling_config": {
                "global_topk": 1,
                "dynamic": False,
                "deterministic": False,
            },
        },
    )

    batched_inputs = []
    batched_sample_params = []
    for pmpt, img, params in zip(PROMPTS, IMAGES, SAMPLING_PARAMS):
        inputs, sampling_params = get_VLLM_mllama_model_inputs(pmpt, img, params)
        # test batch-size = 1
        outputs = llm.generate(inputs, sampling_params)
        print_outputs(outputs)
        batched_inputs.append(inputs)
        batched_sample_params.append(sampling_params)

    # test batch-size = 4
    outputs = llm.generate(batched_inputs, batched_sample_params)
    print_outputs(outputs)
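For readers skimming the diff, the core of the example above is the shape of a single multimodal request: a `TextPrompt` whose `multi_modal_data` carries the PIL image, paired with a `SamplingParams` object. A minimal sketch distilled from the example, using the same helper and model assumptions the example makes:

```python
# Minimal sketch of one multimodal request, distilled from the example above.
import torch
from neuronx_distributed_inference.models.mllama.utils import add_instruct
from PIL import Image

from vllm import SamplingParams, TextPrompt

image = Image.open("local_image.jpg")  # any RGB image; the example fetches one over HTTP

# add_instruct wraps the raw text in the mllama instruct chat template and
# accounts for the image token when has_image is set.
prompt = add_instruct("Describe this image", torch.tensor([1]))

request = TextPrompt(prompt=prompt)
request["multi_modal_data"] = {"image": image}
params = SamplingParams(top_k=1, temperature=1.0, top_p=1.0, max_tokens=16)

# With `llm` built as in the __main__ block above:
# outputs = llm.generate(request, params)
```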
@@ -1360,6 +1360,16 @@ class ModelConfig:
     @property
     def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
+        """
+        For Mllama, VLLM overrides HF's is_encoder_decoder flag and sets it to
+        True to enable cross-attention
+        Neuron needs all multimodal data to be in the decoder and does not
+        need to explicitly enable cross-attention
+        """
+        if (current_platform.is_neuron()
+                and self.hf_config.model_type == "mllama"):
+            return False
+
         return is_encoder_decoder(self.hf_config)
 
     @property
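The effect of this override is easiest to see in isolation: on Neuron, mllama is reported as decoder-only, so all multimodal data flows through the decoder path and cross-attention never needs to be enabled; everywhere else the HF-derived flag is kept. A standalone restatement of that gate (function and argument names here are illustrative, not part of the patch):

```python
# Illustrative restatement of the property above; names are hypothetical.
def reports_encoder_decoder(is_neuron: bool, model_type: str,
                            hf_flag: bool) -> bool:
    # On Neuron, mllama is served decoder-only: multimodal inputs are folded
    # into the decoder, so cross-attention is never explicitly enabled.
    if is_neuron and model_type == "mllama":
        return False
    # Every other platform/model keeps the flag derived from the HF config.
    return hf_flag


assert reports_encoder_decoder(True, "mllama", True) is False
assert reports_encoder_decoder(False, "mllama", True) is True
```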
@@ -204,6 +204,11 @@ class NeuronMllamaForCausalLM(nn.Module):
                  config: PretrainedConfig,
                  on_device_sampling_disabled: bool = False) -> None:
         super().__init__()
+        # has_image is the only multimodal input that is used in
+        # token-generation
+        # This is a cache (on CPU) that saves has_image data per sequence id
+        # The number of entries in this cache is <= Batch-Size
+        self.has_image_cache: dict[int, torch.Tensor] = {}
         self.config = config
         self.logits_processor = LogitsProcessor(
             config.get_text_config().vocab_size, logits_as_input=True)
@@ -215,11 +220,57 @@ class NeuronMllamaForCausalLM(nn.Module):
 
         # Lazy initialized
         self.model: nn.Module
+        self.is_reorder_needed: bool = True
+
+    def read_from_has_image_cache(self, seq_ids: torch.Tensor):
+        has_image_list = []
+        for index in range(len(seq_ids)):
+            seq_id = seq_ids[index].item()
+            if seq_id in self.has_image_cache:
+                has_image_list.append(self.has_image_cache[seq_id])
+            else:
+                has_image_list.append(torch.tensor([0]))
+        return torch.tensor(has_image_list)
+
+    def write_to_has_image_cache(self, seq_ids: torch.Tensor,
+                                 has_image: torch.Tensor):
+        for index in range(len(seq_ids)):
+            seq_id = seq_ids[index].item()
+            if index < len(has_image):
+                self.has_image_cache[seq_id] = has_image[index]
+            else:
+                self.has_image_cache[seq_id] = torch.zeros(1)
 
     def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
                 seq_ids: torch.Tensor, pixel_values: torch.Tensor,
                 aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
                 has_image: torch.Tensor, sampling_params) -> torch.Tensor:
+
+        # We update the has_image cache during prefill
+        # and read the has_image cache during decode
+        if input_ids.shape[-1] > 1:  # prefill
+            self.write_to_has_image_cache(seq_ids, has_image)
+        else:
+            has_image = self.read_from_has_image_cache(seq_ids)
+            bs = input_ids.shape[0]
+            num_chunks = torch.zeros((bs, 1))
+            aspect_ratios = torch.zeros((bs, 1, 2))
+
+        input_block_ids = seq_ids
+        origin_input_block_ids = seq_ids
+        if self.is_reorder_needed:
+            # sort block ids sequentially for perf/neuron support reasons
+            input_block_ids, sorted_indices = torch.sort(input_block_ids)
+            input_ids = torch.index_select(input_ids, 0, sorted_indices)
+            positions = torch.index_select(positions, 0, sorted_indices)
+            sampling_params = torch.index_select(sampling_params, 0,
+                                                 sorted_indices)
+            pixel_values = torch.index_select(pixel_values, 0, sorted_indices)
+            aspect_ratios = torch.index_select(aspect_ratios, 0,
+                                               sorted_indices)
+            num_chunks = torch.index_select(num_chunks, 0, sorted_indices)
+            has_image = torch.index_select(has_image, 0, sorted_indices)
+
         self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
         output = self.model(
             input_ids.to(torch.int32),
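The comments in this hunk describe the pattern: has_image is the only multimodal signal still needed during token generation, so it is written to a CPU cache keyed by sequence id at prefill and read back at decode, when the image tensors are no longer present. A self-contained toy sketch of that write-on-prefill / read-on-decode round trip (outside of vLLM; names and values are illustrative):

```python
import torch

# Toy stand-in for self.has_image_cache: sequence id -> has_image flag.
has_image_cache: dict[int, torch.Tensor] = {}


def write_cache(seq_ids: torch.Tensor, has_image: torch.Tensor) -> None:
    # Called at prefill, while the batch still carries the has_image tensor.
    for i in range(len(seq_ids)):
        has_image_cache[seq_ids[i].item()] = has_image[i]


def read_cache(seq_ids: torch.Tensor) -> torch.Tensor:
    # Called at decode; unknown sequence ids default to "no image".
    return torch.tensor([
        has_image_cache.get(seq_ids[i].item(), torch.tensor([0])).item()
        for i in range(len(seq_ids))
    ])


# Prefill for sequences 3 and 7: only sequence 3 carried an image.
write_cache(torch.tensor([3, 7]), torch.tensor([[1], [0]]))
# A later decode step recovers the flags from the sequence ids alone.
print(read_cache(torch.tensor([3, 7, 9])))  # tensor([1, 0, 0])
```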
@@ -235,8 +286,14 @@ class NeuronMllamaForCausalLM(nn.Module):
             has_image=has_image.to(torch.int32),
         )
         if self.config.neuron_config.on_device_sampling_config:
-            return output.hidden_states
-        return output.logits[:, -1, :]
+            output = output.hidden_states
+        else:
+            output = output.logits[:, -1, :]
+
+        if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1:
+            restored_indices = torch.argsort(sorted_indices)
+            output = torch.index_select(output, 0, restored_indices)
+        return output
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
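The other addition to forward is a sort/restore round trip: block ids are sorted before the compiled Neuron model is invoked (the in-diff comment cites perf / Neuron support reasons), and when more than one sequence is in flight, `torch.argsort(sorted_indices)` maps the outputs back to the caller's original order. A quick standalone check of that inverse-permutation trick:

```python
import torch

block_ids = torch.tensor([2, 0, 3, 1])
outputs_in_original_order = torch.tensor([20.0, 0.0, 30.0, 10.0])

# Sort by block id, as forward() does before invoking the compiled model;
# the model's outputs are then aligned with the sorted order.
sorted_block_ids, sorted_indices = torch.sort(block_ids)
sorted_outputs = torch.index_select(outputs_in_original_order, 0,
                                    sorted_indices)

# argsort of the permutation is its inverse: it restores the original order.
restored_indices = torch.argsort(sorted_indices)
restored = torch.index_select(sorted_outputs, 0, restored_indices)

assert torch.equal(restored, outputs_in_original_order)
```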
@@ -299,7 +356,7 @@ class NeuronMllamaForCausalLM(nn.Module):
             self.model = neuronx_model_cls(compiled_model_path)
             tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
             self.vision_token_id = tokenizer(
-                "<|image|>", add_special_tokens=False).input_ids
+                "<|image|>", add_special_tokens=False).input_ids[0]
             self.model.load(compiled_model_path)
             return
         except (FileNotFoundError, ValueError):
@@ -326,7 +383,7 @@ class NeuronMllamaForCausalLM(nn.Module):
 
         # Read "<|image|>" token_id from the tokenizer
         self.vision_token_id = tokenizer("<|image|>",
-                                         add_special_tokens=False).input_ids
+                                         add_special_tokens=False).input_ids[0]
         logger.info("\nLoading model from compiled checkpoint...")
         self.model.load(compiled_model_path)
 
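The `[0]` added in both call sites fixes the type of `vision_token_id`: a Hugging Face tokenizer returns a list of ids even for a single-token string, while the downstream vision-mask helper presumably expects the bare id. A quick check (assumes access to the gated meta-llama/Llama-3.2-11B-Vision-Instruct tokenizer; any mllama tokenizer with the `<|image|>` special token behaves the same way):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct")

ids = tokenizer("<|image|>", add_special_tokens=False).input_ids
print(ids)     # a one-element list, e.g. [128256]
print(ids[0])  # the single int used to build the vision mask
```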
@@ -169,6 +169,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
 
             mm_kwargs = seq_group_metadata.multi_modal_data
             if mm_kwargs:
+                mm_kwargs = self.process_multi_modal_data_neuron(mm_kwargs)
                 multi_modal_kwargs_list.append(mm_kwargs)
 
         max_seq_len = max(seq_lens)
@@ -274,6 +275,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             sampling_params.top_p = top_p
             sampling_params.temperature = temperature
 
+        # we need multi_modal_data for later tokens as well
+        multi_modal_kwargs_list: List[MultiModalKwargs] = []
+        for seq_group_metadata in seq_group_metadata_list:
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                multi_modal_kwargs_list.append(mm_data)
+        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
@@ -422,6 +431,10 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()
 
+    def process_multi_modal_data_neuron(self, mm_data):
+        # this is a no-op for NeuronModelRunner
+        return mm_data
+
     def remove_all_loras(self):
         raise NotImplementedError(
             "LoRAs are not supported for Transformers NeuronX framework")
@@ -3,6 +3,8 @@
 from typing import List, Optional, Set
 
 import torch
+from neuronx_distributed_inference.models.mllama.aspect_ratio_utils import (
+    get_all_supported_aspect_ratios)
 from neuronx_distributed_inference.modules.generation.sampling import (
     prepare_sampling_params)
 from neuronx_distributed_inference.modules.lora_serving import (
@@ -17,7 +19,7 @@ from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.neuronx_distributed import (
     _get_model_architecture, get_neuron_model)
-from vllm.platforms import current_platform
+from vllm.multimodal import MultiModalKwargs
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
 from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
                                              NeuronModelRunner)
@@ -121,42 +123,28 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
         sampling_params = self.get_nxd_sampling_params(
             model_input.sampling_metadata)
 
-        if model_input.multi_modal_kwargs.get('image') is not None:
-            pixel_values = []
-            aspect_ratios = []
-            num_chunks = []
-            has_image = []
-            for multi_modal_input in model_input.multi_modal_kwargs.get(
-                    'image'):
-                image_tensors = self.get_multi_modal_data_neuron(
-                    multi_modal_input.squeeze(0))
-                pixel_values.append(image_tensors[0])
-                aspect_ratios.append(image_tensors[1])
-                num_chunks.append(image_tensors[2])
-                has_image.append(image_tensors[3])
-
-            pixel_values = torch.cat(pixel_values, dim=0)
-            aspect_ratios = torch.cat(aspect_ratios, dim=0)
-            num_chunks = torch.cat(num_chunks, dim=0)
-            has_image = torch.cat(has_image, dim=0)
-
+        if model_input.multi_modal_kwargs.get('pixel_values') is not None:
             hidden_states = self.model(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
                 seq_ids=model_input.input_block_ids,
-                pixel_values=pixel_values,
-                aspect_ratios=aspect_ratios,
+                pixel_values=model_input.multi_modal_kwargs.get(
+                    'pixel_values'),
+                aspect_ratios=model_input.multi_modal_kwargs.get(
+                    'aspect_ratios'),
                 sampling_params=sampling_params,
-                num_chunks=num_chunks,
-                has_image=has_image,
+                num_chunks=model_input.multi_modal_kwargs.get('num_chunks'),
+                has_image=model_input.multi_modal_kwargs.get(
+                    'has_image').squeeze(1),
             )
         else:
-            empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560],
+            bs = model_input.input_tokens.shape[0] if (model_input.input_tokens
+                                                       is not None) else 1
+            empty_pixel_values = torch.zeros([bs, 1, 4, 3, 560, 560],
                                              dtype=torch.bfloat16)
-            empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64)
-            num_chunks = torch.tensor([[1]
-                                       ])  # dummy num_chunks, will not be used
-            has_image = torch.tensor([0])
+            empty_aspect_ratios = torch.ones([bs, 1, 2], dtype=torch.int64)
+            num_chunks = torch.zeros((bs, 1), dtype=torch.int32)
+            has_image = torch.zeros([bs], dtype=torch.int32)
             hidden_states = self.model(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
@@ -175,6 +163,27 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
 
         return [output]
 
+    def process_multi_modal_data_neuron(self, mm_data):
+        # Neuron uses aspect_ratios instead of aspect_ratio_ids
+        all_supported_aspect_ratios = get_all_supported_aspect_ratios(
+            self.model.config.vision_config.max_num_tiles)
+        aspect_ratio_ids = mm_data.get("aspect_ratio_ids")
+        mm_data["aspect_ratios"] = torch.tensor(
+            all_supported_aspect_ratios[aspect_ratio_ids]).unsqueeze(0)
+
+        # Neuron's num_chunks is HF's num_tiles
+        mm_data["num_chunks"] = mm_data.get("num_tiles")
+
+        # Input has an image if it has pixel_values
+        bs = mm_data["num_chunks"].shape[0]
+        pixel_values = mm_data.get("pixel_values")
+        if pixel_values is not None and not torch.all(pixel_values == 0):
+            mm_data["has_image"] = torch.ones(bs)
+        else:
+            mm_data["has_image"] = torch.zeros(bs)
+        return mm_data
+
     def _get_lora_adapter_ids(self, seq_group_metadata_list):
         # set LoRA adapter IDs for multi-lora serving
         batch_size = len(seq_group_metadata_list)
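process_multi_modal_data_neuron is the HF-to-Neuron adapter: the HF mllama processor emits `aspect_ratio_ids` and `num_tiles`, while the Neuron model wants explicit `aspect_ratios`, `num_chunks`, and a `has_image` flag. A rough sketch of that key mapping with toy tensors; the lookup table below is a hypothetical stand-in for `get_all_supported_aspect_ratios`, not its real output:

```python
import torch

# Hypothetical stand-in for get_all_supported_aspect_ratios(max_num_tiles=4):
# row i is the tile grid used for aspect_ratio_id i in this toy table.
SUPPORTED_ASPECT_RATIOS = torch.tensor(
    [[1, 1], [1, 2], [1, 3], [1, 4], [2, 1], [2, 2], [3, 1], [4, 1]])


def to_neuron_mm_kwargs(mm_data: dict) -> dict:
    # Neuron uses aspect_ratios instead of HF's aspect_ratio_ids.
    aspect_ratio_ids = mm_data["aspect_ratio_ids"]
    mm_data["aspect_ratios"] = SUPPORTED_ASPECT_RATIOS[
        aspect_ratio_ids].unsqueeze(0)
    # Neuron's num_chunks is HF's num_tiles.
    mm_data["num_chunks"] = mm_data["num_tiles"]
    # The input carries an image iff pixel_values is non-trivial.
    bs = mm_data["num_chunks"].shape[0]
    pixel_values = mm_data.get("pixel_values")
    has_image = pixel_values is not None and bool(torch.any(pixel_values != 0))
    mm_data["has_image"] = torch.ones(bs) if has_image else torch.zeros(bs)
    return mm_data


toy = to_neuron_mm_kwargs({
    "aspect_ratio_ids": torch.tensor([6]),  # -> row 6 of the toy table
    "num_tiles": torch.tensor([[3]]),
    "pixel_values": torch.rand(1, 1, 4, 3, 560, 560),
})
print(toy["aspect_ratios"], toy["num_chunks"], toy["has_image"])
```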
@@ -200,7 +209,6 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
             virtual_engine: int = 0,
             finished_requests_ids: Optional[List[str]] = None
     ) -> ModelInputForNeuron:
-        multi_modal_kwargs = None
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
         is_prompt = seq_group_metadata_list[0].is_prompt
@@ -223,6 +231,14 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
             sampling_params.top_p = top_p
             sampling_params.temperature = temperature
 
+        # we need multi_modal_data for later tokens as well
+        multi_modal_kwargs_list: List[MultiModalKwargs] = []
+        for seq_group_metadata in seq_group_metadata_list:
+            mm_data = seq_group_metadata.multi_modal_data
+            if mm_data:
+                multi_modal_kwargs_list.append(mm_data)
+        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
+
         lora_adapter_ids = self._get_lora_adapter_ids(seq_group_metadata_list)
 
         sampling_metadata = SamplingMetadata.prepare(
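As the in-diff comment says, the runners now gather per-sequence `multi_modal_data` and batch it with `MultiModalKwargs.batch` "for later tokens as well", not only for prompts. Conceptually, batching stacks each key across the sequences in the batch; a plain-dict illustration of that idea (not the real vLLM implementation, which also handles nested, ragged, and missing fields):

```python
import torch

# Plain-dict illustration of what batching per-sequence multimodal kwargs
# accomplishes; names here are illustrative only.
def batch_mm_kwargs(per_seq: list[dict]) -> dict:
    if not per_seq:
        return {}
    return {
        key: torch.stack([item[key] for item in per_seq])
        for key in per_seq[0]
    }


batched = batch_mm_kwargs([
    {"has_image": torch.tensor(1), "num_chunks": torch.tensor([3])},
    {"has_image": torch.tensor(0), "num_chunks": torch.tensor([0])},
])
print(batched["has_image"])   # tensor([1, 0])
print(batched["num_chunks"])  # tensor([[3], [0]])
```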
@@ -236,18 +252,6 @@ class NeuronxDistributedModelRunner(NeuronModelRunner):
             self.pin_memory,
             generators=self.get_generators(finished_requests_ids))
 
-        if current_platform.use_transformers_neuronx(
-        ) and not self._on_device_sampling_disabled:
-            # Once the request IDs are changed in current iteration, we will
-            # update the on-device sampling parameters.
-            current_batch_request_ids = [
-                seq_group_meta_data.request_id
-                for seq_group_meta_data in seq_group_metadata_list
-            ]
-            if current_batch_request_ids != self._previous_batch_request_ids:
-                self._update_neuron_sampling_params(seq_group_metadata_list)
-            self._previous_batch_request_ids = current_batch_request_ids
-
         return ModelInputForNeuron(input_tokens=input_tokens,
                                    input_positions=input_positions,
                                    input_block_ids=input_block_ids,