Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-26 12:09:38 +08:00)
[V1][Core] Add worker_base for v1 worker (#12816)
Signed-off-by: Aoyu <aoyuzhan@amazon.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Aoyu <aoyuzhan@amazon.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in: parent c9d3ecf016, commit 2092a6fa7d
vllm/utils.py
@@ -2220,3 +2220,46 @@ def import_pynvml():
     """
     import vllm.third_party.pynvml as pynvml
     return pynvml
+
+
+def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]:
+    """
+    A replacement for `abc.ABC`.
+    When we use `abc.ABC`, subclasses will fail to instantiate
+    if they do not implement all abstract methods.
+    Here, we only require `raise NotImplementedError` in the
+    base class, and log a warning if the method is not implemented
+    in the subclass.
+    """
+
+    original_init = cls.__init__
+
+    def find_unimplemented_methods(self: object):
+        unimplemented_methods = []
+        for attr_name in dir(self):
+            # bypass inner method
+            if attr_name.startswith('_'):
+                continue
+
+            try:
+                attr = getattr(self, attr_name)
+                # get the func of callable method
+                if callable(attr):
+                    attr_func = attr.__func__
+            except AttributeError:
+                continue
+            src = inspect.getsource(attr_func)
+            if "NotImplementedError" in src:
+                unimplemented_methods.append(attr_name)
+        if unimplemented_methods:
+            method_names = ','.join(unimplemented_methods)
+            msg = (f"Methods {method_names} not implemented in {self}")
+            logger.warning(msg)
+
+    @wraps(original_init)
+    def wrapped_init(self, *args, **kwargs) -> None:
+        original_init(self, *args, **kwargs)
+        find_unimplemented_methods(self)
+
+    type.__setattr__(cls, '__init__', wrapped_init)
+    return cls
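Usage note (not part of the diff): the decorator is meant to wrap a base class whose interface methods simply `raise NotImplementedError`. The sketch below uses made-up `ToyWorkerBase`/`IncompleteWorker` names to show the behavior described in the docstring: constructing a partial subclass logs a warning instead of failing at instantiation the way `abc.ABC` does.

    from vllm.utils import warn_for_unimplemented_methods


    @warn_for_unimplemented_methods
    class ToyWorkerBase:

        def init_device(self) -> None:
            raise NotImplementedError

        def load_model(self) -> None:
            raise NotImplementedError


    class IncompleteWorker(ToyWorkerBase):

        # Only init_device is overridden; load_model is left unimplemented.
        def init_device(self) -> None:
            print("device initialized")


    # With abc.ABC this would raise TypeError; here construction succeeds and
    # vLLM logs roughly "Methods load_model not implemented in <IncompleteWorker ...>".
    worker = IncompleteWorker()
    worker.init_device()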

vllm/v1/worker/gpu_worker.py
@@ -21,6 +21,7 @@ from vllm.utils import GiB_bytes
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.worker_base import WorkerBase
 
 logger = init_logger(__name__)
 
@@ -28,7 +29,7 @@ if TYPE_CHECKING:
     from vllm.v1.core.scheduler_output import SchedulerOutput
 
 
-class Worker:
+class Worker(WorkerBase):
 
     def __init__(
         self,
@@ -39,23 +40,11 @@ class Worker:
         is_driver_worker: bool = False,
     ):
 
-        # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        self.parallel_config.rank = rank
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
+        super().__init__(vllm_config=vllm_config,
+                         local_rank=local_rank,
+                         rank=rank,
+                         distributed_init_method=distributed_init_method,
+                         is_driver_worker=is_driver_worker)
 
         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -126,7 +115,8 @@ class Worker:
         set_random_seed(self.model_config.seed)
 
         # Construct the model runner
-        self.model_runner = GPUModelRunner(self.vllm_config, self.device)
+        self.model_runner: GPUModelRunner = GPUModelRunner(
+            self.vllm_config, self.device)
 
     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
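Net effect of these hunks (editorial sketch, not part of the diff): the v1 GPU `Worker` stops unpacking `VllmConfig` by hand and instead defers to the new v1 `WorkerBase` (added in the new file below), which itself extends the v0 `WorkerBase`. Assuming the v1 GPU worker module is `vllm/v1/worker/gpu_worker.py`, the resulting chain looks like:

    # Sketch only: the gpu_worker module path is assumed from the imports above.
    from vllm.v1.worker.gpu_worker import Worker
    from vllm.v1.worker.worker_base import WorkerBase as WorkerBaseV1
    from vllm.worker.worker_base import WorkerBase as WorkerBaseV0

    assert issubclass(Worker, WorkerBaseV1)
    assert issubclass(WorkerBaseV1, WorkerBaseV0)

    # Worker.__init__ forwards (vllm_config, local_rank, rank,
    # distributed_init_method, is_driver_worker) to WorkerBaseV1.__init__, which
    # stores the rank fields and calls WorkerBaseV0.__init__(vllm_config) to
    # populate self.model_config, self.cache_config, and the other sub-configs.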

vllm/v1/worker/worker_base.py (new file, 63 lines)
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
+
+logger = init_logger(__name__)
+
+
+class WorkerBase(WorkerBaseV0):
+    """
+    Abstract class for v1 worker, mainly define some methods for v1.
+    For methods shared by v0 and v1, define them in v0 WorkerBase
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+    ):
+        """
+        Initialize common worker components.
+
+        Args:
+            vllm_config: Complete vLLM configuration
+            local_rank: Local device index
+            rank: Global rank in distributed setup
+            distributed_init_method: Distributed initialization method
+            is_driver_worker: Whether this worker handles driver
+                responsibilities
+        """
+        # Configuration storage
+        super().__init__(vllm_config=vllm_config)
+
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+
+        # Device and model state
+        self.device: Optional[torch.device] = None
+        self.model_runner: Optional[nn.Module] = None
+
+    def get_kv_cache_spec(self) -> KVCacheSpec:
+        """Get specifications for KV cache implementation."""
+        raise NotImplementedError
+
+    def compile_or_warm_up_model(self) -> None:
+        """Prepare model for execution through compilation/warmup."""
+        raise NotImplementedError
+
+    def check_health(self) -> None:
+        """Basic health check (override for device-specific checks)."""
+        return
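To illustrate how the new base is intended to be subclassed (a sketch, not part of the diff): `MyDeviceWorker` below is hypothetical, but its method names and signatures come from the v1 and v0 `WorkerBase` classes in this commit.

    from vllm.config import VllmConfig
    from vllm.v1.kv_cache_interface import KVCacheSpec
    from vllm.v1.worker.worker_base import WorkerBase


    class MyDeviceWorker(WorkerBase):

        def __init__(self, vllm_config: VllmConfig, local_rank: int, rank: int,
                     distributed_init_method: str):
            # The v1 base stores the rank fields and calls the v0 base to unpack
            # vllm_config into model_config, cache_config, etc.
            super().__init__(vllm_config=vllm_config,
                             local_rank=local_rank,
                             rank=rank,
                             distributed_init_method=distributed_init_method)

        def init_device(self) -> None:
            # Inherited from the v0 interface; real device setup would go here.
            ...

        def compile_or_warm_up_model(self) -> None:
            # v1-specific hook added by this commit; a no-op is acceptable here.
            ...

        def get_kv_cache_spec(self) -> KVCacheSpec:
            # v1-specific hook; left unimplemented on purpose. Because the v0
            # base is decorated with warn_for_unimplemented_methods, constructing
            # this worker logs a warning listing it (and any other interface
            # methods still raising NotImplementedError) rather than failing.
            raise NotImplementedError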

vllm/worker/worker_base.py
@@ -3,7 +3,7 @@
 import dataclasses
 import os
 import time
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
 
 import cloudpickle
@@ -19,7 +19,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, run_method,
-                        update_environment_variables)
+                        update_environment_variables,
+                        warn_for_unimplemented_methods)
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                            ModelRunnerBase,
                                            ModelRunnerInputBase)
@@ -27,7 +28,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
 logger = init_logger(__name__)
 
 
-class WorkerBase(ABC):
+@warn_for_unimplemented_methods
+class WorkerBase:
     """Worker interface that allows vLLM to cleanly separate implementations for
     different hardware. Also abstracts control plane communication, e.g., to
     communicate request metadata to other workers.
@@ -53,35 +55,31 @@ class WorkerBase(ABC):
         from vllm.platforms import current_platform
         self.current_platform = current_platform
 
-    @abstractmethod
     def init_device(self) -> None:
         """Initialize device state, such as loading the model or other on-device
         memory allocations.
         """
         raise NotImplementedError
 
-    @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-
-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Initialize the KV cache with the given size in blocks.
         """
         raise NotImplementedError
 
+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
+
+    def load_model(self) -> None:
+        """Load model onto target device."""
+        raise NotImplementedError
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        raise NotImplementedError
+
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.
 
@@ -94,40 +92,43 @@ class WorkerBase(ABC):
         if output is None:
             return None
 
-    @abstractmethod
-    def get_model(self) -> nn.Module:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available blocks for the GPU KV cache and
+        swappable CPU KV cache.
+
+        The implementation may run profiling or other heuristics to determine
+        the size of caches.
+
+        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+        are blocks that are "active" on the device and can be appended to.
+        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        appended to.
+        """
         raise NotImplementedError
 
-    @abstractmethod
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
-        raise NotImplementedError
-
-    @abstractmethod
     def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in
         speculative decoding.
         """
         raise NotImplementedError
 
-    @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
 
-    @abstractmethod
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
-    @abstractmethod
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
-    @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError
 
+    @property
+    def vocab_size(self) -> int:
+        """Get vocabulary size from model configuration."""
+        return self.model_config.get_vocab_size()
+
 
 class DelegateWorkerBase(WorkerBase):
     """
@@ -156,6 +157,10 @@ class DelegateWorkerBase(WorkerBase):
                          num_cpu_blocks: int) -> None:
         self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
     def load_model(self) -> None:
         """Load model onto target device."""
         self.worker.load_model()
 
+    def get_model(self) -> nn.Module:
+        return self.worker.get_model()
+