diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index dfde67e1713c5..ab7ef2112b083 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -4,19 +4,14 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from dataclasses import dataclass, fields
-from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
-                    Protocol, Set, Tuple, Type, TypeVar)
+from typing import (Any, Dict, Generic, List, Optional, Protocol, Set, Tuple,
+                    Type, TypeVar)
 
 import torch
 
 from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
 from vllm.multimodal import MultiModalPlaceholderMap
 
-if TYPE_CHECKING:
-    from vllm.worker.model_runner_base import (ModelRunnerBase,
-                                               ModelRunnerInputBase,
-                                               ModelRunnerInputBuilderBase)
-
 
 class AttentionType:
     """
@@ -170,7 +165,7 @@ class AttentionState(ABC, Generic[T]):
     lifetime of the model runner."""
 
     @abstractmethod
-    def __init__(self, runner: "ModelRunnerBase"):
+    def __init__(self, runner: Any):
         ...
 
     @abstractmethod
@@ -210,7 +205,7 @@ class AttentionState(ABC, Generic[T]):
         ...
 
     @abstractmethod
-    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
+    def begin_forward(self, model_input) -> None:
         """Prepare state for forward pass."""
         ...
 
@@ -219,7 +214,7 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
     """Abstract class for attention metadata builders."""
 
     @abstractmethod
-    def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:
+    def __init__(self, input_builder) -> None:
         """Create the builder, remember some configuration and parameters."""
         raise NotImplementedError
 
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 289cfa2177438..b28e6a4237cb6 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -5,8 +5,7 @@ from collections import defaultdict
 from contextlib import contextmanager
 from dataclasses import dataclass
 from itertools import accumulate
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
-                    TypeVar, Union)
+from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
 import torch
@@ -21,9 +20,6 @@ from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 logger = init_logger(__name__)
 
-if TYPE_CHECKING:
-    from vllm.worker.model_runner_base import ModelRunnerBase
-
 # Error string(s) for encoder/decoder
 # unsupported attention scenarios
 STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
@@ -286,7 +282,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
 
 class CommonAttentionState(AttentionState):
 
-    def __init__(self, runner: "ModelRunnerBase"):
+    def __init__(self, runner):
         self.runner = runner
         self._is_graph_capturing = False
 
diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py
deleted file mode 100644
index 1008b743619a4..0000000000000
--- a/vllm/worker/model_runner_base.py
+++ /dev/null
@@ -1,307 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import dataclasses
-from abc import ABC, abstractmethod
-from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
-                    TypeVar)
-
-import torch
-import torch.nn as nn
-
-from vllm.config import VllmConfig
-from vllm.logger import init_logger
-from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.model_executor.models.interfaces import supports_transcription
-from vllm.model_executor.models.interfaces_base import is_text_generation_model
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
-from vllm.tasks import GenerationTask, SupportedTask
-
-if TYPE_CHECKING:
-    from vllm.attention import AttentionMetadata
-    from vllm.attention.backends.abstract import AttentionBackend
-    from vllm.model_executor import SamplingMetadata
-
-logger = init_logger(__name__)
-
-T = TypeVar('T', bound="BroadcastableModelInput")
-
-
-def _add_attn_metadata_broadcastable_dict(
-        tensor_dict: Dict[str, Any],
-        attn_metadata: Optional["AttentionMetadata"]) -> None:
-    """
-    Helper method to update tensor_dict with broadcastable
-    AttentionMetadata fields.
-    """
-    if attn_metadata is not None:
-        tensor_dict.update(attn_metadata.asdict_zerocopy())
-
-
-def _init_attn_metadata_from_tensor_dict(
-    attn_backend: "AttentionBackend",
-    tensor_dict: Dict[str, Any],
-) -> Dict[str, Any]:
-    """
-    Helper method to initialize AttentionMetadata based on an
-    AttentionBackend and broadcastable AttentionMetadata fields.
-    """
-    # Extract the fields used to create AttentionMetadata.
-    valid_attn_kwargs = {}
-    for field in dataclasses.fields(attn_backend.get_metadata_cls()):
-        if field.name in tensor_dict:
-            if field.name == "input_positions":
-                valid_attn_kwargs[field.name] = tensor_dict[field.name]
-            else:
-                valid_attn_kwargs[field.name] = tensor_dict.pop(field.name)
-
-    attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
-    tensor_dict["attn_metadata"] = attn_metadata
-    return tensor_dict
-
-
-def _init_sampling_metadata_from_tensor_dict(  # type: ignore
-        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Helper method to initialize SamplingMetadata based on broadcastable
-    SamplingMetadata fields.
-    """
-    from vllm.model_executor import SamplingMetadata
-
-    selected_token_indices = tensor_dict.pop("selected_token_indices", None)
-    # An empty SamplingMetadata to signal that the worker should skip
-    # sampling.
-    if selected_token_indices is not None:
-        tensor_dict["sampling_metadata"] = SamplingMetadata(
-            seq_groups=None,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=None,
-            num_prompts=0,
-        )
-    return tensor_dict
-
-
-def _add_sampling_metadata_broadcastable_dict(
-        tensor_dict: Dict[str, Any],
-        sampling_metadata: Optional["SamplingMetadata"]) -> None:
-    """
-    Helper method to update tensor_dict with broadcastable
-    SamplingMetadata fields.
-    """
-    if sampling_metadata is not None:
-        tensor_dict["selected_token_indices"] = (
-            sampling_metadata.selected_token_indices)
-
-
-def _init_frozen_model_input_from_tensor_dict(
-        frozen_model_input_cls: Type["ModelRunnerInputBase"],
-        tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
-    """
-    Helper method to initialize a frozen ModelInput based on broadcastable
-    """
-    valid_tensor_kwargs = {}
-    for field in dataclasses.fields(frozen_model_input_cls):
-        val = tensor_dict.pop(field.name, None)
-        if val is not None:
-            valid_tensor_kwargs[field.name] = val
-
-    frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs)
-    tensor_dict["frozen_model_input"] = frozen_model_input
-    return tensor_dict
-
-
-class BroadcastableModelInput(ABC):
-
-    @abstractmethod
-    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
-        """
-        Extract broadcastable fields. Override for fields that require some
-        custom deserialization.
- """ - raise NotImplementedError - - @classmethod - @abstractmethod - def from_broadcasted_tensor_dict( - cls: Type[T], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> T: - """ - Pop fields from the given tensor_dict and populate a new instance of - BroadcastableModelInput. - """ - raise NotImplementedError - - -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(BroadcastableModelInput): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. - """ - pass - - -class ModelRunnerInputBuilderBase(ABC, Generic[T]): - """A builder to create ModelRunnerInputBase objects. - """ - - @abstractmethod - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - raise NotImplementedError - - @abstractmethod - def add_seq_group(self, seq_group_metadata): - """TBA""" - raise NotImplementedError - - @abstractmethod - def build(self, *args, **kwargs) -> T: - """Build metadata with on-device tensors.""" - raise NotImplementedError - - -class ModelRunnerBase(ABC, Generic[T]): - """ - Model runner interface that abstracts a particular hardware and/or type of - model. Model execution may communicate data with model runners in other - processes, but it should not include control plane metadata communication. - - Each ModelRunnerBase subclass should define a corresponding - ModelRunnerInputBase subclass. - """ - - def __init__( - self, - vllm_config: VllmConfig, - ) -> None: - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - - # Map of request_id -> generator used for seeded random sampling - generators: Dict[str, torch.Generator] = {} - - @abstractmethod - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> T: - """ - Make an instance of a ModelRunnerInputBase from the broadcasted tensor - dict. - """ - raise NotImplementedError - - @abstractmethod - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> T: - """ - Prepare the inputs to ModelRunnerBase.execute_model from an execution - request. This method may move data to the worker's local device. It is - not allowed to communicate with other workers or devices. 
- """ - raise NotImplementedError - - @abstractmethod - def get_model(self) -> nn.Module: - raise NotImplementedError - - def get_supported_generation_tasks(self) -> list[GenerationTask]: - model = self.get_model() - supported_tasks = list[GenerationTask]() - - if is_text_generation_model(model): - supported_tasks.append("generate") - - if supports_transcription(model): - if model.supports_transcription_only: - return ["transcription"] - - supported_tasks.append("transcription") - - return supported_tasks - - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - tasks = list[SupportedTask]() - - if self.model_config.runner_type == "generate": - tasks.extend(self.get_supported_generation_tasks()) - - return tuple(tasks) - - def execute_model( - self, - model_input: T, - kv_caches: Optional[List[torch.Tensor]], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - **kwargs, - ) -> Optional[List[SamplerOutput]]: - """ - Execute the model on the given input. - """ - raise NotImplementedError - - def get_generators(self, finished_request_ids: Optional[List[str]] = None): - """ - Return dict of per-request generators used for random sampling. - """ - - # Clean up generators from completed requests - if finished_request_ids: - for request_id in finished_request_ids: - self.generators.pop(request_id, None) - - return self.generators - - -class ModelRunnerWrapperBase: - """ - The whole point of this class is to lazily initialize the model_runner. - """ - - def __init__( - self, - model_runner: ModelRunnerBase, - ) -> None: - self.model_runner: ModelRunnerBase = model_runner - - def __getattr__(self, attr): - return getattr(self.model_runner, attr) - - -class InputProcessingError(Exception): - """This exception is raised when an error occurs preparing the inputs for - a single sequence group. - This allows the engine to gracefully handle errors with a single sequence - group without having to fail the entire batch. 
- """ - - def __init__(self, request_id, message): - """request_id is the id of the offending sequence group""" - self.request_id = request_id - self.message = message - super().__init__(self.message) - - def __str__(self): - return "Failed to prepare inputs for sequence group with request id: " \ - f"{self.request_id}, Error: {self.message}" diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d0a56f6ff4637..eaab976bf7f75 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,31 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import dataclasses import os -import time -from abc import abstractmethod -from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from typing import (Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, + Union) import cloudpickle -import torch import torch.nn as nn -from vllm.config import (ObservabilityConfig, VllmConfig, - set_current_vllm_config) -from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group +from vllm.config import VllmConfig, set_current_vllm_config from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.sequence import ExecuteModelRequest from vllm.utils import (enable_trace_function_call_for_thread, resolve_obj_by_qualname, run_method, update_environment_variables, warn_for_unimplemented_methods) -from vllm.worker.model_runner_base import (BroadcastableModelInput, - ModelRunnerBase, - ModelRunnerInputBase) logger = init_logger(__name__) @@ -141,356 +132,6 @@ class WorkerBase: return -class DelegateWorkerBase(WorkerBase): - """ - A class that delegates all methods to another WorkerBase instance. This is - useful for creating a WorkerBase that wraps another WorkerBase instance, - e.g. speculative decoding. 
- """ - worker: WorkerBase - - def __init__( - self, - *args, - **kwargs, - ) -> None: - vllm_config: VllmConfig = kwargs.get("vllm_config") - cls = resolve_obj_by_qualname(vllm_config.parallel_config.worker_cls) - self.worker = cls(*args, **kwargs) - - def init_device(self) -> None: - self.worker.init_device() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - return self.worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def load_model(self) -> None: - """Load model onto target device.""" - self.worker.load_model() - - def get_model(self) -> nn.Module: - return self.worker.get_model() - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: - return self.worker.execute_model(execute_model_req) - - def get_cache_block_size_bytes(self) -> int: - return self.worker.get_cache_block_size_bytes() - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.worker.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.worker.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.worker.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.worker.list_loras() - - def __getattr__(self, attr): - return getattr(self.worker, attr) - - -class LoRANotSupportedWorkerBase(WorkerBase): - """Partial implementation of WorkerBase that raises exceptions when LoRA - methods are invoked. - """ - - def add_lora(self, lora_request: LoRARequest) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def remove_lora(self, lora_id: int) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def pin_lora(self, lora_id: int) -> bool: - raise ValueError(f"{type(self)} does not support LoRA") - - def list_loras(self) -> Set[int]: - raise ValueError(f"{type(self)} does not support LoRA") - - -@dataclasses.dataclass(frozen=True) -class WorkerInput: - """Local inputs to each worker. May contain device-specific data. These - fields should be broadcastable to other workers. - """ - - num_seq_groups: Optional[int] = None - blocks_to_swap_in: Optional[torch.Tensor] = None - blocks_to_swap_out: Optional[torch.Tensor] = None - blocks_to_copy: Optional[torch.Tensor] = None - virtual_engine: int = 0 - num_steps: int = 1 - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["WorkerInput"], - tensor_dict: Dict[str, Any], - ) -> "WorkerInput": - """ - Pop fields from the given tensor_dict and populate a new instance of - WorkerInput. - """ - return cls( - num_seq_groups=tensor_dict.pop("num_seq_groups"), - blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), - blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), - ) - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - """ - Extract broadcastable fields. 
- """ - tensor_dict = { - "num_seq_groups": self.num_seq_groups, - "blocks_to_swap_in": self.blocks_to_swap_in, - "blocks_to_swap_out": self.blocks_to_swap_out, - "blocks_to_copy": self.blocks_to_copy, - "virtual_engine": self.virtual_engine, - "num_steps": self.num_steps, - } - - return tensor_dict - - -class LocalOrDistributedWorkerBase(WorkerBase): - """ - Partial implementation of WorkerBase that has a default `execute_model` - definition to perform metadata transfer between workers when in distributed - mode. Subclasses of this interface should use model runners that inherit - from ModelRunnerBase, and should only need to implement worker-local logic. - If custom control plane logic is needed to transfer metadata, or if the - model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. - """ - is_driver_worker: bool - model_runner: ModelRunnerBase - observability_config: Optional[ObservabilityConfig] = None - - @property - @abstractmethod - def do_metadata_broadcast(self) -> bool: - """ - Used by the default `execute_model` to check whether broadcast is - needed to transfer request inputs from the driver worker to other - workers in the TP group. If WorkerBase subclass only supports - single-worker execution, then this method should return False. - """ - raise NotImplementedError - - @property - @abstractmethod - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - """ - Gets the list of kv caches to pass to the worker's model runner. Each - element in the list is a kv cache corresponding to a particular virtual - engine (PP stream). Used by the default `execute_model`. If the worker's - model runner does not follow the ModelRunnerBase interface, then inherit - from WorkerBase instead. - """ - raise NotImplementedError - - @abstractmethod - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - """ - Prepare the inputs to WorkerBase.execute_worker from an execution - request. This method may move data to the worker's local device. It is - not allowed to communicate with other workers or devices. - """ - raise NotImplementedError - - @abstractmethod - def execute_worker(self, worker_input: WorkerInput) -> None: - """ - Process an execution request. - """ - raise NotImplementedError - - def _get_worker_input_from_broadcast( - self - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ - str, torch.Tensor]]]: - """ Get the worker input from the broadcasted tensor dict. """ - assert self.do_metadata_broadcast - assert not self.is_driver_worker - broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None - - worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) - model_input = ( - self.model_runner.make_model_input_from_broadcasted_tensor_dict( - broadcast_data)) - - kwargs = extract_previous_hidden_states(broadcast_data) - - return model_input, worker_input, kwargs - - def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest - ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: - """ Get the driver input and broadcast it to other workers. 
""" - assert self.is_driver_worker - - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - - kwargs = extract_previous_hidden_states(execute_model_req) - - if self.do_metadata_broadcast: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update(model_input.as_broadcastable_tensor_dict()) - broadcast_data.update(kwargs) - broadcast_tensor_dict(broadcast_data, src=0) - - if execute_model_req.async_callback: - model_input = dataclasses.replace( # type: ignore - model_input, - async_callback=execute_model_req.async_callback) - - return model_input, worker_input, kwargs - - def prepare_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ - str, torch.Tensor]]]: - """ - Prepare the inputs to ModelRunner and workers. - """ - if self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - return self._get_driver_input_and_broadcast(execute_model_req) - else: - return self._get_worker_input_from_broadcast() - - def get_model(self) -> nn.Module: - return self.model_runner.get_model() - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[List[SamplerOutput]]: - """Executes at least one model step on the given sequences, unless no - sequences are provided.""" - start_time = time.perf_counter() - - inputs = self.prepare_input(execute_model_req) - if inputs is None: - return None - - model_input, worker_input, kwargs = inputs - num_steps = worker_input.num_steps - - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. 
-        if worker_input.num_seq_groups == 0:
-            return []
-
-        intermediate_tensors = None
-        orig_model_execute_time = 0.0
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = IntermediateTensors(
-                get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group()))
-            if (self.observability_config is not None
-                    and self.observability_config.collect_model_execute_time):
-                orig_model_execute_time = intermediate_tensors.tensors.get(
-                    "model_execute_time", torch.tensor(0)).item()
-
-        output = self.model_runner.execute_model(
-            model_input=model_input,
-            kv_caches=self.kv_cache[worker_input.virtual_engine]
-            if self.kv_cache is not None else None,
-            intermediate_tensors=intermediate_tensors,
-            num_steps=num_steps,
-            **kwargs,
-        )
-
-        model_execute_time = time.perf_counter() - start_time
-        if not get_pp_group().is_last_rank:
-            # output is IntermediateTensors
-            assert isinstance(output, IntermediateTensors)
-            if (self.observability_config is not None
-                    and self.observability_config.collect_model_execute_time):
-                output.tensors["model_execute_time"] = torch.tensor(
-                    model_execute_time + orig_model_execute_time)
-            get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
-            return [None]
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_execute_time
-                and output is not None):
-            for o in output:
-                o.model_execute_time = (orig_model_execute_time +
-                                        model_execute_time)
-
-        # output is List[SamplerOutput]
-        return output
-
-    def _execute_model_spmd(
-        self,
-        execute_model_req: ExecuteModelRequest,
-        intermediate_tensors: Optional[IntermediateTensors] = None
-    ) -> Optional[List[SamplerOutput]]:
-        """
-        Execute model in Single Program Multiple Data (SPMD) fashion.
-        All workers take the same request, prepare the input and
-        execute the model.
-        """
-        assert execute_model_req is not None, (
-            "_execute_model_spmd() requires each worker to take in an "
-            "ExecuteModelRequest")
-        worker_input: WorkerInput = self.prepare_worker_input(
-            execute_model_req=execute_model_req)
-        model_input: ModelRunnerInputBase = (
-            self.model_runner.prepare_model_input(
-                execute_model_req.seq_group_metadata_list))
-
-        self.execute_worker(worker_input)
-
-        # If there is no input, we don't need to execute the model.
-        if worker_input.num_seq_groups == 0:
-            return []
-
-        kwargs = extract_previous_hidden_states(execute_model_req)
-
-        return self.model_runner.execute_model(
-            model_input=model_input,
-            kv_caches=self.kv_cache[worker_input.virtual_engine]
-            if self.kv_cache is not None else None,
-            intermediate_tensors=intermediate_tensors,
-            **kwargs,
-        )
-
-
 class WorkerWrapperBase:
     """
     This class represents one process in an executor/engine. It is responsible
@@ -636,23 +277,3 @@ class WorkerWrapperBase:
 
     def __getattr__(self, attr):
         return getattr(self.worker, attr)
-
-
-def extract_previous_hidden_states(
-        data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
-            Dict[str, torch.Tensor]:
-    """If data contains previous_hidden_states, extract it. This returns a dict
-    which can be used directly as additional kwargs in any following
-    execute_model calls. This is used in draft models like EAGLE."""
-    output = {}
-
-    # When called from non-driver worker, data is dict but when called from
-    # driver worker, data is ExecuteModelRequest.
-    if isinstance(data, dict):
-        if "previous_hidden_states" in data:
-            output["previous_hidden_states"] = data["previous_hidden_states"]
-    elif data.previous_hidden_states is not None:
-        output["previous_hidden_states"] = data.previous_hidden_states\
-            .hidden_states
-
-    return output
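Not part of the diff, just an illustrative note: with ModelRunnerBase removed, the attention-state constructors above take the runner as a plain Any and rely on duck typing. A minimal standalone sketch of that pattern follows; ExampleAttentionState and _FakeRunner are hypothetical names used only for illustration, not vLLM APIs.

from typing import Any


class ExampleAttentionState:
    """Sketch of the duck-typed runner pattern (mirrors CommonAttentionState)."""

    def __init__(self, runner: Any) -> None:
        # No worker-layer import is needed; only the attributes this state
        # actually reads (e.g. runner.device) must exist at runtime.
        self.runner = runner
        self._is_graph_capturing = False


class _FakeRunner:
    """Hypothetical stand-in exposing the single attribute used above."""
    device = "cpu"


state = ExampleAttentionState(_FakeRunner())
print(state.runner.device)  # -> cpu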