[Hardware][Neuron] Refactor neuron support (#3471)

commit e90fc21f2e (parent ea5f14e6ff)
@@ -12,7 +12,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 # Create an LLM.
 llm = LLM(
-    model="openlm-research/open_llama_3b",
+    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     max_num_seqs=8,
     # The max_model_len and block_size arguments are required to be same as
     # max sequence length when targeting neuron device.
@@ -24,7 +24,8 @@ llm = LLM(
     # The device can be automatically detected when AWS Neuron SDK is installed.
     # The device argument can be either unspecified for automated detection,
     # or explicitly assigned.
-    device="neuron")
+    device="neuron",
+    tensor_parallel_size=2)
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
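For reference, here is a minimal sketch of how the updated example fits together once both hunks are applied. The prompt list and the `max_model_len`/`block_size` values are illustrative assumptions; the model name, `max_num_seqs`, `device`, and `tensor_parallel_size` arguments come from the diff above.

```python
from vllm import LLM, SamplingParams

# Illustrative prompts; the real example defines its own.
prompts = ["Hello, my name is", "The capital of France is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_num_seqs=8,
    # max_model_len and block_size must equal the max sequence length on
    # Neuron; 128 is an assumed value, not taken from the hunks above.
    max_model_len=128,
    block_size=128,
    # Either omit `device` for auto-detection or set it explicitly.
    device="neuron",
    tensor_parallel_size=2)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)
```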
@@ -33,7 +33,7 @@ def test_worker_apply_lora(sql_lora_files):
                                max_loras=32),
         distributed_init_method=f"file://{tempfile.mkstemp()[1]}",
     )
-    worker.init_model()
+    worker.init_device()
     worker.load_model()
 
     worker.model_runner.set_active_loras([], LoRAMapping([], []))
@@ -71,7 +71,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
 
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
-    worker.init_model()
+    worker.init_device()
 
     vocab_size = 32_000
 
@@ -151,7 +151,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
 
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
-    worker.init_model()
+    worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
                                        high=vocab_size,
@@ -230,7 +230,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
 
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
-    worker.init_model()
+    worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
                                        high=vocab_size,
@@ -342,7 +342,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
 
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
-    worker.init_model()
+    worker.init_device()
 
     proposal_token_ids = torch.randint(low=0,
                                        high=vocab_size,
@@ -486,8 +486,8 @@ def test_empty_input_batch(k: int, batch_size: int):
 
 
 @torch.inference_mode()
-def test_init_model():
-    """Verify SpecDecodeWorker invokes proposer/scorer worker init_model, as
+def test_init_device():
+    """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as
     well as other GPU initialization.
     """
     draft_worker = mock_worker(cls=MultiStepWorker)
@@ -499,11 +499,11 @@ def test_init_model():
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    worker.init_model()
+    worker.init_device()
 
-    draft_worker.init_model.assert_called_once()
+    draft_worker.init_device.assert_called_once()
 
-    target_worker.init_model.assert_called_once()
+    target_worker.init_device.assert_called_once()
 
     metrics_collector.init_gpu_tensors.assert_called_once()
     rejection_sampler.init_gpu_tensors.assert_called_once()
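These test hunks are a mechanical rename of `init_model()` to `init_device()`. A minimal sketch of the sequence the tests exercise after the rename (the four collaborators are mocks in the tests above):

```python
# Hedged sketch: worker setup after the rename.
worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                          metrics_collector)
worker.init_device()  # was worker.init_model() before this commit
```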
@@ -123,7 +123,7 @@ def create_worker(cls: type,
         is_driver_worker=is_driver_worker,
     )
 
-    worker.init_model()
+    worker.init_device()
     worker.load_model()
 
     cache_config.num_gpu_blocks = num_gpu_blocks
@@ -30,7 +30,7 @@ def test_swap() -> None:
     )
 
     # Initialize the worker.
-    worker.init_model()
+    worker.init_device()
     worker.load_model()
     worker.init_cache_engine(cache_config)
     worker.warm_up_model()
@@ -474,15 +474,7 @@ class ParallelConfig:
         placement_group: Optional["PlacementGroup"] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        if is_neuron():
-            # For Neuron device support, here we assign TP=1 to avoid sharding
-            # within vLLM directly. Transformer-neuronx would take
-            # neuron_tp_degree attribute, and distribute the workload
-            # to multiple NeuronCores.
-            self.tensor_parallel_size = 1
-            self.neuron_tp_degree = tensor_parallel_size
-        else:
-            self.tensor_parallel_size = tensor_parallel_size
+        self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
@@ -491,8 +483,7 @@ class ParallelConfig:
         self.placement_group = placement_group
 
         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
-        # Ray worker is not supported for Neuron backend.
-        if self.world_size > 1 and not is_neuron():
+        if self.world_size > 1:
             self.worker_use_ray = True
         self._verify_args()
 
@@ -591,10 +582,6 @@ class DeviceConfig:
         # Set device with device type
         self.device = torch.device(self.device_type)
 
-    @property
-    def is_neuron(self):
-        return self.device_type == "neuron"
-
 
 @dataclass
 class LoRAConfig:
@@ -325,7 +325,12 @@ class AsyncLLMEngine:
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
-        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
+        device_config = engine_configs[4]
+
+        if device_config.device_type == "neuron":
+            raise NotImplementedError("Neuron is not supported for "
+                                      "async engine yet.")
+        elif parallel_config.worker_use_ray or engine_args.engine_use_ray:
             initialize_ray_cluster(parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
             executor_class = RayGPUExecutorAsync
@@ -125,9 +125,13 @@ class LLMEngine:
         # Create the engine configs.
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
+        device_config = engine_configs[4]
 
         # Initialize the cluster and specify the executor class.
-        if parallel_config.worker_use_ray:
+        if device_config.device_type == "neuron":
+            from vllm.executor.neuron_executor import NeuronExecutor
+            executor_class = NeuronExecutor
+        elif parallel_config.worker_use_ray:
             initialize_ray_cluster(parallel_config)
             from vllm.executor.ray_gpu_executor import RayGPUExecutor
             executor_class = RayGPUExecutor
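Taken together, the engine now picks its executor from the device and parallel configs. A hedged sketch of the selection logic; the final single-GPU fallback to `GPUExecutor` is assumed from the surrounding code rather than shown in this hunk:

```python
if device_config.device_type == "neuron":
    from vllm.executor.neuron_executor import NeuronExecutor
    executor_class = NeuronExecutor
elif parallel_config.worker_use_ray:
    initialize_ray_cluster(parallel_config)
    from vllm.executor.ray_gpu_executor import RayGPUExecutor
    executor_class = RayGPUExecutor
else:
    # Assumed fallback, not part of the hunk above.
    from vllm.executor.gpu_executor import GPUExecutor
    executor_class = GPUExecutor
```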
@@ -1,4 +1,3 @@
-import importlib
 from typing import Dict, List, Optional
 
 from vllm.lora.request import LoRARequest
@@ -13,12 +12,6 @@ from vllm.utils import (get_ip, get_open_port, get_distributed_init_method,
 
 logger = init_logger(__name__)
 
-# A map between the device type (in device config) to its worker module.
-DEVICE_TO_WORKER_MODULE_MAP = {
-    "cuda": "vllm.worker.worker",
-    "neuron": "vllm.worker.neuron_worker",
-}
-
 
 class GPUExecutor(ExecutorBase):
 
@@ -44,17 +37,10 @@ class GPUExecutor(ExecutorBase):
         # Profile the memory usage and initialize the cache.
         self._init_cache()
 
-    def _dispatch_worker(self):
-        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
-            self.device_config.device_type]
-        imported_worker = importlib.import_module(worker_module)
-        Worker = imported_worker.Worker
-        return Worker
-
     def _init_worker(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
+        from vllm.worker.worker import Worker
 
         assert self.parallel_config.world_size == 1, (
             "GPUExecutor only supports single GPU.")
@@ -73,7 +59,7 @@ class GPUExecutor(ExecutorBase):
             kv_cache_dtype=self.cache_config.cache_dtype,
             is_driver_worker=True,
         )
-        self.driver_worker.init_model()
+        self.driver_worker.init_device()
         self.driver_worker.load_model()
 
     def _init_cache(self) -> None:
vllm/executor/neuron_executor.py (new file, 80 lines)
@@ -0,0 +1,80 @@
+from typing import Dict, List, Optional
+
+from vllm.lora.request import LoRARequest
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
+from vllm.executor.executor_base import ExecutorBase
+from vllm.logger import init_logger
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+
+logger = init_logger(__name__)
+
+
+class NeuronExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        assert lora_config is None, "LoRA is not supported for Neuron backend."
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        # Set the number of GPU blocks to be the same as the maximum number of
+        # sequences that can be processed in a single batch. This is equivalent
+        # to schedule without PagedAttention.
+        self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
+        self.cache_config.num_cpu_blocks = 0
+
+        # Instantiate the worker and load the model to the device.
+        self._init_worker()
+
+    def _init_worker(self):
+        from vllm.worker.neuron_worker import NeuronWorker
+
+        self.driver_worker = NeuronWorker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+        )
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
+
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
+                and blocks_to_copy == {}), (
+                    "Cache operations are not supported for Neuron backend.")
+
+        output = self.driver_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list)
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def list_loras(self) -> List[int]:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def check_health(self) -> None:
+        # NeuronExecutor will always be healthy as long as
+        # it's running.
+        return
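A hedged usage sketch of the new executor. The config objects are placeholders for whatever `EngineArgs.create_engine_configs()` produces; on Neuron the swap/copy maps must stay empty because cache operations are rejected by the assert in `execute_model`.

```python
executor = NeuronExecutor(
    model_config=model_config,
    cache_config=cache_config,
    parallel_config=parallel_config,
    scheduler_config=scheduler_config,
    device_config=device_config,
    lora_config=None,  # LoRA is rejected by the assert in __init__
)
sampler_output = executor.execute_model(
    seq_group_metadata_list=seq_group_metadata_list,  # placeholder
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
)
```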
@@ -3,7 +3,6 @@ import copy
 from collections import defaultdict
 import os
 import pickle
-import importlib
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -25,12 +24,6 @@ if TYPE_CHECKING:
 
 logger = init_logger(__name__)
 
-# A map between the device type (in device config) to its worker module.
-DEVICE_TO_WORKER_MODULE_MAP = {
-    "cuda": "vllm.worker.worker",
-    "neuron": "vllm.worker.neuron_worker",
-}
-
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
@@ -73,13 +66,6 @@ class RayGPUExecutor(ExecutorBase):
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
 
-    def _dispatch_worker(self):
-        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
-            self.device_config.device_type]
-        imported_worker = importlib.import_module(worker_module)
-        Worker = imported_worker.Worker
-        return Worker
-
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if self.parallel_config.tensor_parallel_size == 1:
@@ -155,7 +141,7 @@ class RayGPUExecutor(ExecutorBase):
 
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        Worker = self._dispatch_worker()
+        from vllm.worker.worker import Worker
 
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
@@ -201,7 +187,7 @@ class RayGPUExecutor(ExecutorBase):
 
         # FIXME(woosuk): We are not properly initializing cupy NCCL when
         # we have multiple nodes.
-        self._run_workers("init_model",
+        self._run_workers("init_device",
                           cupy_port=get_open_port()
                           if not model_config.enforce_eager else None)
         self._run_workers(
@@ -799,8 +799,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         self.device = device
 
     @property
-    def logits_as_hidden_states(self):
-        return self.base_layer.logits_as_hidden_states
+    def logits_as_input(self):
+        return self.base_layer.logits_as_input
 
     @property
     def vocab_size(self):
@@ -1,7 +1,7 @@
 from typing import List, Optional
 
 import torch
-from vllm.utils import in_wsl
+from vllm.utils import is_pin_memory_available
 
 
 class LoRALayerWeights:
@@ -64,7 +64,7 @@ class LoRALayerWeights:
             dtype: torch.dtype,
             device: torch.device,
             embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
-        pin_memory = str(device) == "cpu" and not in_wsl()
+        pin_memory = str(device) == "cpu" and is_pin_memory_available()
         lora_a = torch.zeros([input_dim, rank],
                              dtype=dtype,
                              device=device,
@@ -11,7 +11,7 @@ import torch
 from torch import nn
 
 from vllm.config import LoRAConfig
-from vllm.utils import LRUCache, in_wsl
+from vllm.utils import LRUCache, is_pin_memory_available
 
 from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
                               from_layer_logits_processor)
@@ -143,7 +143,7 @@ class LoRAModel:
         embedding_padding_modules: Optional[List[str]] = None,
     ) -> "LoRAModel":
         """Create a LoRAModel from a dictionary of tensors."""
-        pin_memory = str(device) == "cpu" and not in_wsl()
+        pin_memory = str(device) == "cpu" and is_pin_memory_available()
         loras: Dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
             module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
@@ -1,10 +1,9 @@
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_random_seed, get_model
+from vllm.model_executor.utils import set_random_seed
 
 __all__ = [
     "InputMetadata",
-    "get_model",
     "SamplingMetadata",
     "set_random_seed",
 ]
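After this change `get_model` is no longer re-exported from `vllm.model_executor`; callers import it from the loader module instead, exactly as the `ModelRunner` hunk near the end of this diff does:

```python
from vllm.model_executor.model_loader import get_model
```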
@@ -1,8 +1,9 @@
 from dataclasses import dataclass, fields
-from typing import Optional, List, Any, Dict
+from typing import TYPE_CHECKING, Optional, List, Any, Dict
 
 import torch
-from xformers.ops.fmha.attn_bias import AttentionBias
+if TYPE_CHECKING:
+    from xformers.ops.fmha.attn_bias import AttentionBias
 
 
 @dataclass
@@ -82,7 +83,7 @@ class InputMetadata:
         # when alibi slopes is used. It is because of the limitation
         # from xformer API.
         # will not appear in the __repr__ and __init__
-        self.attn_bias: Optional[List[AttentionBias]] = None
+        self.attn_bias: Optional[List["AttentionBias"]] = None
 
         # Cuda graph is only used for decoding now.
         if self.use_cuda_graph:
@@ -4,8 +4,6 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from vllm.utils import is_neuron
-
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_gather)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -23,7 +21,8 @@ class LogitsProcessor(nn.Module):
     def __init__(self,
                  vocab_size: int,
                  org_vocab_size: Optional[int] = None,
-                 scale: Optional[float] = 1.0) -> None:
+                 scale: Optional[float] = 1.0,
+                 logits_as_input: bool = False) -> None:
         """
         Args:
             scale: A scaling factor to apply to the logits.
@@ -31,8 +30,8 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.scale = scale
         self.vocab_size = vocab_size
-        # Transformers-neuronx generate outputs as logits directly.
-        self.logits_as_hidden_states = is_neuron()
+        # Whether the input is logits (default is hidden states).
+        self.logits_as_input = logits_as_input
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size
 
@@ -43,7 +42,7 @@ class LogitsProcessor(nn.Module):
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if self.logits_as_hidden_states:
+        if self.logits_as_input:
             logits = hidden_states
         else:
             hidden_states = _prune_hidden_states(hidden_states,
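A short sketch of the new flag, assuming a toy vocabulary size. With `logits_as_input=True` (what the Neuron model wrapper further down passes), `forward` returns its input tensor unchanged as logits; with the default `False` it prunes the hidden states and gathers logits through the LM head as before.

```python
from vllm.model_executor.layers.logits_processor import LogitsProcessor

# Neuron-style: transformers-neuronx already returns logits.
neuron_lp = LogitsProcessor(vocab_size=32000, logits_as_input=True)

# GPU-style default: input is hidden states, logits come from the lm_head.
gpu_lp = LogitsProcessor(vocab_size=32000)
```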
@@ -4,13 +4,13 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
+from vllm.model_executor.layers.ops.sample import sample as sample_triton
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceData, SequenceGroupOutput,
                            SequenceOutput)
-from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
 
 
 class Sampler(nn.Module):
@@ -4,7 +4,7 @@ from typing import List, Optional, Type
 import torch.nn as nn
 
 from vllm.logger import init_logger
-from vllm.utils import is_hip, is_neuron
+from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -63,12 +63,6 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }
 
-# Models supported by Neuron.
-_NEURON_SUPPORTED_MODELS = {
-    "LlamaForCausalLM": "neuron.llama",
-    "MistralForCausalLM": "neuron.mistral"
-}
-
 
 class ModelRegistry:
 
@@ -85,15 +79,8 @@ class ModelRegistry:
             logger.warning(
                 f"Model architecture {model_arch} is partially supported "
                 "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
-        elif is_neuron():
-            if model_arch not in _NEURON_SUPPORTED_MODELS:
-                raise ValueError(
-                    f"Model architecture {model_arch} is not supported by "
-                    "Neuron for now.")
 
         module_name, model_cls_name = _MODELS[model_arch]
-        if is_neuron():
-            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
         module = importlib.import_module(
             f"vllm.model_executor.models.{module_name}")
         return getattr(module, model_cls_name, None)
@@ -1,86 +0,0 @@
-"""Inference-only LLaMA model compatible with HuggingFace weights."""
-import os
-from typing import List, Optional, Tuple
-
-import torch
-from torch import nn
-from transformers import LlamaConfig
-
-from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
-
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
-
-class LlamaForCausalLM(nn.Module):
-
-    def __init__(
-        self,
-        config: LlamaConfig,
-        linear_method=None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.linear_method = linear_method
-        self.model = None
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = Sampler()
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        with torch.inference_mode():
-            block_size = self.model.context_buckets[-1]
-            if input_metadata.is_prompt:
-                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
-            else:
-                seq_ids = input_metadata.block_tables
-            logits = self.model(input_ids,
-                                cache_ids=positions,
-                                start_ids=seq_ids.flatten())
-        return logits
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.chkpt_model.lm_head,
-                                       hidden_states, sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None,
-                     **kwargs):
-        from transformers_neuronx.llama.model import LlamaForSampling
-
-        split_model_dir = f"{model_name_or_path}-split"
-        if os.path.isdir(os.path.join(model_name_or_path,
-                                      "pytorch_model.bin")):
-            split_model_dir = model_name_or_path
-        elif not os.path.exists(f"{model_name_or_path}-split"):
-            from transformers.models.llama import LlamaForCausalLM
-            from transformers_neuronx.module import save_pretrained_split
-
-            hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
-                                                        low_cpu_mem_usage=True)
-            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
-
-        self.model = LlamaForSampling.from_pretrained(split_model_dir,
-                                                      **kwargs)
-        self.model.to_neuron()
@@ -1,89 +0,0 @@
-"""Inference-only Mistral model compatible with HuggingFace weights."""
-from typing import List, Optional, Tuple
-
-import torch
-from torch import nn
-from transformers import MistralConfig
-
-from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import SamplerOutput
-import os
-
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
-
-class MistralForCausalLM(nn.Module):
-
-    def __init__(
-        self,
-        config: MistralConfig,
-        linear_method=None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.linear_method = linear_method
-        self.model = None
-        self.lm_head = None
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.sampler = Sampler()
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
-    ) -> SamplerOutput:
-        with torch.inference_mode():
-            seq_ids = []
-            block_size = self.model.context_buckets[-1]
-            if input_metadata.is_prompt:
-                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
-            else:
-                seq_ids = input_metadata.block_tables
-
-            logits = self.model(input_ids,
-                                cache_ids=positions,
-                                start_ids=seq_ids)
-        return logits
-
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.chkpt_model.lm_head,
-                                       hidden_states, sampling_metadata)
-        return logits
-
-    def sample(
-        self,
-        logits: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None,
-                     **kwargs):
-        from transformers_neuronx.mistral.model import MistralForSampling
-
-        split_model_dir = f"{model_name_or_path}-split"
-        if os.path.isdir(os.path.join(model_name_or_path,
-                                      "pytorch_model.bin")):
-            split_model_dir = model_name_or_path
-        elif not os.path.exists(f"{model_name_or_path}-split"):
-            from transformers import MistralForCausalLM
-            from transformers_neuronx.module import save_pretrained_split
-
-            hf_model = MistralForCausalLM.from_pretrained(
-                model_name_or_path, low_cpu_mem_usage=True)
-            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
-
-        self.model = MistralForSampling.from_pretrained(
-            split_model_dir, **kwargs)
-        self.model.to_neuron()
@@ -1,12 +1,18 @@
-"""Utilities for selecting and loading models."""
-from typing import Type
+"""Utilities for selecting and loading neuron models."""
+import importlib
+import os
+from typing import Optional, Type
 
 import torch
 import torch.nn as nn
+import transformers
 from transformers import PretrainedConfig
 
-from vllm.config import ModelConfig, DeviceConfig
-from vllm.model_executor.models import ModelRegistry
+from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
 
 TORCH_DTYPE_TO_NEURON_AMP = {
     "auto": "f32",
@@ -20,31 +26,95 @@ TORCH_DTYPE_TO_NEURON_AMP = {
     torch.float32: "f32",
 }
 
+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {
+    "LlamaForCausalLM": ("transformers_neuronx.llama.model",
+                         "LlamaForSampling", "LlamaForCausalLM"),
+    "MistralForCausalLM": ("transformers_neuronx.mistral.model",
+                           "MistralForSampling", "MistralForCausalLM")
+}
+
+
+class NeuronCasualLM(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.model = None
+        self.logits_processor = LogitsProcessor(config.vocab_size,
+                                                logits_as_input=True)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_block_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        logits = self.model(input_ids,
+                            cache_ids=positions,
+                            start_ids=input_block_ids)
+        return logits
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(None, hidden_states, sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, model_name_or_path: str, **kwargs):
+        arch = _get_model_architecture(self.config)
+        neuronx_module_path, neuronx_model_cls, hf_model_cls = (
+            _NEURON_SUPPORTED_MODELS[arch])
+        neuronx_module = importlib.import_module(neuronx_module_path)
+        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
+
+        split_model_dir = f"{model_name_or_path}-split"
+        if os.path.isdir(os.path.join(model_name_or_path,
+                                      "pytorch_model.bin")):
+            split_model_dir = model_name_or_path
+        elif not os.path.exists(f"{model_name_or_path}-split"):
+            hf_model_cls = getattr(transformers, hf_model_cls)
+            from transformers_neuronx.module import save_pretrained_split
+
+            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
+                                                    low_cpu_mem_usage=True)
+            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
+
+        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
+                                                       **kwargs)
+        self.model.to_neuron()
+
+
 def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
-        model_cls = ModelRegistry.load_model_cls(arch)
-        if model_cls is not None:
-            return model_cls
+        if arch in _NEURON_SUPPORTED_MODELS:
+            return arch
     raise ValueError(
-        f"Model architectures {architectures} are not supported for now. "
-        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+        f"Model architectures {architectures} are not supported on Neuron "
+        f"for now. Supported architectures: "
+        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
 
 
-def get_model(model_config: ModelConfig, device_config: DeviceConfig,
-              **kwargs) -> nn.Module:
+def get_neuron_model(model_config: ModelConfig,
+                     parallel_config: ParallelConfig,
+                     scheduler_config: SchedulerConfig) -> nn.Module:
     from transformers_neuronx.config import (NeuronConfig,
                                              ContinuousBatchingConfig)
 
-    parallel_config = kwargs.get("parallel_config")
-    scheduler_config = kwargs.get("scheduler_config")
-
-    model_class = _get_model_architecture(model_config.hf_config)
-    linear_method = None
-
     # Create a model instance.
-    model = model_class(model_config.hf_config, linear_method)
+    model = NeuronCasualLM(model_config.hf_config)
 
     continuous_batching_config = ContinuousBatchingConfig(
         batch_size_for_shared_caches=scheduler_config.max_num_seqs)
@@ -54,10 +124,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
     # Load the weights from the cached or downloaded files.
     model.load_weights(
         model_config.model,
-        model_config.download_dir,
-        model_config.load_format,
-        model_config.revision,
-        tp_degree=parallel_config.neuron_tp_degree,
+        tp_degree=parallel_config.tensor_parallel_size,
         amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
         neuron_config=neuron_config,
         context_length_estimate=[scheduler_config.max_model_len],
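A hedged sketch of how the renamed entry point is intended to be called. The module path follows the old `DEVICE_TO_MODEL_LOADER_MAP` entry (`vllm.model_executor.neuron_model_loader`) and the config objects are placeholders; the `NeuronWorker` that actually performs this call is not part of the hunks shown here.

```python
# Assumed import path; the worker-side wiring is outside this diff.
from vllm.model_executor.neuron_model_loader import get_neuron_model

model = get_neuron_model(model_config,
                         parallel_config=parallel_config,
                         scheduler_config=scheduler_config)
# NeuronCasualLM.forward takes flat block ids instead of a KV-cache list.
logits = model(input_ids, positions, input_block_ids)
```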
@@ -4,11 +4,11 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import random
 
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceData
-from vllm.utils import in_wsl, is_neuron
 from vllm.model_executor.layers.ops.sample import (
     get_num_triton_sampler_splits)
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sequence import SequenceData
+from vllm.utils import is_pin_memory_available
 
 _SAMPLING_EPS = 1e-5
 _SEED_0_REPLACEMENT = 3403598558
@@ -213,7 +213,7 @@ class SamplingTensors:
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        pin_memory = not in_wsl() and not is_neuron()
+        pin_memory = is_pin_memory_available()
         prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
         prompt_padded_tokens = [
             tokens + [vocab_size] * (prompt_max_len - len(tokens))
@@ -1,18 +1,10 @@
 """Utils for model executor."""
 import random
-import importlib
 from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
 
-from vllm.config import DeviceConfig, ModelConfig
-
-DEVICE_TO_MODEL_LOADER_MAP = {
-    "cuda": "model_loader",
-    "neuron": "neuron_model_loader",
-}
-
 
 def set_random_seed(seed: int) -> None:
     random.seed(seed)
@@ -41,12 +33,3 @@ def set_weight_attrs(
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
         setattr(weight, key, value)
-
-
-def get_model(model_config: ModelConfig, device_config: DeviceConfig,
-              **kwargs) -> torch.nn.Module:
-    model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
-    imported_model_loader = importlib.import_module(
-        f"vllm.model_executor.{model_loader_module}")
-    get_model_fn = imported_model_loader.get_model
-    return get_model_fn(model_config, device_config, **kwargs)
@@ -2,7 +2,7 @@ import torch
 from dataclasses import dataclass
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from typing import Optional
-from vllm.utils import in_wsl
+from vllm.utils import is_pin_memory_available
 import time
 from typing import Callable
 
@@ -63,7 +63,7 @@ class AsyncMetricsCollector:
 
         self._in_flight_copy: Optional[torch.cuda.Event] = None
 
-        pin_memory = not in_wsl()
+        pin_memory = is_pin_memory_available()
         self._aggregate_num_accepted_tokens = torch.tensor(
             0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
         self._aggregate_num_emitted_tokens = torch.tensor(
@@ -27,8 +27,8 @@ class MultiStepWorker(Worker):
 
         self._proposer: Optional[DraftModelTop1Proposer] = None
 
-    def init_model(self):
-        super().init_model()
+    def init_device(self):
+        super().init_device()
 
         self._proposer = DraftModelTop1Proposer(
             self,
@@ -79,13 +79,13 @@ class SpecDecodeWorker:
 
         self.scorer: SpeculativeScorer = None
 
-    def init_model(self) -> None:
+    def init_device(self) -> None:
         """Initialize both scorer and proposer models.
         """
         # The scorer worker model is initialized first in case the proposer
         # model has a smaller TP degree than the target worker.
-        self.scorer_worker.init_model()
-        self.proposer_worker.init_model()
+        self.scorer_worker.init_device()
+        self.proposer_worker.init_device()
 
         self._metrics.init_gpu_tensors(self.rank)
         self.rejection_sampler.init_gpu_tensors(self.rank)
@@ -338,7 +338,27 @@ def create_kv_caches_with_random(
     return key_caches, value_caches
 
 
-class measure_cuda_memory:
+@lru_cache
+def print_warning_once(msg: str) -> None:
+    logger.warning(msg)
+
+
+@lru_cache(maxsize=None)
+def is_pin_memory_available() -> bool:
+
+    if in_wsl():
+        # Pinning memory in WSL is not supported.
+        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
+        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
+                           "This may slow down the performance.")
+        return False
+    elif is_neuron():
+        print_warning_once("Pin memory is not supported on Neuron.")
+        return False
+    return True
+
+
+class CudaMemoryProfiler:
 
     def __init__(self, device=None):
         self.device = device
@@ -360,3 +380,44 @@ class measure_cuda_memory:
 
         # Force garbage collection
         gc.collect()
+
+
+def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
+    assert len(x) <= max_len
+    return x + [pad] * (max_len - len(x))
+
+
+def make_tensor_with_pad(
+    x: List[List[int]],
+    max_len: int,
+    pad: int,
+    dtype: torch.dtype,
+    device: Optional[Union[str, torch.device]],
+) -> torch.Tensor:
+    """Make a padded tensor of a 2D inputs.
+
+    The padding is applied to the end of each inner list until it reaches
+    `max_len`.
+    """
+    padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
+    return torch.tensor(padded_x, dtype=dtype, device=device)
+
+
+def async_tensor_h2d(
+    data: list,
+    dtype: torch.dtype,
+    target_device: Union[str, torch.device],
+    pin_memory: bool,
+) -> torch.Tensor:
+    """Asynchronously create a tensor and copy it from host to device."""
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
+    return t.to(device=target_device, non_blocking=True)
+
+
+def maybe_expand_dim(tensor: torch.Tensor,
+                     target_dims: int,
+                     size: int = 1) -> torch.Tensor:
+    """Expand the tensor to the target_dims."""
+    if tensor.ndim < target_dims:
+        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
+    return tensor
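The helpers added above replace the ad-hoc `in_wsl()` checks and take over from the private `_make_tensor_with_pad`/`_async_h2d`/`_maybe_expand_dim` helpers that `ModelRunner` used before (see the hunks further down). A small, hedged sketch of how they compose; values are illustrative, and pinning is additionally gated on CUDA being present so the snippet stays runnable on CPU-only machines:

```python
import torch
from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
                        make_tensor_with_pad, maybe_expand_dim)

pin_memory = is_pin_memory_available() and torch.cuda.is_available()

# Pad ragged block tables into one (2, 4) int tensor.
block_tables = make_tensor_with_pad([[1, 2], [3]], max_len=4, pad=0,
                                    dtype=torch.int, device="cpu")

# Host-to-device copy helper; "cpu" keeps the example device-agnostic.
indices = async_tensor_h2d([0, 1, 2], dtype=torch.long,
                           target_device="cpu", pin_memory=pin_memory)
indices = maybe_expand_dim(indices, target_dims=2)  # shape (3, 1)
```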
@@ -5,7 +5,7 @@ import torch
 
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import is_pin_memory_available, STR_DTYPE_TO_TORCH_DTYPE
 
 logger = init_logger(__name__)
 
@@ -38,10 +38,6 @@ class CacheEngine:
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks
 
-        # Skip initializing KV cache for Neuron backend.
-        if is_neuron():
-            return
-
         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype
         else:
@@ -90,12 +86,7 @@ class CacheEngine:
         cpu_cache: List[KVCache] = []
         key_block_shape = self.get_key_block_shape()
         value_block_shape = self.get_value_block_shape()
-        pin_memory = not in_wsl()
-        if not pin_memory:
-            # Pinning memory in WSL is not supported.
-            # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
-            logger.warning("Using 'pin_memory=False' as WSL is detected. "
-                           "This may slow down the performance.")
+        pin_memory = is_pin_memory_available()
         for _ in range(self.num_layers):
             key_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *key_block_shape),
@@ -1,6 +1,6 @@
 import contextlib
 import time
-from typing import Dict, List, Optional, Tuple, Set, Union
+from typing import Dict, List, Optional, Tuple, Set
 
 import numpy as np
 import torch
@@ -9,7 +9,8 @@ import torch.nn as nn
 from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.logger import init_logger
-from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor import InputMetadata, SamplingMetadata
+from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
@@ -21,7 +22,9 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
-from vllm.utils import in_wsl, measure_cuda_memory
+from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,
+                        is_pin_memory_available, make_tensor_with_pad,
+                        maybe_expand_dim)
 
 logger = init_logger(__name__)
 
@@ -79,16 +82,11 @@ class ModelRunner:
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
-        # cache in_wsl result
-        self.in_wsl = in_wsl()
+        self.pin_memory = is_pin_memory_available()
         self.kv_cache_dtype = kv_cache_dtype
 
-        # Set enforce_eager to True for Neuron backend, to avoid capturing graph
-        if self.device_config.is_neuron:
-            self.model_config.enforce_eager = True
-
     def load_model(self) -> None:
-        with measure_cuda_memory() as m:
+        with CudaMemoryProfiler() as m:
             self.model = get_model(self.model_config,
                                    self.device_config,
                                    lora_config=self.lora_config,
@@ -238,7 +236,7 @@ class ModelRunner:
                                        device=self.device)
         # Prepare prefix block tables
         max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
-        block_tables = _make_tensor_with_pad(
+        block_tables = make_tensor_with_pad(
            prefix_block_tables,
            max_len=max_prompt_block_table_len,
            pad=0,
@@ -395,7 +393,7 @@ class ModelRunner:
         else:
             max_block_table_len = max(
                 len(block_table) for block_table in block_tables)
-            block_tables = _make_tensor_with_pad(
+            block_tables = make_tensor_with_pad(
                 block_tables,
                 max_len=max_block_table_len,
                 pad=0,
@@ -436,7 +434,6 @@ class ModelRunner:
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
         categorized_sampled_token_indices_start_idx = 0
-        pin_memory = not self.in_wsl and not self.device_config.is_neuron
 
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -469,7 +466,7 @@ class ModelRunner:
 
                 if sampling_params.seed is not None:
                     seq_group_metadata.state.generator = torch.Generator(
-                        device="cuda").manual_seed(sampling_params.seed)
+                        device=self.device).manual_seed(sampling_params.seed)
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(
@@ -494,17 +491,17 @@ class ModelRunner:
             if sampling_params.seed is not None:
                 generators.append(seq_group_metadata.state.generator)
 
-        selected_token_indices = _async_h2d(selected_token_indices,
-                                            dtype=torch.long,
-                                            target_device=self.device,
-                                            pin_memory=not self.in_wsl)
+        selected_token_indices = async_tensor_h2d(selected_token_indices,
+                                                  dtype=torch.long,
+                                                  target_device=self.device,
+                                                  pin_memory=self.pin_memory)
 
         categorized_sample_indices = {
-            t: _maybe_expand_dim(
-                _async_h2d(seq_ids,
-                           dtype=torch.int,
-                           target_device=self.device,
+            t: maybe_expand_dim(
+                async_tensor_h2d(seq_ids,
+                                 dtype=torch.int,
+                                 target_device=self.device,
|
target_device=self.device,
|
||||||
pin_memory=pin_memory), 2, 2)
|
pin_memory=self.pin_memory), 2, 2)
|
||||||
for t, seq_ids in categorized_sample_indices.items()
|
for t, seq_ids in categorized_sample_indices.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -910,27 +907,6 @@ def _maybe_cupy_nccl():
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
|
|
||||||
assert len(x) <= max_len
|
|
||||||
return x + [pad] * (max_len - len(x))
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tensor_with_pad(
|
|
||||||
x: List[List[int]],
|
|
||||||
max_len: int,
|
|
||||||
pad: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: Optional[Union[str, torch.device]],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Make a padded tensor of a 2D inputs.
|
|
||||||
|
|
||||||
The padding is applied to the end of each inner list until it reaches
|
|
||||||
`max_len`.
|
|
||||||
"""
|
|
||||||
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
|
|
||||||
return torch.tensor(padded_x, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_graph_batch_size(batch_size: int) -> int:
|
def _get_graph_batch_size(batch_size: int) -> int:
|
||||||
"""Returns the padded batch size given actual batch size.
|
"""Returns the padded batch size given actual batch size.
|
||||||
|
|
||||||
@ -944,21 +920,3 @@ def _get_graph_batch_size(batch_size: int) -> int:
|
|||||||
else:
|
else:
|
||||||
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
|
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
|
||||||
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
|
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
|
||||||
|
|
||||||
|
|
||||||
def _async_h2d(
|
|
||||||
data: list,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
target_device: Union[str, torch.device],
|
|
||||||
pin_memory: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
|
|
||||||
return t.to(device=target_device, non_blocking=True)
|
|
||||||
|
|
||||||
|
|
||||||
def _maybe_expand_dim(tensor: torch.Tensor,
|
|
||||||
target_dims: int,
|
|
||||||
size: int = 1) -> torch.Tensor:
|
|
||||||
if tensor.ndim < target_dims:
|
|
||||||
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
|
|
||||||
return tensor
|
|
||||||
|
|||||||
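The private helpers deleted above are replaced by imports from vllm.utils (make_tensor_with_pad, async_tensor_h2d, maybe_expand_dim, is_pin_memory_available). A reasonable reading is that they were moved there largely unchanged; below is a self-contained sketch of such shared helpers, reconstructed from the removed code. Their exact home module, signatures, and defaults are assumptions, not taken from this diff.

from typing import List, Optional, Union

import torch


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    # Right-pad a single row of token ids to max_len.
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    # Build a dense (len(x), max_len) tensor from ragged rows.
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x, dtype=dtype, device=device)


def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    # Stage the data in (optionally pinned) CPU memory, then copy it to the
    # target device without blocking the host.
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)


def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    # Add trailing singleton dimensions until the tensor has target_dims dims.
    if tensor.ndim < target_dims:
        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
    return tensor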
vllm/worker/neuron_model_runner.py (new file, 287 lines)
@@ -0,0 +1,287 @@
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig)
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.neuron_model_loader import get_neuron_model
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
+                        make_tensor_with_pad, maybe_expand_dim)
+
+logger = init_logger(__name__)
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class NeuronModelRunner:
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+    ):
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+
+        if model_config is not None and model_config.get_sliding_window():
+            logger.warning("Sliding window is not supported on Neuron. "
+                           "The model will run without sliding window.")
+        self.device_config = (device_config
+                              if device_config is not None else DeviceConfig())
+        self.device = self.device_config.device
+        self.model = None
+        self.pin_memory = is_pin_memory_available()
+
+    def load_model(self) -> None:
+        self.model = get_neuron_model(self.model_config,
+                                      parallel_config=self.parallel_config,
+                                      scheduler_config=self.scheduler_config)
+
+    def _prepare_prompt(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+        assert len(seq_group_metadata_list) > 0
+        input_tokens: List[List[int]] = []
+        input_positions: List[List[int]] = []
+        input_block_ids: List[int] = []
+
+        prompt_lens: List[int] = []
+        for seq_group_metadata in seq_group_metadata_list:
+            assert seq_group_metadata.is_prompt
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            assert len(seq_ids) == 1
+            seq_id = seq_ids[0]
+
+            seq_data = seq_group_metadata.seq_data[seq_id]
+            prompt_tokens = seq_data.get_token_ids()
+            prompt_len = len(prompt_tokens)
+            prompt_lens.append(prompt_len)
+
+            input_tokens.append(prompt_tokens)
+            input_positions.append(list(range(prompt_len)))
+
+            assert seq_group_metadata.block_tables is not None
+            block_table = seq_group_metadata.block_tables[seq_id]
+            assert len(block_table) == 1
+            input_block_ids.append(block_table[0])
+
+        max_prompt_len = max(prompt_lens)
+        assert max_prompt_len > 0
+        input_tokens = make_tensor_with_pad(input_tokens,
+                                            max_prompt_len,
+                                            pad=0,
+                                            dtype=torch.long,
+                                            device=self.device)
+        input_positions = make_tensor_with_pad(input_positions,
+                                               max_prompt_len,
+                                               pad=0,
+                                               dtype=torch.long,
+                                               device=self.device)
+        input_block_ids = torch.tensor(input_block_ids,
+                                       dtype=torch.long,
+                                       device=self.device)
+
+        return input_tokens, input_positions, input_block_ids, prompt_lens
+
+    def _prepare_decode(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert len(seq_group_metadata_list) > 0
+        input_tokens: List[List[int]] = []
+        input_positions: List[List[int]] = []
+        input_block_ids: List[int] = []
+        context_lens: List[int] = []
+
+        for seq_group_metadata in seq_group_metadata_list:
+            assert not seq_group_metadata.is_prompt
+
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+
+            for seq_id in seq_ids:
+                seq_data = seq_group_metadata.seq_data[seq_id]
+                generation_token = seq_data.get_last_token_id()
+                input_tokens.append([generation_token])
+
+                seq_len = seq_data.get_len()
+                position = seq_len - 1
+                input_positions.append([position])
+                context_lens.append(seq_len)
+
+                assert seq_group_metadata.block_tables is not None
+                block_table = seq_group_metadata.block_tables[seq_id]
+                assert len(block_table) == 1
+                input_block_ids.append(block_table[0])
+
+        input_tokens = make_tensor_with_pad(input_tokens,
+                                            max_len=1,
+                                            pad=0,
+                                            dtype=torch.long,
+                                            device=self.device)
+        input_positions = make_tensor_with_pad(input_positions,
+                                               max_len=1,
+                                               pad=0,
+                                               dtype=torch.long,
+                                               device=self.device)
+        context_lens = torch.tensor(context_lens,
+                                    dtype=torch.int,
+                                    device=self.device)
+        input_block_ids = torch.tensor(input_block_ids,
+                                       dtype=torch.long,
+                                       device=self.device)
+
+        return input_tokens, input_positions, input_block_ids
+
+    def _prepare_sample(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        prompt_lens: List[int],
+    ) -> SamplingMetadata:
+        seq_groups: List[Tuple[List[int], SamplingParams]] = []
+        selected_token_indices: List[int] = []
+        generators: List[torch.Generator] = []
+        selected_token_start_idx = 0
+        categorized_sample_indices = {t: [] for t in SamplingType}
+        categorized_sample_indices_start_idx = 0
+        categorized_sampled_token_indices_start_idx = 0
+
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            seq_ids = list(seq_group_metadata.seq_data.keys())
+            sampling_params = seq_group_metadata.sampling_params
+            seq_groups.append((seq_ids, sampling_params))
+
+            if seq_group_metadata.is_prompt:
+                assert len(seq_ids) == 1
+                assert prompt_lens is not None
+                prompt_len = prompt_lens[i]
+                if sampling_params.prompt_logprobs is not None:
+                    # NOTE: prompt token positions do not need sample, skip
+                    categorized_sample_indices_start_idx += prompt_len - 1
+
+                categorized_sample_indices[
+                    sampling_params.sampling_type].append([
+                        categorized_sample_indices_start_idx,
+                        categorized_sampled_token_indices_start_idx
+                    ])
+                categorized_sample_indices_start_idx += 1
+                categorized_sampled_token_indices_start_idx += 1
+
+                if sampling_params.prompt_logprobs is not None:
+                    selected_token_indices.extend(
+                        range(selected_token_start_idx,
+                              selected_token_start_idx + prompt_len - 1))
+                selected_token_indices.append(selected_token_start_idx +
+                                              prompt_len - 1)
+                selected_token_start_idx += prompt_len
+
+                if sampling_params.seed is not None:
+                    seq_group_metadata.state.generator = torch.Generator(
+                        device=self.device).manual_seed(sampling_params.seed)
+            else:
+                num_seqs = len(seq_ids)
+                selected_token_indices.extend(
+                    range(selected_token_start_idx,
+                          selected_token_start_idx + num_seqs))
+                selected_token_start_idx += num_seqs
+
+                categorized_sample_indices[
+                    sampling_params.sampling_type].extend(
+                        zip(
+                            range(
+                                categorized_sample_indices_start_idx,
+                                categorized_sample_indices_start_idx +
+                                num_seqs),
+                            range(
+                                categorized_sampled_token_indices_start_idx,
+                                categorized_sampled_token_indices_start_idx +
+                                num_seqs)))
+                categorized_sample_indices_start_idx += num_seqs
+                categorized_sampled_token_indices_start_idx += num_seqs
+
+                if sampling_params.seed is not None:
+                    generators.append(seq_group_metadata.state.generator)
+
+        selected_token_indices = async_tensor_h2d(selected_token_indices,
+                                                  dtype=torch.long,
+                                                  target_device=self.device,
+                                                  pin_memory=self.pin_memory)
+
+        categorized_sample_indices = {
+            t: maybe_expand_dim(
+                async_tensor_h2d(seq_ids,
+                                 dtype=torch.int,
+                                 target_device=self.device,
+                                 pin_memory=self.pin_memory), 2, 2)
+            for t, seq_ids in categorized_sample_indices.items()
+        }
+
+        seq_data: Dict[int, SequenceData] = {}
+        for seq_group_metadata in seq_group_metadata_list:
+            seq_data.update(seq_group_metadata.seq_data)
+
+        sampling_metadata = SamplingMetadata(
+            seq_groups=seq_groups,
+            seq_data=seq_data,
+            prompt_lens=prompt_lens,
+            selected_token_indices=selected_token_indices,
+            categorized_sample_indices=categorized_sample_indices,
+            generators=generators,
+        )
+        return sampling_metadata
+
+    def prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
+        # NOTE: We assume that all sequences in the group are all prompts or
+        # all decodes.
+        is_prompt = seq_group_metadata_list[0].is_prompt
+        # Prepare input tensors.
+        if is_prompt:
+            (input_tokens, input_positions, input_block_ids,
+             prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+        else:
+            (input_tokens, input_positions,
+             input_block_ids) = self._prepare_decode(seq_group_metadata_list)
+            prompt_lens = []
+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 prompt_lens)
+
+        return (input_tokens, input_positions, input_block_ids,
+                sampling_metadata)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Optional[SamplerOutput]:
+        (input_tokens, input_positions, input_block_ids, sampling_metadata
+         ) = self.prepare_input_tensors(seq_group_metadata_list)
+
+        hidden_states = self.model(
+            input_ids=input_tokens,
+            positions=input_positions,
+            input_block_ids=input_block_ids,
+        )
+
+        # Compute the logits.
+        logits = self.model.compute_logits(hidden_states, sampling_metadata)
+
+        # Sample the next token.
+        output = self.model.sample(
+            logits=logits,
+            sampling_metadata=sampling_metadata,
+        )
+        return output
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
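As a quick illustration of the prompt-side batching in _prepare_prompt above: each request contributes its full prompt, right-padded to the longest prompt in the batch, plus a single Neuron block id identifying its KV-cache slot. The standalone sketch below is purely illustrative; pad_prompts and the literal token ids are hypothetical and not part of the diff.

import torch


def pad_prompts(rows, pad_id=0):
    # Right-pad ragged rows of token ids into a dense (batch, max_len) tensor,
    # mirroring what make_tensor_with_pad does in _prepare_prompt.
    max_len = max(len(r) for r in rows)
    return torch.tensor([r + [pad_id] * (max_len - len(r)) for r in rows],
                        dtype=torch.long)


prompts = [[5, 7, 9], [11, 13]]                 # token ids for two requests
input_tokens = pad_prompts(prompts)             # shape (2, 3)
input_positions = pad_prompts([list(range(len(p))) for p in prompts])
input_block_ids = torch.tensor([0, 1], dtype=torch.long)  # one block per sequence
print(input_tokens)
print(input_positions)
print(input_block_ids)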
@@ -1,22 +1,17 @@
 """A Neuron worker class."""
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional
 
 import torch
 import torch.distributed
 
-from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, LoRAConfig)
+from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig)
 from vllm.model_executor import set_random_seed
-from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.parallel_state import (
-    ensure_model_parallel_initialized)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
-from vllm.worker.cache_engine import CacheEngine
-from vllm.worker.model_runner import ModelRunner
+from vllm.worker.neuron_model_runner import NeuronModelRunner
 
 
-class Worker:
+class NeuronWorker:
     """A worker class that executes the model on a group of neuron cores.
     """
 
@@ -26,168 +21,32 @@ class Worker:
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         device_config: DeviceConfig,
-        local_rank: int,
-        rank: int,
-        distributed_init_method: str,
-        lora_config: Optional[LoRAConfig] = None,
-        kv_cache_dtype: Optional[str] = "auto",
-        is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.device_config = device_config
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
-        self.lora_config = lora_config
-        self.is_driver_worker = is_driver_worker
-        if self.is_driver_worker:
-            assert self.rank == 0, "The driver worker must have rank 0."
 
-        self.model_runner = ModelRunner(model_config,
-                                        parallel_config,
-                                        scheduler_config,
-                                        device_config,
-                                        lora_config=self.lora_config,
-                                        is_driver_worker=is_driver_worker)
-        # Uninitialized cache engine. Will be initialized by
-        # self.init_cache_engine().
-        self.cache_config = None
-        self.cache_engine = None
-        self.cache_events = None
-        self.gpu_cache = None
+        self.model_runner = NeuronModelRunner(model_config, parallel_config,
+                                              scheduler_config, device_config)
 
-    def init_model(self) -> None:
-        # Initialize the distributed environment.
-        _init_distributed_environment(self.parallel_config,
-                                      self.rank,
-                                      self.distributed_init_method,
-                                      distributed_backend="gloo")
-
-        # Initialize the model.
+    def init_device(self) -> None:
+        # Set random seed.
         set_random_seed(self.model_config.seed)
 
     def load_model(self):
         self.model_runner.load_model()
 
-    @torch.inference_mode()
-    def profile_num_available_blocks(
-        self,
-        block_size: int = 128,
-        gpu_memory_utilization: float = 0.9,
-        cpu_swap_space: int = 0,
-        cache_dtype: str = "float16",
-    ) -> Tuple[int, int]:
-        """Simply returns max_num_seqs as num_gpu_blocks, 0 as
-        num_cpu_blocks."""
-        num_gpu_blocks = self.scheduler_config.max_num_seqs
-        num_cpu_blocks = 0
-        return num_gpu_blocks, num_cpu_blocks
-
-    def init_cache_engine(self, cache_config: CacheConfig) -> None:
-        self.cache_config = cache_config
-        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
-                                        self.parallel_config)
-        self.model_runner.set_block_size(self.cache_engine.block_size)
-
-    def warm_up_model(self) -> None:
-        # Warm up is maintained in transformers-neuronx
-        pass
-
-    def cache_swap(
-        self,
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-    ) -> None:
-        # Issue cache operations.
-        issued_cache_op = False
-        if blocks_to_swap_in:
-            self.cache_engine.swap_in(blocks_to_swap_in)
-            issued_cache_op = True
-        if blocks_to_swap_out:
-            self.cache_engine.swap_out(blocks_to_swap_out)
-            issued_cache_op = True
-        if blocks_to_copy:
-            self.cache_engine.copy(blocks_to_copy)
-            issued_cache_op = True
-
-        cache_events = self.cache_events if issued_cache_op else None
-
-        # Wait for cache operations to finish.
-        if cache_events is not None:
-            raise NotImplementedError(
-                "cache operations are not implemented for neuron backend.")
-
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
-        blocks_to_swap_in: Optional[Dict[int, int]] = None,
-        blocks_to_swap_out: Optional[Dict[int, int]] = None,
-        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Optional[SamplerOutput]:
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            num_seq_groups = len(seq_group_metadata_list)
-            assert blocks_to_swap_in is not None
-            assert blocks_to_swap_out is not None
-            assert blocks_to_copy is not None
-            data = {
-                "num_seq_groups": num_seq_groups,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            }
-            broadcast_tensor_dict(data, src=0)
-        else:
-            data = broadcast_tensor_dict(src=0)
-            num_seq_groups = data["num_seq_groups"]
-            blocks_to_swap_in = data["blocks_to_swap_in"]
-            blocks_to_swap_out = data["blocks_to_swap_out"]
-            blocks_to_copy = data["blocks_to_copy"]
-
-        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
-
+        num_seq_groups = len(seq_group_metadata_list)
+
         # If there is no input, we don't need to execute the model.
         if num_seq_groups == 0:
             return {}
 
-        output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 self.gpu_cache)
+        output = self.model_runner.execute_model(seq_group_metadata_list)
         return output
-
-
-def _init_distributed_environment(
-    parallel_config: ParallelConfig,
-    rank: int,
-    distributed_init_method: Optional[str] = None,
-    distributed_backend: Optional[str] = None,
-) -> None:
-    """Initialize the distributed environment."""
-    if torch.distributed.is_initialized():
-        torch_world_size = torch.distributed.get_world_size()
-        if torch_world_size != parallel_config.world_size:
-            raise RuntimeError(
-                "torch.distributed is already initialized but the torch world "
-                "size does not match parallel_config.world_size "
-                f"({torch_world_size} vs. {parallel_config.world_size}).")
-    elif not distributed_init_method:
-        raise ValueError(
-            "distributed_init_method must be set if torch.distributed "
-            "is not already initialized")
-    else:
-        distributed_backend = (distributed_backend
-                               if distributed_backend else "nccl")
-        torch.distributed.init_process_group(
-            backend=distributed_backend,
-            world_size=parallel_config.world_size,
-            rank=rank,
-            init_method=distributed_init_method,
-        )
-
-    # A small all_reduce for warmup.
-    torch.distributed.all_reduce(torch.zeros(1))
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
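Compared with the GPU worker, the refactored NeuronWorker above no longer sets up torch.distributed, a cache engine, or cache-swap operations. A condensed sketch of the resulting lifecycle is shown below; the worker construction and the scheduled_seq_groups list are stand-ins for whatever the engine provides, not APIs defined in this diff.

def run_neuron_worker_once(worker, scheduled_seq_groups):
    # init_device() only seeds the RNG; there is no distributed init on Neuron.
    worker.init_device()
    # load_model() delegates to NeuronModelRunner.load_model(), which builds
    # the model through get_neuron_model().
    worker.load_model()
    # execute_model() runs the model directly; no broadcast or cache swap.
    return worker.execute_model(scheduled_seq_groups)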
@@ -67,7 +67,7 @@ class Worker:
         self.cache_engine = None
         self.gpu_cache = None
 
-    def init_model(self, cupy_port: Optional[int] = None) -> None:
+    def init_device(self, cupy_port: Optional[int] = None) -> None:
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until
             # the synchronization point. This causes the memory usage to grow
@@ -91,7 +91,7 @@ class Worker:
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
                                      cupy_port, self.distributed_init_method)
-        # Initialize the model.
+        # Set random seed.
         set_random_seed(self.model_config.seed)
 
     def load_model(self):