From 3b7178cfa4a317922d4aef9dd3b2647b8d950e7d Mon Sep 17 00:00:00 2001 From: Liangfu Chen Date: Wed, 28 Feb 2024 09:34:34 -0800 Subject: [PATCH] [Neuron] Support inference with transformers-neuronx (#2569) --- examples/offline_inference_neuron.py | 33 ++++ tests/lora/conftest.py | 8 +- vllm/config.py | 41 ++++- vllm/engine/arg_utils.py | 16 +- vllm/engine/llm_engine.py | 21 ++- vllm/lora/layers.py | 4 + vllm/model_executor/__init__.py | 3 +- vllm/model_executor/layers/sampler.py | 18 +- vllm/model_executor/model_loader.py | 10 +- vllm/model_executor/models/__init__.py | 12 +- vllm/model_executor/models/neuron/llama.py | 79 +++++++++ vllm/model_executor/neuron_model_loader.py | 66 +++++++ vllm/model_executor/sampling_metadata.py | 4 +- vllm/model_executor/utils.py | 17 ++ vllm/utils.py | 8 + vllm/worker/cache_engine.py | 11 +- vllm/worker/model_runner.py | 16 +- vllm/worker/neuron_worker.py | 191 +++++++++++++++++++++ 18 files changed, 516 insertions(+), 42 deletions(-) create mode 100644 examples/offline_inference_neuron.py create mode 100644 vllm/model_executor/models/neuron/llama.py create mode 100644 vllm/model_executor/neuron_model_loader.py create mode 100644 vllm/worker/neuron_worker.py diff --git a/examples/offline_inference_neuron.py b/examples/offline_inference_neuron.py new file mode 100644 index 000000000000..9b9dc4d94892 --- /dev/null +++ b/examples/offline_inference_neuron.py @@ -0,0 +1,33 @@ +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM( + model="openlm-research/open_llama_3b", + max_num_seqs=8, + # The max_model_len and block_size arguments are required to be same as max sequence length, + # when targeting neuron device. Currently, this is a known limitation in continuous batching + # support in transformers-neuronx. + # TODO(liangfu): Support paged-attention in transformers-neuronx. + max_model_len=128, + block_size=128, + # The device can be automatically detected when AWS Neuron SDK is installed. + # The device argument can be either unspecified for automated detection, or explicitly assigned. + device="neuron") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. 
+for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 0ca0715334c2..75f4e41290c3 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -131,9 +131,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() get_model_old = get_model - def get_model_patched(model_config, device_config, lora_config=None): - return get_model_old(model_config, device_config, - LoRAConfig(max_loras=4, max_lora_rank=8)) + def get_model_patched(model_config, device_config, **kwargs): + return get_model_old(model_config, + device_config, + lora_config=LoRAConfig(max_loras=4, + max_lora_rank=8)) with patch("vllm.worker.model_runner.get_model", get_model_patched): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) diff --git a/vllm/config.py b/vllm/config.py index bd0dc89b585f..fc848b72d7f2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -8,7 +8,7 @@ from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.transformers_utils.config import get_config -from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version +from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version logger = init_logger(__name__) @@ -380,13 +380,21 @@ class ParallelConfig: disable_custom_all_reduce: bool = False, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size - self.tensor_parallel_size = tensor_parallel_size + if is_neuron(): + # For Neuron device support, here we assign TP=1 to avoid sharding within vLLM directly. + # Transformer-neuronx would take neuron_tp_degree attribute, and distribute the workload + # to multiple NeuronCores. + self.tensor_parallel_size = 1 + self.neuron_tp_degree = tensor_parallel_size + else: + self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce - self.world_size = pipeline_parallel_size * tensor_parallel_size - if self.world_size > 1: + self.world_size = pipeline_parallel_size * self.tensor_parallel_size + # Ray worker is not supported for Neuron backend. 
+ if self.world_size > 1 and not is_neuron(): self.worker_use_ray = True self._verify_args() @@ -465,8 +473,29 @@ class SchedulerConfig: class DeviceConfig: - def __init__(self, device: str = "cuda") -> None: - self.device = torch.device(device) + def __init__(self, device: str = "auto") -> None: + if device == "auto": + # Automated device type detection + if torch.cuda.is_available(): + self.device_type = "cuda" + elif is_neuron(): + self.device_type = "neuron" + else: + raise RuntimeError("No supported device detected.") + else: + # Device type is assigned explicitly + self.device_type = device + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron"]: + self.device = torch.device("cpu") + else: + # Set device with device type + self.device = torch.device(self.device_type) + + @property + def is_neuron(self): + return self.device_type == "neuron" @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a4efd171b871..c01e7311fb89 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -44,7 +44,7 @@ class EngineArgs: lora_extra_vocab_size: int = 256 lora_dtype = 'auto' max_cpu_loras: Optional[int] = None - device: str = 'cuda' + device: str = 'auto' def __post_init__(self): if self.tokenizer is None: @@ -171,7 +171,7 @@ class EngineArgs: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32], + choices=[8, 16, 32, 128], help='token block size') parser.add_argument('--seed', type=int, @@ -264,13 +264,11 @@ class EngineArgs: help=('Maximum number of LoRAs to store in CPU memory. ' 'Must be >= than max_num_seqs. ' 'Defaults to max_num_seqs.')) - parser.add_argument( - "--device", - type=str, - default=EngineArgs.device, - choices=["cuda"], - help=('Device type for vLLM execution. ' - 'Currently, only CUDA-compatible devices are supported.')) + parser.add_argument("--device", + type=str, + default=EngineArgs.device, + choices=["auto", "cuda", "neuron"], + help='Device type for vLLM execution.') return parser @classmethod diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f5b2145c22d6..f0fd7efdef81 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -3,6 +3,7 @@ from collections import defaultdict import os import time import pickle +import importlib from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union) @@ -20,7 +21,8 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, TokenizerGroup) -from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method +from vllm.utils import (Counter, set_cuda_visible_devices, get_ip, + get_open_port, get_distributed_init_method) if ray: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -31,6 +33,12 @@ if TYPE_CHECKING: logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 +# A map between the device type (in device config) to its worker module. +DEVICE_TO_WORKER_MODULE_MAP = { + "cuda": "vllm.worker.worker", + "neuron": "vllm.worker.neuron_worker", +} + # If the env var is set, it uses the Ray's compiled DAG API # which optimizes the control plane overhead. # Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. 
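The DEVICE_TO_WORKER_MODULE_MAP above is consumed by the _dispatch_worker helper added in the next hunk, which imports the worker module lazily. A minimal standalone sketch of that import-by-device-type pattern (the KeyError handling here is illustrative and not part of this change):

    import importlib

    _WORKER_MODULES = {
        "cuda": "vllm.worker.worker",
        "neuron": "vllm.worker.neuron_worker",
    }

    def resolve_worker_cls(device_type: str):
        # Import lazily so torch.cuda/xformers are not touched before
        # CUDA_VISIBLE_DEVICES is set, and so a Neuron host never imports
        # the CUDA worker at all.
        try:
            module_name = _WORKER_MODULES[device_type]
        except KeyError:
            raise ValueError(f"No worker registered for device {device_type!r}")
        return importlib.import_module(module_name).Worker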
@@ -138,10 +146,17 @@ class LLMEngine: def get_tokenizer_for_seq(self, sequence: Sequence): return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + def _dispatch_worker(self): + worker_module = DEVICE_TO_WORKER_MODULE_MAP[ + self.device_config.device_type] + imported_worker = importlib.import_module(worker_module) + Worker = imported_worker.Worker + return Worker + def _init_workers(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker + Worker = self._dispatch_worker() assert self.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") @@ -243,7 +258,7 @@ class LLMEngine: # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker + Worker = self._dispatch_worker() # Initialize torch distributed process group for the workers. model_config = copy.deepcopy(self.model_config) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e1aac20b038b..e667d70f71e3 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -795,6 +795,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA): self.dtype = dtype self.device = device + @property + def logits_as_hidden_states(self): + return self.base_layer.logits_as_hidden_states + @property def vocab_size(self): return self.base_layer.vocab_size diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py index 0d5b2004ad7c..cd6dbde5f54c 100644 --- a/vllm/model_executor/__init__.py +++ b/vllm/model_executor/__init__.py @@ -1,7 +1,6 @@ from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.utils import set_random_seed +from vllm.model_executor.utils import set_random_seed, get_model __all__ = [ "InputMetadata", diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 884d84387e50..71655b216fb3 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTens from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) +from vllm.utils import is_neuron class Sampler(nn.Module): @@ -32,6 +33,8 @@ class Sampler(nn.Module): org_vocab_size: Optional[int] = None) -> None: super().__init__() self.vocab_size = vocab_size + # Transformers-neuronx generate outputs as logits directly. + self.logits_as_hidden_states = is_neuron() # original vocabulary size (without LoRA). self.org_vocab_size = org_vocab_size or vocab_size @@ -55,10 +58,14 @@ class Sampler(nn.Module): embedding_bias: Optional[torch.Tensor] = None, ) -> Optional[SamplerOutput]: # Get the hidden states that we use for sampling. - hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + if self.logits_as_hidden_states: + logits = hidden_states + else: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) - # Get the logits for the next tokens. - logits = self._get_logits(hidden_states, embedding, embedding_bias) + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, embedding, embedding_bias) # Only perform sampling in the driver worker. 
# Note: `_get_logits` is still distributed across TP workers because @@ -395,7 +402,8 @@ def _sample( sample_metadata[sampling_type] = (seq_group_ids, seq_groups, is_prompts, sample_indices) if sampling_type == SamplingType.GREEDY: - greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1) + greedy_samples = torch.argmax(logprobs[sample_indices.long()], + dim=-1) elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): max_best_of = 1 for seq_group, is_prompt in zip(seq_groups, is_prompts): @@ -407,7 +415,7 @@ def _sample( "generators": sampling_metadata.generators, } multinomial_samples[sampling_type] = _multinomial( - probs[sample_indices], max_best_of, **seeded_args) + probs[sample_indices.long()], max_best_of, **seeded_args) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] else: diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index ebe092b5d62b..cb64d80c8147 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -1,11 +1,11 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Optional, Type +from typing import Type import torch import torch.nn as nn -from vllm.config import DeviceConfig, ModelConfig, LoRAConfig +from vllm.config import DeviceConfig, ModelConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: f"Supported architectures: {ModelRegistry.get_supported_archs()}") -def get_model(model_config: ModelConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig] = None) -> nn.Module: +def get_model(model_config: ModelConfig, device_config: DeviceConfig, + **kwargs) -> nn.Module: + lora_config = kwargs.get("lora_config", None) model_class = _get_model_architecture(model_config) # Get the (maybe quantized) linear method. diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 66d28207d664..e4f3a785cd99 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -4,7 +4,7 @@ from typing import List, Optional, Type import torch.nn as nn from vllm.logger import init_logger -from vllm.utils import is_hip +from vllm.utils import is_hip, is_neuron logger = init_logger(__name__) @@ -61,6 +61,9 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = { "Sliding window attention is not yet supported in ROCm's flash attention", } +# Models not supported by Neuron. 
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"} + class ModelRegistry: @@ -77,8 +80,15 @@ class ModelRegistry: logger.warning( f"Model architecture {model_arch} is partially supported " "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + elif is_neuron(): + if model_arch not in _NEURON_SUPPORTED_MODELS: + raise ValueError( + f"Model architecture {model_arch} is not supported by " + "Neuron for now.") module_name, model_cls_name = _MODELS[model_arch] + if is_neuron(): + module_name = _NEURON_SUPPORTED_MODELS[model_arch] module = importlib.import_module( f"vllm.model_executor.models.{module_name}") return getattr(module, model_cls_name, None) diff --git a/vllm/model_executor/models/neuron/llama.py b/vllm/model_executor/models/neuron/llama.py new file mode 100644 index 000000000000..e2856da99d9b --- /dev/null +++ b/vllm/model_executor/models/neuron/llama.py @@ -0,0 +1,79 @@ +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +import os +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class LlamaForCausalLM(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method=None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = None + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + with torch.inference_mode(): + block_size = self.model.context_buckets[-1] + if input_metadata.is_prompt: + seq_ids = input_metadata.slot_mapping[:, 0] // block_size + else: + seq_ids = input_metadata.block_tables + logits = self.model(input_ids, + cache_ids=positions, + start_ids=seq_ids.flatten()) + return logits + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.model.chkpt_model.lm_head, + hidden_states, sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + **kwargs): + from transformers_neuronx.llama.model import LlamaForSampling + + split_model_dir = f"{model_name_or_path}-split" + if os.path.isdir(os.path.join(model_name_or_path, + "pytorch_model.bin")): + split_model_dir = model_name_or_path + elif not os.path.exists(f"{model_name_or_path}-split"): + from transformers.models.llama import LlamaForCausalLM + from transformers_neuronx.module import save_pretrained_split + + hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path, + low_cpu_mem_usage=True) + save_pretrained_split(hf_model, f"{model_name_or_path}-split") + + self.model = LlamaForSampling.from_pretrained(split_model_dir, + **kwargs) + self.model.to_neuron() diff --git a/vllm/model_executor/neuron_model_loader.py b/vllm/model_executor/neuron_model_loader.py new file mode 100644 index 000000000000..b8d63d4ff12f --- /dev/null +++ b/vllm/model_executor/neuron_model_loader.py @@ -0,0 +1,66 @@ +"""Utilities for selecting and loading models.""" +from typing import Type + +import torch +import 
torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import ModelConfig, DeviceConfig +from vllm.model_executor.models import ModelRegistry + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "f32", + "half": "f16", + "float16": "f16", + "bfloat16": "bf16", + "float": "f32", + "float32": "f32", + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", +} + + +def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_model(model_config: ModelConfig, device_config: DeviceConfig, + **kwargs) -> nn.Module: + from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig + + parallel_config = kwargs.get("parallel_config") + scheduler_config = kwargs.get("scheduler_config") + + model_class = _get_model_architecture(model_config.hf_config) + linear_method = None + + # Create a model instance. + model = model_class(model_config.hf_config, linear_method) + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + neuron_config = NeuronConfig( + continuous_batching=continuous_batching_config) + + # Load the weights from the cached or downloaded files. + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + tp_degree=parallel_config.neuron_tp_degree, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=[scheduler_config.max_model_len], + n_positions=[scheduler_config.max_model_len], + batch_size=scheduler_config.max_num_seqs) + + return model.eval() diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index d0ffeecd2d74..7deb80801856 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -5,7 +5,7 @@ import torch from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData -from vllm.utils import in_wsl +from vllm.utils import in_wsl, is_neuron _SAMPLING_EPS = 1e-5 @@ -155,7 +155,7 @@ class SamplingTensors: dtype: torch.dtype) -> "SamplingTensors": # Note that the performance will be very bad without # pinned memory. 
- pin_memory = not in_wsl() + pin_memory = not in_wsl() and not is_neuron() prompt_max_len = max(len(tokens) for tokens in prompt_tokens) prompt_padded_tokens = [ tokens + [vocab_size] * (prompt_max_len - len(tokens)) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 336bc1cd005c..0113e3edf067 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,10 +1,18 @@ """Utils for model executor.""" import random +import importlib from typing import Any, Dict, Optional import numpy as np import torch +from vllm.config import DeviceConfig, ModelConfig + +DEVICE_TO_MODEL_LOADER_MAP = { + "cuda": "model_loader", + "neuron": "neuron_model_loader", +} + def set_random_seed(seed: int) -> None: random.seed(seed) @@ -33,3 +41,12 @@ def set_weight_attrs( assert not hasattr( weight, key), (f"Overwriting existing tensor attribute: {key}") setattr(weight, key, value) + + +def get_model(model_config: ModelConfig, device_config: DeviceConfig, + **kwargs) -> torch.nn.Module: + model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type] + imported_model_loader = importlib.import_module( + f"vllm.model_executor.{model_loader_module}") + get_model_fn = imported_model_loader.get_model + return get_model_fn(model_config, device_config, **kwargs) diff --git a/vllm/utils.py b/vllm/utils.py index c8ac57de6f5f..a4f9bfe6aac9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -118,6 +118,14 @@ def is_hip() -> bool: return torch.version.hip is not None +def is_neuron() -> bool: + try: + import transformers_neuronx + except ImportError: + transformers_neuronx = None + return transformers_neuronx is not None + + def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # NOTE: This import statement should be executed lazily since diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index bbe33989fc2a..880299783935 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -3,10 +3,9 @@ from typing import Dict, List, Tuple import torch -from vllm._C import cache_ops from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE +from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE logger = init_logger(__name__) @@ -39,6 +38,10 @@ class CacheEngine: self.num_gpu_blocks = cache_config.num_gpu_blocks self.num_cpu_blocks = cache_config.num_cpu_blocks + # Skip initializing CUDA stream and buffer for Neuron backend. + if is_neuron(): + return + if cache_config.cache_dtype == "auto": self.dtype = model_config.dtype else: @@ -121,6 +124,8 @@ class CacheEngine: dst: List[KVCache], src_to_dst: Dict[int, int], ) -> None: + from vllm._C import cache_ops + with torch.cuda.stream(self.cache_stream): for i in range(self.num_layers): src_key_cache, src_value_cache = src[i] @@ -140,6 +145,8 @@ class CacheEngine: self._swap(self.gpu_cache, self.cpu_cache, src_to_dst) def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + from vllm._C import cache_ops + key_caches = [key_cache for key_cache, _ in self.gpu_cache] value_caches = [value_cache for _, value_cache in self.gpu_cache] # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU. 
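The cache_engine.py change above defers the vllm._C import to the call sites so that CacheEngine can be constructed on a Neuron host where the CUDA extension may not be built (the constructor also returns early before allocating CUDA streams). A condensed sketch of the deferred-import pattern; the cache_ops.copy_blocks call mirrors how the existing CacheEngine.copy() uses the extension on CUDA devices:

    from typing import Dict, List, Tuple

    import torch

    KVCache = Tuple[torch.Tensor, torch.Tensor]

    def copy_gpu_blocks(gpu_cache: List[KVCache],
                        src_to_dsts: Dict[int, List[int]]) -> None:
        # Deferred import: the compiled extension is only resolved when a
        # CUDA copy is actually issued, so importing this module does not
        # fail on hosts without vllm._C.
        from vllm._C import cache_ops
        key_caches = [key_cache for key_cache, _ in gpu_cache]
        value_caches = [value_cache for _, value_cache in gpu_cache]
        cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)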
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index b99a409e02d1..efe570778fb4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -80,9 +80,16 @@ class ModelRunner: self.in_wsl = in_wsl() self.kv_cache_dtype = kv_cache_dtype + # Set enforce_eager to True for Neuron backend, to avoid capturing graph + if self.device_config.is_neuron: + self.model_config.enforce_eager = True + def load_model(self) -> None: - self.model = get_model(self.model_config, self.device_config, - self.lora_config) + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) vocab_size = self.model.config.vocab_size @@ -393,6 +400,7 @@ class ModelRunner: selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 + pin_memory = not self.in_wsl and not self.device_config.is_neuron max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): @@ -443,12 +451,12 @@ class ModelRunner: selected_token_indices = _async_h2d(selected_token_indices, dtype=torch.long, target_device=self.device, - pin_memory=not self.in_wsl) + pin_memory=pin_memory) categorized_sample_indices = { t: _async_h2d(seq_ids, dtype=torch.int, target_device=self.device, - pin_memory=not self.in_wsl) + pin_memory=pin_memory) for t, seq_ids in categorized_sample_indices.items() } diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py new file mode 100644 index 000000000000..3229a21c11a3 --- /dev/null +++ b/vllm/worker/neuron_worker.py @@ -0,0 +1,191 @@ +"""A Neuron worker class.""" +from typing import Dict, List, Optional, Tuple + +import torch +import torch.distributed + +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.model_executor import set_random_seed +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.parallel_state import ( + ensure_model_parallel_initialized) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.model_runner import ModelRunner + + +class Worker: + """A worker class that executes the model on a group of neuron cores. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + self.model_runner = ModelRunner(model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + is_driver_worker=is_driver_worker) + # Uninitialized cache engine. Will be initialized by + # self.init_cache_engine(). 
+ self.cache_config = None + self.cache_engine = None + self.cache_events = None + self.gpu_cache = None + + def init_model(self) -> None: + # Initialize the distributed environment. + _init_distributed_environment(self.parallel_config, + self.rank, + self.distributed_init_method, + distributed_backend="gloo") + + # Initialize the model. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode() + def profile_num_available_blocks( + self, + block_size: int = 128, + gpu_memory_utilization: float = 0.9, + cpu_swap_space: int = 0, + cache_dtype: str = "float16", + ) -> Tuple[int, int]: + """Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks.""" + num_gpu_blocks = self.scheduler_config.max_num_seqs + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig) -> None: + self.cache_config = cache_config + self.cache_engine = CacheEngine(self.cache_config, self.model_config, + self.parallel_config) + self.model_runner.set_block_size(self.cache_engine.block_size) + + def warm_up_model(self) -> None: + # Warm up is maintained in transformers-neuronx + pass + + def cache_swap( + self, + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + # Issue cache operations. + issued_cache_op = False + if blocks_to_swap_in: + self.cache_engine.swap_in(blocks_to_swap_in) + issued_cache_op = True + if blocks_to_swap_out: + self.cache_engine.swap_out(blocks_to_swap_out) + issued_cache_op = True + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + issued_cache_op = True + + cache_events = self.cache_events if issued_cache_op else None + + # Wait for cache operations to finish. + if cache_events is not None: + raise NotImplementedError( + "cache operations are not implemented for neuron backend.") + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + data = { + "num_seq_groups": num_seq_groups, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_swap_in = data["blocks_to_swap_in"] + blocks_to_swap_out = data["blocks_to_swap_out"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) + + # If there is no input, we don't need to execute the model. 
+ if num_seq_groups == 0: + return {} + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.gpu_cache) + return output + + +def _init_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + distributed_backend: Optional[str] = None, +) -> None: + """Initialize the distributed environment.""" + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + distributed_backend = distributed_backend if distributed_backend else "nccl" + torch.distributed.init_process_group( + backend=distributed_backend, + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1)) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size)
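A smoke-test sketch for exercising the new Neuron path end to end, analogous to examples/offline_inference_neuron.py. It assumes an Inf2/Trn1 host with the AWS Neuron SDK and transformers-neuronx installed; the model name, NeuronCore count, and sequence limits are illustrative only:

    from vllm import LLM, SamplingParams

    # tensor_parallel_size is forwarded to transformers-neuronx as
    # neuron_tp_degree (see the ParallelConfig change); vLLM keeps its own
    # tensor parallelism at 1 on Neuron.
    llm = LLM(model="openlm-research/open_llama_3b",
              tensor_parallel_size=2,   # illustrative NeuronCore count
              max_num_seqs=4,
              max_model_len=128,        # must match block_size on Neuron
              block_size=128,
              device="neuron")

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)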