[Core] Simplifications to executor classes (#4071)

This commit is contained in:
Nick Hill 2024-04-15 13:05:09 -07:00 committed by GitHub
parent 0003e9154b
commit eb46fbfda2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 44 additions and 109 deletions

View File

@ -1,10 +1,9 @@
import os import os
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Set, Tuple
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
ParallelConfig, SchedulerConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -16,23 +15,13 @@ logger = init_logger(__name__)
class CPUExecutor(ExecutorBase): class CPUExecutor(ExecutorBase):
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, def _init_executor(self) -> None:
parallel_config: ParallelConfig, assert self.device_config.device_type == "cpu"
scheduler_config: SchedulerConfig, assert self.lora_config is None, "cpu backend doesn't support LoRA"
device_config: DeviceConfig, self.model_config = _verify_and_get_model_config(self.model_config)
lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: self.cache_config = _verify_and_get_cache_config(self.cache_config)
assert device_config.device_type == "cpu" self.scheduler_config = _verify_and_get_scheduler_config(
assert lora_config is None, "cpu backend doesn't support LoRA" self.scheduler_config)
model_config = _verify_and_get_model_config(model_config)
cache_config = _verify_and_get_cache_config(cache_config)
scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
# Instantiate the worker and load the model to CPU. # Instantiate the worker and load the model to CPU.
self._init_worker() self._init_worker()
@ -99,7 +88,7 @@ class CPUExecutor(ExecutorBase):
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id) return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def check_health(self) -> None: def check_health(self) -> None:

View File

@ -1,9 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig) TensorizerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@ -16,7 +16,6 @@ class ExecutorBase(ABC):
that can execute the model on multiple devices. that can execute the model on multiple devices.
""" """
@abstractmethod
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
@ -27,8 +26,23 @@ class ExecutorBase(ABC):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
tensorizer_config: Optional[TensorizerConfig],
) -> None: ) -> None:
raise NotImplementedError self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.speculative_config = speculative_config
self.tensorizer_config = tensorizer_config
self._init_executor()
@abstractmethod
def _init_executor(self) -> None:
pass
@abstractmethod @abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
@ -71,7 +85,7 @@ class ExecutorBase(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@ -94,8 +108,7 @@ class ExecutorAsyncBase(ExecutorBase):
"""Executes one model step on the given sequences.""" """Executes one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an """Checks if the executor is healthy. If not, it should raise an
exception.""" exception."""
raise NotImplementedError self.check_health()

View File

@ -1,8 +1,5 @@
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TensorizerConfig, VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -15,24 +12,8 @@ logger = init_logger(__name__)
class GPUExecutor(ExecutorBase): class GPUExecutor(ExecutorBase):
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, def _init_executor(self) -> None:
parallel_config: ParallelConfig, assert (not self.speculative_config
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
tensorizer_config: Optional[TensorizerConfig]) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.tensorizer_config = tensorizer_config
assert (not speculative_config
), "Speculative decoding not yet supported for GPU backend" ), "Speculative decoding not yet supported for GPU backend"
# Instantiate the worker and load the model to GPU. # Instantiate the worker and load the model to GPU.
@ -103,7 +84,7 @@ class GPUExecutor(ExecutorBase):
assert lora_id > 0, "lora_id must be greater than 0." assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id) return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def check_health(self) -> None: def check_health(self) -> None:
@ -127,8 +108,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy) blocks_to_copy=blocks_to_copy)
return output return output
async def check_health_async(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return

View File

@ -1,8 +1,5 @@
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -13,24 +10,10 @@ logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase): class NeuronExecutor(ExecutorBase):
def __init__( def _init_executor(self) -> None:
self, assert (self.lora_config is
model_config: ModelConfig, None), "LoRA is not supported for Neuron backend."
cache_config: CacheConfig, assert (not self.speculative_config
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
assert lora_config is None, "LoRA is not supported for Neuron backend."
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
assert (not speculative_config
), "Speculative decoding not yet supported for Neuron backend." ), "Speculative decoding not yet supported for Neuron backend."
# Instantiate the worker and load the model to the device. # Instantiate the worker and load the model to the device.
@ -80,7 +63,7 @@ class NeuronExecutor(ExecutorBase):
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id) return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def check_health(self) -> None: def check_health(self) -> None:

View File

@ -3,11 +3,8 @@ import copy
import os import os
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
TensorizerConfig, VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
@ -32,27 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
class RayGPUExecutor(ExecutorBase): class RayGPUExecutor(ExecutorBase):
def __init__( def _init_executor(self) -> None:
self, assert (not self.speculative_config
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
tensorizer_config: Optional[TensorizerConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.tensorizer_config = tensorizer_config
assert (not speculative_config
), "Speculative decoding not yet supported for RayGPU backend." ), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray assert self.parallel_config.worker_use_ray
@ -273,7 +251,7 @@ class RayGPUExecutor(ExecutorBase):
lora_id=lora_id, lora_id=lora_id,
) )
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self._run_workers("list_loras") return self._run_workers("list_loras")
def _run_workers( def _run_workers(
@ -416,7 +394,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
output = all_outputs[0] output = all_outputs[0]
return output return output
async def check_health_async(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()