mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 16:27:01 +08:00
[Core] Simplifications to executor classes (#4071)
This commit is contained in:
parent
0003e9154b
commit
eb46fbfda2
@ -1,10 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
||||||
ParallelConfig, SchedulerConfig)
|
|
||||||
from vllm.executor.executor_base import ExecutorBase
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -16,23 +15,13 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
class CPUExecutor(ExecutorBase):
|
class CPUExecutor(ExecutorBase):
|
||||||
|
|
||||||
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
|
def _init_executor(self) -> None:
|
||||||
parallel_config: ParallelConfig,
|
assert self.device_config.device_type == "cpu"
|
||||||
scheduler_config: SchedulerConfig,
|
assert self.lora_config is None, "cpu backend doesn't support LoRA"
|
||||||
device_config: DeviceConfig,
|
self.model_config = _verify_and_get_model_config(self.model_config)
|
||||||
lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
|
self.cache_config = _verify_and_get_cache_config(self.cache_config)
|
||||||
assert device_config.device_type == "cpu"
|
self.scheduler_config = _verify_and_get_scheduler_config(
|
||||||
assert lora_config is None, "cpu backend doesn't support LoRA"
|
self.scheduler_config)
|
||||||
model_config = _verify_and_get_model_config(model_config)
|
|
||||||
cache_config = _verify_and_get_cache_config(cache_config)
|
|
||||||
scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
|
|
||||||
|
|
||||||
self.model_config = model_config
|
|
||||||
self.cache_config = cache_config
|
|
||||||
self.lora_config = lora_config
|
|
||||||
self.parallel_config = parallel_config
|
|
||||||
self.scheduler_config = scheduler_config
|
|
||||||
self.device_config = device_config
|
|
||||||
|
|
||||||
# Instantiate the worker and load the model to CPU.
|
# Instantiate the worker and load the model to CPU.
|
||||||
self._init_worker()
|
self._init_worker()
|
||||||
@ -99,7 +88,7 @@ class CPUExecutor(ExecutorBase):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
return self.driver_worker.remove_lora(lora_id)
|
return self.driver_worker.remove_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> List[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self.driver_worker.list_loras()
|
return self.driver_worker.list_loras()
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||||
VisionLanguageConfig)
|
TensorizerConfig, VisionLanguageConfig)
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||||
|
|
||||||
@ -16,7 +16,6 @@ class ExecutorBase(ABC):
|
|||||||
that can execute the model on multiple devices.
|
that can execute the model on multiple devices.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
@ -27,8 +26,23 @@ class ExecutorBase(ABC):
|
|||||||
lora_config: Optional[LoRAConfig],
|
lora_config: Optional[LoRAConfig],
|
||||||
vision_language_config: Optional[VisionLanguageConfig],
|
vision_language_config: Optional[VisionLanguageConfig],
|
||||||
speculative_config: Optional[SpeculativeConfig],
|
speculative_config: Optional[SpeculativeConfig],
|
||||||
|
tensorizer_config: Optional[TensorizerConfig],
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
self.model_config = model_config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.lora_config = lora_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.device_config = device_config
|
||||||
|
self.vision_language_config = vision_language_config
|
||||||
|
self.speculative_config = speculative_config
|
||||||
|
self.tensorizer_config = tensorizer_config
|
||||||
|
|
||||||
|
self._init_executor()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _init_executor(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
@ -71,7 +85,7 @@ class ExecutorBase(ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list_loras(self) -> List[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -94,8 +108,7 @@ class ExecutorAsyncBase(ExecutorBase):
|
|||||||
"""Executes one model step on the given sequences."""
|
"""Executes one model step on the given sequences."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def check_health_async(self) -> None:
|
async def check_health_async(self) -> None:
|
||||||
"""Checks if the executor is healthy. If not, it should raise an
|
"""Checks if the executor is healthy. If not, it should raise an
|
||||||
exception."""
|
exception."""
|
||||||
raise NotImplementedError
|
self.check_health()
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
|
||||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
|
||||||
TensorizerConfig, VisionLanguageConfig)
|
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -15,24 +12,8 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
class GPUExecutor(ExecutorBase):
|
class GPUExecutor(ExecutorBase):
|
||||||
|
|
||||||
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
|
def _init_executor(self) -> None:
|
||||||
parallel_config: ParallelConfig,
|
assert (not self.speculative_config
|
||||||
scheduler_config: SchedulerConfig,
|
|
||||||
device_config: DeviceConfig,
|
|
||||||
lora_config: Optional[LoRAConfig],
|
|
||||||
vision_language_config: Optional[VisionLanguageConfig],
|
|
||||||
speculative_config: Optional[SpeculativeConfig],
|
|
||||||
tensorizer_config: Optional[TensorizerConfig]) -> None:
|
|
||||||
self.model_config = model_config
|
|
||||||
self.cache_config = cache_config
|
|
||||||
self.lora_config = lora_config
|
|
||||||
self.parallel_config = parallel_config
|
|
||||||
self.scheduler_config = scheduler_config
|
|
||||||
self.device_config = device_config
|
|
||||||
self.vision_language_config = vision_language_config
|
|
||||||
self.tensorizer_config = tensorizer_config
|
|
||||||
|
|
||||||
assert (not speculative_config
|
|
||||||
), "Speculative decoding not yet supported for GPU backend"
|
), "Speculative decoding not yet supported for GPU backend"
|
||||||
|
|
||||||
# Instantiate the worker and load the model to GPU.
|
# Instantiate the worker and load the model to GPU.
|
||||||
@ -103,7 +84,7 @@ class GPUExecutor(ExecutorBase):
|
|||||||
assert lora_id > 0, "lora_id must be greater than 0."
|
assert lora_id > 0, "lora_id must be greater than 0."
|
||||||
return self.driver_worker.remove_lora(lora_id)
|
return self.driver_worker.remove_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> List[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self.driver_worker.list_loras()
|
return self.driver_worker.list_loras()
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
@ -127,8 +108,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
|
|||||||
blocks_to_swap_out=blocks_to_swap_out,
|
blocks_to_swap_out=blocks_to_swap_out,
|
||||||
blocks_to_copy=blocks_to_copy)
|
blocks_to_copy=blocks_to_copy)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
async def check_health_async(self) -> None:
|
|
||||||
# GPUExecutor will always be healthy as long as
|
|
||||||
# it's running.
|
|
||||||
return
|
|
||||||
|
|||||||
@ -1,8 +1,5 @@
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Set, Tuple
|
||||||
|
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
|
||||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
|
||||||
VisionLanguageConfig)
|
|
||||||
from vllm.executor.executor_base import ExecutorBase
|
from vllm.executor.executor_base import ExecutorBase
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -13,24 +10,10 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
class NeuronExecutor(ExecutorBase):
|
class NeuronExecutor(ExecutorBase):
|
||||||
|
|
||||||
def __init__(
|
def _init_executor(self) -> None:
|
||||||
self,
|
assert (self.lora_config is
|
||||||
model_config: ModelConfig,
|
None), "LoRA is not supported for Neuron backend."
|
||||||
cache_config: CacheConfig,
|
assert (not self.speculative_config
|
||||||
parallel_config: ParallelConfig,
|
|
||||||
scheduler_config: SchedulerConfig,
|
|
||||||
device_config: DeviceConfig,
|
|
||||||
lora_config: Optional[LoRAConfig],
|
|
||||||
vision_language_config: Optional[VisionLanguageConfig],
|
|
||||||
speculative_config: Optional[SpeculativeConfig],
|
|
||||||
) -> None:
|
|
||||||
self.model_config = model_config
|
|
||||||
self.cache_config = cache_config
|
|
||||||
assert lora_config is None, "LoRA is not supported for Neuron backend."
|
|
||||||
self.parallel_config = parallel_config
|
|
||||||
self.scheduler_config = scheduler_config
|
|
||||||
self.device_config = device_config
|
|
||||||
assert (not speculative_config
|
|
||||||
), "Speculative decoding not yet supported for Neuron backend."
|
), "Speculative decoding not yet supported for Neuron backend."
|
||||||
|
|
||||||
# Instantiate the worker and load the model to the device.
|
# Instantiate the worker and load the model to the device.
|
||||||
@ -80,7 +63,7 @@ class NeuronExecutor(ExecutorBase):
|
|||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
return self.driver_worker.remove_lora(lora_id)
|
return self.driver_worker.remove_lora(lora_id)
|
||||||
|
|
||||||
def list_loras(self) -> List[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self.driver_worker.list_loras()
|
return self.driver_worker.list_loras()
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
|
|||||||
@ -3,11 +3,8 @@ import copy
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
|
||||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
|
||||||
TensorizerConfig, VisionLanguageConfig)
|
|
||||||
from vllm.engine.ray_utils import RayWorkerVllm, ray
|
from vllm.engine.ray_utils import RayWorkerVllm, ray
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -32,27 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
|
|||||||
|
|
||||||
class RayGPUExecutor(ExecutorBase):
|
class RayGPUExecutor(ExecutorBase):
|
||||||
|
|
||||||
def __init__(
|
def _init_executor(self) -> None:
|
||||||
self,
|
assert (not self.speculative_config
|
||||||
model_config: ModelConfig,
|
|
||||||
cache_config: CacheConfig,
|
|
||||||
parallel_config: ParallelConfig,
|
|
||||||
scheduler_config: SchedulerConfig,
|
|
||||||
device_config: DeviceConfig,
|
|
||||||
lora_config: Optional[LoRAConfig],
|
|
||||||
vision_language_config: Optional[VisionLanguageConfig],
|
|
||||||
speculative_config: Optional[SpeculativeConfig],
|
|
||||||
tensorizer_config: Optional[TensorizerConfig],
|
|
||||||
) -> None:
|
|
||||||
self.model_config = model_config
|
|
||||||
self.cache_config = cache_config
|
|
||||||
self.lora_config = lora_config
|
|
||||||
self.parallel_config = parallel_config
|
|
||||||
self.scheduler_config = scheduler_config
|
|
||||||
self.device_config = device_config
|
|
||||||
self.vision_language_config = vision_language_config
|
|
||||||
self.tensorizer_config = tensorizer_config
|
|
||||||
assert (not speculative_config
|
|
||||||
), "Speculative decoding not yet supported for RayGPU backend."
|
), "Speculative decoding not yet supported for RayGPU backend."
|
||||||
|
|
||||||
assert self.parallel_config.worker_use_ray
|
assert self.parallel_config.worker_use_ray
|
||||||
@ -273,7 +251,7 @@ class RayGPUExecutor(ExecutorBase):
|
|||||||
lora_id=lora_id,
|
lora_id=lora_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
def list_loras(self) -> List[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
return self._run_workers("list_loras")
|
return self._run_workers("list_loras")
|
||||||
|
|
||||||
def _run_workers(
|
def _run_workers(
|
||||||
@ -416,7 +394,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
|
|||||||
# Only the driver worker returns the sampling results.
|
# Only the driver worker returns the sampling results.
|
||||||
output = all_outputs[0]
|
output = all_outputs[0]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
async def check_health_async(self) -> None:
|
|
||||||
"""Raises an error if engine is unhealthy."""
|
|
||||||
self._check_if_any_actor_is_dead()
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user