vllm/vllm/worker/worker_base.py

84 lines
2.7 KiB
Python

from abc import ABC, abstractmethod
from typing import Dict, List, Tuple
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
class WorkerBase(ABC):
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware.
"""
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
memory allocations.
"""
raise NotImplementedError
@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache.
The implementation may run profiling or other heuristics to determine
the size of caches.
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to.
"""
raise NotImplementedError
@abstractmethod
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Initialize the KV cache with the given size in blocks.
"""
raise NotImplementedError
@abstractmethod
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
"""Executes one model step on the given sequences."""
raise NotImplementedError
@abstractmethod
def get_cache_block_size_bytes() -> int:
"""Return the size of a single cache block, in bytes. Used in
speculative decoding.
"""
raise NotImplementedError
@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError
@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError
@abstractmethod
def list_loras(self) -> List[int]:
raise NotImplementedError
class LoraNotSupportedWorkerBase(WorkerBase):
"""Partial implementation of WorkerBase that raises exceptions when LoRA
methods are invoked.
"""
def add_lora(self, lora_request: LoRARequest) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def remove_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def list_loras(self) -> List[int]:
raise ValueError(f"{type(self)} does not support LoRA")