[V1][Core] Add worker_base for v1 worker (#12816)

Signed-off-by: Aoyu <aoyuzhan@amazon.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Aoyu <aoyuzhan@amazon.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Aoyu 2025-02-13 20:35:18 +08:00 committed by GitHub
parent c9d3ecf016
commit 2092a6fa7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 154 additions and 53 deletions

View File

@ -2220,3 +2220,46 @@ def import_pynvml():
""" """
import vllm.third_party.pynvml as pynvml import vllm.third_party.pynvml as pynvml
return pynvml return pynvml
def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]:
    """
    A replacement for `abc.ABC`.

    When we use `abc.ABC`, subclasses will fail to instantiate
    if they do not implement all abstract methods.
    Here, we only require `raise NotImplementedError` in the
    base class, and log a warning if the method is not implemented
    in the subclass.

    The check is a source-text heuristic: after `__init__` finishes, every
    public bound method of the instance whose source contains the string
    "NotImplementedError" is reported in a single warning.

    Args:
        cls: The class to decorate. Its `__init__` is replaced in place by
            a wrapper that runs the check after the original initializer.

    Returns:
        The same class object, with `__init__` wrapped.
    """

    original_init = cls.__init__

    def find_unimplemented_methods(self: object):
        """Log one warning naming the public methods that look like stubs."""
        unimplemented_methods = []
        for attr_name in dir(self):
            # bypass inner method
            if attr_name.startswith('_'):
                continue

            try:
                attr = getattr(self, attr_name)
            except AttributeError:
                # e.g. a property that reads state which is not set yet
                continue

            # Only methods can be "unimplemented"; plain data attributes
            # must never be inspected (nor inherit the function that was
            # looked up for a previous attribute).
            if not callable(attr):
                continue

            # get the func of callable method; callables without
            # `__func__` (static methods, callable instance attributes)
            # are skipped.
            attr_func = getattr(attr, '__func__', None)
            if attr_func is None:
                continue

            try:
                src = inspect.getsource(attr_func)
            except (OSError, TypeError):
                # Source is unavailable (built-in or dynamically created
                # function); this diagnostic must not break construction.
                continue

            if "NotImplementedError" in src:
                unimplemented_methods.append(attr_name)

        if unimplemented_methods:
            method_names = ','.join(unimplemented_methods)
            msg = (f"Methods {method_names} not implemented in {self}")
            logger.warning(msg)

    @wraps(original_init)
    def wrapped_init(self, *args, **kwargs) -> None:
        original_init(self, *args, **kwargs)
        find_unimplemented_methods(self)

    # Bypass any custom metaclass `__setattr__` when installing the wrapper.
    type.__setattr__(cls, '__init__', wrapped_init)
    return cls

View File

@ -21,6 +21,7 @@ from vllm.utils import GiB_bytes
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
logger = init_logger(__name__) logger = init_logger(__name__)
@ -28,7 +29,7 @@ if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput from vllm.v1.core.scheduler_output import SchedulerOutput
class Worker: class Worker(WorkerBase):
def __init__( def __init__(
self, self,
@ -39,23 +40,11 @@ class Worker:
is_driver_worker: bool = False, is_driver_worker: bool = False,
): ):
# TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) super().__init__(vllm_config=vllm_config,
self.vllm_config = vllm_config local_rank=local_rank,
self.model_config = vllm_config.model_config rank=rank,
self.cache_config = vllm_config.cache_config distributed_init_method=distributed_init_method,
self.lora_config = vllm_config.lora_config is_driver_worker=is_driver_worker)
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
self.parallel_config.rank = rank
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
if self.model_config.trust_remote_code: if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing # note: lazy import to avoid importing torch before initializing
@ -126,7 +115,8 @@ class Worker:
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
# Construct the model runner # Construct the model runner
self.model_runner = GPUModelRunner(self.vllm_config, self.device) self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
def load_model(self) -> None: def load_model(self) -> None:
if self.vllm_config.model_config.enable_sleep_mode: if self.vllm_config.model_config.enable_sleep_mode:

View File

@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
logger = init_logger(__name__)
class WorkerBase(WorkerBaseV0):
    """
    Abstract class for v1 worker, mainly define some methods for v1.
    For methods shared by v0 and v1, define them in v0 WorkerBase
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):
        """
        Initialize common worker components.

        Args:
            vllm_config: Complete vLLM configuration
            local_rank: Local device index
            rank: Global rank in distributed setup
            distributed_init_method: Distributed initialization method
            is_driver_worker: Whether this worker handles driver
                responsibilities
        """
        # Configuration storage
        super().__init__(vllm_config=vllm_config)

        # The base initializer only receives `vllm_config`, so it cannot
        # record this worker's rank; store it on the parallel config here.
        vllm_config.parallel_config.rank = rank

        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        # Device and model state; both start unset and are expected to be
        # assigned later by the concrete worker.
        self.device: Optional[torch.device] = None
        self.model_runner: Optional[nn.Module] = None

    def get_kv_cache_spec(self) -> KVCacheSpec:
        """Get specifications for KV cache implementation."""
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        """Prepare model for execution through compilation/warmup."""
        raise NotImplementedError

    def check_health(self) -> None:
        """Basic health check (override for device-specific checks)."""
        # The default implementation does nothing.
        return

View File

@ -3,7 +3,7 @@
import dataclasses import dataclasses
import os import os
import time import time
from abc import ABC, abstractmethod from abc import abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import cloudpickle import cloudpickle
@ -19,7 +19,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method, resolve_obj_by_qualname, run_method,
update_environment_variables) update_environment_variables,
warn_for_unimplemented_methods)
from vllm.worker.model_runner_base import (BroadcastableModelInput, from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase, ModelRunnerBase,
ModelRunnerInputBase) ModelRunnerInputBase)
@ -27,7 +28,8 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
logger = init_logger(__name__) logger = init_logger(__name__)
class WorkerBase(ABC): @warn_for_unimplemented_methods
class WorkerBase:
"""Worker interface that allows vLLM to cleanly separate implementations for """Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers. communicate request metadata to other workers.
@ -53,35 +55,31 @@ class WorkerBase(ABC):
from vllm.platforms import current_platform from vllm.platforms import current_platform
self.current_platform = current_platform self.current_platform = current_platform
@abstractmethod
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device """Initialize device state, such as loading the model or other on-device
memory allocations. memory allocations.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
@abstractmethod
def initialize_cache(self, num_gpu_blocks: int, def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None: num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks. """Initialize the KV cache with the given size in blocks.
""" """
raise NotImplementedError raise NotImplementedError
def get_model(self) -> nn.Module:
raise NotImplementedError
def load_model(self) -> None:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
raise NotImplementedError
def start_worker_execution_loop(self) -> None: def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker. """Execute model loop in parallel worker.
@ -94,40 +92,43 @@ class WorkerBase(ABC):
if output is None: if output is None:
return None return None
@abstractmethod def determine_num_available_blocks(self) -> Tuple[int, int]:
def get_model(self) -> nn.Module: """Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError raise NotImplementedError
@abstractmethod
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
raise NotImplementedError
@abstractmethod
def get_cache_block_size_bytes(self) -> int: def get_cache_block_size_bytes(self) -> int:
"""Return the size of a single cache block, in bytes. Used in """Return the size of a single cache block, in bytes. Used in
speculative decoding. speculative decoding.
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def pin_lora(self, lora_id: int) -> bool: def pin_lora(self, lora_id: int) -> bool:
raise NotImplementedError raise NotImplementedError
@abstractmethod
def list_loras(self) -> Set[int]: def list_loras(self) -> Set[int]:
raise NotImplementedError raise NotImplementedError
@property
def vocab_size(self) -> int:
"""Get vocabulary size from model configuration."""
return self.model_config.get_vocab_size()
class DelegateWorkerBase(WorkerBase): class DelegateWorkerBase(WorkerBase):
""" """
@ -156,6 +157,10 @@ class DelegateWorkerBase(WorkerBase):
num_cpu_blocks: int) -> None: num_cpu_blocks: int) -> None:
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def load_model(self) -> None:
"""Load model onto target device."""
self.worker.load_model()
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
return self.worker.get_model() return self.worker.get_model()