[Core] Refactor Worker and ModelRunner to consolidate control plane communication (#5408)

Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: Stephanie <swang@anyscale.com>
Co-authored-by: Stephanie <swang@anyscale.com>
Stephanie Wang authored on 2024-06-25 20:30:03 -07:00; committed by GitHub
parent 82079729cc
commit dda4811591
29 changed files with 1106 additions and 573 deletions

View File

@ -0,0 +1,152 @@
import dataclasses
from typing import List, Tuple, Type
import torch
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
class MockAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod
def get_impl_cls():
raise NotImplementedError
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return AttentionMetadata
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
raise NotImplementedError
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
pass
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
pass
def test_model_runner_input():
sampling_metadata = SamplingMetadata(
["seq_group"],
"selected_token_indices",
"categorized_sample_indices",
"num_prompts",
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
)
model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithSamplingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# For sampling metadata, only selected_token_indices is copied.
assert (received_model_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices)
assert received_model_input.sampling_metadata.seq_groups is None
def test_embedding_model_runner_input():
pooling_metadata = PoolingMetadata(
seq_groups=[[0]],
seq_data={},
prompt_lens=[1],
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
)
model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
pooling_metadata=pooling_metadata,
attn_metadata=attn_metadata)
assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)
# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (
ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))
# Check that received copy has correct values.
assert isinstance(received_model_input,
ModelInputForGPUWithPoolingMetadata)
assert received_model_input.input_tokens is not None
assert (
received_model_input.input_tokens == model_input.input_tokens).all()
assert received_model_input.input_positions is not None
assert (received_model_input.input_positions == model_input.input_positions
).all()
assert received_model_input.multi_modal_kwargs is None
assert (received_model_input.multi_modal_kwargs ==
model_input.multi_modal_kwargs)
assert received_model_input.lora_requests is None
assert received_model_input.lora_requests == model_input.lora_requests
assert received_model_input.lora_mapping is None
assert received_model_input.lora_mapping == model_input.lora_mapping
for field in dataclasses.fields(AttentionMetadata):
assert getattr(received_model_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None

View File

@ -61,12 +61,13 @@ def test_prepare_prompt(batch_size):
expected_selected_token_indices.append(selected_token_start_idx +
seq_len - 1)
selected_token_start_idx += seq_len
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = model_input.slot_mapping
slot_mapping = attn_metadata.slot_mapping
assert return_seq_lens == seq_lens
assert len(slot_mapping) == len(input_tokens)
@ -174,10 +175,11 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.slot_mapping)
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
@ -259,32 +261,29 @@ def test_empty_seq_group():
enforce_eager=False,
)
seq_group_metadata_list: List[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert len(slot_mapping) == 0
model_input = model_runner._prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, slot_mapping,
return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.slot_mapping,
model_input.seq_lens,
)
assert len(input_tokens) == 0
assert len(input_positions) == 0
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.seq_lens,
)
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_seq_lens) == 0
assert return_seq_lens is None
@pytest.fixture
@ -353,8 +352,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)
(input_tokens, input_positions, attn_metadata, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
@ -367,7 +370,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
attn_metadata = model_runner._prepare_model_input(
attn_metadata = model_runner._prepare_model_input_tensors(
seq_group_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),

View File

@ -21,9 +21,13 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def make_metadata(*args, **kwargs) -> "AttentionMetadata":
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@classmethod
def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_kv_cache_shape(

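Example (not part of the commit): with the hunk above, a backend only declares its metadata class via get_metadata_cls(); the concrete make_metadata() classmethod on AttentionBackend instantiates it. A minimal sketch for a hypothetical backend (MyMetadata and MyBackend are illustrative names; the remaining abstract methods are omitted):

import dataclasses
from typing import Type

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend


@dataclasses.dataclass
class MyMetadata(AttentionMetadata):
    # Hypothetical backend-specific field.
    window_size: int = 0


class MyBackend(AttentionBackend):

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        # The only metadata hook a backend has to provide now.
        return MyMetadata

    # get_name, get_impl_cls, get_kv_cache_shape, swap_blocks, copy_blocks
    # omitted for brevity.


# The inherited classmethod builds the backend-specific metadata type:
# MyBackend.make_metadata(num_prefills=1, num_prefill_tokens=2,
#                         num_decode_tokens=0, slot_mapping=...) -> MyMetadata
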
View File

@ -90,8 +90,8 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
return BlocksparseFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
return BlocksparseFlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return BlocksparseFlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(

View File

@ -25,8 +25,8 @@ class FlashAttentionBackend(AttentionBackend):
return FlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
return FlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(

View File

@ -22,8 +22,8 @@ class FlashInferBackend(AttentionBackend):
return FlashInferImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
return FlashInferMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashInferMetadata
@staticmethod
def get_kv_cache_shape(

View File

@ -25,8 +25,8 @@ class IpexAttnBackend(AttentionBackend):
return IpexAttnBackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "IpexAttnMetadata":
return IpexAttnMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
return IpexAttnMetadata
@staticmethod
def get_kv_cache_shape(

View File

@ -16,8 +16,8 @@ class PallasAttentionBackend(AttentionBackend):
return PallasAttentionBackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "PallasMetadata":
return PallasMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["PallasMetadata"]:
return PallasMetadata
@staticmethod
def get_kv_cache_shape(

View File

@ -25,8 +25,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
return ROCmFlashAttentionImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
return ROCmFlashAttentionMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return ROCmFlashAttentionMetadata
@staticmethod
def get_kv_cache_shape(

View File

@ -31,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
return TorchSDPABackendImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
return TorchSDPAMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return TorchSDPAMetadata
@staticmethod
def get_kv_cache_shape(

View File

@ -28,8 +28,8 @@ class XFormersBackend(AttentionBackend):
return XFormersImpl
@staticmethod
def make_metadata(*args, **kwargs) -> "XFormersMetadata":
return XFormersMetadata(*args, **kwargs)
def get_metadata_cls() -> Type["AttentionMetadata"]:
return XFormersMetadata
@staticmethod
def get_kv_cache_shape(

View File

@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
num_cpu_blocks=num_cpu_blocks)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
"start_worker_execution_loop",
@ -79,7 +79,7 @@ class DistributedGPUExecutor(GPUExecutor):
if self.parallel_worker_tasks is None:
return
self._driver_execute_model()
self._driver_execute_model(execute_model_req=None)
parallel_worker_tasks = self.parallel_worker_tasks
self.parallel_worker_tasks = None
# Ensure that workers exit model loop cleanly
@ -123,13 +123,13 @@ class DistributedGPUExecutor(GPUExecutor):
@abstractmethod
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
Passing None will cause the driver to stop the model execution loop
running in each of the remote workers. In this case, this method
returns None. Otherwise, this method returns the model output.
"""
raise NotImplementedError

View File

@ -69,8 +69,8 @@ class ExecutorBase(ABC):
@abstractmethod
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences."""
raise NotImplementedError

View File

@ -87,7 +87,7 @@ class GPUExecutor(ExecutorBase):
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> List[Union[SamplerOutput, PoolerOutput]]:
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.driver_worker.execute_model(execute_model_req)
return output

View File

@ -78,16 +78,14 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
worker_monitor.close()
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
return self.driver_worker.execute_model(
execute_model_req=execute_model_req)
return self.driver_worker.execute_model(execute_model_req)
def _run_workers(
self,

View File

@ -55,8 +55,7 @@ class NeuronExecutor(ExecutorBase):
assert execute_model_req.num_lookahead_slots == 0, (
"lookahead not supported for Neuron backend.")
output = self.driver_worker.execute_model(
execute_model_req.seq_group_metadata_list)
output = self.driver_worker.execute_model(execute_model_req)
return output
def add_lora(self, lora_request: LoRARequest) -> bool:

View File

@ -190,9 +190,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
max_parallel_loading_workers)
def _driver_execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution

View File

@ -887,7 +887,8 @@ class HiddenStates:
@dataclass
class ExecuteModelRequest:
"""The model execution request."""
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
# Blocks to swap in. List of CPU -> GPU block number.

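Example (not part of the commit): a hedged sketch of how the engine builds one of these CPU-only requests per scheduled batch. The block lists hold (source, destination) block-number pairs, as implied by the CPU worker hunk further down that reshapes blocks_to_copy with .view(-1, 2); the variable names here are illustrative.

from typing import List

from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata

# Normally produced by the scheduler; left empty so this sketch stands alone.
scheduled_seq_groups: List[SequenceGroupMetadata] = []

execute_model_req = ExecuteModelRequest(
    seq_group_metadata_list=scheduled_seq_groups,
    blocks_to_swap_in=[],    # CPU -> GPU (src, dst) block numbers
    blocks_to_swap_out=[],   # GPU -> CPU (src, dst) block numbers
    blocks_to_copy=[],       # intra-device copy-on-write (src, dst) pairs
)
# The executor then forwards this request to its driver worker, e.g.
# output = self.driver_worker.execute_model(execute_model_req)
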
View File

@ -7,7 +7,6 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.worker.model_runner import ModelInput
class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
@ -56,7 +55,7 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, List[int], List[int]]:
if not seq_group_metadata_list:
return ModelInput.empty(self.device)
return torch.empty(0, device=self.device), [], []
input_tokens: List[int] = []
seq_lens: List[int] = []

View File

@ -1,5 +1,6 @@
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
from torch import nn
@ -8,20 +9,64 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
class CPUModelRunner:
@dataclass(frozen=True)
class CPUModelInput(ModelRunnerInputBase):
"""
Used by the CPUModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["CPUModelInput"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> "CPUModelInput":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
def __init__(
self,
@ -270,86 +315,70 @@ class CPUModelRunner:
attn_metadata,
)
def prepare_input_tensors(
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> CPUModelInput:
return CPUModelInput.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Optional[Dict[str, torch.Tensor]]]:
) -> CPUModelInput:
multi_modal_kwargs = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
# Broadcast the metadata.
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"selected_token_indices":
sampling_metadata.selected_token_indices,
}
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs
) = self._prepare_prompt(seq_group_metadata_list)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
seq_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
generators=None,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, multi_modal_kwargs)
(input_tokens, input_positions,
attn_metadata) = self._prepare_decode(seq_group_metadata_list)
seq_lens = []
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
seq_lens,
# query_lens is not needed if chunked prefill is not
# supported. Since CPU worker doesn't support chunked prefill
# just use seq_lens instead.
seq_lens,
self.device,
pin_memory=False)
return CPUModelInput(
input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
sampling_metadata=sampling_metadata,
)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: CPUModelInput,
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"attn_metadata": model_input.attn_metadata,
}
if self.vision_language_config and multi_modal_input is not None:
execute_model_kwargs.update(multi_modal_input)
if (self.vision_language_config
and model_input.multi_modal_kwargs is not None):
execute_model_kwargs.update(model_input.multi_modal_kwargs)
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
@ -358,6 +387,6 @@ class CPUModelRunner:
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
sampling_metadata=model_input.sampling_metadata,
)
return output

View File

@ -1,5 +1,5 @@
"""A CPU worker class."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed
@ -8,15 +8,15 @@ from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
logger = init_logger(__name__)
@ -110,7 +110,7 @@ class CPUCacheEngine:
return dtype_size * total
class CPUWorker(LoraNotSupportedWorkerBase):
class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a CPU socket.
Each worker is associated with a single CPU socket. The worker is
@ -154,7 +154,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
# note: lazy import to avoid importing torch before initializing
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = CPUModelRunner(
self.model_runner: CPUModelRunner = CPUModelRunner(
model_config,
parallel_config,
scheduler_config,
@ -255,54 +255,37 @@ class CPUWorker(LoraNotSupportedWorkerBase):
for layer_cache in self.cpu_cache:
layer_cache.fill_(0)
def cache_copy(
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
return self.cpu_cache
def execute_worker(
self,
blocks_to_copy: torch.Tensor,
worker_input: WorkerInput,
) -> None:
if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy)
@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups: int = len(seq_group_metadata_list)
assert execute_model_req is not None
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_copy": execute_model_req.blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.cpu_cache)
# CPU worker only supports single-step execution.
return [output]
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
assert execute_model_req is not None
num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
blocks_to_copy = execute_model_req.blocks_to_copy
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device="cpu",
dtype=torch.int64).view(-1, 2)
assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(execute_model_req.blocks_to_swap_out) == 0
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_copy=blocks_to_copy,
)
def init_distributed_environment(self) -> None:
"""Initialize the distributed environment."""

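Example (not part of the commit): the worker_base.py changes that tie these hooks together are not included in this excerpt, so the following is only a rough sketch of the consolidated control flow the refactor aims for, with the broadcast details simplified and the helper names on WorkerInput assumed to mirror ModelRunnerInputBase.

from vllm.distributed import broadcast_tensor_dict
from vllm.worker.worker_base import WorkerInput


def _execute_model_sketch(worker, execute_model_req=None):
    """Illustrative only: one step of the shared driver/follower loop."""
    if worker.is_driver_worker:
        if execute_model_req is None:
            # An empty broadcast tells followers to leave their busy loop.
            if worker.do_metadata_broadcast:
                broadcast_tensor_dict({}, src=0)
            return None
        worker_input = worker.prepare_worker_input(execute_model_req)
        model_input = worker.model_runner.prepare_model_input(
            execute_model_req.seq_group_metadata_list)
        if worker.do_metadata_broadcast:
            data = worker_input.as_broadcastable_tensor_dict()
            data.update(model_input.as_broadcastable_tensor_dict())
            broadcast_tensor_dict(data, src=0)
    else:
        data = broadcast_tensor_dict(src=0)
        if not data:
            return None
        worker_input = WorkerInput.from_broadcasted_tensor_dict(data)
        model_input = (worker.model_runner.
                       make_model_input_from_broadcasted_tensor_dict(data))

    worker.execute_worker(worker_input)  # cache swaps / copies
    if worker_input.num_seq_groups == 0:
        return []
    output = worker.model_runner.execute_model(model_input, worker.kv_cache)
    return [output]
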
View File

@ -1,24 +1,32 @@
from typing import Dict, List, Optional, Set, Tuple
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU
logger = init_logger(__name__)
class EmbeddingModelRunner(ModelRunner):
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
"""
Used by the EmbeddingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None
class EmbeddingModelRunner(
GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
ModelInputForGPUWithPoolingMetadata)
def __init__(
self,
@ -47,21 +55,22 @@ class EmbeddingModelRunner(ModelRunner):
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
model_input: ModelInputForGPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
) -> Optional[PoolerOutput]:
(input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
@ -70,13 +79,14 @@ class EmbeddingModelRunner(ModelRunner):
kv_caches = [None] * num_layers
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"attn_metadata": model_input.attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
execute_model_kwargs.update({"image_input": multi_modal_kwargs})
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.
@ -84,66 +94,31 @@ class EmbeddingModelRunner(ModelRunner):
return None
return self.model.pooler(hidden_states=hidden_states,
pooling_metadata=pooling_metadata)
pooling_metadata=model_input.pooling_metadata)
def prepare_input_tensors(
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForGPUWithPoolingMetadata:
return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)
def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
# Prepare input tensors.
(
input_tokens,
input_positions,
attn_metadata,
seq_lens,
_,
lora_mapping,
lora_requests,
multi_modal_kwargs,
slot_mapping,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
) = self._prepare_model_input(seq_group_metadata_list)
# Prepare PoolingMetadata
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
seq_lens)
) -> ModelInputForGPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
}
if attn_metadata:
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
attn_metadata = None
pooling_metadata = PoolingMetadata(seq_groups=None,
seq_data=None,
prompt_lens=None)
return (input_tokens, input_positions, attn_metadata, pooling_metadata,
lora_requests, lora_mapping, multi_modal_kwargs)
return dataclasses.replace(model_input,
pooling_metadata=pooling_metadata)
def _prepare_pooling(
self,

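Example (not part of the commit): because the new model-input classes are frozen dataclasses, prepare_model_input above attaches the pooling metadata with dataclasses.replace() instead of mutating the object. A tiny standalone illustration (the class here is a stand-in, not a vLLM type):

import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class _FakeModelInput:
    input_tokens: Optional[list] = None
    pooling_metadata: Optional[dict] = None


base = _FakeModelInput(input_tokens=[1, 2, 3])
# replace() returns a new frozen instance with the extra field filled in;
# `base` itself is never mutated.
full = dataclasses.replace(base, pooling_metadata={"prompt_lens": [3]})
assert base.pooling_metadata is None
assert full.input_tokens == [1, 2, 3]
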
View File

@ -1,8 +1,10 @@
import dataclasses
import gc
import time
import warnings
from collections import defaultdict
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
import numpy as np
import torch
@ -12,7 +14,6 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.distributed.parallel_state import graph_capture
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
@ -26,6 +27,15 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
is_pin_memory_available, make_tensor_with_pad)
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
@ -39,40 +49,90 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]
_NUM_WARMUP_ITERS = 2
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
class ModelInput(NamedTuple):
input_tokens: torch.Tensor
input_positions: torch.Tensor
attn_metadata: Optional[AttentionMetadata]
seq_lens: List[int]
query_lens: List[int]
lora_mapping: Optional[LoRAMapping]
lora_requests: Set[LoRARequest]
multi_modal_kwargs: Dict[str, torch.Tensor]
slot_mapping: torch.Tensor
num_prefill_tokens: int
num_decode_tokens: int
num_prefills: int
@dataclasses.dataclass(frozen=True)
class ModelInputForGPU(ModelRunnerInputBase):
"""
This base class contains metadata needed for the base model forward pass
but not metadata for possible additional steps, e.g., sampling. Model
runners that run additional steps should subclass this class to add
additional fields.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
query_lens: Optional[List[int]] = None
lora_mapping: Optional["LoRAMapping"] = None
lora_requests: Optional[Set[LoRARequest]] = None
attn_metadata: Optional["AttentionMetadata"] = None
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
return tensor_dict
@classmethod
def empty(cls, device):
return ModelInput(
input_tokens=torch.empty(0, device=device),
input_positions=torch.empty(0, device=device),
attn_metadata=None,
seq_lens=[],
query_lens=[],
lora_mapping=None,
lora_requests=set(),
multi_modal_kwargs={},
slot_mapping=torch.empty(0, device=device),
num_prefill_tokens=0,
num_decode_tokens=0,
num_prefills=0,
)
def from_broadcasted_tensor_dict(
cls: Type[TModelInputForGPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> TModelInputForGPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class ModelRunner:
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"""
Used by the ModelRunner.
"""
sampling_metadata: Optional["SamplingMetadata"] = None
# Used for speculative decoding. We do not broadcast it because it is only
# used by the driver worker.
is_prompt: Optional[bool] = None
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForGPUWithSamplingMetadata":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
Helper class for shared methods between GPU model runners.
"""
_model_input_cls: Type[TModelInputForGPU]
def __init__(
self,
@ -241,11 +301,13 @@ class ModelRunner:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size
def _prepare_model_input(
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInput:
"""Prepare the model input based on a given sequence group.
) -> TModelInputForGPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
@ -296,7 +358,7 @@ class ModelRunner:
paged_kv_last_page_len: List[int] = []
if len(seq_group_metadata_list) == 0:
return ModelInput.empty(self.device)
return self._model_input_cls()
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window + self.block_size -
@ -646,7 +708,7 @@ class ModelRunner:
for k, v in multi_modal_kwargs_list.items()
}
return ModelInput(
return self._model_input_cls(
input_tokens=input_tokens_tensor,
input_positions=input_positions_tensor,
attn_metadata=attn_metadata,
@ -655,132 +717,8 @@ class ModelRunner:
lora_mapping=lora_mapping,
lora_requests=lora_requests,
multi_modal_kwargs=multi_modal_kwargs,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
)
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
# Prepare input tensors.
(
input_tokens,
input_positions,
attn_metadata,
seq_lens,
query_lens,
lora_mapping,
lora_requests,
multi_modal_kwargs,
slot_mapping,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
) = self._prepare_model_input(seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.pin_memory)
metadata_dict = {
"input_tokens": input_tokens,
"input_positions": input_positions,
"selected_token_indices":
sampling_metadata.selected_token_indices,
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
"multi_modal_kwargs": multi_modal_kwargs,
"num_prefill_tokens": num_prefill_tokens,
"num_decode_tokens": num_decode_tokens,
"slot_mapping": slot_mapping,
"num_prefills": num_prefills,
}
if attn_metadata:
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
input_tokens = metadata_dict.pop("input_tokens")
input_positions = metadata_dict.pop("input_positions")
selected_token_indices = metadata_dict.pop(
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
if metadata_dict:
attn_metadata = self.attn_backend.make_metadata(
**metadata_dict)
else:
attn_metadata = None
sampling_metadata = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
num_prompts=0,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping,
multi_modal_kwargs)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_mapping, multi_modal_kwargs
) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
# Currently cuda graph is only supported by the decode phase.
prefill_meta = attn_metadata.prefill_metadata
decode_meta = attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
hidden_states = model_executable(
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
**multi_modal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return None
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert seq_group_metadata_list is not None
if seq_group_metadata_list[0].is_prompt:
hidden_states = hidden_states.index_select(
0, sampling_metadata.selected_token_indices)
output.hidden_states = hidden_states
return output
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
@ -853,7 +791,8 @@ class ModelRunner:
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
self.execute_model(seqs, kv_caches)
model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches)
torch.cuda.synchronize()
return
@ -986,6 +925,110 @@ class ModelRunner:
return self.model_config.get_vocab_size()
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
"""
GPU model runner with sampling step.
"""
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
ModelInputForGPUWithSamplingMetadata)
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForGPUWithSamplingMetadata:
return (
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> ModelInputForGPUWithSamplingMetadata:
"""Prepare the model input based on a given sequence group, including
metadata for the sampling step.
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
The result tensors and data structure also batches input in prefill
-> decode order. For example,
- input_tokens[:num_prefill_tokens] contains prefill tokens.
- input_tokens[num_prefill_tokens:] contains decode tokens.
If cuda graph is required, this API automatically pads inputs.
"""
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
self.pin_memory)
is_prompt = (seq_group_metadata_list[0].is_prompt
if seq_group_metadata_list else None)
return dataclasses.replace(model_input,
sampling_metadata=sampling_metadata,
is_prompt=is_prompt)
@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
) -> SamplerOutput:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
**multi_modal_kwargs,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return None
# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
if model_input.is_prompt:
assert model_input.sampling_metadata is not None
hidden_states = hidden_states.index_select(
0, model_input.sampling_metadata.selected_token_indices)
output.hidden_states = hidden_states
return output
class CUDAGraphRunner:
def __init__(self, model: nn.Module):

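Example (not part of the commit): the prepare_model_input docstring above states that prefill and decode requests share one flat batch, split at num_prefill_tokens. A small illustration of that layout with made-up sizes:

import torch

# Say 2 prefill sequences contribute 5 prompt tokens in total and
# 3 decode sequences contribute 1 token each (made-up numbers).
num_prefill_tokens, num_decode_tokens = 5, 3
input_tokens = torch.arange(num_prefill_tokens + num_decode_tokens)

prefill_tokens = input_tokens[:num_prefill_tokens]  # positions 0..4
decode_tokens = input_tokens[num_prefill_tokens:]   # positions 5..7
assert prefill_tokens.numel() == num_prefill_tokens
assert decode_tokens.numel() == num_decode_tokens
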
View File

@ -0,0 +1,157 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)
import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
T = TypeVar('T', bound="ModelRunnerInputBase")
def _add_attn_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
attn_metadata: Optional["AttentionMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
AttentionMetadata fields.
"""
if attn_metadata is not None:
tensor_dict.update(attn_metadata.asdict_zerocopy())
def _init_attn_metadata_from_tensor_dict(
attn_backend: "AttentionBackend",
tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
"""
Helper method to initialize AttentionMetadata based on an
AttentionBackend and broadcastable AttentionMetadata fields.
"""
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_attn_kwargs[field.name] = val
attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata
return tensor_dict
def _init_sampling_metadata_from_tensor_dict( # type: ignore
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to initialize SamplingMetadata based on broadcastable
SamplingMetadata fields.
"""
from vllm.model_executor import SamplingMetadata
selected_token_indices = tensor_dict.pop("selected_token_indices", None)
# An empty SamplingMetadata to signal that the worker should skip
# sampling.
if selected_token_indices is not None:
tensor_dict["sampling_metadata"] = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
num_prompts=0,
)
return tensor_dict
def _add_sampling_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
sampling_metadata: Optional["SamplingMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
SamplingMetadata fields.
"""
if sampling_metadata is not None:
tensor_dict["selected_token_indices"] = (
sampling_metadata.selected_token_indices)
@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.
Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
"""
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"""
Extract broadcastable fields. Override for fields that require some
custom deserialization.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def from_broadcasted_tensor_dict(
cls: Type[T],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> T:
"""
Pop fields from the given tensor_dict and populate a new instance of
ModelRunnerInputBase.
"""
raise NotImplementedError
class ModelRunnerBase(ABC, Generic[T]):
"""
Model runner interface that abstracts a particular hardware and/or type of
model. Model execution may communicate data with model runners in other
processes, but it should not include control plane metadata communication.
Each ModelRunnerBase subclass should define a corresponding
ModelRunnerInputBase subclass.
"""
@abstractmethod
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> T:
"""
Make an instance of a ModelRunnerInputBase from the broadcasted tensor
dict.
"""
raise NotImplementedError
@abstractmethod
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError
@torch.inference_mode()
def execute_model(
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
) -> Optional[SamplerOutput]:
"""
Execute the model on the given input.
"""
raise NotImplementedError

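Example (not part of the commit): a condensed, hypothetical backend wired into the new interfaces above; the class names, fields, and method bodies are illustrative only.

import dataclasses
from typing import Any, Dict, List, Optional

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.model_runner_base import (ModelRunnerBase,
                                           ModelRunnerInputBase)


@dataclasses.dataclass(frozen=True)
class MyDeviceModelInput(ModelRunnerInputBase):
    # Device-local tensors; only broadcastable fields go into the dict below.
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        return {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict, attn_backend=None):
        return cls(**tensor_dict)


class MyDeviceModelRunner(ModelRunnerBase[MyDeviceModelInput]):

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> MyDeviceModelInput:
        return MyDeviceModelInput.from_broadcasted_tensor_dict(tensor_dict)

    def prepare_model_input(
            self, seq_group_metadata_list: List[SequenceGroupMetadata]
    ) -> MyDeviceModelInput:
        # Tokenize, pad, and move tensors to the local device here.
        return MyDeviceModelInput(input_tokens=torch.empty(0),
                                  input_positions=torch.empty(0))

    def execute_model(
            self, model_input: MyDeviceModelInput,
            kv_caches: Optional[List[torch.Tensor]],
    ) -> Optional[SamplerOutput]:
        # Run the forward pass; only the driver samples and returns output.
        return None
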
View File

@ -1,4 +1,5 @@
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
@ -10,11 +11,39 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
class NeuronModelRunner:
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
"""
Used by the NeuronModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")
@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForNeuron":
assert attn_backend is None
return cls(**tensor_dict)
class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def __init__(
self,
@ -139,10 +168,14 @@ class NeuronModelRunner:
return input_tokens, input_positions, input_block_ids
def prepare_input_tensors(
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
@ -164,30 +197,31 @@ class NeuronModelRunner:
self.device,
self.pin_memory)
return (input_tokens, input_positions, input_block_ids,
sampling_metadata)
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_block_ids, sampling_metadata
) = self.prepare_input_tensors(seq_group_metadata_list)
hidden_states = self.model(
input_ids=input_tokens,
positions=input_positions,
input_block_ids=input_block_ids,
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
sampling_metadata=model_input.sampling_metadata,
)
return output

View File

@ -1,5 +1,5 @@
"""A Neuron worker class."""
from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
import torch.distributed
@ -7,12 +7,13 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)
class NeuronWorker(LoraNotSupportedWorkerBase):
class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""
@ -34,8 +35,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.model_runner = NeuronModelRunner(model_config, parallel_config,
scheduler_config, device_config)
self.model_runner: NeuronModelRunner = NeuronModelRunner(
model_config, parallel_config, scheduler_config, device_config)
self.is_driver_worker = True
def init_device(self) -> None:
# Set random seed.
@ -73,22 +75,19 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks
@property
def do_metadata_broadcast(self) -> bool:
return False
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
return None
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[SamplerOutput]:
num_seq_groups = len(seq_group_metadata_list)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []
output = self.model_runner.execute_model(seq_group_metadata_list)
# Neuron worker only supports single-step output. Wrap the output in a
# list to conform to interface.
return [output]
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
return WorkerInput(num_seq_groups=len(
execute_model_req.seq_group_metadata_list), )
def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.

@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Set, Tuple, Type
import torch
import torch.distributed
@ -9,21 +9,20 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
class Worker(WorkerBase):
class Worker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a GPU.
Each worker is associated with a single GPU. The worker is responsible for
@ -78,9 +77,10 @@ class Worker(WorkerBase):
or (speculative_config.draft_model_config.hf_config.model_type !=
"mlp_speculator") else {"return_hidden_states": True}
ModelRunnerClass = (EmbeddingModelRunner if
self.model_config.embedding_mode else ModelRunner)
self.model_runner = ModelRunnerClass(
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if self.model_config.embedding_mode:
ModelRunnerClass = EmbeddingModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
@ -225,40 +225,18 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def cache_swap(
self,
blocks_to_swap_in: torch.Tensor,
blocks_to_swap_out: torch.Tensor,
blocks_to_copy: torch.Tensor,
) -> None:
# Issue cache operations.
if blocks_to_swap_in.numel() > 0:
self.cache_engine.swap_in(blocks_to_swap_in)
if blocks_to_swap_out.numel() > 0:
self.cache_engine.swap_out(blocks_to_swap_out)
if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1
@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
return self.gpu_cache
@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[Union[SamplerOutput, PoolerOutput]]:
if not self.is_driver_worker:
self._execute_model_non_driver()
return []
if execute_model_req is None:
# This signals that there's no more requests to process for now.
# All workers are running infinite loop with broadcast_tensor_dict,
# and it stops the loop when the driver broadcasts an empty input.
# Send an empty input to notify all other workers to stop their
# execution loop.
broadcast_tensor_dict({}, src=0)
return []
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
num_seq_groups = len(seq_group_metadata_list)
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are CPU tensors.
# They contain parameters to launch cudaMemcpyAsync.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
@ -273,59 +251,26 @@ class Worker(WorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)
# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return [output]
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
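The copy tensor is reshaped to one (source block, destination block) pair per row via `.view(-1, 2)`; the swap tensors presumably follow the same layout, though their reshaping is elided from this hunk. A small illustration with invented values (the element type of `execute_model_req.blocks_to_copy` is not visible here, so this only shows the resulting row layout):

import torch

# Illustrative only: three copy operations, one (src, dst) block pair per row.
blocks_to_copy = torch.tensor([[2, 9], [3, 10], [7, 11]], dtype=torch.int64)
assert blocks_to_copy.view(-1, 2).shape == (3, 2)

`execute_worker`, further down in this file's diff, then forwards each non-empty tensor to the matching `cache_engine` call (swap_in, swap_out, or copy).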
@torch.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
while self._execute_model_non_driver():
pass
def _execute_model_non_driver(self) -> bool:
"""Execute model in parallel worker.
Returns True iff there are remaining sequences to process.
"""
assert not self.is_driver_worker
data = broadcast_tensor_dict(src=0)
if not data:
return False
num_seq_groups = data.get("num_seq_groups", 0)
blocks_to_swap_in = data.get("blocks_to_swap_in")
blocks_to_swap_out = data.get("blocks_to_swap_out")
blocks_to_copy = data.get("blocks_to_copy")
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return False
self.model_runner.execute_model(None, self.gpu_cache)
return True
def execute_worker(self, worker_input: WorkerInput) -> None:
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)

@ -1,20 +1,26 @@
import dataclasses
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
logger = init_logger(__name__)
class WorkerBase(ABC):
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware.
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
"""
@abstractmethod
@ -46,13 +52,23 @@ class WorkerBase(ABC):
"""
raise NotImplementedError
@torch.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.
You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None
@abstractmethod
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
) -> Optional[List[SamplerOutput]]:
raise NotImplementedError
@abstractmethod
@ -98,6 +114,150 @@ class LoraNotSupportedWorkerBase(WorkerBase):
raise ValueError(f"{type(self)} does not support LoRA")
@dataclasses.dataclass(frozen=True)
class WorkerInput:
"""Local inputs to each worker. May contain device-specific data. These
fields should be broadcastable to other workers.
"""
num_seq_groups: Optional[int] = None
blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
tensor_dict: Dict[str, Any],
) -> "WorkerInput":
"""
Pop fields from the given tensor_dict and populate a new instance of
WorkerInput.
"""
return cls(
num_seq_groups=tensor_dict.pop("num_seq_groups"),
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
)
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
"""
Extract broadcastable fields.
"""
tensor_dict = {
"num_seq_groups": self.num_seq_groups,
"blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
}
return tensor_dict
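Taken together, these two helpers give WorkerInput (as defined above) a simple broadcast round trip. A hedged sketch with invented values, showing how a driver-side object survives the tensor-dict hop (the actual transport is `broadcast_tensor_dict`, used in `execute_model` below):

import torch

# Driver side: an input describing 4 sequence groups and one block copy.
src = WorkerInput(
    num_seq_groups=4,
    blocks_to_swap_in=torch.empty((0, 2), dtype=torch.int64),
    blocks_to_swap_out=torch.empty((0, 2), dtype=torch.int64),
    blocks_to_copy=torch.tensor([[0, 5]], dtype=torch.int64),
)
tensor_dict = src.as_broadcastable_tensor_dict()

# Non-driver side: rebuild the same fields from the received dict.
received = WorkerInput.from_broadcasted_tensor_dict(tensor_dict)
assert received.num_seq_groups == 4
assert torch.equal(received.blocks_to_copy, src.blocks_to_copy)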
class LocalOrDistributedWorkerBase(WorkerBase):
"""
Partial implementation of WorkerBase that has a default `execute_model`
definition to perform metadata transfer between workers when in distributed
mode. Subclasses of this interface should use model runners that inherit
from ModelRunnerBase, and should only need to implement worker-local logic.
If custom control plane logic is needed to transfer metadata, or if the
model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
"""
is_driver_worker: bool
model_runner: ModelRunnerBase
@property
@abstractmethod
def do_metadata_broadcast(self) -> bool:
"""
Used by the default `execute_model` to check whether broadcast is
needed to transfer request inputs from the driver worker to other
workers in the TP group. If WorkerBase subclass only supports
single-worker execution, then this method should return False.
"""
raise NotImplementedError
@property
@abstractmethod
def kv_cache(self) -> Optional[List[torch.Tensor]]:
"""
Get the kv cache to pass to the worker's model runner. Used by the
default `execute_model`. If the worker's model runner does not follow
the ModelRunnerBase interface, then inherit from WorkerBase instead.
"""
raise NotImplementedError
@abstractmethod
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
"""
Prepare the inputs to WorkerBase.execute_worker from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError
@abstractmethod
def execute_worker(self, worker_input: WorkerInput) -> None:
"""
Process an execution request.
"""
raise NotImplementedError
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
# This signals that there are no more requests to process for
# now. All workers are running an infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
assert self.do_metadata_broadcast
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
output = self.model_runner.execute_model(model_input, self.kv_cache)
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return [output]
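To make the new contract concrete, here is a minimal, hypothetical subclass (every name below is invented for illustration, and the remaining WorkerBase abstract methods such as device init and cache sizing are omitted). It only fills in the four worker-local hooks and lets the inherited `execute_model` above handle the control plane:

class ToySingleDeviceWorker(LoraNotSupportedWorkerBase,
                            LocalOrDistributedWorkerBase):
    """Illustrative only: one worker, no metadata broadcast, no KV cache."""

    def __init__(self, model_runner: ModelRunnerBase):
        self.model_runner = model_runner
        self.is_driver_worker = True  # single worker, so always the driver

    @property
    def do_metadata_broadcast(self) -> bool:
        return False  # no TP group to broadcast to

    @property
    def kv_cache(self) -> Optional[List[torch.Tensor]]:
        return None  # this toy backend manages no cache blocks

    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        return WorkerInput(
            num_seq_groups=len(execute_model_req.seq_group_metadata_list))

    def execute_worker(self, worker_input: WorkerInput) -> None:
        pass  # no swaps or copies to issue

This is essentially the shape of the NeuronWorker change earlier in this commit; the GPU Worker differs mainly in that `do_metadata_broadcast` depends on the tensor-parallel size and `execute_worker` issues cache operations.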
class WorkerWrapperBase:
"""
The whole point of this class is to lazily initialize the worker.

@ -1,4 +1,5 @@
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
@ -14,6 +15,15 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
logger = init_logger(__name__)
@ -24,7 +34,42 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]
class XPUModelRunner:
@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
"""
Used by the XPUModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_input: Optional[Dict[str, torch.Tensor]] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForXPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForXPU":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)
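As with the other model-input dataclasses in this commit, the intent is a broadcast round trip. A hedged sketch, assuming `runner` is a constructed XPUModelRunner and `model_input` came from its `prepare_model_input` (both names are placeholders for this illustration):

# Driver side: flatten into ints/tensors that broadcast_tensor_dict can ship.
tensor_dict = model_input.as_broadcastable_tensor_dict()

# Non-driver side: rebuild the dataclass via the runner helper shown later in
# this diff. The attention backend is passed through so the correct
# AttentionMetadata class can be re-instantiated from the dict.
rebuilt = runner.make_model_input_from_broadcasted_tensor_dict(tensor_dict)

Note that `multi_modal_input` is not part of the broadcastable dict above, so it is not transferred by this path.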
class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
def __init__(
self,
@ -130,15 +175,22 @@ class XPUModelRunner:
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
self.execute_model(seqs, kv_caches)
model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches)
torch.xpu.synchronize()
return
def prepare_input_tensors(
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU:
return (ModelInputForXPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Optional[torch.Tensor]]:
) -> ModelInputForXPU:
multi_modal_input = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
@ -185,8 +237,11 @@ class XPUModelRunner:
num_prompts=0,
)
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, multi_modal_input)
return ModelInputForXPU(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
sampling_metadata=sampling_metadata,
multi_modal_input=multi_modal_input)
def _prepare_decode(
self,
@ -277,27 +332,25 @@ class XPUModelRunner:
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: ModelInputForXPU,
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)
model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"attn_metadata": model_input.attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
execute_model_kwargs.update(
{"image_input": model_input.multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)
# Only perform sampling in the driver worker.
if not self.is_driver_worker:
@ -306,7 +359,7 @@ class XPUModelRunner:
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
sampling_metadata=model_input.sampling_metadata,
)
return output