Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-26 00:36:32 +08:00)

[V0 Deprecation] Remove V0 model runner base & simplify worker base (#25328)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

This commit is contained in:
parent b18dde7478
commit a6cf307fa8
@@ -4,19 +4,14 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
                    Protocol, Set, Tuple, Type, TypeVar)
from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
                    Type, TypeVar)

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
    from vllm.worker.model_runner_base import (ModelRunnerBase,
                                               ModelRunnerInputBase,
                                               ModelRunnerInputBuilderBase)


class AttentionType:
    """

@@ -170,7 +165,7 @@ class AttentionState(ABC, Generic[T]):
    lifetime of the model runner."""

    @abstractmethod
    def __init__(self, runner: "ModelRunnerBase"):
    def __init__(self, runner: Any):
        ...

    @abstractmethod

@@ -210,7 +205,7 @@ class AttentionState(ABC, Generic[T]):
        ...

    @abstractmethod
    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
    def begin_forward(self, model_input) -> None:
        """Prepare state for forward pass."""
        ...

@@ -219,7 +214,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
    def __init__(self, input_builder) -> None:
        """Create the builder, remember some configuration and parameters."""
        raise NotImplementedError

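For context, here is a minimal, self-contained sketch (toy classes, not the vLLM ones in this diff) of what the loosened signatures mean for backend code. With the V0 ModelRunnerBase type deleted, the abstract state class can only annotate the runner as Any; call sites are unchanged, since any object exposing what the backend needs can still be passed in.

from abc import ABC, abstractmethod
from typing import Any


class ToyAttentionState(ABC):
    """Stand-in for the abstract attention state above."""

    @abstractmethod
    def __init__(self, runner: Any):  # was: runner: "ModelRunnerBase"
        ...

    @abstractmethod
    def begin_forward(self, model_input) -> None:
        """Prepare state for forward pass."""
        ...


class ToyCommonAttentionState(ToyAttentionState):

    def __init__(self, runner):
        # The annotation no longer names the removed V0 class; the state
        # simply keeps a reference to whatever runner object it is given.
        self.runner = runner
        self._is_graph_capturing = False

    def begin_forward(self, model_input) -> None:
        pass


class FakeRunner:
    device = "cpu"


state = ToyCommonAttentionState(FakeRunner())
state.begin_forward(model_input=None)
print(type(state.runner).__name__)  # FakeRunner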
@@ -5,8 +5,7 @@ from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
                    TypeVar, Union)
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch

@@ -21,9 +20,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.worker.model_runner_base import ModelRunnerBase

# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "

@@ -286,7 +282,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

class CommonAttentionState(AttentionState):

    def __init__(self, runner: "ModelRunnerBase"):
    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False

@@ -1,307 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
                    TypeVar)

import torch
import torch.nn as nn

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import is_text_generation_model
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.tasks import GenerationTask, SupportedTask

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.attention.backends.abstract import AttentionBackend
    from vllm.model_executor import SamplingMetadata

logger = init_logger(__name__)

T = TypeVar('T', bound="BroadcastableModelInput")


def _add_attn_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        attn_metadata: Optional["AttentionMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    AttentionMetadata fields.
    """
    if attn_metadata is not None:
        tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
    attn_backend: "AttentionBackend",
    tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper method to initialize AttentionMetadata based on an
    AttentionBackend and broadcastable AttentionMetadata fields.
    """
    # Extract the fields used to create AttentionMetadata.
    valid_attn_kwargs = {}
    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
        if field.name in tensor_dict:
            if field.name == "input_positions":
                valid_attn_kwargs[field.name] = tensor_dict[field.name]
            else:
                valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)

    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
    tensor_dict["attn_metadata"] = attn_metadata
    return tensor_dict


def _init_sampling_metadata_from_tensor_dict(  # type: ignore
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize SamplingMetadata based on broadcastable
    SamplingMetadata fields.
    """
    from vllm.model_executor import SamplingMetadata

    selected_token_indices = tensor_dict.pop("selected_token_indices", None)
    # An empty SamplingMetadata to signal that the worker should skip
    # sampling.
    if selected_token_indices is not None:
        tensor_dict["sampling_metadata"] = SamplingMetadata(
            seq_groups=None,
            selected_token_indices=selected_token_indices,
            categorized_sample_indices=None,
            num_prompts=0,
        )
    return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
        tensor_dict: Dict[str, Any],
        sampling_metadata: Optional["SamplingMetadata"]) -> None:
    """
    Helper method to update tensor_dict with broadcastable
    SamplingMetadata fields.
    """
    if sampling_metadata is not None:
        tensor_dict["selected_token_indices"] = (
            sampling_metadata.selected_token_indices)


def _init_frozen_model_input_from_tensor_dict(
        frozen_model_input_cls: Type["ModelRunnerInputBase"],
        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
    """
    Helper method to initialize a frozen ModelInput based on broadcastable
    """
    valid_tensor_kwargs = {}
    for field in dataclasses.fields(frozen_model_input_cls):
        val = tensor_dict.pop(field.name, None)
        if val is not None:
            valid_tensor_kwargs[field.name] = val

    frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
    tensor_dict["frozen_model_input"] = frozen_model_input
    return tensor_dict


class BroadcastableModelInput(ABC):

    @abstractmethod
    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        """
        Extract broadcastable fields. Override for fields that require some
        custom deserialization.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def from_broadcasted_tensor_dict(
        cls: Type[T],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> T:
        """
        Pop fields from the given tensor_dict and populate a new instance of
        BroadcastableModelInput.
        """
        raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
    """Local inputs to each worker's model runner. May contain
    device-specific data. Different worker backends may have different methods
    of converting from the global ExecuteModelRequest produced by the LLM
    engine to the worker-local ModelRunnerInputBase objects.

    Model runners that support multi-GPU execution should define a
    ModelRunnerInputBase subclass, add their required fields, and specify how to
    serialize/deserialize a ModelInput for broadcast between workers.
    """
    pass


class ModelRunnerInputBuilderBase(ABC, Generic[T]):
    """A builder to create ModelRunnerInputBase objects.
    """

    @abstractmethod
    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def add_seq_group(self, seq_group_metadata):
        """TBA"""
        raise NotImplementedError

    @abstractmethod
    def build(self, *args, **kwargs) -> T:
        """Build metadata with on-device tensors."""
        raise NotImplementedError


class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type of
    model. Model execution may communicate data with model runners in other
    processes, but it should not include control plane metadata communication.

    Each ModelRunnerBase subclass should define a corresponding
    ModelRunnerInputBase subclass.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
    ) -> None:
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.observability_config = vllm_config.observability_config

    # Map of request_id -> generator used for seeded random sampling
    generators: Dict[str, torch.Generator] = {}

    @abstractmethod
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> T:
        """
        Make an instance of a ModelRunnerInputBase from the broadcasted tensor
        dict.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def get_model(self) -> nn.Module:
        raise NotImplementedError

    def get_supported_generation_tasks(self) -> list[GenerationTask]:
        model = self.get_model()
        supported_tasks = list[GenerationTask]()

        if is_text_generation_model(model):
            supported_tasks.append("generate")

        if supports_transcription(model):
            if model.supports_transcription_only:
                return ["transcription"]

            supported_tasks.append("transcription")

        return supported_tasks

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        tasks = list[SupportedTask]()

        if self.model_config.runner_type == "generate":
            tasks.extend(self.get_supported_generation_tasks())

        return tuple(tasks)

    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
        **kwargs,
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute the model on the given input.
        """
        raise NotImplementedError

    def get_generators(self, finished_request_ids: Optional[List[str]] = None):
        """
        Return dict of per-request generators used for random sampling.
        """

        # Clean up generators from completed requests
        if finished_request_ids:
            for request_id in finished_request_ids:
                self.generators.pop(request_id, None)

        return self.generators


class ModelRunnerWrapperBase:
    """
    The whole point of this class is to lazily initialize the model_runner.
    """

    def __init__(
        self,
        model_runner: ModelRunnerBase,
    ) -> None:
        self.model_runner: ModelRunnerBase = model_runner

    def __getattr__(self, attr):
        return getattr(self.model_runner, attr)


class InputProcessingError(Exception):
    """This exception is raised when an error occurs preparing the inputs for
    a single sequence group.
    This allows the engine to gracefully handle errors with a single sequence
    group without having to fail the entire batch.
    """

    def __init__(self, request_id, message):
        """request_id is the id of the offending sequence group"""
        self.request_id = request_id
        self.message = message
        super().__init__(self.message)

    def __str__(self):
        return "Failed to prepare inputs for sequence group with request id: " \
            f"{self.request_id}, Error: {self.message}"

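The file deleted above implemented the V0 "broadcastable tensor dict" pattern: the driver flattens its model input into a plain dict, broadcasts it, and follower workers rebuild an equivalent input object. A minimal sketch of that round trip, using a hypothetical ToyModelInput rather than any class from this diff:

import dataclasses
from typing import Any, Dict, Optional

import torch


@dataclasses.dataclass(frozen=True)
class ToyModelInput:
    input_tokens: Optional[torch.Tensor] = None
    virtual_engine: int = 0

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # What the driver would hand to broadcast_tensor_dict.
        return {
            "input_tokens": self.input_tokens,
            "virtual_engine": self.virtual_engine,
        }

    @classmethod
    def from_broadcasted_tensor_dict(
            cls, tensor_dict: Dict[str, Any]) -> "ToyModelInput":
        # Followers pop the fields back out and rebuild the input object.
        return cls(
            input_tokens=tensor_dict.pop("input_tokens"),
            virtual_engine=tensor_dict.pop("virtual_engine"),
        )


src = ToyModelInput(input_tokens=torch.tensor([1, 2, 3]), virtual_engine=0)
wire = src.as_broadcastable_tensor_dict()
dst = ToyModelInput.from_broadcasted_tensor_dict(dict(wire))
assert torch.equal(dst.input_tokens, src.input_tokens)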
@@ -1,31 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import os
import time
from abc import abstractmethod
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
                    TypeVar, Union)
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar,
                    Union)

import cloudpickle
import torch
import torch.nn as nn

from vllm.config import (ObservabilityConfig, VllmConfig,
                         set_current_vllm_config)
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (enable_trace_function_call_for_thread,
                        resolve_obj_by_qualname, run_method,
                        update_environment_variables,
                        warn_for_unimplemented_methods)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                           ModelRunnerBase,
                                           ModelRunnerInputBase)

logger = init_logger(__name__)

@@ -141,356 +132,6 @@ class WorkerBase:
        return


class DelegateWorkerBase(WorkerBase):
    """
    A class that delegates all methods to another WorkerBase instance. This is
    useful for creating a WorkerBase that wraps another WorkerBase instance,
    e.g. speculative decoding.
    """
    worker: WorkerBase

    def __init__(
        self,
        *args,
        **kwargs,
    ) -> None:
        vllm_config: VllmConfig = kwargs.get("vllm_config")
        cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
        self.worker = cls(*args, **kwargs)

    def init_device(self) -> None:
        self.worker.init_device()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return self.worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def load_model(self) -> None:
        """Load model onto target device."""
        self.worker.load_model()

    def get_model(self) -> nn.Module:
        return self.worker.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
        return self.worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        return self.worker.get_cache_block_size_bytes()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.worker.list_loras()

    def __getattr__(self, attr):
        return getattr(self.worker, attr)


class LoRANotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")


@dataclasses.dataclass(frozen=True)
class WorkerInput:
    """Local inputs to each worker. May contain device-specific data. These
    fields should be broadcastable to other workers.
    """

    num_seq_groups: Optional[int] = None
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None
    virtual_engine: int = 0
    num_steps: int = 1

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["WorkerInput"],
        tensor_dict: Dict[str, Any],
    ) -> "WorkerInput":
        """
        Pop fields from the given tensor_dict and populate a new instance of
        WorkerInput.
        """
        return cls(
            num_seq_groups=tensor_dict.pop("num_seq_groups"),
            blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
            blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
            blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
            virtual_engine=tensor_dict["virtual_engine"],
            num_steps=tensor_dict.pop("num_steps"),
        )

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        """
        Extract broadcastable fields.
        """
        tensor_dict = {
            "num_seq_groups": self.num_seq_groups,
            "blocks_to_swap_in": self.blocks_to_swap_in,
            "blocks_to_swap_out": self.blocks_to_swap_out,
            "blocks_to_copy": self.blocks_to_copy,
            "virtual_engine": self.virtual_engine,
            "num_steps": self.num_steps,
        }

        return tensor_dict


class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in distributed
    mode. Subclasses of this interface should use model runners that inherit
    from ModelRunnerBase, and should only need to implement worker-local logic.
    If custom control plane logic is needed to transfer metadata, or if the
    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
    """
    is_driver_worker: bool
    model_runner: ModelRunnerBase
    observability_config: Optional[ObservabilityConfig] = None

    @property
    @abstractmethod
    def do_metadata_broadcast(self) -> bool:
        """
        Used by the default `execute_model` to check whether broadcast is
        needed to transfer request inputs from the driver worker to other
        workers in the TP group. If WorkerBase subclass only supports
        single-worker execution, then this method should return False.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        """
        Gets the list of kv caches to pass to the worker's model runner. Each
        element in the list is a kv cache corresponding to a particular virtual
        engine (PP stream). Used by the default `execute_model`. If the worker's
        model runner does not follow the ModelRunnerBase interface, then inherit
        from WorkerBase instead.
        """
        raise NotImplementedError

    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def _get_worker_input_from_broadcast(
        self
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """ Get the worker input from the broadcasted tensor dict. """
        assert self.do_metadata_broadcast
        assert not self.is_driver_worker
        broadcast_data = broadcast_tensor_dict(src=0)
        if not broadcast_data:
            return None

        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
        model_input = (
            self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                broadcast_data))

        kwargs = extract_previous_hidden_states(broadcast_data)

        return model_input, worker_input, kwargs

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
        """ Get the driver input and broadcast it to other workers. """
        assert self.is_driver_worker

        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.virtual_engine,
                execute_model_req.finished_requests_ids))

        kwargs = extract_previous_hidden_states(execute_model_req)

        if self.do_metadata_broadcast:
            broadcast_data = worker_input.as_broadcastable_tensor_dict()
            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_data.update(kwargs)
            broadcast_tensor_dict(broadcast_data, src=0)

        if execute_model_req.async_callback:
            model_input = dataclasses.replace(  # type: ignore
                model_input,
                async_callback=execute_model_req.async_callback)

        return model_input, worker_input, kwargs

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
            str, torch.Tensor]]]:
        """
        Prepare the inputs to ModelRunner and workers.
        """
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    # This signals that there's no more requests to process for
                    # now. All workers are running infinite loop with
                    # broadcast_tensor_dict, and it stops the loop when the
                    # driver broadcasts an empty input. Send an empty input to
                    # notify all other workers to stop their execution loop.
                    broadcast_tensor_dict({}, src=0)
                return None
            return self._get_driver_input_and_broadcast(execute_model_req)
        else:
            return self._get_worker_input_from_broadcast()

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        start_time = time.perf_counter()

        inputs = self.prepare_input(execute_model_req)
        if inputs is None:
            return None

        model_input, worker_input, kwargs = inputs
        num_steps = worker_input.num_steps

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        intermediate_tensors = None
        orig_model_execute_time = 0.0
        if not get_pp_group().is_first_rank:
            intermediate_tensors = IntermediateTensors(
                get_pp_group().recv_tensor_dict(
                    all_gather_group=get_tp_group()))
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                orig_model_execute_time = intermediate_tensors.tensors.get(
                    "model_execute_time", torch.tensor(0)).item()

        output = self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            num_steps=num_steps,
            **kwargs,
        )

        model_execute_time = time.perf_counter() - start_time
        if not get_pp_group().is_last_rank:
            # output is IntermediateTensors
            assert isinstance(output, IntermediateTensors)
            if (self.observability_config is not None
                    and self.observability_config.collect_model_execute_time):
                output.tensors["model_execute_time"] = torch.tensor(
                    model_execute_time + orig_model_execute_time)
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return [None]
        if (self.observability_config is not None
                and self.observability_config.collect_model_execute_time
                and output is not None):
            for o in output:
                o.model_execute_time = (orig_model_execute_time +
                                        model_execute_time)

        # output is List[SamplerOutput]
        return output

    def _execute_model_spmd(
        self,
        execute_model_req: ExecuteModelRequest,
        intermediate_tensors: Optional[IntermediateTensors] = None
    ) -> Optional[List[SamplerOutput]]:
        """
        Execute model in Single Program Multiple Data (SPMD) fashion.
        All workers take the same request, prepare the input and
        execute the model.
        """
        assert execute_model_req is not None, (
            "_execute_model_spmd() requires each worker to take in an "
            "ExecuteModelRequest")
        worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
        model_input: ModelRunnerInputBase = (
            self.model_runner.prepare_model_input(
                execute_model_req.seq_group_metadata_list))

        self.execute_worker(worker_input)

        # If there is no input, we don't need to execute the model.
        if worker_input.num_seq_groups == 0:
            return []

        kwargs = extract_previous_hidden_states(execute_model_req)

        return self.model_runner.execute_model(
            model_input=model_input,
            kv_caches=self.kv_cache[worker_input.virtual_engine]
            if self.kv_cache is not None else None,
            intermediate_tensors=intermediate_tensors,
            **kwargs,
        )


class WorkerWrapperBase:
    """
    This class represents one process in an executor/engine. It is responsible

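The `prepare_input` logic removed in the hunk above drives the V0 control loop: the driver broadcasts a prepared input dict each step, and an empty dict tells follower workers to exit their receive loop. A minimal, single-process sketch of that flow follows; `fake_broadcast_channel` and the helper names are illustrative stand-ins, not vLLM APIs:

from typing import Dict, List, Optional

fake_broadcast_channel: List[Dict] = []  # stands in for broadcast_tensor_dict


def driver_prepare_input(request: Optional[Dict]) -> Optional[Dict]:
    if request is None:
        fake_broadcast_channel.append({})  # empty dict = stop signal
        return None
    data = {"num_seq_groups": len(request["seq_groups"]), "num_steps": 1}
    fake_broadcast_channel.append(data)
    return data


def follower_loop() -> int:
    steps = 0
    for data in fake_broadcast_channel:
        if not data:   # an empty broadcast ends the execution loop
            break
        steps += 1     # a real worker would call execute_model here
    return steps


driver_prepare_input({"seq_groups": ["a", "b"]})
driver_prepare_input({"seq_groups": ["c"]})
driver_prepare_input(None)
print(follower_loop())  # 2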
@@ -636,23 +277,3 @@ class WorkerWrapperBase:

    def __getattr__(self, attr):
        return getattr(self.worker, attr)


def extract_previous_hidden_states(
        data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
        Dict[str, torch.Tensor]:
    """If data contains previous_hidden_states, extract it. This returns a dict
    which can be used directly as additional kwargs in any following
    execute_model calls. This is used in draft models like EAGLE."""
    output = {}

    # When called from non-driver worker, data is dict but when called from
    # driver worker, data is ExecuteModelRequest.
    if isinstance(data, dict):
        if "previous_hidden_states" in data:
            output["previous_hidden_states"] = data["previous_hidden_states"]
    elif data.previous_hidden_states is not None:
        output["previous_hidden_states"] = data.previous_hidden_states\
            .hidden_states

    return output
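A small usage sketch of the helper removed above, with SimpleNamespace standing in for ExecuteModelRequest (an assumption for illustration only): a non-driver worker receives a plain dict after broadcast, while the driver passes the request object whose `.previous_hidden_states.hidden_states` is unwrapped.

from types import SimpleNamespace


def extract_previous_hidden_states_sketch(data):
    # Mirrors the deleted helper's two code paths.
    output = {}
    if isinstance(data, dict):
        if "previous_hidden_states" in data:
            output["previous_hidden_states"] = data["previous_hidden_states"]
    elif data.previous_hidden_states is not None:
        output["previous_hidden_states"] = \
            data.previous_hidden_states.hidden_states
    return output


hidden = [[0.1, 0.2]]
as_dict = {"previous_hidden_states": hidden}
as_request = SimpleNamespace(
    previous_hidden_states=SimpleNamespace(hidden_states=hidden))

assert extract_previous_hidden_states_sketch(as_dict) == \
    {"previous_hidden_states": hidden}
assert extract_previous_hidden_states_sketch(as_request) == \
    {"previous_hidden_states": hidden}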