[Core] Simplifications to executor classes (#4071)

Nick Hill 2024-04-15 13:05:09 -07:00 committed by GitHub
parent 0003e9154b
commit eb46fbfda2
5 changed files with 44 additions and 109 deletions

vllm/executor/cpu_executor.py

@@ -1,10 +1,9 @@
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple

 import torch

-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig)
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -16,23 +15,13 @@ logger = init_logger(__name__)

 class CPUExecutor(ExecutorBase):

-    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
-                 parallel_config: ParallelConfig,
-                 scheduler_config: SchedulerConfig,
-                 device_config: DeviceConfig,
-                 lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
-        assert device_config.device_type == "cpu"
-        assert lora_config is None, "cpu backend doesn't support LoRA"
-        model_config = _verify_and_get_model_config(model_config)
-        cache_config = _verify_and_get_cache_config(cache_config)
-        scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
+    def _init_executor(self) -> None:
+        assert self.device_config.device_type == "cpu"
+        assert self.lora_config is None, "cpu backend doesn't support LoRA"
+        self.model_config = _verify_and_get_model_config(self.model_config)
+        self.cache_config = _verify_and_get_cache_config(self.cache_config)
+        self.scheduler_config = _verify_and_get_scheduler_config(
+            self.scheduler_config)

         # Instantiate the worker and load the model to CPU.
         self._init_worker()
@@ -99,7 +88,7 @@ class CPUExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)

-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()

     def check_health(self) -> None:

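Alongside the structural change, every backend's `list_loras` return annotation moves from `List[int]` to `Set[int]`. A set matches the semantics better: loaded adapter IDs are unique and unordered, and membership tests are O(1). A minimal sketch of why the set type fits (the `ToyWorker` class here is illustrative, not vLLM code):

    from typing import Set


    class ToyWorker:
        """Illustrative stand-in for a driver worker that tracks LoRA adapters."""

        def __init__(self) -> None:
            self._lora_ids: Set[int] = set()

        def add_lora(self, lora_id: int) -> bool:
            # Adding the same adapter twice is a no-op, which a set makes explicit.
            already_loaded = lora_id in self._lora_ids
            self._lora_ids.add(lora_id)
            return not already_loaded

        def remove_lora(self, lora_id: int) -> bool:
            if lora_id in self._lora_ids:
                self._lora_ids.remove(lora_id)
                return True
            return False

        def list_loras(self) -> Set[int]:
            # Return a copy so callers cannot mutate internal state.
            return set(self._lora_ids)


    worker = ToyWorker()
    worker.add_lora(1)
    worker.add_lora(2)
    assert worker.list_loras() == {1, 2}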
vllm/executor/executor_base.py

@@ -1,9 +1,9 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple

 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
+                         TensorizerConfig, VisionLanguageConfig)
 from vllm.lora.request import LoRARequest
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -16,7 +16,6 @@ class ExecutorBase(ABC):
     that can execute the model on multiple devices.
     """

-    @abstractmethod
     def __init__(
         self,
         model_config: ModelConfig,
@@ -27,8 +26,23 @@ class ExecutorBase(ABC):
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
+        tensorizer_config: Optional[TensorizerConfig],
     ) -> None:
-        raise NotImplementedError
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.vision_language_config = vision_language_config
+        self.speculative_config = speculative_config
+        self.tensorizer_config = tensorizer_config
+
+        self._init_executor()
+
+    @abstractmethod
+    def _init_executor(self) -> None:
+        pass

     @abstractmethod
     def determine_num_available_blocks(self) -> Tuple[int, int]:
@@ -71,7 +85,7 @@ class ExecutorBase(ABC):
         raise NotImplementedError

     @abstractmethod
-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         raise NotImplementedError

     @abstractmethod
@@ -94,8 +108,7 @@ class ExecutorAsyncBase(ExecutorBase):
         """Executes one model step on the given sequences."""
         raise NotImplementedError

-    @abstractmethod
     async def check_health_async(self) -> None:
         """Checks if the executor is healthy. If not, it should raise an
         exception."""
-        raise NotImplementedError
+        self.check_health()

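This file is the heart of the commit: `ExecutorBase.__init__` stops being abstract, stores all the config objects in one place, and then calls a new abstract `_init_executor` hook, so each backend keeps only its own setup logic. This is the classic template-method pattern; a condensed, runnable sketch with the config list trimmed to two plain fields for brevity (not the real signatures):

    from abc import ABC, abstractmethod


    class ExecutorBase(ABC):
        """Base class: owns common state, defers backend setup to a hook."""

        def __init__(self, model_config: str, device_type: str) -> None:
            # The boilerplate assignments now live in exactly one place.
            self.model_config = model_config
            self.device_type = device_type
            self._init_executor()

        @abstractmethod
        def _init_executor(self) -> None:
            """Backend-specific initialization (asserts, worker creation)."""


    class CPUExecutor(ExecutorBase):

        def _init_executor(self) -> None:
            # Only the backend-specific checks remain in the subclass.
            assert self.device_type == "cpu"


    executor = CPUExecutor(model_config="toy-model", device_type="cpu")

Because the hook is invoked from the base `__init__`, a subclass cannot forget to run its setup, and the duplicated nine-line assignment block disappears from all four backends below.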
vllm/executor/gpu_executor.py

@@ -1,8 +1,5 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple

-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         TensorizerConfig, VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -15,24 +12,8 @@ logger = init_logger(__name__)

 class GPUExecutor(ExecutorBase):

-    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
-                 parallel_config: ParallelConfig,
-                 scheduler_config: SchedulerConfig,
-                 device_config: DeviceConfig,
-                 lora_config: Optional[LoRAConfig],
-                 vision_language_config: Optional[VisionLanguageConfig],
-                 speculative_config: Optional[SpeculativeConfig],
-                 tensorizer_config: Optional[TensorizerConfig]) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.vision_language_config = vision_language_config
-        self.tensorizer_config = tensorizer_config
-
-        assert (not speculative_config
+    def _init_executor(self) -> None:
+        assert (not self.speculative_config
                 ), "Speculative decoding not yet supported for GPU backend"

         # Instantiate the worker and load the model to GPU.
@@ -103,7 +84,7 @@ class GPUExecutor(ExecutorBase):
         assert lora_id > 0, "lora_id must be greater than 0."
         return self.driver_worker.remove_lora(lora_id)

-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()

     def check_health(self) -> None:
@@ -127,8 +108,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy)
         return output
-
-    async def check_health_async(self) -> None:
-        # GPUExecutor will always be healthy as long as
-        # it's running.
-        return

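`GPUExecutorAsync` can drop its `check_health_async` override because `ExecutorAsyncBase` now provides a default that delegates to the synchronous `check_health`; a backend only overrides the async variant when health checking genuinely needs to await something. A small runnable sketch of that layering (class bodies reduced to the health-check path):

    import asyncio
    from abc import ABC, abstractmethod


    class ExecutorBase(ABC):

        @abstractmethod
        def check_health(self) -> None:
            """Raise if the executor is unhealthy."""


    class ExecutorAsyncBase(ExecutorBase):

        async def check_health_async(self) -> None:
            # Default: delegate to the synchronous check, as the commit does.
            self.check_health()


    class GPUExecutor(ExecutorBase):

        def check_health(self) -> None:
            # The GPU executor is healthy as long as the process is running.
            return


    class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
        # No check_health_async override needed anymore.
        pass


    asyncio.run(GPUExecutorAsync().check_health_async())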
vllm/executor/neuron_executor.py

@@ -1,8 +1,5 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Set, Tuple

-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         VisionLanguageConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -13,24 +10,10 @@ logger = init_logger(__name__)

 class NeuronExecutor(ExecutorBase):

-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
-        speculative_config: Optional[SpeculativeConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        assert lora_config is None, "LoRA is not supported for Neuron backend."
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-
-        assert (not speculative_config
+    def _init_executor(self) -> None:
+        assert (self.lora_config is
+                None), "LoRA is not supported for Neuron backend."
+        assert (not self.speculative_config
+                ), "Speculative decoding not yet supported for Neuron backend."

         # Instantiate the worker and load the model to the device.
@@ -80,7 +63,7 @@ class NeuronExecutor(ExecutorBase):
     def remove_lora(self, lora_id: int) -> bool:
         return self.driver_worker.remove_lora(lora_id)

-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self.driver_worker.list_loras()

     def check_health(self) -> None:

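After the refactor, a backend file is reduced to what is genuinely its own: `NeuronExecutor` keeps just two feature-gate assertions plus worker creation. A hypothetical sketch of what adding a new backend now looks like (the `FooExecutor` name, dict-typed configs, and `_init_worker` body are all illustrative assumptions, not vLLM code):

    from abc import ABC, abstractmethod
    from typing import Optional


    class ExecutorBase(ABC):

        def __init__(self, lora_config: Optional[dict],
                     speculative_config: Optional[dict]) -> None:
            self.lora_config = lora_config
            self.speculative_config = speculative_config
            self._init_executor()

        @abstractmethod
        def _init_executor(self) -> None:
            ...


    class FooExecutor(ExecutorBase):
        """Hypothetical backend: gate unsupported features, then build a worker."""

        def _init_executor(self) -> None:
            # Reject unsupported features up front, mirroring NeuronExecutor.
            assert self.lora_config is None, (
                "LoRA is not supported for the Foo backend.")
            assert self.speculative_config is None, (
                "Speculative decoding not yet supported for the Foo backend.")
            self._init_worker()

        def _init_worker(self) -> None:
            self.driver_worker = object()  # stand-in for a real worker


    executor = FooExecutor(lora_config=None, speculative_config=None)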
vllm/executor/ray_gpu_executor.py

@@ -3,11 +3,8 @@ import copy
 import os
 import pickle
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         TensorizerConfig, VisionLanguageConfig)
 from vllm.engine.ray_utils import RayWorkerVllm, ray
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
@@ -32,27 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))

 class RayGPUExecutor(ExecutorBase):

-    def __init__(
-        self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
-        speculative_config: Optional[SpeculativeConfig],
-        tensorizer_config: Optional[TensorizerConfig],
-    ) -> None:
-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
-        self.parallel_config = parallel_config
-        self.scheduler_config = scheduler_config
-        self.device_config = device_config
-        self.vision_language_config = vision_language_config
-        self.tensorizer_config = tensorizer_config
-
-        assert (not speculative_config
+    def _init_executor(self) -> None:
+        assert (not self.speculative_config
                 ), "Speculative decoding not yet supported for RayGPU backend."

         assert self.parallel_config.worker_use_ray
@@ -273,7 +251,7 @@ class RayGPUExecutor(ExecutorBase):
             lora_id=lora_id,
         )

-    def list_loras(self) -> List[int]:
+    def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")

     def _run_workers(
@@ -416,7 +394,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
         # Only the driver worker returns the sampling results.
         output = all_outputs[0]
         return output
-
-    async def check_health_async(self) -> None:
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
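Unlike the single-device backends, `RayGPUExecutor` routes `list_loras` through its `_run_workers` helper, which broadcasts a method name to every worker and collects one result per worker (with the driver's result first, as the final hunk's `all_outputs[0]` shows). A simplified, Ray-free sketch of that dispatch shape; the `Worker` and `ToyRayExecutor` classes are illustrative, and taking the first result is a simplification of the real helper:

    from typing import Any, List, Set


    class Worker:
        """Illustrative local stand-in for a Ray worker actor."""

        def __init__(self, lora_ids: Set[int]) -> None:
            self._lora_ids = lora_ids

        def list_loras(self) -> Set[int]:
            return set(self._lora_ids)


    class ToyRayExecutor:

        def __init__(self, workers: List[Worker]) -> None:
            # By convention, workers[0] plays the role of the driver worker.
            self.workers = workers

        def _run_workers(self, method: str, *args: Any, **kwargs: Any) -> List[Any]:
            # Look the method up by name on each worker and invoke it.
            return [getattr(w, method)(*args, **kwargs) for w in self.workers]

        def list_loras(self) -> Set[int]:
            # Every worker holds the same adapters; use the driver's answer.
            return self._run_workers("list_loras")[0]


    executor = ToyRayExecutor([Worker({1, 2}), Worker({1, 2})])
    assert executor.list_loras() == {1, 2}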