Add NeuronxDistributedInference support, Speculative Decoding, Dynamic on-device sampling (#16357)

Signed-off-by: Satyajith Chilappagari <satchill@amazon.com>
Co-authored-by: Aaron Dou <yzdou@amazon.com>
Co-authored-by: Shashwat Srijan <sssrijan@amazon.com>
Co-authored-by: Chongming Ni <chongmni@amazon.com>
Co-authored-by: Amulya Ballakur <amulyaab@amazon.com>
Co-authored-by: Patrick Lange <patlange@amazon.com>
Co-authored-by: Elaine Zhao <elaineyz@amazon.com>
Co-authored-by: Lin Lin Pan <tailinpa@amazon.com>
Co-authored-by: Navyadhara Gogineni <navyadha@amazon.com>
Co-authored-by: Yishan McNabb <yishanm@amazon.com>
Co-authored-by: Mrinal Shukla <181322398+mrinalks@users.noreply.github.com>
Satyajith Chilappagari 2025-05-07 00:07:30 -07:00 committed by GitHub
parent ba7703e659
commit 043e4c4955
15 changed files with 1622 additions and 101 deletions

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with an EAGLE speculative
decoding model on neuron. To use EAGLE speculative decoding, you must use
a draft model that is specifically fine-tuned for EAGLE speculation.
Additionally, to use EAGLE with NxD Inference, the draft model must include
the LM head weights from the target model. These weights are shared between
the draft and target model.
"""
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"What is annapurna labs?",
]
# Create a sampling params object.
sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True)
# Create an LLM.
llm = LLM(
model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct",
speculative_config={
"model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft",
"num_speculative_tokens": 5,
"max_model_len": 2048
},
max_num_seqs=4,
# The max_model_len and block_size arguments are required to be the same
# as the max sequence length when targeting the Neuron device.
# Currently, this is a known limitation of continuous batching support
# in neuronx-distributed-inference.
max_model_len=2048,
block_size=2048,
# The device can be automatically detected when the AWS Neuron SDK is
# installed. The device argument can either be left unspecified for
# automatic detection, or explicitly assigned.
device="neuron",
tensor_parallel_size=32,
override_neuron_config={
"enable_eagle_speculation": True,
"enable_fused_speculation": True
},
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}")

View File

@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to run offline inference with a speculative
decoding model on neuron.
"""
import os
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, I am a language model and I can help",
"The president of the United States is",
"The capital of France is",
]
def config_buckets():
"""Configure context length and token gen buckets."""
# creates XLA hlo graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# creates XLA hlo graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
def initialize_model():
"""Create an LLM with speculative decoding."""
return LLM(
model="openlm-research/open_llama_7b",
speculative_config={
"model": "openlm-research/open_llama_3b",
"num_speculative_tokens": 4,
"max_model_len": 2048
},
max_num_seqs=4,
max_model_len=2048,
block_size=2048,
use_v2_block_manager=True,
device="neuron",
tensor_parallel_size=32,
)
def process_requests(model: LLM, sampling_params: SamplingParams):
"""Generate texts from prompts and print them."""
outputs = model.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
def main():
"""Main function that sets up the model and processes prompts."""
config_buckets()
model = initialize_model()
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, top_k=1)
process_requests(model, sampling_params)
if __name__ == '__main__':
main()
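For reference, these bucket environment variables are plain comma-separated lists of integer sequence lengths. A minimal sketch of how such a value can be parsed (a hypothetical helper, mirroring the _get_buckets utility that appears later in this diff):
import os

def parse_buckets(env_name: str, default: list) -> list:
    """Hypothetical helper: parse a comma-separated bucket list from the
    environment, falling back to a default when the variable is unset."""
    value = os.environ.get(env_name)
    if value is None:
        return default
    return [int(v) for v in value.split(",")]

os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
print(parse_buckets("NEURON_TOKEN_GEN_BUCKETS", [2048]))  # [128, 512, 1024, 2048]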

View File

@ -5,4 +5,5 @@
packaging>=24.2
setuptools>=77.0.3,<80.0.0
torch-neuronx >= 2.5.0
neuronx-cc
neuronx-cc>=2.0.0a0
torchvision # Required for Llama3.2 multimodal image preprocessing

View File

@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
import os
from unittest.mock import MagicMock
from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData, SequenceGroupMetadata
from vllm.worker.neuron_model_runner import NeuronModelRunner
os.environ[
'VLLM_NEURON_FRAMEWORK'] = NeuronFramework.TRANSFORMERS_NEURONX.value
def _create_neuron_model_runner(model: str, *args,
**kwargs) -> NeuronModelRunner:
engine_args = EngineArgs(model, *args, **kwargs)
engine_config = engine_args.create_engine_config()
vllm_config = VllmConfig(
model_config=engine_config.model_config,
parallel_config=engine_config.parallel_config,
scheduler_config=engine_config.scheduler_config,
device_config=engine_config.device_config,
)
neuron_model_runner = NeuronModelRunner(vllm_config=vllm_config)
return neuron_model_runner
def test_update_neuron_sampling_params_not_full_batch():
os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
model_runner = _create_neuron_model_runner(
"facebook/opt-125m",
seed=0,
dtype="float16",
max_num_seqs=2,
)
assert not model_runner._on_device_sampling_disabled
# Test sampling param updating only when TNx is the framework;
# NxDI handles sampling parameter updating inside the model.
if current_platform.use_transformers_neuronx():
model_mock = MagicMock()
model_runner.model = model_mock
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(temperature=0.5,
top_k=1,
top_p=0.5),
block_tables={0: [1]},
)
]
model_runner.prepare_model_input(seq_group_metadata_list)
# Index neuron sampling parameters based on block_tables indices.
# The first block_id of sequence 0 is 1, so its parameters are
# placed at index 1. So the sampling parameters will be:
# Index 0: default sampling parameters
# Index 1: sequence 0's sampling parameters.
neuron_sampling_params = (
model_runner.model_config.neuron_sampling_params)
assert neuron_sampling_params.temperature == [1.0, 0.5]
assert neuron_sampling_params.top_k == [
model_runner._MAX_NEURON_SAMPLING_TOP_K, 1
]
assert neuron_sampling_params.top_p == [1.0, 0.5]
model_mock.model.update_generation_config.assert_called_once_with(
neuron_sampling_params)
def test_update_neuron_sampling_params_full_batch():
os.environ["NEURON_ON_DEVICE_SAMPLING_DISABLED"] = "0"
model_runner = _create_neuron_model_runner(
"facebook/opt-125m",
seed=0,
dtype="float16",
max_num_seqs=2,
)
assert not model_runner._on_device_sampling_disabled
# Test sampling param updating only when TNx is the framework;
# NxDI handles sampling parameter updating inside the model.
if current_platform.use_transformers_neuronx():
model_mock = MagicMock()
model_runner.model = model_mock
seq_group_metadata_list = [
SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={0: SequenceData.from_seqs([1, 2, 3])},
sampling_params=SamplingParams(temperature=0.5,
top_k=1,
top_p=0.5),
block_tables={0: [1]},
),
SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={1: SequenceData.from_seqs([4, 5, 6])},
sampling_params=SamplingParams(temperature=0.2,
top_k=2,
top_p=0.2),
block_tables={1: [0]},
)
]
model_runner.prepare_model_input(seq_group_metadata_list)
# Index neuron sampling parameters based on block_tables indices.
# The first block_id of sequence 0 is 1, so its parameters are
# placed at index 1. So the sampling parameters will be:
# Index 0: sequence 1's sampling parameters
# Index 1: sequence 0's sampling parameters.
neuron_sampling_params = (
model_runner.model_config.neuron_sampling_params)
assert neuron_sampling_params.temperature == [0.2, 0.5]
assert neuron_sampling_params.top_k == [2, 1]
assert neuron_sampling_params.top_p == [0.2, 0.5]
model_mock.model.update_generation_config.assert_called_once_with(
neuron_sampling_params)
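To make the block-table indexing in these tests concrete, here is a small illustrative sketch (plain Python, not part of the test suite) of how per-sequence sampling parameters are placed into the batched on-device slots:
# Illustrative only: each sequence's slot is its first block id, and slots
# without an assigned sequence keep the default parameters.
batch_slots = {"top_k": [256, 256], "top_p": [1.0, 1.0], "temperature": [1.0, 1.0]}
block_tables = {"seq0": [1], "seq1": [0]}                    # first block id -> slot
requested = {"seq0": (1, 0.5, 0.5), "seq1": (2, 0.2, 0.2)}   # (top_k, top_p, temp)

for seq, blocks in block_tables.items():
    slot = blocks[0]
    top_k, top_p, temperature = requested[seq]
    batch_slots["top_k"][slot] = top_k
    batch_slots["top_p"][slot] = top_p
    batch_slots["temperature"][slot] = temperature

print(batch_slots)
# {'top_k': [2, 1], 'top_p': [0.2, 0.5], 'temperature': [0.2, 0.5]}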

View File

@ -2273,6 +2273,9 @@ class SpeculativeConfig:
"""Scaling factor for entropy-based threshold, applied when using
`TypicalAcceptanceSampler`."""
speculative_token_tree: Optional[str] = None
"""Specifies the tree structure for speculative token generation.
"""
# required configuration params passed from engine
target_model_config: ModelConfig = field(default=None,
init=True) # type: ignore
@ -2447,10 +2450,11 @@ class SpeculativeConfig:
"Chunked prefill and EAGLE are not compatible "
"when using V0.")
from vllm.platforms import current_platform
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
EAGLEConfig):
EAGLEConfig) or current_platform.is_neuron():
pass
else:
eagle_config = EAGLEConfig(
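For context, speculative_token_tree is consumed later in this diff (see get_neuron_eagle_speculation_model) as a string literal describing a parent-to-children map. An illustrative, hypothetical value and its parsing:
import ast

# Hypothetical tree: node 0 (the root draft token) branches into nodes 1 and 2,
# each with a single child. The exact semantics are defined by the Neuron
# EagleSpeculativeDecoder; this only illustrates the expected string format.
speculative_token_tree = "{0: [1, 2], 1: [3], 2: [4]}"
token_tree = ast.literal_eval(speculative_token_tree)
assert token_tree == {0: [1, 2], 1: [3], 2: [4]}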

View File

@ -399,10 +399,8 @@ class LLMEngine:
self.scheduler,
self.seq_counter,
get_tokenizer_for_seq,
stop_checker=StopChecker(
self.scheduler_config.max_model_len,
get_tokenizer_for_seq,
),
stop_checker=StopChecker(self.scheduler_config.max_model_len,
get_tokenizer_for_seq),
))
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""Utilities for selecting and loading neuron models."""
"""Utilities for selecting and loading Neuron models in transformers-neuronx
framework."""
import ast
import copy
import importlib
import os
@ -9,7 +11,8 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@ -113,6 +116,67 @@ class NeuronCausalLM(nn.Module):
self.model.to_neuron()
class NeuronSpeculationCausalLM(nn.Module):
"""A Neuron-optimized causal language model with speculative decoding."""
SPECULATION_TERMINATION_ID = -1
def __init__(self, speculation_model) -> None:
super().__init__()
self.model = speculation_model
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
) -> torch.Tensor:
tokens, counts = self.model.speculative_iteration(
input_ids, positions, input_block_ids)
# Mark the end of accepted speculative tokens for each sequence with the
# speculation termination id.
batch_size, steps = tokens.shape
mask = torch.arange(steps).expand(batch_size, -1) >= counts
tokens[mask] = self.SPECULATION_TERMINATION_ID
return tokens
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[List[SamplerOutput]]:
batch_size, num_steps = logits.shape
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
# Organize input tensors by step instead of by sequence.
accepted_token_ids_by_step = logits.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
sampler_output_list = []
for step_index in range(num_steps):
if all(token_id == self.SPECULATION_TERMINATION_ID
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for sequence_index in range(batch_size):
token_id = accepted_token_ids_by_step[step_index][
sequence_index]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
return sampler_output_list
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
@ -138,6 +202,7 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]:
def _get_default_neuron_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig):
"""Generate a neuron config based on vllm config args."""
from transformers_neuronx.config import ContinuousBatchingConfig
from transformers_neuronx.constants import LAYOUT_BSH
@ -162,6 +227,27 @@ def _get_default_neuron_config(model_config: ModelConfig,
return default_neuron_args
def _get_default_neuron_config_for_speculation(
model_config: ModelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig):
"""Generate a neuron config for speculative decoding based on
vllm config args."""
from transformers_neuronx.config import ContinuousBatchingConfig
from transformers_neuronx.constants import LAYOUT_BSH
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
default_neuron_args = dict(collectives_layout=LAYOUT_BSH,
attention_layout=LAYOUT_BSH,
fuse_qkv=True,
on_device_embedding=True,
continuous_batching=continuous_batching_config,
on_device_generation=copy.deepcopy(
model_config.neuron_sampling_params))
return default_neuron_args
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
if not _is_neuron_on_device_sampling_disabled(model_config):
return copy.deepcopy(model_config.neuron_sampling_params)
@ -213,7 +299,7 @@ def _get_neuron_config_after_override(default_neuron_config,
def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
"""Initializes a neuron-optimized model for inference."""
# Create a model instance.
model = NeuronCausalLM(
model_config.hf_config,
@ -230,7 +316,6 @@ def get_neuron_model(model_config: ModelConfig,
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
# Load the weights from the cached or downloaded files.
model.load_weights(model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
@ -240,3 +325,151 @@ def get_neuron_model(model_config: ModelConfig,
batch_size=scheduler_config.max_num_seqs)
return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized speculation model for inference.
This method is only applicable for speculation with a standalone draft model
"""
from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder
# For EAGLE speculative decoding, additional parameters must be passed in the neuron config.
is_eagle = getattr(speculation_config.draft_model_config.hf_config,
"is_eagle", False)
# Create target model instance.
target_model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config_for_speculation(
model_config, parallel_config, scheduler_config)
if is_eagle:
default_neuron_config_args['is_eagle_target'] = True
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
target_model.load_weights(
model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
target_model.eval()
# Create draft model instance.
draft_model = NeuronCausalLM(
speculation_config.draft_model_config.hf_config)
default_draft_neuron_config_args = (
_get_default_neuron_config_for_speculation(
speculation_config.draft_model_config, parallel_config,
scheduler_config))
if is_eagle:
default_draft_neuron_config_args['is_eagle_draft'] = True
default_draft_neuron_config_args['has_pre_attention_norm'] = False
draft_neuron_config = _get_neuron_config_after_override(
default_draft_neuron_config_args,
speculation_config.draft_model_config.override_neuron_config)
draft_model.load_weights(speculation_config.draft_model_config.model,
tp_degree=speculation_config.
draft_parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[
speculation_config.draft_model_config.dtype],
neuron_config=draft_neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
draft_model.eval()
num_speculative_tokens = speculation_config.num_speculative_tokens
# Create speculation model instance.
speculation_model = FusedSpeculativeDecoder(draft_model.model,
target_model.model,
num_speculative_tokens)
speculation_model.to_neuron()
return NeuronSpeculationCausalLM(speculation_model)
def get_neuron_eagle_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized EAGLE speculation model for inference."""
from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder
# Create target model instance.
target_model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config_for_speculation(
model_config, parallel_config, scheduler_config)
default_neuron_config_args['is_eagle_target'] = True
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
[scheduler_config.max_model_len])
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
[scheduler_config.max_model_len])
target_model.load_weights(
model_config.model,
tp_degree=parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
target_model.eval()
# Create draft model instance.
draft_model = NeuronCausalLM(
speculation_config.draft_model_config.hf_config)
default_draft_neuron_config_args = (
_get_default_neuron_config_for_speculation(
speculation_config.draft_model_config, parallel_config,
scheduler_config))
default_draft_neuron_config_args['is_eagle_draft'] = True
default_draft_neuron_config_args['has_pre_attention_norm'] = False
draft_neuron_config = _get_neuron_config_after_override(
default_draft_neuron_config_args,
speculation_config.draft_model_config.override_neuron_config)
draft_model.load_weights(speculation_config.draft_model_config.model,
tp_degree=speculation_config.
draft_parallel_config.tensor_parallel_size,
amp=TORCH_DTYPE_TO_NEURON_AMP[
speculation_config.draft_model_config.dtype],
neuron_config=draft_neuron_config,
context_length_estimate=context_length_estimates,
n_positions=n_positions,
batch_size=scheduler_config.max_num_seqs)
draft_model.eval()
token_tree: Dict[int, List[int]] = ast.literal_eval(
speculation_config.speculative_token_tree)
speculation_model = EagleSpeculativeDecoder(draft_model.model,
target_model.model,
token_tree=token_tree)
speculation_model.to_neuron()
return NeuronSpeculationCausalLM(speculation_model)
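To illustrate how NeuronSpeculationCausalLM.forward right-pads the accepted speculative tokens with the termination id, here is a minimal standalone sketch of the masking step (illustrative values only):
import torch

# Illustrative only: each row keeps counts[i] accepted tokens; the rest are
# marked with the speculation termination id (-1).
tokens = torch.tensor([[11, 12, 13, 14], [21, 22, 23, 24]])
counts = torch.tensor([[3], [1]])   # accepted tokens per sequence
mask = torch.arange(tokens.shape[1]).expand(tokens.shape[0], -1) >= counts
tokens[mask] = -1
print(tokens)
# tensor([[11, 12, 13, -1],
#         [21, -1, -1, -1]])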

View File

@ -0,0 +1,584 @@
# SPDX-License-Identifier: Apache-2.0
"""Utilities for selecting and loading Neuron models in
neuronx-distributed-inference framework."""
# Disabling yapf because yapf and isort have conflicts for the below imports
# yapf: disable
import copy
import hashlib
import importlib
import multiprocessing
import os
import shutil
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from neuronx_distributed_inference.models.config import (
FusedSpecNeuronConfig, OnDeviceSamplingConfig)
from neuronx_distributed_inference.models.mllama.utils import (
create_vision_mask)
from neuronx_distributed_inference.utils.hf_adapter import (
load_pretrained_config)
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)
# yapf: enable
logger = init_logger(__name__)
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "float32",
"half": "float16",
"float16": "float16",
"bfloat16": "bfloat16",
"float": "float32",
"float32": "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.float32: "float32",
}
# Models supported by neuronx-distributed-inference.
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str]] = {
"LlamaForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"DbrxForCausalLM":
("neuronx_distributed_inference.models.dbrx.modeling_dbrx",
"NeuronDbrxForCausalLM"),
"MixtralForCausalLM":
("neuronx_distributed_inference.models.mixtral.modeling_mixtral",
"NeuronMixtralForCausalLM"),
"MllamaForConditionalGeneration":
("neuronx_distributed_inference.models.mllama.modeling_mllama",
"NeuronMllamaForCausalLM"),
}
class NeuronCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
sampling_params: torch.Tensor,
) -> torch.Tensor:
output = self.model(input_ids,
attention_mask=None,
position_ids=positions,
seq_ids=input_block_ids,
sampling_params=sampling_params)
# on-device sampling
if self.config.neuron_config.on_device_sampling_config:
return output.hidden_states
else:
return output.logits[:, -1, :]
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
# on-device sampling
if self.config.neuron_config.on_device_sampling_config:
batch_size = logits.shape
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
assert len(seq_ids) == list(batch_size)[0], "batch size mismatch"
# Organize input tensors by step instead of by sequence.
accepted_token_ids_by_step = logits.flatten()
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
step_output_token_ids = []
for i, seq_id in enumerate(seq_ids):
token_id = accepted_token_ids_by_step[i]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
return SamplerOutput(outputs=step_output_token_ids)
else:
return self.sampler(logits, sampling_metadata)
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
self.config.neuron_config = neuron_config
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
hashed_config = hashlib.md5(
config.to_json_string().encode('utf-8')).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
try:
self.model = neuronx_model_cls(compiled_model_path)
override_neuron_config = kwargs["override_neuron_config"]
for k, v in override_neuron_config.items():
setattr(self.model.config.neuron_config, k, v)
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError) as e:
logger.warning("Exception: %s", e)
logger.warning("Failed to load the model from %s, Recompiling...",
compiled_model_path)
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
self.model = neuronx_model_cls(model_name_or_path, config)
self.model.compile(compiled_model_path)
self.model.load(compiled_model_path)
class NeuronMllamaForCausalLM(nn.Module):
def __init__(self,
config: PretrainedConfig,
on_device_sampling_disabled: bool = False) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(
config.get_text_config().vocab_size, logits_as_input=True)
self.on_device_sampling_disabled = on_device_sampling_disabled
if self.on_device_sampling_disabled:
# Use default sampler
self.sampler = Sampler()
# Lazy initialized
self.model: nn.Module
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
seq_ids: torch.Tensor, pixel_values: torch.Tensor,
aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
has_image: torch.Tensor, sampling_params) -> torch.Tensor:
self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
output = self.model(
input_ids.to(torch.int32),
attention_mask=None,
position_ids=positions.to(torch.int32),
seq_ids=seq_ids.flatten().to(torch.int32),
pixel_values=pixel_values.to(
self.config.vision_config.torch_dtype),
aspect_ratios=aspect_ratios.to(torch.int32),
vision_mask=self.vision_mask.to(torch.int32),
sampling_params=sampling_params,
num_chunks=num_chunks.to(torch.int32),
has_image=has_image.to(torch.int32),
)
if self.config.neuron_config.on_device_sampling_config:
return output.hidden_states
return output.logits[:, -1, :]
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(None, hidden_states, sampling_metadata)
return logits
def sample(self, hidden_states, sampling_metadata):
if not self.on_device_sampling_disabled:
with torch.profiler.record_function("sample"):
hidden_states = hidden_states.flatten()
res = []
sample_idx = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
samples = []
for seq_id in seq_ids:
token_id = hidden_states[sample_idx].item()
samples.append(
SequenceOutput(
parent_seq_id=seq_id,
output_token=token_id,
logprobs={token_id: Logprob(token_id)}))
sample_idx += 1
res.append(
CompletionSequenceGroupOutput(samples=samples,
prompt_logprobs=None))
next_tokens = SamplerOutput(outputs=res)
else:
next_tokens = self.sampler(None, hidden_states, sampling_metadata)
return next_tokens
def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
self.config.neuron_config = neuron_config
logger.info("neuron_config buckets: %s",
self.config.neuron_config.buckets)
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
hashed_config = hashlib.md5(
config.to_json_string().encode('utf-8')).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
try:
self.model = neuronx_model_cls(compiled_model_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.vision_token_id = tokenizer(
"<|image|>", add_special_tokens=False).input_ids
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError):
logger.warning("Failed to load the model from %s, Recompiling...",
compiled_model_path)
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
self.model = neuronx_model_cls(model_name_or_path, config)
logger.info("\nCompiling and saving model to %s", model_name_or_path)
p = multiprocessing.Process(target=compile_model,
args=(self, compiled_model_path))
p.start()
p.join()
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.save_pretrained(compiled_model_path)
logger.info("Successfully compiled and saved the model in %s",
compiled_model_path)
# Read "<|image|>" token_id from the tokenizer
self.vision_token_id = tokenizer("<|image|>",
add_special_tokens=False).input_ids
logger.info("\nLoading model from compiled checkpoint...")
self.model.load(compiled_model_path)
def compile_model(neuron_model, traced_model_path):
neuron_model.model.compile(traced_model_path)
class NeuronSpeculationCausalLM(nn.Module):
"""A Neuron-optimized causal language model with speculative decoding."""
def __init__(
self,
config: PretrainedConfig,
) -> None:
super().__init__()
self.config = config
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
# Lazy initialized
self.model: nn.Module
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_block_ids: torch.Tensor,
sampling_params: torch.Tensor,
) -> torch.Tensor:
output = self.model(input_ids,
attention_mask=None,
position_ids=positions,
seq_ids=input_block_ids,
sampling_params=sampling_params)
# CTX encoding
if (positions[:, 0]).sum().item() == 0:
return output.fused_outputs[0][:, 0:1]
# Fused Spec (Generation)
accepted_tokens_with_padding = output.fused_outputs[0]
next_pos_ids = output.fused_outputs[-1]
generated_token_counts = next_pos_ids - positions
assert torch.any(generated_token_counts == 0).item() is False, \
"NxDI model generated no output for one or more sequences."
batch_size, steps = accepted_tokens_with_padding.shape
mask = torch.arange(steps).expand(batch_size,
-1) >= generated_token_counts
accepted_tokens_with_padding[mask] = -1
return accepted_tokens_with_padding
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[List[SamplerOutput]]:
batch_size, num_steps = logits.shape
seq_ids = [
seq_id for sg in sampling_metadata.seq_groups
for seq_id in sg.seq_ids
]
# Organize input tensors by step instead of by sequence.
accepted_token_ids_by_step = logits.transpose(0, 1)
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
sampler_output_list = []
for step_index in range(num_steps):
if all(token_id == -1
for token_id in accepted_token_ids_by_step[step_index]):
break
step_output_token_ids = []
for sequence_index in range(batch_size):
token_id = accepted_token_ids_by_step[step_index][
sequence_index]
step_output_token_ids.append(
CompletionSequenceGroupOutput(samples=[
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
output_token=token_id,
logprobs={token_id: Logprob(token_id)})
],
prompt_logprobs=None))
sampler_output_list.append(
SamplerOutput(outputs=step_output_token_ids))
return sampler_output_list
def load_weights(self, model_name_or_path: str,
draft_model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
**kwargs['neuron_config'])
config = neuronx_model_cls.get_config_cls()(
neuron_config,
load_config=load_pretrained_config(model_name_or_path))
draft_neuron_config = copy.deepcopy(config.neuron_config)
if not config.neuron_config.enable_eagle_speculation:
draft_neuron_config.speculation_length = 0
draft_neuron_config.trace_tokengen_model = True
draft_neuron_config.enable_fused_speculation = False
if config.neuron_config.enable_eagle_speculation:
draft_neuron_config.is_eagle_draft = True
draft_neuron_config.sequence_parallel_enabled = False
draft_config = neuronx_model_cls.get_config_cls()(
draft_neuron_config,
load_config=load_pretrained_config(draft_model_name_or_path))
fused_spec_config = (FusedSpecNeuronConfig(
neuronx_model_cls._model_cls,
draft_config=draft_config,
draft_model_path=draft_model_name_or_path))
config.fused_spec_config = fused_spec_config
self.config.neuron_config = neuron_config
hashed_config = hashlib.md5(
config.to_json_string().encode('utf-8')).hexdigest()
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
elif os.path.exists(model_name_or_path):
compiled_model_path = os.path.join(model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
else:
compiled_model_path = os.path.join("local-models",
model_name_or_path,
"neuron-compiled-artifacts",
hashed_config)
shutil.rmtree(compiled_model_path, ignore_errors=True)
try:
self.model = neuronx_model_cls(compiled_model_path)
override_neuron_config = kwargs["override_neuron_config"]
for k, v in override_neuron_config.items():
setattr(self.model.config.neuron_config, k, v)
self.model.load(compiled_model_path)
return
except (FileNotFoundError, ValueError) as e:
logger.warning("Exception: %s", e)
logger.warning("Failed to load the model from %s Recompiling...",
compiled_model_path)
if not os.path.exists(model_name_or_path):
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
saved_path = os.path.join("local-models", model_name_or_path)
hf_model.save_pretrained(saved_path)
model_name_or_path = saved_path
if not os.path.exists(draft_model_name_or_path):
if draft_model_name_or_path != model_name_or_path:
hf_model = AutoModelForCausalLM.from_pretrained(
draft_model_name_or_path)
saved_path = os.path.join("local-models",
draft_model_name_or_path)
hf_model.save_pretrained(saved_path)
draft_model_name_or_path = saved_path
else:
draft_model_name_or_path = model_name_or_path
config.fused_spec_config.draft_model_path = draft_model_name_or_path
self.model = neuronx_model_cls(model_name_or_path, config)
self.model.compile(compiled_model_path)
self.model.load(compiled_model_path)
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
return arch
raise ValueError(
f"Model architectures {architectures} are not supported on Neuron "
f"for now. Supported architectures: "
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_default_neuron_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig):
"""Generate a neuron config based on vllm config args."""
on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True,
deterministic=False)
batch_size = scheduler_config.max_num_seqs
neuron_config = dict(
tp_degree=parallel_config.tensor_parallel_size,
ctx_batch_size=1,
batch_size=batch_size,
max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len,
enable_bucketing=True,
is_continuous_batching=(batch_size > 1),
quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
padding_side="right",
on_device_sampling_config=on_device_sampling_config,
sequence_parallel_enabled=True,
)
return neuron_config
def _get_default_speculation_config(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Generate a neuron config for speculative decoding based on vllm config
args."""
neuron_config = dict(
tp_degree=parallel_config.tensor_parallel_size,
batch_size=scheduler_config.max_num_seqs,
max_context_length=scheduler_config.max_model_len,
seq_len=scheduler_config.max_model_len,
speculation_length=speculation_config.num_speculative_tokens,
trace_tokengen_model=False,
enable_fused_speculation=True,
enable_bucketing=True,
quantized=False,
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
on_device_sampling_config=dict(
top_k=1,
do_sample=False,
))
return neuron_config
def _get_neuron_config_after_override(default_neuron_config,
overridden_neuron_config):
"""Update default neuron config values with override args"""
overridden_neuron_config = overridden_neuron_config or {}
default_neuron_config.update(overridden_neuron_config)
return default_neuron_config
def get_neuron_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig) -> nn.Module:
"""Initializes a neuron-optimized model for inference."""
model_arch = _get_model_architecture(model_config.hf_config)
if model_arch == "MllamaForConditionalGeneration":
model = NeuronMllamaForCausalLM(model_config.hf_config)
else:
model = NeuronCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_neuron_config(
model_config, parallel_config, scheduler_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
override_neuron_config = model_config.override_neuron_config
model.load_weights(model_config.model,
neuron_config=neuron_config,
override_neuron_config=override_neuron_config)
return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
speculation_config: SpeculativeConfig):
"""Initializes a neuron-optimized speculation model for inference.
This model handles speculation using both a draft model and an EAGLE draft.
"""
model = NeuronSpeculationCausalLM(model_config.hf_config)
default_neuron_config_args = _get_default_speculation_config(
model_config, parallel_config, scheduler_config, speculation_config)
neuron_config = _get_neuron_config_after_override(
default_neuron_config_args, model_config.override_neuron_config)
override_neuron_config = model_config.override_neuron_config
model.load_weights(model_config.model,
speculation_config.draft_model_config.model,
neuron_config=neuron_config,
override_neuron_config=override_neuron_config)
return model.eval()
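A note on the compiled-artifact caching used throughout this loader: the cache directory is keyed by an MD5 hash of the serialized Neuron config, so any config change results in a fresh compile (or set NEURON_COMPILED_ARTIFACTS to reuse a known location). A minimal sketch with a hypothetical config string:
import hashlib
import os

# Hypothetical serialized config; in the loader this comes from
# config.to_json_string() on the neuronx-distributed-inference config object.
config_json = '{"neuron_config": {"tp_degree": 32, "seq_len": 2048}}'
hashed_config = hashlib.md5(config_json.encode("utf-8")).hexdigest()
compiled_model_path = os.path.join("local-models", "my-model",
                                   "neuron-compiled-artifacts", hashed_config)
print(compiled_model_path)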

View File

@ -176,17 +176,26 @@ def cpu_platform_plugin() -> Optional[str]:
def neuron_platform_plugin() -> Optional[str]:
is_neuron = False
tnx_installed = False
nxd_installed = False
logger.debug("Checking if Neuron platform is available.")
try:
import transformers_neuronx # noqa: F401
is_neuron = True
tnx_installed = True
logger.debug("Confirmed Neuron platform is available because"
" transformers_neuronx is found.")
except ImportError as e:
logger.debug("Neuron platform is not available because: %s", str(e))
except ImportError:
pass
try:
import neuronx_distributed_inference # noqa: F401
nxd_installed = True
logger.debug("Confirmed Neuron platform is available because"
" neuronx_distributed_inference is found.")
except ImportError:
pass
is_neuron = tnx_installed or nxd_installed
return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import enum
import os
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from vllm import envs
@ -15,6 +17,11 @@ else:
logger = init_logger(__name__)
class NeuronFramework(enum.Enum):
TRANSFORMERS_NEURONX = "transformers-neuronx"
NEURONX_DISTRIBUTED_INFERENCE = "neuronx-distributed-inference"
class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
device_name: str = "neuron"
@ -43,8 +50,6 @@ class NeuronPlatform(Platform):
assert (vllm_config.lora_config
is None), "LoRA is not supported for Neuron backend."
assert (not vllm_config.speculative_config
), "Speculative decoding not yet supported for Neuron backend."
cache_config = vllm_config.cache_config
if cache_config:
@ -67,3 +72,71 @@ class NeuronPlatform(Platform):
@classmethod
def use_all_gather(cls) -> bool:
return True
@classmethod
@lru_cache
def is_neuronx_distributed_inference(cls) -> bool:
try:
import neuronx_distributed_inference
except ImportError:
neuronx_distributed_inference = None
return neuronx_distributed_inference is not None
@classmethod
@lru_cache
def is_transformers_neuronx(cls) -> bool:
try:
import transformers_neuronx
except ImportError:
transformers_neuronx = None
return transformers_neuronx is not None
def get_neuron_framework_to_use(self):
"""Return the specified framework if corresponding installations are
available.
If no framework is specified, use neuronx-distributed-inference by
default.
If that's unavailable, check and switch to transformers-neuronx.
"""
if not self.is_neuron():
raise AssertionError(
f"Neuron Framework unavailable for platform: {self}")
tnx_installed = self.is_transformers_neuronx()
nxd_installed = self.is_neuronx_distributed_inference()
specified_framework = os.environ.get("VLLM_NEURON_FRAMEWORK")
tnx_framework = NeuronFramework.TRANSFORMERS_NEURONX.value
nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE.value
if specified_framework == tnx_framework and tnx_installed:
return NeuronFramework.TRANSFORMERS_NEURONX
if ((specified_framework == nxd_framework and nxd_installed)
or (specified_framework is None and nxd_installed)):
return NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
if specified_framework is None and tnx_installed:
return NeuronFramework.TRANSFORMERS_NEURONX
return None
def use_neuronx_distributed(self):
"""
Return True if the framework determined in get_neuron_framework_to_use()
is NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE, False otherwise. This
is used to select the Neuron model framework and framework-specific
configuration to apply during model compilation.
"""
nxd_framework = NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE
return self.get_neuron_framework_to_use() == nxd_framework
def use_transformers_neuronx(self):
"""
Return True if the framework determined in get_neuron_framework_to_use()
is NeuronFramework.TRANSFORMERS_NEURONX, False otherwise. This is used
to select the Neuron model framework and framework-specific
configuration to apply during model compilation.
"""
return self.get_neuron_framework_to_use(
) == NeuronFramework.TRANSFORMERS_NEURONX
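Taken together, the framework is chosen from the VLLM_NEURON_FRAMEWORK environment variable when set; otherwise neuronx-distributed-inference is preferred when installed, falling back to transformers-neuronx. For example, assuming both packages are installed, transformers-neuronx can be pinned explicitly before constructing the LLM:
import os

# The values match the NeuronFramework enum above.
os.environ["VLLM_NEURON_FRAMEWORK"] = "transformers-neuronx"
# or: os.environ["VLLM_NEURON_FRAMEWORK"] = "neuronx-distributed-inference"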

View File

@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
from importlib.util import find_spec
from typing import List, Optional
import torch
from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
NeuronModelRunner)
class MultiStepNeuronModelRunner(NeuronModelRunner):
"""A model runner for multi step decoding using the transformers_neuronx
framework"""
def __init__(
self,
vllm_config: VllmConfig,
):
super().__init__(vllm_config)
self.speculation_config = self.speculative_config
from transformers_neuronx.config import GenerationConfig
self.speculation_config.draft_model_config.neuron_sampling_params = (
GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
* self.scheduler_config.max_num_seqs,
top_p=[1.0] * self.scheduler_config.max_num_seqs,
temperature=[1.0] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K
))
def load_model(self) -> None:
if find_spec("transformers_neuronx") is not None:
from vllm.model_executor.model_loader.neuron import (
get_neuron_eagle_speculation_model,
get_neuron_speculation_model)
if self.speculation_config.speculative_token_tree is not None:
self.model = get_neuron_eagle_speculation_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
speculation_config=self.speculation_config)
else:
self.model = get_neuron_speculation_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
speculation_config=self.speculation_config)
else:
raise NotImplementedError(
"Supports only Transformer-NeuronX based models.")
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
logits = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
)
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return output

View File

@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
import torch
from vllm.config import VllmConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.worker.neuronx_distributed_model_runner import (
NeuronxDistributedModelRunner)
class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
"""A model runner for multi-step decoding using the
neuronx-distributed-inference framework"""
def __init__(
self,
vllm_config: VllmConfig,
):
super().__init__(vllm_config)
def load_model(self) -> None:
from vllm.model_executor.model_loader.neuronx_distributed import (
get_neuron_speculation_model)
self.model = get_neuron_speculation_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
speculation_config=self.speculative_config)
@torch.inference_mode()
def execute_model(
self,
model_input,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
sampling_params = torch.tensor([[
seq_group.sampling_params.top_k,
seq_group.sampling_params.top_p,
seq_group.sampling_params.temperature,
] for seq_group in model_input.sampling_metadata.seq_groups])
logits = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
)
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
return output
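The sampling_params argument passed to the Neuron model here (and in NeuronModelRunner.execute_model below) is a dense [batch_size, 3] tensor of per-sequence (top_k, top_p, temperature). A small illustrative sketch with hypothetical values:
import torch

# Hypothetical per-request parameters: (top_k, top_p, temperature).
requests = [(50, 0.9, 0.8), (1, 1.0, 1.0)]
sampling_params = torch.tensor([[k, p, t] for k, p, t in requests])
print(sampling_params.shape)  # torch.Size([2, 3])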

View File

@ -2,20 +2,20 @@
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers_neuronx.config import GenerationConfig
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.config import DeviceConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
@ -34,12 +34,18 @@ class ModelInputForNeuron(ModelRunnerInputBase):
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
sampling_metadata: SamplingMetadata = None
multi_modal_kwargs: BatchedTensorInputs = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")
return {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"input_block_ids": self.input_block_ids,
"sampling_metadata": self.sampling_metadata,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
@classmethod
def from_broadcasted_tensor_dict(
@ -47,11 +53,17 @@ class ModelInputForNeuron(ModelRunnerInputBase):
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForNeuron":
assert attn_backend is None
return cls.from_broadcasted_tensor_dict(tensor_dict)
return ModelInputForNeuron(
input_tokens=tensor_dict["input_tokens"],
input_positions=tensor_dict["input_positions"],
input_block_ids=tensor_dict["input_block_ids"],
sampling_metadata=tensor_dict["sampling_metadata"],
multi_modal_kwargs=tensor_dict["multi_modal_kwargs"],
)
class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
"""A model runner for AWS Neuron hardware"""
# NEURON has an upper limit on the top_k
_MAX_NEURON_SAMPLING_TOP_K = 256
@ -61,13 +73,20 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
vllm_config: VllmConfig,
):
ModelRunnerBase.__init__(self, vllm_config)
model_config = self.model_config
if model_config is not None and model_config.get_sliding_window():
if (self.model_config is not None
and self.model_config.get_sliding_window()):
logger.warning("Sliding window is not supported on Neuron. "
"The model will run without sliding window.")
self.device_config = (self.device_config if self.device_config
is not None else DeviceConfig())
self.device = self.device_config.device
self.pin_memory = is_pin_memory_available()
# Multi-modal data support
self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
.create_input_mapper(self.model_config)
# Lazy initialization.
self.model: nn.Module # initialize after load_model.
@ -82,32 +101,33 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self._previous_batch_request_ids: List[str] = []
if not self._on_device_sampling_disabled:
logger.warning(
"On-device sampling is turned on in Neuron by default, only "
"top_k, top_p, and temperature are current supported sampling "
"parameters. To turn off the on-device sampling, please set "
"the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1."
)
self.model_config.neuron_sampling_params = GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
* self.scheduler_config.max_num_seqs,
top_p=[1.0] * self.scheduler_config.max_num_seqs,
temperature=[1.0] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
self._init_neuron_sampling()
def _init_neuron_sampling(self) -> None:
if current_platform.use_transformers_neuronx():
from transformers_neuronx.config import GenerationConfig
else:
from transformers import GenerationConfig
logger.warning(
"On-device sampling is turned on in Neuron by default, only "
"top_k, top_p, and temperature are current supported sampling "
"parameters. To turn off the on-device sampling, please set "
"the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1.")
self.model_config.neuron_sampling_params = GenerationConfig(
max_length=self.scheduler_config.max_model_len,
do_sample=True,
per_batch_line=True,
top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \
* self.scheduler_config.max_num_seqs,
top_p=[1.0] * self.scheduler_config.max_num_seqs,
temperature=[1.0] * self.scheduler_config.max_num_seqs,
dynamic=True,
global_top_k=self._MAX_NEURON_SAMPLING_TOP_K)
def load_model(self) -> None:
if find_spec("transformers_neuronx") is not None:
self.model = get_neuron_model(
self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
else:
raise NotImplementedError(
"Supports only Transformer-NeuronX based models.")
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
def get_model(self) -> nn.Module:
return self.model
@ -240,6 +260,16 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
(input_tokens, input_positions,
input_block_ids) = self._prepare_decode(seq_group_metadata_list)
seq_lens = None
if not self._on_device_sampling_disabled:
for seq_group_metadata in seq_group_metadata_list:
sampling_params = seq_group_metadata.sampling_params
top_k, top_p, temperature = (
self._convert_to_neuron_sampling_params(sampling_params))
sampling_params.top_k = top_k
sampling_params.top_p = top_p
sampling_params.temperature = temperature
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
@ -251,7 +281,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
self.pin_memory,
generators=self.get_generators(finished_requests_ids))
if not self._on_device_sampling_disabled:
if current_platform.use_transformers_neuronx(
) and not self._on_device_sampling_disabled:
# Once the request IDs are changed in current iteration, we will
# update the on-device sampling parameters.
current_batch_request_ids = [
@ -259,7 +290,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
for seq_group_meta_data in seq_group_metadata_list
]
if current_batch_request_ids != self._previous_batch_request_ids:
self._update_neuron_sampling_params(sampling_metadata)
self._update_neuron_sampling_params(seq_group_metadata_list)
self._previous_batch_request_ids = current_batch_request_ids
return ModelInputForNeuron(input_tokens=input_tokens,
@ -268,31 +299,59 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
sampling_metadata=sampling_metadata,
multi_modal_kwargs=multi_modal_kwargs)
def _update_neuron_sampling_params(self,
sampling_metadata: SamplingMetadata):
def _update_neuron_sampling_params(
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
# Update Neuron sampling parameters (GenerationConfig in Neuron)
current_sampling_params = self.model_config.neuron_sampling_params
assert current_sampling_params is not None, (
f"Failed to update sampling_params, "
f"current sampling params is {current_sampling_params}")
is_update_needed = False
top_k = current_sampling_params.top_k
top_p = current_sampling_params.top_p
temperature = current_sampling_params.temperature
for index, sequence_group_to_sample in enumerate(
sampling_metadata.seq_groups):
top_k[index] = self._convert_to_neuron_top_k(
sequence_group_to_sample.sampling_params.top_k)
top_p[index] = sequence_group_to_sample.sampling_params.top_p
temperature[index] = \
sequence_group_to_sample.sampling_params.temperature
self.model.model.update_generation_config(current_sampling_params)
# The index of a sequence's sampling parameters in neuron is equal to
# its index in `input_block_ids`.
for seq_group_metadata in seq_group_metadata_list:
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
def _convert_to_neuron_top_k(self, top_k: int) -> int:
seq_group_top_k = sampling_params.top_k
seq_group_top_p = sampling_params.top_p
seq_group_temperature = sampling_params.temperature
for seq_id in seq_ids:
index = seq_group_metadata.block_tables[seq_id][0]
if (top_k[index] != seq_group_top_k
or top_p[index] != seq_group_top_p
or temperature[index] != seq_group_temperature):
is_update_needed = True
top_k[index] = seq_group_top_k
top_p[index] = seq_group_top_p
temperature[index] = seq_group_temperature
# update_generation_config is only available in transformers-neuronx
if is_update_needed and current_platform.use_transformers_neuronx():
self.model.model.update_generation_config(current_sampling_params)
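As a rough illustration of the index mapping noted above (values are invented; on the Neuron backend each sequence owns a single block because block_size matches the maximum sequence length, so its first block id doubles as its slot in the per-batch parameter arrays):
# Hypothetical example: two sequences, four batch slots (max_num_seqs = 4).
block_tables = {0: [2], 1: [0]}   # seq_id -> [block id]
top_k = [1, 1, 1, 1]              # one entry per batch slot
top_k[block_tables[0][0]] = 50    # sequence 0 updates slot 2
top_k[block_tables[1][0]] = 10    # sequence 1 updates slot 0
print(top_k)                      # [10, 1, 50, 1]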
def _convert_to_neuron_sampling_params(
self, sampling_params: SamplingParams) -> Tuple[int, float, float]:
# Returns the top_k, top_p and temperature parameters for neuron.
top_k = sampling_params.top_k
top_p = sampling_params.top_p
temperature = sampling_params.temperature
if temperature == 0.0:
# Enable greedy sampling on zero temperature
return (1, 1.0, 1.0)
if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K:
return self._MAX_NEURON_SAMPLING_TOP_K
return top_k
top_k = self._MAX_NEURON_SAMPLING_TOP_K
return (top_k, top_p, temperature)
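For a concrete sense of the conversion, a minimal standalone sketch of the rule above (the 256 cap is illustrative and stands in for the runner's _MAX_NEURON_SAMPLING_TOP_K):
_MAX_TOP_K = 256  # assumed cap for illustration

def convert(top_k: int, top_p: float, temperature: float):
    if temperature == 0.0:
        return (1, 1.0, 1.0)   # zero temperature -> greedy sampling
    if top_k < 0 or top_k > _MAX_TOP_K:
        top_k = _MAX_TOP_K     # "consider all tokens" (-1) is clamped to the cap
    return (top_k, top_p, temperature)

print(convert(-1, 0.9, 0.7))  # (256, 0.9, 0.7)
print(convert(50, 0.9, 0.0))  # (1, 1.0, 1.0)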
@torch.inference_mode()
def execute_model(
@ -306,7 +365,26 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
raise ValueError(
"NeuronModelRunner does not support multi-step execution.")
with set_forward_context(None, self.vllm_config, 0):
# Extract top_k, top_p and temperature from model_input for the
# neuron forward call.
sampling_params = (torch.tensor([[
seq_group.sampling_params.top_k, seq_group.sampling_params.top_p,
seq_group.sampling_params.temperature
] for seq_group in model_input.sampling_metadata.seq_groups]))
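# For a batch of N sequence groups this produces an [N, 3] tensor with
# one [top_k, top_p, temperature] row per sequence group.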
if current_platform.use_neuronx_distributed():
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
or {},
device=self.device),
)
elif current_platform.use_transformers_neuronx():
# [TODO] validate on-device sampling
# The model signature may need change for on-device sampling
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
View File

@ -1,61 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
"""A Neuron worker class."""
import os
from typing import List, Optional, Tuple
import torch
import torch.distributed
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.platforms.neuron import NeuronFramework
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoRANotSupportedWorkerBase, WorkerBase,
WorkerInput)
logger = init_logger(__name__)
class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""
def __init__(
self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = True,
) -> None:
model_runner: NeuronModelRunner
def __init__(self,
vllm_config: VllmConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
is_driver_worker: bool = False) -> None:
WorkerBase.__init__(self, vllm_config=vllm_config)
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner: NeuronModelRunner = NeuronModelRunner(
vllm_config=vllm_config)
self.is_driver_worker = is_driver_worker
neuron_framework = current_platform.get_neuron_framework_to_use()
if neuron_framework == NeuronFramework.TRANSFORMERS_NEURONX:
self.model_runner = self.get_tnx_model_runner(vllm_config)
elif neuron_framework == NeuronFramework.NEURONX_DISTRIBUTED_INFERENCE:
self.model_runner = self.get_neuronx_distributed_model_runner(
vllm_config)
else:
raise NotImplementedError(
"Specified framework" +
f" {os.environ.get('VLLM_NEURON_FRAMEWORK')}" +
" is either not installed or not supported." +
" Supported frameworks: " +
"[transformers-neuronx, neuronx-distributed-inference]")
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
assert execute_model_req is not None
assert (not execute_model_req.blocks_to_swap_in
and not execute_model_req.blocks_to_swap_out
and not execute_model_req.blocks_to_copy), (
"Cache operations are not supported for Neuron backend.")
assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
return output
def get_tnx_model_runner(self, vllm_config):
from vllm.worker.multi_step_neuron_model_runner import (
MultiStepNeuronModelRunner)
if self.speculative_config is not None:
return MultiStepNeuronModelRunner(vllm_config=vllm_config)
else:
return NeuronModelRunner(vllm_config=vllm_config)
def get_neuronx_distributed_model_runner(self, vllm_config):
from vllm.worker.multi_step_neuronx_distributed_model_runner import (
MultiStepNeuronxDistributedModelRunner)
from vllm.worker.neuronx_distributed_model_runner import (
NeuronxDistributedModelRunner)
if self.speculative_config is not None:
return MultiStepNeuronxDistributedModelRunner(
vllm_config=vllm_config)
else:
return NeuronxDistributedModelRunner(vllm_config=vllm_config)
def init_device(self) -> None:
self.init_distributed_environment()
@ -121,17 +141,17 @@ class NeuronWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
def init_distributed_environment(self):
"""Neuron uses transformers-neuronx for tensor parallelism.
It has only one process to control multiple devices.
vLLM still needs the environment initialized when TP/PP > 1,
so we initialize a distributed environment with one process.
vLLM still needs the environment initialized when TP/PP > 1
"""
init_distributed_environment(
world_size=1,
rank=0,
local_rank=0,
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
)
ensure_model_parallel_initialized(
1,
1,
View File

@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
import torch
from neuronx_distributed_inference.modules.generation.sampling import (
prepare_sampling_params)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.neuronx_distributed import (
_get_model_architecture, get_neuron_model)
from vllm.sequence import IntermediateTensors
from vllm.worker.neuron_model_runner import (ModelInputForNeuron,
NeuronModelRunner)
logger = init_logger(__name__)
class NeuronxDistributedModelRunner(NeuronModelRunner):
def __init__(
self,
vllm_config: VllmConfig,
):
super().__init__(vllm_config)
def load_model(self) -> None:
self.model = get_neuron_model(self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
def get_nxd_sampling_params(self, sampling_metadata):
if self.model.config.neuron_config.on_device_sampling_config:
max_topk = (self.model.config.neuron_config.
on_device_sampling_config.global_topk)
else:
max_topk = self.model.config.vocab_size
top_k = [1] * self.scheduler_config.max_num_seqs
top_p = [1.0] * self.scheduler_config.max_num_seqs
temperature = [1.0] * self.scheduler_config.max_num_seqs
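# Batch slots without a live sequence group keep these defaults
# (top_k=1, top_p=1.0, temperature=1.0), i.e. greedy sampling for
# padded entries.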
for index, sequence_group_to_sample in enumerate(
sampling_metadata.seq_groups):
top_k[index] = (sequence_group_to_sample.sampling_params.top_k
if sequence_group_to_sample.sampling_params.top_k > 0
else max_topk)
top_p[index] = sequence_group_to_sample.sampling_params.top_p
temperature[index] = (
sequence_group_to_sample.sampling_params.temperature)
sampling_params = prepare_sampling_params(
batch_size=self.scheduler_config.max_num_seqs,
top_k=top_k,
top_p=top_p,
temperature=temperature)
return sampling_params
def get_multi_modal_data_neuron(self, input_images):
raise NotImplementedError("need to restore multi-modal support")
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[List[SamplerOutput]]:
if num_steps > 1:
raise ValueError(
"NeuronModelRunner does not support multi-step execution.")
if _get_model_architecture(
self.model.config) != "MllamaForConditionalGeneration":
return super().execute_model(model_input, kv_caches,
intermediate_tensors, num_steps)
sampling_params = self.get_nxd_sampling_params(
model_input.sampling_metadata)
if model_input.multi_modal_kwargs.get('image') is not None:
pixel_values = []
aspect_ratios = []
num_chunks = []
has_image = []
for multi_modal_input in model_input.multi_modal_kwargs.get(
'image'):
image_tensors = self.get_multi_modal_data_neuron(
multi_modal_input.squeeze(0))
pixel_values.append(image_tensors[0])
aspect_ratios.append(image_tensors[1])
num_chunks.append(image_tensors[2])
has_image.append(image_tensors[3])
pixel_values = torch.cat(pixel_values, dim=0)
aspect_ratios = torch.cat(aspect_ratios, dim=0)
num_chunks = torch.cat(num_chunks, dim=0)
has_image = torch.cat(has_image, dim=0)
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
seq_ids=model_input.input_block_ids,
pixel_values=pixel_values,
aspect_ratios=aspect_ratios,
sampling_params=sampling_params,
num_chunks=num_chunks,
has_image=has_image,
)
else:
empty_pixel_values = torch.zeros([1, 1, 4, 3, 560, 560],
dtype=torch.bfloat16)
empty_aspect_ratios = torch.ones([1, 1, 2], dtype=torch.int64)
num_chunks = torch.tensor([[1]
]) # dummy num_chunks, will not be used
has_image = torch.tensor([0])
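# No image in this batch: dummy vision inputs are passed (and flagged
# unused via has_image=0), presumably to satisfy the fixed input
# signature of the traced multi-modal graph.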
hidden_states = self.model(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
seq_ids=model_input.input_block_ids,
pixel_values=empty_pixel_values,
aspect_ratios=empty_aspect_ratios,
sampling_params=sampling_params,
num_chunks=num_chunks,
has_image=has_image,
)
output = self.model.sample(
hidden_states=hidden_states,
sampling_metadata=model_input.sampling_metadata,
)
return [output]