mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 04:04:25 +08:00
[V1][Core] Add worker_base for v1 worker (#12816)
Signed-off-by: Aoyu <aoyuzhan@amazon.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: Aoyu <aoyuzhan@amazon.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
c9d3ecf016
commit
2092a6fa7d
@ -2220,3 +2220,46 @@ def import_pynvml():
|
|||||||
"""
|
"""
|
||||||
import vllm.third_party.pynvml as pynvml
|
import vllm.third_party.pynvml as pynvml
|
||||||
return pynvml
|
return pynvml
|
||||||
|
|
||||||
|
|
||||||
|
def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]:
    """
    A replacement for `abc.ABC`.
    When we use `abc.ABC`, subclasses will fail to instantiate
    if they do not implement all abstract methods.
    Here, we only require `raise NotImplementedError` in the
    base class, and log a warning if the method is not implemented
    in the subclass.

    The decorator replaces `cls.__init__` with a wrapper that first runs
    the original initializer and then scans the public attributes of the
    new instance. A method is reported when the source text of its
    underlying function contains the string "NotImplementedError"; this
    is a plain substring search, so comments and docstrings count too.

    Args:
        cls: The class to decorate. It is modified in place.

    Returns:
        The same class object, with `__init__` wrapped.
    """

    original_init = cls.__init__

    def find_unimplemented_methods(self: object):
        # Collect the names of public methods that still look like stubs
        # and emit a single warning listing all of them.
        unimplemented_methods = []
        for attr_name in dir(self):
            # bypass inner method
            if attr_name.startswith('_'):
                continue

            try:
                # NOTE: a plain getattr also evaluates properties.
                attr = getattr(self, attr_name)
            except AttributeError:
                continue

            # Plain data attributes carry no source to inspect. Skipping
            # them here avoids reusing the function found for a previous
            # attribute (or hitting an unbound local on the first one).
            if not callable(attr):
                continue

            # Only bound methods expose `__func__`; other callables are
            # ignored.
            attr_func = getattr(attr, '__func__', None)
            if attr_func is None:
                continue

            try:
                src = inspect.getsource(attr_func)
            except (OSError, TypeError):
                # Source is unavailable (e.g. built-in or dynamically
                # created function); the check is best effort and must
                # not break object construction.
                continue

            if "NotImplementedError" in src:
                unimplemented_methods.append(attr_name)

        if unimplemented_methods:
            method_names = ','.join(unimplemented_methods)
            msg = (f"Methods {method_names} not implemented in {self}")
            logger.warning(msg)

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs) -> None:
        # Run the real initializer first so the scan sees the fully
        # constructed instance.
        original_init(self, *args, **kwargs)
        find_unimplemented_methods(self)

    # Install the wrapper through `type.__setattr__`, as the original did.
    type.__setattr__(cls, '__init__', wrapped_init)
    return cls
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from vllm.utils import GiB_bytes
|
|||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||||
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -28,7 +29,7 @@ if TYPE_CHECKING:
|
|||||||
from vllm.v1.core.scheduler_output import SchedulerOutput
|
from vllm.v1.core.scheduler_output import SchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker(WorkerBase):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -39,23 +40,11 @@ class Worker:
|
|||||||
is_driver_worker: bool = False,
|
is_driver_worker: bool = False,
|
||||||
):
|
):
|
||||||
|
|
||||||
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
|
super().__init__(vllm_config=vllm_config,
|
||||||
self.vllm_config = vllm_config
|
local_rank=local_rank,
|
||||||
self.model_config = vllm_config.model_config
|
rank=rank,
|
||||||
self.cache_config = vllm_config.cache_config
|
distributed_init_method=distributed_init_method,
|
||||||
self.lora_config = vllm_config.lora_config
|
is_driver_worker=is_driver_worker)
|
||||||
self.load_config = vllm_config.load_config
|
|
||||||
self.parallel_config = vllm_config.parallel_config
|
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
|
||||||
self.device_config = vllm_config.device_config
|
|
||||||
self.speculative_config = vllm_config.speculative_config
|
|
||||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
|
||||||
self.observability_config = vllm_config.observability_config
|
|
||||||
|
|
||||||
self.parallel_config.rank = rank
|
|
||||||
self.local_rank = local_rank
|
|
||||||
self.rank = rank
|
|
||||||
self.distributed_init_method = distributed_init_method
|
|
||||||
|
|
||||||
if self.model_config.trust_remote_code:
|
if self.model_config.trust_remote_code:
|
||||||
# note: lazy import to avoid importing torch before initializing
|
# note: lazy import to avoid importing torch before initializing
|
||||||
@ -126,7 +115,8 @@ class Worker:
|
|||||||
set_random_seed(self.model_config.seed)
|
set_random_seed(self.model_config.seed)
|
||||||
|
|
||||||
# Construct the model runner
|
# Construct the model runner
|
||||||
self.model_runner = GPUModelRunner(self.vllm_config, self.device)
|
self.model_runner: GPUModelRunner = GPUModelRunner(
|
||||||
|
self.vllm_config, self.device)
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
if self.vllm_config.model_config.enable_sleep_mode:
|
if self.vllm_config.model_config.enable_sleep_mode:
|
||||||
|
|||||||
63
vllm/v1/worker/worker_base.py
Normal file
63
vllm/v1/worker/worker_base.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||||
|
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerBase(WorkerBaseV0):
    """
    Base class for v1 workers; it declares the hooks specific to v1.
    Anything common to v0 and v1 belongs in the v0 WorkerBase instead.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """
        Set up the state shared by all v1 workers.

        Args:
            vllm_config: Complete vLLM configuration
            local_rank: Local device index
            rank: Global rank in distributed setup
            distributed_init_method: Distributed initialization method
            is_driver_worker: Whether this worker handles driver
                responsibilities
        """
        # Configuration handling is left to the v0 base initializer.
        super().__init__(vllm_config=vllm_config)

        # Device and model runner start unset.
        self.model_runner: Optional[nn.Module] = None
        self.device: Optional[torch.device] = None

        # Identity of this worker within the distributed setup.
        self.is_driver_worker = is_driver_worker
        self.distributed_init_method = distributed_init_method
        self.rank = rank
        self.local_rank = local_rank

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Return the KV cache specification; subclasses must override."""
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        """Compile and/or warm up the model; subclasses must override."""
        raise NotImplementedError

    def check_health(self) -> None:
        """Default health check: a no-op that device-specific workers may
        override."""
        return None
|
||||||
@ -3,7 +3,7 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
@ -19,7 +19,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput
|
|||||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||||
from vllm.utils import (enable_trace_function_call_for_thread,
|
from vllm.utils import (enable_trace_function_call_for_thread,
|
||||||
resolve_obj_by_qualname, run_method,
|
resolve_obj_by_qualname, run_method,
|
||||||
update_environment_variables)
|
update_environment_variables,
|
||||||
|
warn_for_unimplemented_methods)
|
||||||
from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
||||||
ModelRunnerBase,
|
ModelRunnerBase,
|
||||||
ModelRunnerInputBase)
|
ModelRunnerInputBase)
|
||||||
@ -27,7 +28,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class WorkerBase(ABC):
|
@warn_for_unimplemented_methods
|
||||||
|
class WorkerBase:
|
||||||
"""Worker interface that allows vLLM to cleanly separate implementations for
|
"""Worker interface that allows vLLM to cleanly separate implementations for
|
||||||
different hardware. Also abstracts control plane communication, e.g., to
|
different hardware. Also abstracts control plane communication, e.g., to
|
||||||
communicate request metadata to other workers.
|
communicate request metadata to other workers.
|
||||||
@ -53,35 +55,31 @@ class WorkerBase(ABC):
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
self.current_platform = current_platform
|
self.current_platform = current_platform
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def init_device(self) -> None:
|
def init_device(self) -> None:
|
||||||
"""Initialize device state, such as loading the model or other on-device
|
"""Initialize device state, such as loading the model or other on-device
|
||||||
memory allocations.
|
memory allocations.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
|
||||||
"""Determine the number of available blocks for the GPU KV cache and
|
|
||||||
swappable CPU KV cache.
|
|
||||||
|
|
||||||
The implementation may run profiling or other heuristics to determine
|
|
||||||
the size of caches.
|
|
||||||
|
|
||||||
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
|
||||||
are blocks that are "active" on the device and can be appended to.
|
|
||||||
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
|
||||||
appended to.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def initialize_cache(self, num_gpu_blocks: int,
|
def initialize_cache(self, num_gpu_blocks: int,
|
||||||
num_cpu_blocks: int) -> None:
|
num_cpu_blocks: int) -> None:
|
||||||
"""Initialize the KV cache with the given size in blocks.
|
"""Initialize the KV cache with the given size in blocks.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_model(self) -> nn.Module:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
"""Load model onto target device."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
|
) -> Optional[List[SamplerOutput]]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def start_worker_execution_loop(self) -> None:
|
def start_worker_execution_loop(self) -> None:
|
||||||
"""Execute model loop in parallel worker.
|
"""Execute model loop in parallel worker.
|
||||||
|
|
||||||
@ -94,40 +92,43 @@ class WorkerBase(ABC):
|
|||||||
if output is None:
|
if output is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
def get_model(self) -> nn.Module:
|
"""Determine the number of available blocks for the GPU KV cache and
|
||||||
|
swappable CPU KV cache.
|
||||||
|
|
||||||
|
The implementation may run profiling or other heuristics to determine
|
||||||
|
the size of caches.
|
||||||
|
|
||||||
|
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
||||||
|
are blocks that are "active" on the device and can be appended to.
|
||||||
|
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
||||||
|
appended to.
|
||||||
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def execute_model(
|
|
||||||
self,
|
|
||||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
|
||||||
) -> Optional[List[SamplerOutput]]:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_cache_block_size_bytes(self) -> int:
|
def get_cache_block_size_bytes(self) -> int:
|
||||||
"""Return the size of a single cache block, in bytes. Used in
|
"""Return the size of a single cache block, in bytes. Used in
|
||||||
speculative decoding.
|
speculative decoding.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def remove_lora(self, lora_id: int) -> bool:
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def pin_lora(self, lora_id: int) -> bool:
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_loras(self) -> Set[int]:
|
def list_loras(self) -> Set[int]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
"""Get vocabulary size from model configuration."""
|
||||||
|
return self.model_config.get_vocab_size()
|
||||||
|
|
||||||
|
|
||||||
class DelegateWorkerBase(WorkerBase):
|
class DelegateWorkerBase(WorkerBase):
|
||||||
"""
|
"""
|
||||||
@ -156,6 +157,10 @@ class DelegateWorkerBase(WorkerBase):
|
|||||||
num_cpu_blocks: int) -> None:
|
num_cpu_blocks: int) -> None:
|
||||||
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
"""Load model onto target device."""
|
||||||
|
self.worker.load_model()
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.worker.get_model()
|
return self.worker.get_model()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user