From eb46fbfda25348422918c4a876e17aef05fc5e34 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 15 Apr 2024 13:05:09 -0700
Subject: [PATCH] [Core] Simplifications to executor classes (#4071)

---
 vllm/executor/cpu_executor.py     | 31 +++++++++-------------------
 vllm/executor/executor_base.py    | 27 +++++++++++++++++-------
 vllm/executor/gpu_executor.py     | 32 ++++-------------------------
 vllm/executor/neuron_executor.py  | 29 ++++++--------------------
 vllm/executor/ray_gpu_executor.py | 34 ++++---------------------------
 5 files changed, 44 insertions(+), 109 deletions(-)

diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index e63a88be7868f..f562e4e0ae3de 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -1,10 +1,9 @@
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
 
 import torch
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -16,23 +15,13 @@ logger = init_logger(__name__)
 
 class CPUExecutor(ExecutorBase):
 
-    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
-                 parallel_config: ParallelConfig,
-                 scheduler_config: SchedulerConfig,
-                 device_config: DeviceConfig,
-                 lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
-        assert device_config.device_type == "cpu"
-        assert lora_config is None, "cpu backend doesn't support LoRA"
-        model_config = _verify_and_get_model_config(model_config)
-        cache_config = _verify_and_get_cache_config(cache_config)
-        scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
-
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
+    def _init_executor(self) -> None:
+        assert self.device_config.device_type == "cpu"
+        assert self.lora_config is None, "cpu backend doesn't support LoRA"
+        self.model_config = _verify_and_get_model_config(self.model_config)
+        self.cache_config = _verify_and_get_cache_config(self.cache_config)
+        self.scheduler_config = _verify_and_get_scheduler_config(
+            self.scheduler_config)
 
         # Instantiate the worker and load the model to CPU.
         self._init_worker()
@@ -99,7 +88,7 @@ class CPUExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
     def check_health(self) -> None:
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index bbfbfc689c99f..bbb6ec80f7b7e 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -1,9 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 
@@ -16,7 +16,6 @@ class ExecutorBase(ABC):
     that can execute the model on multiple devices.
""" - @abstractmethod def __init__( self, model_config: ModelConfig, @@ -27,8 +26,23 @@ class ExecutorBase(ABC): lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], speculative_config: Optional[SpeculativeConfig], + tensorizer_config: Optional[TensorizerConfig], ) -> None: - raise NotImplementedError + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.vision_language_config = vision_language_config + self.speculative_config = speculative_config + self.tensorizer_config = tensorizer_config + + self._init_executor() + + @abstractmethod + def _init_executor(self) -> None: + pass @abstractmethod def determine_num_available_blocks(self) -> Tuple[int, int]: @@ -71,7 +85,7 @@ class ExecutorBase(ABC): raise NotImplementedError @abstractmethod - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise NotImplementedError @abstractmethod @@ -94,8 +108,7 @@ class ExecutorAsyncBase(ExecutorBase): """Executes one model step on the given sequences.""" raise NotImplementedError - @abstractmethod async def check_health_async(self) -> None: """Checks if the executor is healthy. If not, it should raise an exception.""" - raise NotImplementedError + self.check_health() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 30577ecf62faa..bae509f48025b 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,8 +1,5 @@ -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Set, Tuple -from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - TensorizerConfig, VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,24 +12,8 @@ logger = init_logger(__name__) class GPUExecutor(ExecutorBase): - def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - tensorizer_config: Optional[TensorizerConfig]) -> None: - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.vision_language_config = vision_language_config - self.tensorizer_config = tensorizer_config - - assert (not speculative_config + def _init_executor(self) -> None: + assert (not self.speculative_config ), "Speculative decoding not yet supported for GPU backend" # Instantiate the worker and load the model to GPU. @@ -103,7 +84,7 @@ class GPUExecutor(ExecutorBase): assert lora_id > 0, "lora_id must be greater than 0." 
         return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
     def check_health(self) -> None:
@@ -127,8 +108,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy)
         return output
-
-    async def check_health_async(self) -> None:
-        # GPUExecutor will always be healthy as long as
-        # it's running.
-        return
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index d45f18e466256..273b17a927efd 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -1,8 +1,5 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -13,24 +10,10 @@ logger = init_logger(__name__)
 
 class NeuronExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
-        speculative_config: Optional[SpeculativeConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        assert lora_config is None, "LoRA is not supported for Neuron backend."
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        assert (not speculative_config
+    def _init_executor(self) -> None:
+        assert (self.lora_config is
+                None), "LoRA is not supported for Neuron backend."
+        assert (not self.speculative_config
                 ), "Speculative decoding not yet supported for Neuron backend."
 
         # Instantiate the worker and load the model to the device.
@@ -80,7 +63,7 @@ class NeuronExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()
 
     def check_health(self) -> None:
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 28dc3e0db312a..5db2f3f652532 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -3,11 +3,8 @@ import copy
 import os
 import pickle
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         TensorizerConfig, VisionLanguageConfig)
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
@@ -32,27 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
 
 class RayGPUExecutor(ExecutorBase):
 
-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
-        speculative_config: Optional[SpeculativeConfig],
-        tensorizer_config: Optional[TensorizerConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.vision_language_config = vision_language_config
-        self.tensorizer_config = tensorizer_config
-        assert (not speculative_config
+    def _init_executor(self) -> None:
+        assert (not self.speculative_config
                 ), "Speculative decoding not yet supported for RayGPU backend."
 
         assert self.parallel_config.worker_use_ray
@@ -273,7 +251,7 @@ class RayGPUExecutor(ExecutorBase):
             lora_id=lora_id,
         )
 
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")
 
     def _run_workers(
@@ -416,7 +394,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
         return output
-
-    async def check_health_async(self) -> None:
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
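
Note (not part of the patch): the change above is a template-method refactor.
ExecutorBase.__init__ now stores the shared config objects and then calls an
abstract _init_executor() hook, so each backend executor (CPU, GPU, Neuron,
Ray) implements only its own validation and worker setup instead of
duplicating the constructor; ExecutorAsyncBase likewise gains a default
check_health_async() that falls back to the synchronous check. The sketch
below is a minimal, standalone Python illustration of the same pattern; the
Mini* class names and the simplified config arguments are invented for this
example and are not part of vLLM.

    from abc import ABC, abstractmethod
    from typing import Optional, Set


    class MiniExecutorBase(ABC):
        """Toy stand-in for ExecutorBase after this patch."""

        def __init__(self,
                     device_type: str,
                     lora_config: Optional[dict] = None,
                     speculative_config: Optional[dict] = None) -> None:
            # The base class owns the shared state...
            self.device_type = device_type
            self.lora_config = lora_config
            self.speculative_config = speculative_config
            # ...then delegates backend-specific setup to the hook.
            self._init_executor()

        @abstractmethod
        def _init_executor(self) -> None:
            """Backend-specific validation and worker initialization."""

        def list_loras(self) -> Set[int]:
            # Set[int] mirrors the corrected return annotation in the patch.
            return set()

        def check_health(self) -> None:
            pass  # healthy by default in this sketch

        async def check_health_async(self) -> None:
            # Mirrors the new default in ExecutorAsyncBase: fall back to
            # the synchronous health check unless a subclass overrides it.
            self.check_health()


    class MiniCPUExecutor(MiniExecutorBase):

        def _init_executor(self) -> None:
            # Subclasses now contain only backend-specific logic, as
            # CPUExecutor._init_executor does in the patch.
            assert self.device_type == "cpu"
            assert self.lora_config is None, "cpu backend doesn't support LoRA"


    if __name__ == "__main__":
        executor = MiniCPUExecutor(device_type="cpu")
        print(executor.list_loras())  # -> set()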