[V0 Deprecation] Remove V0 model runner base & simplify worker base (#25328)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Authored by Woosuk Kwon, 2025-09-20 20:49:09 -07:00; committed by yewentao256
parent b18dde7478
commit a6cf307fa8
4 changed files with 11 additions and 706 deletions

View File

@@ -4,19 +4,14 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
Protocol, Set, Tuple, Type, TypeVar)
from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
Type, TypeVar)
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING:
from vllm.worker.model_runner_base import (ModelRunnerBase,
ModelRunnerInputBase,
ModelRunnerInputBuilderBase)
class AttentionType:
"""
@@ -170,7 +165,7 @@ class AttentionState(ABC, Generic[T]):
lifetime of the model runner."""
@abstractmethod
def __init__(self, runner: "ModelRunnerBase"):
def __init__(self, runner: Any):
...
@abstractmethod
@@ -210,7 +205,7 @@ class AttentionState(ABC, Generic[T]):
...
@abstractmethod
def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
def begin_forward(self, model_input) -> None:
"""Prepare state for forward pass."""
...
@@ -219,7 +214,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@abstractmethod
def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
def __init__(self, input_builder) -> None:
"""Create the builder, remember some configuration and parameters."""
raise NotImplementedError

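Illustrative sketch (not part of the diff): the change above loosens the constructor contract so attention backends no longer reference the removed V0 runner types. A minimal stand-alone re-creation of the duck-typed interface, with stand-in names (AttentionStateSketch, GraphCaptureState) rather than the real vLLM classes:

from abc import ABC, abstractmethod
from typing import Any


class AttentionStateSketch(ABC):
    """Stand-in for the abstract AttentionState interface after this change."""

    @abstractmethod
    def __init__(self, runner: Any):  # previously typed as "ModelRunnerBase"
        ...


class GraphCaptureState(AttentionStateSketch):
    """Hypothetical concrete state: only duck-typed access to the runner remains."""

    def __init__(self, runner: Any):
        self.runner = runner
        self._is_graph_capturing = False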
View File

@@ -5,8 +5,7 @@ from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
import numpy as np
import torch
@@ -21,9 +20,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.worker.model_runner_base import ModelRunnerBase
# Error string(s) for encoder/decoder
# unsupported attention scenarios
STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
@@ -286,7 +282,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
class CommonAttentionState(AttentionState):
def __init__(self, runner: "ModelRunnerBase"):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False

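Illustrative sketch (not part of the diff): with the annotation gone, CommonAttentionState only needs an object exposing the attributes it actually reads, so any V1 runner (or a test double) can be passed in. A hypothetical construction with a stand-in class:

from types import SimpleNamespace


class CommonAttentionStateSketch:
    """Mirrors the simplified constructor shown above."""

    def __init__(self, runner):
        self.runner = runner
        self._is_graph_capturing = False


# Duck-typed runner: only the attributes the state touches need to exist.
runner = SimpleNamespace(device="cuda", max_batch_size=256)
state = CommonAttentionStateSketch(runner)
assert state.runner.device == "cuda"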
View File

@@ -1,307 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
import torch
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import is_text_generation_model
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.tasks import GenerationTask, SupportedTask
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
logger = init_logger(__name__)
T = TypeVar('T', bound="BroadcastableModelInput")
def _add_attn_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
attn_metadata: Optional["AttentionMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
AttentionMetadata fields.
"""
if attn_metadata is not None:
tensor_dict.update(attn_metadata.asdict_zerocopy())
def _init_attn_metadata_from_tensor_dict(
attn_backend: "AttentionBackend",
tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
"""
Helper method to initialize AttentionMetadata based on an
AttentionBackend and broadcastable AttentionMetadata fields.
"""
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
if field.name in tensor_dict:
if field.name == "input_positions":
valid_attn_kwargs[field.name] = tensor_dict[field.name]
else:
valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata
return tensor_dict
def _init_sampling_metadata_from_tensor_dict( # type: ignore
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to initialize SamplingMetadata based on broadcastable
SamplingMetadata fields.
"""
from vllm.model_executor import SamplingMetadata
selected_token_indices = tensor_dict.pop("selected_token_indices", None)
# An empty SamplingMetadata to signal that the worker should skip
# sampling.
if selected_token_indices is not None:
tensor_dict["sampling_metadata"] = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
num_prompts=0,
)
return tensor_dict
def _add_sampling_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
sampling_metadata: Optional["SamplingMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
SamplingMetadata fields.
"""
if sampling_metadata is not None:
tensor_dict["selected_token_indices"] = (
sampling_metadata.selected_token_indices)
def _init_frozen_model_input_from_tensor_dict(
frozen_model_input_cls: Type["ModelRunnerInputBase"],
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to initialize a frozen ModelInput based on broadcastable
"""
valid_tensor_kwargs = {}
for field in dataclasses.fields(frozen_model_input_cls):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_tensor_kwargs[field.name] = val
frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
tensor_dict["frozen_model_input"] = frozen_model_input
return tensor_dict
class BroadcastableModelInput(ABC):
@abstractmethod
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"""
Extract broadcastable fields. Override for fields that require some
custom deserialization.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def from_broadcasted_tensor_dict(
cls: Type[T],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> T:
"""
Pop fields from the given tensor_dict and populate a new instance of
BroadcastableModelInput.
"""
raise NotImplementedError
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(BroadcastableModelInput):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.
Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
"""
pass
class ModelRunnerInputBuilderBase(ABC, Generic[T]):
"""A builder to create ModelRunnerInputBase objects.
"""
@abstractmethod
def prepare(self,
finished_requests_ids: Optional[List[str]] = None) -> None:
raise NotImplementedError
@abstractmethod
def add_seq_group(self, seq_group_metadata):
"""TBA"""
raise NotImplementedError
@abstractmethod
def build(self, *args, **kwargs) -> T:
"""Build metadata with on-device tensors."""
raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
"""
Model runner interface that abstracts a particular hardware and/or type of
model. Model execution may communicate data with model runners in other
processes, but it should not include control plane metadata communication.
Each ModelRunnerBase subclass should define a corresponding
ModelRunnerInputBase subclass.
"""
def __init__(
self,
vllm_config: VllmConfig,
) -> None:
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.device_config = vllm_config.device_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
# Map of request_id -> generator used for seeded random sampling
generators: Dict[str, torch.Generator] = {}
@abstractmethod
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> T:
"""
Make an instance of a ModelRunnerInputBase from the broadcasted tensor
dict.
"""
raise NotImplementedError
@abstractmethod
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None,
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError
@abstractmethod
def get_model(self) -> nn.Module:
raise NotImplementedError
def get_supported_generation_tasks(self) -> list[GenerationTask]:
model = self.get_model()
supported_tasks = list[GenerationTask]()
if is_text_generation_model(model):
supported_tasks.append("generate")
if supports_transcription(model):
if model.supports_transcription_only:
return ["transcription"]
supported_tasks.append("transcription")
return supported_tasks
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks = list[SupportedTask]()
if self.model_config.runner_type == "generate":
tasks.extend(self.get_supported_generation_tasks())
return tuple(tasks)
def execute_model(
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
**kwargs,
) -> Optional[List[SamplerOutput]]:
"""
Execute the model on the given input.
"""
raise NotImplementedError
def get_generators(self, finished_request_ids: Optional[List[str]] = None):
"""
Return dict of per-request generators used for random sampling.
"""
# Clean up generators from completed requests
if finished_request_ids:
for request_id in finished_request_ids:
self.generators.pop(request_id, None)
return self.generators
class ModelRunnerWrapperBase:
"""
The whole point of this class is to lazily initialize the model_runner.
"""
def __init__(
self,
model_runner: ModelRunnerBase,
) -> None:
self.model_runner: ModelRunnerBase = model_runner
def __getattr__(self, attr):
return getattr(self.model_runner, attr)
class InputProcessingError(Exception):
"""This exception is raised when an error occurs preparing the inputs for
a single sequence group.
This allows the engine to gracefully handle errors with a single sequence
group without having to fail the entire batch.
"""
def __init__(self, request_id, message):
"""request_id is the id of the offending sequence group"""
self.request_id = request_id
self.message = message
super().__init__(self.message)
def __str__(self):
return "Failed to prepare inputs for sequence group with request id: " \
f"{self.request_id}, Error: {self.message}"

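Illustrative sketch (not part of the diff): the deleted file above centered on the broadcastable-tensor-dict pattern, in which a frozen model-input dataclass is flattened to a plain dict on the driver, broadcast, and rebuilt on each worker. A simplified, self-contained re-creation of that round trip, using stand-in names and no real broadcast:

import dataclasses
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class ModelInputSketch:
    """Simplified stand-in for a ModelRunnerInputBase subclass."""
    input_tokens: Any = None
    input_positions: Any = None
    virtual_engine: int = 0

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        # Only fields that can cross the process boundary are included.
        return dataclasses.asdict(self)

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict: Dict[str, Any]) -> "ModelInputSketch":
        # Pop the fields this class owns; anything left belongs to other consumers.
        kwargs = {f.name: tensor_dict.pop(f.name)
                  for f in dataclasses.fields(cls) if f.name in tensor_dict}
        return cls(**kwargs)


# Driver side flattens; worker side rebuilds from the (copied) broadcast payload.
payload = ModelInputSketch(input_tokens=[1, 2, 3],
                           input_positions=[0, 1, 2]).as_broadcastable_tensor_dict()
rebuilt = ModelInputSketch.from_broadcasted_tensor_dict(dict(payload))
assert rebuilt.input_tokens == [1, 2, 3]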
View File

@@ -1,31 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import time
from abc import abstractmethod
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar,
Union)
import cloudpickle
import torch
import torch.nn as nn
from vllm.config import (ObservabilityConfig, VllmConfig,
set_current_vllm_config)
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.sequence import ExecuteModelRequest
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method,
update_environment_variables,
warn_for_unimplemented_methods)
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
logger = init_logger(__name__)
@@ -141,356 +132,6 @@ class WorkerBase:
return
class DelegateWorkerBase(WorkerBase):
"""
A class that delegates all methods to another WorkerBase instance. This is
useful for creating a WorkerBase that wraps another WorkerBase instance,
e.g. speculative decoding.
"""
worker: WorkerBase
def __init__(
self,
*args,
**kwargs,
) -> None:
vllm_config: VllmConfig = kwargs.get("vllm_config")
cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls)
self.worker = cls(*args, **kwargs)
def init_device(self) -> None:
self.worker.init_device()
def determine_num_available_blocks(self) -> Tuple[int, int]:
return self.worker.determine_num_available_blocks()
def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def load_model(self) -> None:
"""Load model onto target device."""
self.worker.load_model()
def get_model(self) -> nn.Module:
return self.worker.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
return self.worker.execute_model(execute_model_req)
def get_cache_block_size_bytes(self) -> int:
return self.worker.get_cache_block_size_bytes()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.worker.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
return self.worker.remove_lora(lora_id)
def pin_lora(self, lora_id: int) -> bool:
return self.worker.pin_lora(lora_id)
def list_loras(self) -> Set[int]:
return self.worker.list_loras()
def __getattr__(self, attr):
return getattr(self.worker, attr)
class LoRANotSupportedWorkerBase(WorkerBase):
"""Partial implementation of WorkerBase that raises exceptions when LoRA
methods are invoked.
"""
def add_lora(self, lora_request: LoRARequest) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def remove_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def pin_lora(self, lora_id: int) -> bool:
raise ValueError(f"{type(self)} does not support LoRA")
def list_loras(self) -> Set[int]:
raise ValueError(f"{type(self)} does not support LoRA")
@dataclasses.dataclass(frozen=True)
class WorkerInput:
"""Local inputs to each worker. May contain device-specific data. These
fields should be broadcastable to other workers.
"""
num_seq_groups: Optional[int] = None
blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
num_steps: int = 1
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
tensor_dict: Dict[str, Any],
) -> "WorkerInput":
"""
Pop fields from the given tensor_dict and populate a new instance of
WorkerInput.
"""
return cls(
num_seq_groups=tensor_dict.pop("num_seq_groups"),
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
)
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
"""
Extract broadcastable fields.
"""
tensor_dict = {
"num_seq_groups": self.num_seq_groups,
"blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
}
return tensor_dict
class LocalOrDistributedWorkerBase(WorkerBase):
"""
Partial implementation of WorkerBase that has a default `execute_model`
definition to perform metadata transfer between workers when in distributed
mode. Subclasses of this interface should use model runners that inherit
from ModelRunnerBase, and should only need to implement worker-local logic.
If custom control plane logic is needed to transfer metadata, or if the
model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
"""
is_driver_worker: bool
model_runner: ModelRunnerBase
observability_config: Optional[ObservabilityConfig] = None
@property
@abstractmethod
def do_metadata_broadcast(self) -> bool:
"""
Used by the default `execute_model` to check whether broadcast is
needed to transfer request inputs from the driver worker to other
workers in the TP group. If WorkerBase subclass only supports
single-worker execution, then this method should return False.
"""
raise NotImplementedError
@property
@abstractmethod
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
"""
Gets the list of kv caches to pass to the worker's model runner. Each
element in the list is a kv cache corresponding to a particular virtual
engine (PP stream). Used by the default `execute_model`. If the worker's
model runner does not follow the ModelRunnerBase interface, then inherit
from WorkerBase instead.
"""
raise NotImplementedError
@abstractmethod
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
"""
Prepare the inputs to WorkerBase.execute_worker from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError
@abstractmethod
def execute_worker(self, worker_input: WorkerInput) -> None:
"""
Process an execution request.
"""
raise NotImplementedError
def _get_worker_input_from_broadcast(
self
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
str, torch.Tensor]]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
model_input = (
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
broadcast_data))
kwargs = extract_previous_hidden_states(broadcast_data)
return model_input, worker_input, kwargs
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
""" Get the driver input and broadcast it to other workers. """
assert self.is_driver_worker
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
kwargs = extract_previous_hidden_states(execute_model_req)
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
broadcast_data.update(kwargs)
broadcast_tensor_dict(broadcast_data, src=0)
if execute_model_req.async_callback:
model_input = dataclasses.replace( # type: ignore
model_input,
async_callback=execute_model_req.async_callback)
return model_input, worker_input, kwargs
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
str, torch.Tensor]]]:
"""
Prepare the inputs to ModelRunner and workers.
"""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
return self._get_driver_input_and_broadcast(execute_model_req)
else:
return self._get_worker_input_from_broadcast()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time = time.perf_counter()
inputs = self.prepare_input(execute_model_req)
if inputs is None:
return None
model_input, worker_input, kwargs = inputs
num_steps = worker_input.num_steps
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
intermediate_tensors = None
orig_model_execute_time = 0.0
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item()
output = self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
assert isinstance(output, IntermediateTensors)
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor(
model_execute_time + orig_model_execute_time)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return [None]
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time
and output is not None):
for o in output:
o.model_execute_time = (orig_model_execute_time +
model_execute_time)
# output is List[SamplerOutput]
return output
def _execute_model_spmd(
self,
execute_model_req: ExecuteModelRequest,
intermediate_tensors: Optional[IntermediateTensors] = None
) -> Optional[List[SamplerOutput]]:
"""
Execute model in Single Program Multiple Data (SPMD) fashion.
All workers take the same request, prepare the input and
execute the model.
"""
assert execute_model_req is not None, (
"_execute_model_spmd() requires each worker to take in an "
"ExecuteModelRequest")
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
kwargs = extract_previous_hidden_states(execute_model_req)
return self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
**kwargs,
)
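Illustrative sketch (not part of the diff): the removed prepare_input/execute_model pair above encodes the V0 control-plane convention: the driver broadcasts a tensor dict every step, and an empty dict tells the other ranks to leave their receive loop. A minimal stand-alone model of that convention, using a local queue in place of broadcast_tensor_dict (all names here are illustrative, not vLLM APIs):

from queue import Queue
from typing import Dict, Optional

channel: "Queue[Dict]" = Queue()  # stand-in for broadcast_tensor_dict


def driver_step(request: Optional[Dict]) -> Optional[Dict]:
    if request is None:
        channel.put({})  # empty broadcast = "stop looping" signal
        return None
    payload = {"num_seq_groups": len(request["seq_groups"]), **request}
    channel.put(payload)  # broadcast to the non-driver ranks
    return payload


def worker_loop() -> None:
    while True:
        payload = channel.get()
        if not payload:  # empty dict ends the loop
            break
        # ... here the real worker would run execute_worker / execute_model ...


driver_step({"seq_groups": ["req-0"]})
driver_step(None)
worker_loop()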
class WorkerWrapperBase:
"""
This class represents one process in an executor/engine. It is responsible
@@ -636,23 +277,3 @@
def __getattr__(self, attr):
return getattr(self.worker, attr)
def extract_previous_hidden_states(
data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
Dict[str, torch.Tensor]:
"""If data contains previous_hidden_states, extract it. This returns a dict
which can be used directly as additional kwargs in any following
execute_model calls. This is used in draft models like EAGLE."""
output = {}
# When called from non-driver worker, data is dict but when called from
# driver worker, data is ExecuteModelRequest.
if isinstance(data, dict):
if "previous_hidden_states" in data:
output["previous_hidden_states"] = data["previous_hidden_states"]
elif data.previous_hidden_states is not None:
output["previous_hidden_states"] = data.previous_hidden_states\
.hidden_states
return output
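Illustrative sketch (not part of the diff): the removed extract_previous_hidden_states normalized EAGLE-style hidden states from either a broadcast dict (non-driver ranks) or an ExecuteModelRequest (driver) into extra model kwargs. A small standalone re-creation of the same normalization, with stand-in types rather than the vLLM originals:

from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class HiddenStatesSketch:  # stand-in for the request's hidden-states holder
    hidden_states: List[float]


@dataclass
class RequestSketch:  # stand-in for ExecuteModelRequest
    previous_hidden_states: Optional[HiddenStatesSketch] = None


def extract_previous_hidden_states_sketch(
        data: Union[RequestSketch, Dict[str, List[float]]]) -> Dict[str, List[float]]:
    output: Dict[str, List[float]] = {}
    if isinstance(data, dict):
        if "previous_hidden_states" in data:
            output["previous_hidden_states"] = data["previous_hidden_states"]
    elif data.previous_hidden_states is not None:
        output["previous_hidden_states"] = data.previous_hidden_states.hidden_states
    return output


assert extract_previous_hidden_states_sketch({"previous_hidden_states": [0.1]}) == {
    "previous_hidden_states": [0.1]}
assert extract_previous_hidden_states_sketch(RequestSketch()) == {}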