[Core] Refactor Worker and ModelRunner to consolidate control plane communication (#5408)
Signed-off-by: Stephanie Wang <swang@cs.berkeley.edu>
Signed-off-by: Stephanie <swang@anyscale.com>
Co-authored-by: Stephanie <swang@anyscale.com>
This commit is contained in:
parent 82079729cc
commit dda4811591
tests/worker/test_model_input.py (new file, 152 lines)
@@ -0,0 +1,152 @@
import dataclasses
from typing import List, Tuple, Type

import torch

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class MockAttentionBackend(AttentionBackend):

    @staticmethod
    def get_name() -> str:
        raise NotImplementedError

    @staticmethod
    def get_impl_cls():
        raise NotImplementedError

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        pass

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
        pass


def test_model_runner_input():
    sampling_metadata = SamplingMetadata(
        ["seq_group"],
        "selected_token_indices",
        "categorized_sample_indices",
        "num_prompts",
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithSamplingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        sampling_metadata=sampling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithSamplingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))
    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithSamplingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions == model_input.input_positions
            ).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # For sampling metadata, only selected_token_indices is copied.
    assert (received_model_input.sampling_metadata.selected_token_indices ==
            sampling_metadata.selected_token_indices)
    assert received_model_input.sampling_metadata.seq_groups is None


def test_embedding_model_runner_input():
    pooling_metadata = PoolingMetadata(
        seq_groups=[[0]],
        seq_data={},
        prompt_lens=[1],
    )
    attn_metadata = AttentionMetadata(
        num_prefills=1,
        num_prefill_tokens=2,
        num_decode_tokens=3,
        slot_mapping=torch.zeros(1),
    )
    model_input = ModelInputForGPUWithPoolingMetadata(
        input_tokens=torch.ones(10),
        input_positions=torch.ones(10),
        pooling_metadata=pooling_metadata,
        attn_metadata=attn_metadata)

    assert isinstance(model_input, ModelInputForGPUWithPoolingMetadata)

    # Test round trip serialization.
    tensor_dict = model_input.as_broadcastable_tensor_dict()
    attn_backend = MockAttentionBackend()
    received_model_input = (
        ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=attn_backend))
    # Check that received copy has correct values.
    assert isinstance(received_model_input,
                      ModelInputForGPUWithPoolingMetadata)
    assert received_model_input.input_tokens is not None
    assert (
        received_model_input.input_tokens == model_input.input_tokens).all()
    assert received_model_input.input_positions is not None
    assert (received_model_input.input_positions == model_input.input_positions
            ).all()
    assert received_model_input.multi_modal_kwargs is None
    assert (received_model_input.multi_modal_kwargs ==
            model_input.multi_modal_kwargs)
    assert received_model_input.lora_requests is None
    assert received_model_input.lora_requests == model_input.lora_requests
    assert received_model_input.lora_mapping is None
    assert received_model_input.lora_mapping == model_input.lora_mapping
    for field in dataclasses.fields(AttentionMetadata):
        assert getattr(received_model_input.attn_metadata, field.name,
                       None) == getattr(attn_metadata, field.name, None)
    # Pooling metadata is not broadcast.
    assert received_model_input.pooling_metadata is None
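Note: the in-process round trip exercised above is the same serialization path the driver uses to ship model input to the other ranks elsewhere in this diff: the driver flattens its prepared input with as_broadcastable_tensor_dict(), broadcasts it, and non-driver ranks rebuild it with make_model_input_from_broadcasted_tensor_dict(). A hedged sketch of that flow; the helper name sync_model_input and its runner argument are hypothetical, only the three calls come from this commit.

from vllm.distributed import broadcast_tensor_dict


def sync_model_input(runner, model_input=None):
    # Hypothetical helper: rank 0 flattens and broadcasts its prepared input,
    # while the other ranks rebuild an equivalent (partially populated) copy.
    if runner.is_driver_worker:
        broadcast_tensor_dict(model_input.as_broadcastable_tensor_dict(),
                              src=0)
        return model_input
    tensor_dict = broadcast_tensor_dict(src=0)
    return runner.make_model_input_from_broadcasted_tensor_dict(tensor_dict)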
@@ -61,12 +61,13 @@ def test_prepare_prompt(batch_size):
        expected_selected_token_indices.append(selected_token_start_idx +
                                               seq_len - 1)
        selected_token_start_idx += seq_len
    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens = model_input.input_tokens
    input_positions = model_input.input_positions
    attn_metadata = model_input.attn_metadata
    return_seq_lens = model_input.seq_lens
    slot_mapping = model_input.slot_mapping
    slot_mapping = attn_metadata.slot_mapping
    assert return_seq_lens == seq_lens
    assert len(slot_mapping) == len(input_tokens)

@@ -174,10 +175,11 @@ def test_prepare_decode_cuda_graph(batch_size):
        assert seq_group_metadata.token_chunk_size == 1
        seq_group_metadata_list.append(seq_group_metadata)

    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens, input_positions, attn_metadata, slot_mapping = (
        model_input.input_tokens, model_input.input_positions,
        model_input.attn_metadata, model_input.slot_mapping)
        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
    assert len(slot_mapping) == len(input_tokens)

    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))

@@ -259,32 +261,29 @@ def test_empty_seq_group():
        enforce_eager=False,
    )
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
    input_tokens, input_positions, attn_metadata, slot_mapping = (
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    input_tokens, input_positions, attn_metadata = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
        model_input.slot_mapping,
    )
    assert len(input_tokens) == 0
    assert len(input_positions) == 0
    assert input_tokens is None
    assert input_positions is None
    assert attn_metadata is None
    assert len(slot_mapping) == 0

    model_input = model_runner._prepare_model_input(seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata, slot_mapping,
     return_seq_lens) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
        model_input.slot_mapping,
        model_input.seq_lens,
    )
    assert len(input_tokens) == 0
    assert len(input_positions) == 0
    model_input = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata, return_seq_lens) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
        model_input.seq_lens,
    )
    assert input_tokens is None
    assert input_positions is None
    assert attn_metadata is None
    assert len(slot_mapping) == 0
    assert len(return_seq_lens) == 0
    assert return_seq_lens is None


@pytest.fixture
@@ -353,8 +352,12 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
        seq_group_metadata_list.append(seq_group_metadata)
        decode_metadata_list.append(seq_group_metadata)

    (input_tokens, input_positions, attn_metadata, _, _, _,
     _) = model_runner.prepare_input_tensors(seq_group_metadata_list)
    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
    (input_tokens, input_positions, attn_metadata) = (
        model_input.input_tokens,
        model_input.input_positions,
        model_input.attn_metadata,
    )

    prefill_meta_actual = attn_metadata.prefill_metadata
    decode_meta_actual = attn_metadata.decode_metadata

@@ -367,7 +370,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):

    # Verify attn metadata is consistent. We don't need to test individual
    # values here because they are tested above.
    attn_metadata = model_runner._prepare_model_input(
    attn_metadata = model_runner._prepare_model_input_tensors(
        seq_group_metadata_list).attn_metadata

    for attr_expected, attr_actual in zip(vars(attn_metadata.prefill_metadata),
@@ -21,9 +21,13 @@ class AttentionBackend(ABC):

    @staticmethod
    @abstractmethod
    def make_metadata(*args, **kwargs) -> "AttentionMetadata":
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        raise NotImplementedError

    @classmethod
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
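Note: with this change a concrete backend only declares its metadata class and inherits make_metadata from the base. A minimal sketch assuming a hypothetical DummyBackend; the constructor kwargs mirror the new test above, and the omitted abstract methods are fine because the class is never instantiated.

from typing import Type

import torch

from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend


class DummyBackend(AttentionBackend):
    # Hypothetical backend: only the metadata hook is shown.

    @staticmethod
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata


# The inherited classmethod instantiates whatever get_metadata_cls() returns.
meta = DummyBackend.make_metadata(num_prefills=1,
                                  num_prefill_tokens=2,
                                  num_decode_tokens=3,
                                  slot_mapping=torch.zeros(1))
assert meta.num_decode_tokens == 3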
@@ -90,8 +90,8 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
        return BlocksparseFlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
        return BlocksparseFlashAttentionMetadata(*args, **kwargs)
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return BlocksparseFlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(

@@ -25,8 +25,8 @@ class FlashAttentionBackend(AttentionBackend):
        return FlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashAttentionMetadata":
        return FlashAttentionMetadata(*args, **kwargs)
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(

@@ -22,8 +22,8 @@ class FlashInferBackend(AttentionBackend):
        return FlashInferImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
        return FlashInferMetadata(*args, **kwargs)
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashInferMetadata

    @staticmethod
    def get_kv_cache_shape(

@@ -25,8 +25,8 @@ class IpexAttnBackend(AttentionBackend):
        return IpexAttnBackendImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "IpexAttnMetadata":
        return IpexAttnMetadata(*args, **kwargs)
    def get_metadata_cls() -> Type["IpexAttnMetadata"]:
        return IpexAttnMetadata

    @staticmethod
    def get_kv_cache_shape(

@@ -16,8 +16,8 @@ class PallasAttentionBackend(AttentionBackend):
        return PallasAttentionBackendImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "PallasMetadata":
        return PallasMetadata(*args, **kwargs)
    def get_metadata_cls() -> Type["PallasMetadata"]:
        return PallasMetadata

    @staticmethod
    def get_kv_cache_shape(

@@ -25,8 +25,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
        return ROCmFlashAttentionImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
        return ROCmFlashAttentionMetadata(*args, **kwargs)
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return ROCmFlashAttentionMetadata

    @staticmethod
    def get_kv_cache_shape(

@@ -31,8 +31,8 @@ class TorchSDPABackend(AttentionBackend):
        return TorchSDPABackendImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "TorchSDPAMetadata":
        return TorchSDPAMetadata(*args, **kwargs)
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return TorchSDPAMetadata

    @staticmethod
    def get_kv_cache_shape(

@@ -28,8 +28,8 @@ class XFormersBackend(AttentionBackend):
        return XFormersImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "XFormersMetadata":
        return XFormersMetadata(*args, **kwargs)
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return XFormersMetadata

    @staticmethod
    def get_kv_cache_shape(
@@ -64,8 +64,8 @@ class DistributedGPUExecutor(GPUExecutor):
                          num_cpu_blocks=num_cpu_blocks)

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
            self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[SamplerOutput]]:
        if self.parallel_worker_tasks is None:
            self.parallel_worker_tasks = self._run_workers(
                "start_worker_execution_loop",

@@ -79,7 +79,7 @@ class DistributedGPUExecutor(GPUExecutor):
        if self.parallel_worker_tasks is None:
            return

        self._driver_execute_model()
        self._driver_execute_model(execute_model_req=None)
        parallel_worker_tasks = self.parallel_worker_tasks
        self.parallel_worker_tasks = None
        # Ensure that workers exit model loop cleanly

@@ -123,13 +123,13 @@ class DistributedGPUExecutor(GPUExecutor):

    @abstractmethod
    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        Passing None will cause the driver to stop the model execution loop
        running in each of the remote workers. In this case, this method
        returns None. Otherwise, this method returns the model output.
        """
        raise NotImplementedError


@@ -69,8 +69,8 @@ class ExecutorBase(ABC):

    @abstractmethod
    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
            self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[SamplerOutput]]:
        """Executes at least one model step on the given sequences."""
        raise NotImplementedError


@@ -87,7 +87,7 @@ class GPUExecutor(ExecutorBase):

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> List[Union[SamplerOutput, PoolerOutput]]:
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        output = self.driver_worker.execute_model(execute_model_req)
        return output


@@ -78,16 +78,14 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
            worker_monitor.close()

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
        loop running in each of the remote workers.
        """
        return self.driver_worker.execute_model(
            execute_model_req=execute_model_req)
        return self.driver_worker.execute_model(execute_model_req)

    def _run_workers(
        self,

@@ -55,8 +55,7 @@ class NeuronExecutor(ExecutorBase):
        assert execute_model_req.num_lookahead_slots == 0, (
            "lookahead not supported for Neuron backend.")

        output = self.driver_worker.execute_model(
            execute_model_req.seq_group_metadata_list)
        output = self.driver_worker.execute_model(execute_model_req)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:

@@ -190,9 +190,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
                                  max_parallel_loading_workers)

    def _driver_execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        """Run execute_model in the driver worker.

        Passing None will cause the driver to stop the model execution
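Note: every DistributedGPUExecutor subclass now implements _driver_execute_model with a required (possibly None) request, where None tells the driver worker to stop the execution loop running on the remote workers. A hedged sketch of such a subclass, mirroring the MultiprocessingGPUExecutor hunk above; the class name is hypothetical and the module path vllm.executor.distributed_gpu_executor is an assumption.

from typing import List, Optional

from vllm.executor.distributed_gpu_executor import DistributedGPUExecutor
from vllm.sequence import ExecuteModelRequest, SamplerOutput


class MyDistributedExecutor(DistributedGPUExecutor):
    # Hypothetical subclass: only the new abstract hook is overridden here.

    def _driver_execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest]
    ) -> Optional[List[SamplerOutput]]:
        # Passing None makes the driver worker stop the remote execution
        # loop and return None; otherwise the model output is returned.
        return self.driver_worker.execute_model(execute_model_req)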
@@ -887,7 +887,8 @@ class HiddenStates:

@dataclass
class ExecuteModelRequest:
    """The model execution request."""
    """The model execution request, containing CPU metadata only. The LLM
    engine should create an instance of this class for each request batch."""
    # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
    # Blocks to swap in. List of CPU -> GPU block number.

@@ -7,7 +7,6 @@ from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.worker.model_runner import ModelInput


class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):

@@ -56,7 +55,7 @@ class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker):
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        if not seq_group_metadata_list:
            return ModelInput.empty(self.device)
            return torch.empty(0, device=self.device), [], []

        input_tokens: List[int] = []
        seq_lens: List[int] = []
@@ -1,5 +1,6 @@
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import torch
from torch import nn

@@ -8,20 +9,64 @@ from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

_PAD_SLOT_ID = -1


class CPUModelRunner:
@dataclass(frozen=True)
class CPUModelInput(ModelRunnerInputBase):
    """
    Used by the CPUModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    sampling_metadata: Optional["SamplingMetadata"] = None
    multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
            "multi_modal_kwargs": self.multi_modal_kwargs,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
            cls: Type["CPUModelInput"],
            tensor_dict: Dict[str, Any],
            attn_backend: Optional["AttentionBackend"] = None
    ) -> "CPUModelInput":
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class CPUModelRunner(ModelRunnerBase[CPUModelInput]):

    def __init__(
        self,

@@ -270,86 +315,70 @@ class CPUModelRunner:
            attn_metadata,
        )

    def prepare_input_tensors(
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str, Any],
    ) -> CPUModelInput:
        return CPUModelInput.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
               Optional[Dict[str, torch.Tensor]]]:
    ) -> CPUModelInput:
        multi_modal_kwargs = None
        if self.is_driver_worker:
            # NOTE: We assume that all sequences in the group are all prompts or
            # all decodes.
            is_prompt = seq_group_metadata_list[0].is_prompt
            # Prepare input tensors.
            if is_prompt:
                (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_kwargs
                 ) = self._prepare_prompt(seq_group_metadata_list)
            else:
                (input_tokens, input_positions,
                 attn_metadata) = self._prepare_decode(seq_group_metadata_list)
                seq_lens = []
            sampling_metadata = SamplingMetadata.prepare(
                seq_group_metadata_list,
                seq_lens,
                # query_lens is not needed if chunked prefill is not
                # supported. Since CPU worker doesn't support chunked prefill
                # just use seq_lens instead.
                seq_lens,
                self.device,
                pin_memory=False)
            # Broadcast the metadata.
            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "selected_token_indices":
                sampling_metadata.selected_token_indices,
            }
            metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs
             ) = self._prepare_prompt(seq_group_metadata_list)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            selected_token_indices = metadata_dict.pop(
                "selected_token_indices")
            attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
            sampling_metadata = SamplingMetadata(
                seq_groups=None,
                seq_data=None,
                seq_lens=None,
                selected_token_indices=selected_token_indices,
                categorized_sample_indices=None,
                generators=None,
            )

        return (input_tokens, input_positions, attn_metadata,
                sampling_metadata, multi_modal_kwargs)
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(seq_group_metadata_list)
            seq_lens = []
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            seq_lens,
            # query_lens is not needed if chunked prefill is not
            # supported. Since CPU worker doesn't support chunked prefill
            # just use seq_lens instead.
            seq_lens,
            self.device,
            pin_memory=False)
        return CPUModelInput(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            sampling_metadata=sampling_metadata,
        )

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        model_input: CPUModelInput,
        kv_caches: List[torch.Tensor],
    ) -> Optional[SamplerOutput]:
        (input_tokens, input_positions, attn_metadata, sampling_metadata,
         multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        model_executable = self.model
        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            "attn_metadata": model_input.attn_metadata,
        }
        if self.vision_language_config and multi_modal_input is not None:
            execute_model_kwargs.update(multi_modal_input)
        if (self.vision_language_config
                and model_input.multi_modal_kwargs is not None):
            execute_model_kwargs.update(model_input.multi_modal_kwargs)

        hidden_states = model_executable(**execute_model_kwargs)

        # Compute the logits.
        logits = self.model.compute_logits(hidden_states, sampling_metadata)
        logits = self.model.compute_logits(hidden_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:

@@ -358,6 +387,6 @@ class CPUModelRunner:
        # Sample the next token.
        output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
            sampling_metadata=model_input.sampling_metadata,
        )
        return output
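Note: a hedged, in-process sketch of what the new CPUModelInput round trip preserves, mirroring the GPU tests at the top of this commit; the dummy values are illustrative only.

import torch

from vllm.model_executor import SamplingMetadata
from vllm.worker.cpu_model_runner import CPUModelInput

# Driver side: flatten the frozen dataclass into a broadcastable dict.
model_input = CPUModelInput(
    input_tokens=torch.ones(4, dtype=torch.long),
    input_positions=torch.arange(4),
    sampling_metadata=SamplingMetadata(
        ["seq_group"], torch.tensor([3]), "categorized_sample_indices", 1),
)
tensor_dict = model_input.as_broadcastable_tensor_dict()

# Worker side: only selected_token_indices survives, so the rebuilt input
# carries a stub SamplingMetadata (seq_groups is None).
received = CPUModelInput.from_broadcasted_tensor_dict(tensor_dict)
assert received.sampling_metadata.seq_groups is None
assert torch.equal(received.input_tokens, model_input.input_tokens)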
@@ -1,5 +1,5 @@
"""A CPU worker class."""
from typing import Any, Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed

@@ -8,15 +8,15 @@ from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
                              ensure_model_parallel_initialized,
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                     LoraNotSupportedWorkerBase, WorkerInput)

logger = init_logger(__name__)

@@ -110,7 +110,7 @@ class CPUCacheEngine:
        return dtype_size * total


class CPUWorker(LoraNotSupportedWorkerBase):
class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is

@@ -154,7 +154,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()
        self.model_runner = CPUModelRunner(
        self.model_runner: CPUModelRunner = CPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,

@@ -255,54 +255,37 @@ class CPUWorker(LoraNotSupportedWorkerBase):
        for layer_cache in self.cpu_cache:
            layer_cache.fill_(0)

    def cache_copy(
    @property
    def do_metadata_broadcast(self) -> bool:
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[torch.Tensor]]:
        return self.cpu_cache

    def execute_worker(
        self,
        blocks_to_copy: torch.Tensor,
        worker_input: WorkerInput,
    ) -> None:
        if blocks_to_copy.numel() > 0:
            self.cache_engine.copy(blocks_to_copy)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine.copy(worker_input.blocks_to_copy)

    @torch.inference_mode()
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:

        if execute_model_req is None:
            seq_group_metadata_list = None
        else:
            seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups: int = len(seq_group_metadata_list)
            assert execute_model_req is not None
            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
            assert len(execute_model_req.blocks_to_swap_in) == 0
            assert len(execute_model_req.blocks_to_swap_out) == 0
            data: Dict[str, Any] = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_copy": execute_model_req.blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_copy(blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return []

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.cpu_cache)

        # CPU worker only supports single-step execution.
        return [output]
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        assert execute_model_req is not None
        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
        blocks_to_copy = execute_model_req.blocks_to_copy
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device="cpu",
                                      dtype=torch.int64).view(-1, 2)
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_copy=blocks_to_copy,
        )

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
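Note: the hooks added above (do_metadata_broadcast, kv_cache, execute_worker, prepare_worker_input) are what the shared LocalOrDistributedWorkerBase loop calls; that loop lives in vllm/worker/worker_base.py and is not shown in this excerpt. A hedged, driver-only paraphrase of the flow it consolidates (the function name is hypothetical and the cross-rank broadcast is omitted):

def run_one_step(worker, execute_model_req):
    """Sketch of the driver-side flow the new hooks plug into."""
    # Control-plane metadata (cache ops, group count) and model input are
    # now prepared separately, then consumed by the same worker loop.
    worker_input = worker.prepare_worker_input(execute_model_req)
    model_input = worker.model_runner.prepare_model_input(
        execute_model_req.seq_group_metadata_list)
    # Cache operations (here: block copies) happen before the forward pass.
    worker.execute_worker(worker_input)
    if worker_input.num_seq_groups == 0:
        return []
    output = worker.model_runner.execute_model(model_input, worker.kv_cache)
    return [output]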
@@ -1,24 +1,32 @@
from typing import Dict, List, Optional, Set, Tuple
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.lora.layers import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.pooling_params import PoolingParams
from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelInputForGPU

logger = init_logger(__name__)


class EmbeddingModelRunner(ModelRunner):
@dataclasses.dataclass(frozen=True)
class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU):
    """
    Used by the EmbeddingModelRunner.
    """
    pooling_metadata: Optional["PoolingMetadata"] = None


class EmbeddingModelRunner(
        GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]):
    _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = (
        ModelInputForGPUWithPoolingMetadata)

    def __init__(
        self,

@@ -47,21 +55,22 @@ class EmbeddingModelRunner(ModelRunner):
    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
        model_input: ModelInputForGPUWithPoolingMetadata,
        kv_caches: List[torch.Tensor],
    ) -> Optional[PoolerOutput]:
        (input_tokens, input_positions, attn_metadata, pooling_metadata,
         lora_requests, lora_mapping, multi_modal_input
         ) = self.prepare_input_tensors(seq_group_metadata_list)

        if self.lora_config:
            self.set_active_loras(lora_requests, lora_mapping)
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        # Currently cuda graph is only supported by the decode phase.
        prefill_meta = attn_metadata.prefill_metadata
        decode_meta = attn_metadata.decode_metadata
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        if prefill_meta is None and decode_meta.use_cuda_graph:
            graph_batch_size = input_tokens.shape[0]
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[graph_batch_size]
        else:
            model_executable = self.model

@@ -70,13 +79,14 @@ class EmbeddingModelRunner(ModelRunner):
            kv_caches = [None] * num_layers

        execute_model_kwargs = {
            "input_ids": input_tokens,
            "positions": input_positions,
            "input_ids": model_input.input_tokens,
            "positions": model_input.input_positions,
            "kv_caches": kv_caches,
            "attn_metadata": attn_metadata,
            "attn_metadata": model_input.attn_metadata,
        }
        if self.vision_language_config:
            execute_model_kwargs.update({"image_input": multi_modal_input})
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}
            execute_model_kwargs.update({"image_input": multi_modal_kwargs})
        hidden_states = model_executable(**execute_model_kwargs)

        # Only perform pooling in the driver worker.

@@ -84,66 +94,31 @@ class EmbeddingModelRunner(ModelRunner):
            return None

        return self.model.pooler(hidden_states=hidden_states,
                                 pooling_metadata=pooling_metadata)
                                 pooling_metadata=model_input.pooling_metadata)

    def prepare_input_tensors(
    def make_model_input_from_broadcasted_tensor_dict(
        self,
        tensor_dict: Dict[str,
                          Any]) -> ModelInputForGPUWithPoolingMetadata:
        return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict(
            tensor_dict,
            attn_backend=self.attn_backend,
        )

    def prepare_model_input(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
               Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            # Prepare input tensors.
            (
                input_tokens,
                input_positions,
                attn_metadata,
                seq_lens,
                _,
                lora_mapping,
                lora_requests,
                multi_modal_kwargs,
                slot_mapping,
                num_prefill_tokens,
                num_decode_tokens,
                num_prefills,
            ) = self._prepare_model_input(seq_group_metadata_list)
            # Prepare PoolingMetadata
            pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                     seq_lens)
    ) -> ModelInputForGPUWithPoolingMetadata:
        assert seq_group_metadata_list is not None
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list)
        # Prepare PoolingMetadata.
        assert model_input.seq_lens is not None
        pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
                                                 model_input.seq_lens)

            metadata_dict = {
                "input_tokens": input_tokens,
                "input_positions": input_positions,
                "lora_requests": lora_requests,
                "lora_mapping": lora_mapping,
                "multi_modal_kwargs": multi_modal_kwargs,
                "num_prefill_tokens": num_prefill_tokens,
                "num_decode_tokens": num_decode_tokens,
                "slot_mapping": slot_mapping,
                "num_prefills": num_prefills,
            }
            if attn_metadata:
                metadata_dict.update(attn_metadata.asdict_zerocopy())
            broadcast_tensor_dict(metadata_dict, src=0)
        else:
            metadata_dict = broadcast_tensor_dict(src=0)
            input_tokens = metadata_dict.pop("input_tokens")
            input_positions = metadata_dict.pop("input_positions")
            lora_mapping = metadata_dict.pop("lora_mapping")
            lora_requests = metadata_dict.pop("lora_requests")
            multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
            if metadata_dict:
                attn_metadata = self.attn_backend.make_metadata(
                    **metadata_dict)
            else:
                attn_metadata = None
            pooling_metadata = PoolingMetadata(seq_groups=None,
                                               seq_data=None,
                                               prompt_lens=None)

        return (input_tokens, input_positions, attn_metadata, pooling_metadata,
                lora_requests, lora_mapping, multi_modal_kwargs)
        return dataclasses.replace(model_input,
                                   pooling_metadata=pooling_metadata)

    def _prepare_pooling(
        self,
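Note: because the new model-input dataclasses are frozen, per-runner metadata is layered on with dataclasses.replace rather than mutation, as in prepare_model_input above. A small hedged example with dummy values only:

import dataclasses

import torch

from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.worker.embedding_model_runner import (
    ModelInputForGPUWithPoolingMetadata)

base = ModelInputForGPUWithPoolingMetadata(input_tokens=torch.ones(4))
pooling = PoolingMetadata(seq_groups=[[0]], seq_data={}, prompt_lens=[1])
derived = dataclasses.replace(base, pooling_metadata=pooling)

# The original stays untouched; the copy shares tensors and adds metadata.
assert base.pooling_metadata is None
assert derived.pooling_metadata is pooling
assert derived.input_tokens is base.input_tokens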
@ -1,8 +1,10 @@
|
||||
import dataclasses
|
||||
import gc
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
|
||||
TypeVar, Union)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -12,7 +14,6 @@ from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.distributed import broadcast_tensor_dict
|
||||
from vllm.distributed.parallel_state import graph_capture
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
@ -26,6 +27,15 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
|
||||
from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip,
|
||||
is_pin_memory_available, make_tensor_with_pad)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict,
|
||||
_init_attn_metadata_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -39,40 +49,90 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
|
||||
]
|
||||
_NUM_WARMUP_ITERS = 2
|
||||
|
||||
TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
|
||||
|
||||
class ModelInput(NamedTuple):
|
||||
input_tokens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
attn_metadata: Optional[AttentionMetadata]
|
||||
seq_lens: List[int]
|
||||
query_lens: List[int]
|
||||
lora_mapping: Optional[LoRAMapping]
|
||||
lora_requests: Set[LoRARequest]
|
||||
multi_modal_kwargs: Dict[str, torch.Tensor]
|
||||
slot_mapping: torch.Tensor
|
||||
num_prefill_tokens: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ModelInputForGPU(ModelRunnerInputBase):
|
||||
"""
|
||||
This base class contains metadata needed for the base model forward pass
|
||||
but not metadata for possible additional steps, e.g., sampling. Model
|
||||
runners that run additional steps should subclass this method to add
|
||||
additional fields.
|
||||
"""
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
seq_lens: Optional[List[int]] = None
|
||||
query_lens: Optional[List[int]] = None
|
||||
lora_mapping: Optional["LoRAMapping"] = None
|
||||
lora_requests: Optional[Set[LoRARequest]] = None
|
||||
attn_metadata: Optional["AttentionMetadata"] = None
|
||||
multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def empty(cls, device):
|
||||
return ModelInput(
|
||||
input_tokens=torch.empty(0, device=device),
|
||||
input_positions=torch.empty(0, device=device),
|
||||
attn_metadata=None,
|
||||
seq_lens=[],
|
||||
query_lens=[],
|
||||
lora_mapping=None,
|
||||
lora_requests=set(),
|
||||
multi_modal_kwargs={},
|
||||
slot_mapping=torch.empty(0, device=device),
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=0,
|
||||
num_prefills=0,
|
||||
)
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type[TModelInputForGPU],
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> TModelInputForGPU:
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
|
||||
"""
|
||||
Used by the ModelRunner.
|
||||
"""
|
||||
sampling_metadata: Optional["SamplingMetadata"] = None
|
||||
# Used for speculative decoding. We do not broadcast it because it is only
|
||||
# used by the driver worker.
|
||||
is_prompt: Optional[bool] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||
self.sampling_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "ModelInputForGPUWithSamplingMetadata":
|
||||
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
"""
|
||||
Helper class for shared methods between GPU model runners.
|
||||
"""
|
||||
_model_input_cls: Type[TModelInputForGPU]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -241,11 +301,13 @@ class ModelRunner:
|
||||
block_size = self.block_size
|
||||
return (self.max_seq_len_to_capture + block_size - 1) // block_size
|
||||
|
||||
def _prepare_model_input(
|
||||
def _prepare_model_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> ModelInput:
|
||||
"""Prepare the model input based on a given sequence group.
|
||||
) -> TModelInputForGPU:
|
||||
"""Helper method to prepare the model input based on a given sequence
|
||||
group. Prepares metadata needed for the base model forward pass but not
|
||||
metadata for possible additional steps, e.g., sampling.
|
||||
|
||||
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
|
||||
|
||||
@ -296,7 +358,7 @@ class ModelRunner:
|
||||
paged_kv_last_page_len: List[int] = []
|
||||
|
||||
if len(seq_group_metadata_list) == 0:
|
||||
return ModelInput.empty(self.device)
|
||||
return self._model_input_cls()
|
||||
|
||||
if self.sliding_window is not None:
|
||||
sliding_window_blocks = (self.sliding_window + self.block_size -
|
||||
@ -646,7 +708,7 @@ class ModelRunner:
|
||||
for k, v in multi_modal_kwargs_list.items()
|
||||
}
|
||||
|
||||
return ModelInput(
|
||||
return self._model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
input_positions=input_positions_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
@ -655,132 +717,8 @@ class ModelRunner:
|
||||
lora_mapping=lora_mapping,
|
||||
lora_requests=lora_requests,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
)
|
||||
|
||||
def prepare_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
|
||||
Set[LoRARequest], LoRAMapping, Dict[str, torch.Tensor]]:
|
||||
if self.is_driver_worker:
|
||||
assert seq_group_metadata_list is not None
|
||||
# Prepare input tensors.
|
||||
(
|
||||
input_tokens,
|
||||
input_positions,
|
||||
attn_metadata,
|
||||
seq_lens,
|
||||
query_lens,
|
||||
lora_mapping,
|
||||
lora_requests,
|
||||
multi_modal_kwargs,
|
||||
slot_mapping,
|
||||
num_prefill_tokens,
|
||||
num_decode_tokens,
|
||||
num_prefills,
|
||||
) = self._prepare_model_input(seq_group_metadata_list)
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list, seq_lens, query_lens, self.device,
|
||||
self.pin_memory)
|
||||
|
||||
metadata_dict = {
|
||||
"input_tokens": input_tokens,
|
||||
"input_positions": input_positions,
|
||||
"selected_token_indices":
|
||||
sampling_metadata.selected_token_indices,
|
||||
"lora_requests": lora_requests,
|
||||
"lora_mapping": lora_mapping,
|
||||
"multi_modal_kwargs": multi_modal_kwargs,
|
||||
"num_prefill_tokens": num_prefill_tokens,
|
||||
"num_decode_tokens": num_decode_tokens,
|
||||
"slot_mapping": slot_mapping,
|
||||
"num_prefills": num_prefills,
|
||||
}
|
||||
if attn_metadata:
|
||||
metadata_dict.update(attn_metadata.asdict_zerocopy())
|
||||
broadcast_tensor_dict(metadata_dict, src=0)
|
||||
else:
|
||||
metadata_dict = broadcast_tensor_dict(src=0)
|
||||
input_tokens = metadata_dict.pop("input_tokens")
|
||||
input_positions = metadata_dict.pop("input_positions")
|
||||
selected_token_indices = metadata_dict.pop(
|
||||
"selected_token_indices")
|
||||
lora_mapping = metadata_dict.pop("lora_mapping")
|
||||
lora_requests = metadata_dict.pop("lora_requests")
|
||||
multi_modal_kwargs = metadata_dict.pop("multi_modal_kwargs")
|
||||
if metadata_dict:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
**metadata_dict)
|
||||
else:
|
||||
attn_metadata = None
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=None,
|
||||
selected_token_indices=selected_token_indices,
|
||||
categorized_sample_indices=None,
|
||||
num_prompts=0,
|
||||
)
|
||||
|
||||
return (input_tokens, input_positions, attn_metadata,
|
||||
sampling_metadata, lora_requests, lora_mapping,
|
||||
multi_modal_kwargs)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
) -> Optional[SamplerOutput]:
|
||||
(input_tokens, input_positions, attn_metadata, sampling_metadata,
|
||||
lora_requests, lora_mapping, multi_modal_kwargs
|
||||
) = self.prepare_input_tensors(seq_group_metadata_list)
|
||||
|
||||
if self.lora_config:
|
||||
self.set_active_loras(lora_requests, lora_mapping)
|
||||
|
||||
# Currently cuda graph is only supported by the decode phase.
|
||||
prefill_meta = attn_metadata.prefill_metadata
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
graph_batch_size = input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[graph_batch_size]
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
hidden_states = model_executable(
|
||||
input_ids=input_tokens,
|
||||
positions=input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
**multi_modal_kwargs,
|
||||
)
|
||||
|
||||
# Compute the logits.
|
||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||
|
||||
# Only perform sampling in the driver worker.
|
||||
if not self.is_driver_worker:
|
||||
return None
|
||||
|
||||
# Sample the next token.
|
||||
output: SamplerOutput = self.model.sample(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
if self.return_hidden_states:
|
||||
# we only need to pass hidden states of most recent token
|
||||
assert seq_group_metadata_list is not None
|
||||
if seq_group_metadata_list[0].is_prompt:
|
||||
hidden_states = hidden_states.index_select(
|
||||
0, sampling_metadata.selected_token_indices)
|
||||
output.hidden_states = hidden_states
|
||||
|
||||
return output
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
# Enable top-k sampling to reflect the accurate memory usage.
|
||||
@ -853,7 +791,8 @@ class ModelRunner:
|
||||
# Run the model with the dummy inputs.
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
kv_caches = [None] * num_layers
|
||||
self.execute_model(seqs, kv_caches)
|
||||
model_input = self.prepare_model_input(seqs)
|
||||
self.execute_model(model_input, kv_caches)
|
||||
torch.cuda.synchronize()
|
||||
return
|
||||
|
||||
@ -986,6 +925,110 @@ class ModelRunner:
|
||||
return self.model_config.get_vocab_size()
|
||||
|
||||
|
||||
class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
"""
|
||||
GPU model runner with sampling step.
|
||||
"""
|
||||
_model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = (
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Dict[str, Any],
|
||||
) -> ModelInputForGPUWithSamplingMetadata:
|
||||
return (
|
||||
ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
))
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> ModelInputForGPUWithSamplingMetadata:
|
||||
"""Prepare the model input based on a given sequence group, including
|
||||
metadata for the sampling step.
|
||||
|
||||
The API assumes seq_group_metadata_list is sorted by prefill -> decode.
|
||||
|
||||
The result tensors and data structure also batches input in prefill
|
||||
-> decode order. For example,
|
||||
|
||||
- input_tokens[:num_prefill_tokens] contains prefill tokens.
|
||||
- input_tokens[num_prefill_tokens:] contains decode tokens.
|
||||
|
||||
If cuda graph is required, this API automatically pads inputs.
|
||||
"""
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||
model_input.seq_lens,
|
||||
model_input.query_lens,
|
||||
self.device,
|
||||
self.pin_memory)
|
||||
is_prompt = (seq_group_metadata_list[0].is_prompt
|
||||
if seq_group_metadata_list else None)
|
||||
return dataclasses.replace(model_input,
|
||||
sampling_metadata=sampling_metadata,
|
||||
is_prompt=is_prompt)

@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
) -> SamplerOutput:
if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)

# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
model_executable = self.model

multi_modal_kwargs = model_input.multi_modal_kwargs or {}
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
**multi_modal_kwargs,
)

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)

# Only perform sampling in the driver worker.
if not self.is_driver_worker:
return None

# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)

if self.return_hidden_states:
# we only need to pass hidden states of most recent token
if model_input.is_prompt:
assert model_input.sampling_metadata is not None
hidden_states = hidden_states.index_select(
0, model_input.sampling_metadata.selected_token_indices)
output.hidden_states = hidden_states

return output
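
The return_hidden_states branch above keeps only the hidden state at each sequence's selected position via index_select. A small standalone illustration of that selection, with made-up shapes and no vLLM imports:

    import torch

    # 5 token positions with hidden size 4; pretend positions 1 and 4 are the
    # selected token indices recorded in the sampling metadata.
    hidden_states = torch.arange(20, dtype=torch.float32).reshape(5, 4)
    selected_token_indices = torch.tensor([1, 4])

    selected = hidden_states.index_select(0, selected_token_indices)
    assert selected.shape == (2, 4)
    assert torch.equal(selected[0], hidden_states[1])
    assert torch.equal(selected[1], hidden_states[4])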


class CUDAGraphRunner:

def __init__(self, model: nn.Module):

157
vllm/worker/model_runner_base.py
Normal file
@ -0,0 +1,157 @@
import dataclasses
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
TypeVar)

import torch

from vllm.sequence import SamplerOutput, SequenceGroupMetadata

if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata

T = TypeVar('T', bound="ModelRunnerInputBase")


def _add_attn_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
attn_metadata: Optional["AttentionMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
AttentionMetadata fields.
"""
if attn_metadata is not None:
tensor_dict.update(attn_metadata.asdict_zerocopy())


def _init_attn_metadata_from_tensor_dict(
attn_backend: "AttentionBackend",
tensor_dict: Dict[str, Any],
) -> Dict[str, Any]:
"""
Helper method to initialize AttentionMetadata based on an
AttentionBackend and broadcastable AttentionMetadata fields.
"""
# Extract the fields used to create AttentionMetadata.
valid_attn_kwargs = {}
for field in dataclasses.fields(attn_backend.get_metadata_cls()):
val = tensor_dict.pop(field.name, None)
if val is not None:
valid_attn_kwargs[field.name] = val

attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs)
tensor_dict["attn_metadata"] = attn_metadata
return tensor_dict


def _init_sampling_metadata_from_tensor_dict( # type: ignore
tensor_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Helper method to initialize SamplingMetadata based on broadcastable
SamplingMetadata fields.
"""
from vllm.model_executor import SamplingMetadata

selected_token_indices = tensor_dict.pop("selected_token_indices", None)
# An empty SamplingMetadata to signal that the worker should skip
# sampling.
if selected_token_indices is not None:
tensor_dict["sampling_metadata"] = SamplingMetadata(
seq_groups=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
num_prompts=0,
)
return tensor_dict


def _add_sampling_metadata_broadcastable_dict(
tensor_dict: Dict[str, Any],
sampling_metadata: Optional["SamplingMetadata"]) -> None:
"""
Helper method to update tensor_dict with broadcastable
SamplingMetadata fields.
"""
if sampling_metadata is not None:
tensor_dict["selected_token_indices"] = (
sampling_metadata.selected_token_indices)
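
A hedged sketch of the round trip these two sampling-metadata helpers are meant to support: the driver flattens SamplingMetadata into a broadcastable dict, and a non-driver worker rebuilds a partial SamplingMetadata from it. It assumes a vLLM checkout that already contains this module; the tensor values are made up.

    import torch

    from vllm.model_executor import SamplingMetadata
    from vllm.worker.model_runner_base import (
        _add_sampling_metadata_broadcastable_dict,
        _init_sampling_metadata_from_tensor_dict)

    sampling_metadata = SamplingMetadata(
        seq_groups=None,
        selected_token_indices=torch.tensor([0, 3, 7]),
        categorized_sample_indices=None,
        num_prompts=0,
    )

    # Driver side: only the broadcastable field survives.
    tensor_dict = {}
    _add_sampling_metadata_broadcastable_dict(tensor_dict, sampling_metadata)
    assert set(tensor_dict) == {"selected_token_indices"}

    # Worker side: rebuild a partial SamplingMetadata that is good enough for
    # logits selection; seq_groups stays None, so sampling itself is skipped.
    tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
    received = tensor_dict["sampling_metadata"]
    assert torch.equal(received.selected_token_indices, torch.tensor([0, 3, 7]))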


@dataclasses.dataclass(frozen=True)
class ModelRunnerInputBase(ABC):
"""Local inputs to each worker's model runner. May contain
device-specific data. Different worker backends may have different methods
of converting from the global ExecuteModelRequest produced by the LLM
engine to the worker-local ModelRunnerInputBase objects.

Model runners that support multi-GPU execution should define a
ModelRunnerInputBase subclass, add their required fields, and specify how to
serialize/deserialize a ModelInput for broadcast between workers.
"""

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"""
Extract broadcastable fields. Override for fields that require some
custom deserialization.
"""
raise NotImplementedError

@classmethod
@abstractmethod
def from_broadcasted_tensor_dict(
cls: Type[T],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> T:
"""
Pop fields from the given tensor_dict and populate a new instance of
ModelRunnerInputBase.
"""
raise NotImplementedError


class ModelRunnerBase(ABC, Generic[T]):
"""
Model runner interface that abstracts a particular hardware and/or type of
model. Model execution may communicate data with model runners in other
processes, but it should not include control plane metadata communication.

Each ModelRunnerBase subclass should define a corresponding
ModelRunnerInputBase subclass.
"""

@abstractmethod
def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> T:
"""
Make an instance of a ModelRunnerInputBase from the broadcasted tensor
dict.
"""
raise NotImplementedError

@abstractmethod
def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> T:
"""
Prepare the inputs to ModelRunnerBase.execute_model from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError

@torch.inference_mode()
def execute_model(
self,
model_input: T,
kv_caches: Optional[List[torch.Tensor]],
) -> Optional[SamplerOutput]:
"""
Execute the model on the given input.
"""
raise NotImplementedError
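
The Generic[T] pairing above means each runner declares its input dataclass once, and prepare_model_input()/execute_model() are then typed against that exact class. A minimal, self-contained illustration of the pattern with stand-in types (not vLLM classes):

    from dataclasses import dataclass
    from typing import Generic, List, TypeVar

    T = TypeVar("T")


    @dataclass(frozen=True)
    class FakeModelInput:
        tokens: List[int]


    class FakeRunnerBase(Generic[T]):
        def prepare_model_input(self, seqs: List[str]) -> T:
            raise NotImplementedError


    class FakeRunner(FakeRunnerBase[FakeModelInput]):
        def prepare_model_input(self, seqs: List[str]) -> FakeModelInput:
            # A real runner would build tensors here; this just counts inputs.
            return FakeModelInput(tokens=list(range(len(seqs))))


    model_input = FakeRunner().prepare_model_input(["seq_a", "seq_b"])
    assert model_input.tokens == [0, 1]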
@ -1,4 +1,5 @@
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
@ -10,11 +11,39 @@ from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)


class NeuronModelRunner:
@dataclass(frozen=True)
class ModelInputForNeuron(ModelRunnerInputBase):
"""
Used by the NeuronModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None

def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
raise NotImplementedError("ModelInputForNeuron cannot be broadcast.")

@classmethod
def from_broadcasted_tensor_dict(
cls,
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForNeuron":
assert attn_backend is None
return cls(**tensor_dict)


class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):

def __init__(
self,
@ -139,10 +168,14 @@ class NeuronModelRunner:

return input_tokens, input_positions, input_block_ids

def prepare_input_tensors(
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron:
return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict)

def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
) -> ModelInputForNeuron:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
@ -164,30 +197,31 @@ class NeuronModelRunner:
self.device,
self.pin_memory)

return (input_tokens, input_positions, input_block_ids,
sampling_metadata)
return ModelInputForNeuron(input_tokens=input_tokens,
input_positions=input_positions,
input_block_ids=input_block_ids,
sampling_metadata=sampling_metadata)

@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: ModelInputForNeuron,
kv_caches: Optional[List[torch.Tensor]] = None,
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_block_ids, sampling_metadata
) = self.prepare_input_tensors(seq_group_metadata_list)

hidden_states = self.model(
input_ids=input_tokens,
positions=input_positions,
input_block_ids=input_block_ids,
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
)

# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)

# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
sampling_metadata=model_input.sampling_metadata,
)
return output


@ -1,5 +1,5 @@
"""A Neuron worker class."""
from typing import List, Tuple
from typing import List, Optional, Tuple

import torch
import torch.distributed
@ -7,12 +7,13 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.sequence import ExecuteModelRequest
from vllm.worker.neuron_model_runner import NeuronModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
LoraNotSupportedWorkerBase, WorkerInput)


class NeuronWorker(LoraNotSupportedWorkerBase):
class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
"""A worker class that executes the model on a group of neuron cores.
"""

@ -34,8 +35,9 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()

self.model_runner = NeuronModelRunner(model_config, parallel_config,
scheduler_config, device_config)
self.model_runner: NeuronModelRunner = NeuronModelRunner(
model_config, parallel_config, scheduler_config, device_config)
self.is_driver_worker = True

def init_device(self) -> None:
# Set random seed.
@ -73,22 +75,19 @@ class NeuronWorker(LoraNotSupportedWorkerBase):
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

@property
def do_metadata_broadcast(self) -> bool:
return False

@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
return None

@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[SamplerOutput]:
num_seq_groups = len(seq_group_metadata_list)

# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []

output = self.model_runner.execute_model(seq_group_metadata_list)

# Neuron worker only supports single-step output. Wrap the output in a
# list to conform to interface.
return [output]
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
return WorkerInput(num_seq_groups=len(
execute_model_req.seq_group_metadata_list), )

def get_cache_block_size_bytes(self) -> int:
"""Determine the size in bytes of a cache block.
@ -1,7 +1,7 @@
"""A GPU worker class."""
import gc
import os
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import List, Optional, Set, Tuple, Type

import torch
import torch.distributed
@ -9,21 +9,20 @@ import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, VisionLanguageConfig)
from vllm.distributed import (broadcast_tensor_dict,
ensure_model_parallel_initialized,
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.embedding_model_runner import EmbeddingModelRunner
from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput


class Worker(WorkerBase):
class Worker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a GPU.

Each worker is associated with a single GPU. The worker is responsible for
@ -78,9 +77,10 @@ class Worker(WorkerBase):
or (speculative_config.draft_model_config.hf_config.model_type !=
"mlp_speculator") else {"return_hidden_states": True}

ModelRunnerClass = (EmbeddingModelRunner if
self.model_config.embedding_mode else ModelRunner)
self.model_runner = ModelRunnerClass(
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if self.model_config.embedding_mode:
ModelRunnerClass = EmbeddingModelRunner
self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
model_config,
parallel_config,
scheduler_config,
@ -225,40 +225,18 @@ class Worker(WorkerBase):
# the model initialization and profiling.
set_random_seed(self.model_config.seed)

def cache_swap(
self,
blocks_to_swap_in: torch.Tensor,
blocks_to_swap_out: torch.Tensor,
blocks_to_copy: torch.Tensor,
) -> None:
# Issue cache operations.
if blocks_to_swap_in.numel() > 0:
self.cache_engine.swap_in(blocks_to_swap_in)
if blocks_to_swap_out.numel() > 0:
self.cache_engine.swap_out(blocks_to_swap_out)
if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy)
@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1

@property
def kv_cache(self) -> Optional[List[torch.Tensor]]:
return self.gpu_cache

@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[Union[SamplerOutput, PoolerOutput]]:
if not self.is_driver_worker:
self._execute_model_non_driver()
return []

if execute_model_req is None:
# This signals that there's no more requests to process for now.
# All workers are running infinite loop with broadcast_tensor_dict,
# and it stops the loop when the driver broadcasts an empty input.
# Send an empty input to notify all other workers to stop their
# execution loop.
broadcast_tensor_dict({}, src=0)
return []

seq_group_metadata_list = execute_model_req.seq_group_metadata_list
num_seq_groups = len(seq_group_metadata_list)
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
@ -273,59 +251,26 @@ class Worker(WorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)

self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return []

output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)

# Worker only supports single-step execution. Wrap the output in a list
# to conform to interface.
return [output]
return WorkerInput(
num_seq_groups=num_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)

@torch.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.

You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
while self._execute_model_non_driver():
pass

def _execute_model_non_driver(self) -> bool:
"""Execute model in parallel worker.

Returns True iff there are remaining sequences to process.
"""
assert not self.is_driver_worker
data = broadcast_tensor_dict(src=0)
if not data:
return False

num_seq_groups = data.get("num_seq_groups", 0)
blocks_to_swap_in = data.get("blocks_to_swap_in")
blocks_to_swap_out = data.get("blocks_to_swap_out")
blocks_to_copy = data.get("blocks_to_copy")
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return False

self.model_runner.execute_model(None, self.gpu_cache)
return True
def execute_worker(self, worker_input: WorkerInput) -> None:
# Issue cache operations.
if (worker_input.blocks_to_swap_in is not None
and worker_input.blocks_to_swap_in.numel() > 0):
self.cache_engine.swap_in(worker_input.blocks_to_swap_in)
if (worker_input.blocks_to_swap_out is not None
and worker_input.blocks_to_swap_out.numel() > 0):
self.cache_engine.swap_out(worker_input.blocks_to_swap_out)
if (worker_input.blocks_to_copy is not None
and worker_input.blocks_to_copy.numel() > 0):
self.cache_engine.copy(worker_input.blocks_to_copy)
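
The block tensors consumed by execute_worker() above are packed in prepare_worker_input() as int64 tensors reshaped to (-1, 2), one (source block, destination block) pair per row. A tiny standalone illustration of that layout, with made-up block numbers and plain PyTorch:

    import torch

    # Two copy operations: block 2 -> block 5 and block 7 -> block 9.
    blocks_to_copy = torch.tensor([2, 5, 7, 9], dtype=torch.int64).view(-1, 2)
    assert blocks_to_copy.tolist() == [[2, 5], [7, 9]]

    # An empty tensor (numel() == 0) means the corresponding cache op is skipped.
    blocks_to_swap_in = torch.empty(0, dtype=torch.int64).view(-1, 2)
    assert blocks_to_swap_in.numel() == 0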

def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_runner.add_lora(lora_request)

@ -1,20 +1,26 @@
import dataclasses
import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

import torch

from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread, is_hip,
update_environment_variables)
from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase

logger = init_logger(__name__)


class WorkerBase(ABC):
"""Worker interface that allows vLLM to cleanly separate implementations for
different hardware.
different hardware. Also abstracts control plane communication, e.g., to
communicate request metadata to other workers.
"""

@abstractmethod
@ -46,13 +52,23 @@ class WorkerBase(ABC):
"""
raise NotImplementedError

@torch.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.

You can stop the loop by executing a driver worker with an empty output.
See `stop_remote_worker_execution_loop` for more details.
"""
while True:
output = self.execute_model(execute_model_req=None)
if output is None:
return None

@abstractmethod
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
) -> Optional[List[SamplerOutput]]:
raise NotImplementedError

@abstractmethod
@ -98,6 +114,150 @@ class LoraNotSupportedWorkerBase(WorkerBase):
raise ValueError(f"{type(self)} does not support LoRA")


@dataclasses.dataclass(frozen=True)
class WorkerInput:
"""Local inputs to each worker. May contain device-specific data. These
fields should be broadcastable to other workers.
"""

num_seq_groups: Optional[int] = None
blocks_to_swap_in: Optional[torch.Tensor] = None
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None

@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
tensor_dict: Dict[str, Any],
) -> "WorkerInput":
"""
Pop fields from the given tensor_dict and populate a new instance of
WorkerInput.
"""
return cls(
num_seq_groups=tensor_dict.pop("num_seq_groups"),
blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"),
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
)

def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
"""
Extract broadcastable fields.
"""
tensor_dict = {
"num_seq_groups": self.num_seq_groups,
"blocks_to_swap_in": self.blocks_to_swap_in,
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
}

return tensor_dict
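
A hedged sketch of how WorkerInput is intended to travel over the control plane: the driver serializes it with as_broadcastable_tensor_dict() and other workers rebuild it with from_broadcasted_tensor_dict(). The actual broadcast_tensor_dict call is elided (a plain dict copy stands in for it), and it assumes a vLLM checkout that contains this class.

    import torch

    from vllm.worker.worker_base import WorkerInput

    worker_input = WorkerInput(
        num_seq_groups=2,
        blocks_to_swap_in=torch.empty(0, 2, dtype=torch.int64),
        blocks_to_swap_out=torch.empty(0, 2, dtype=torch.int64),
        blocks_to_copy=torch.tensor([[2, 5]], dtype=torch.int64),
    )

    # Driver side: flatten into a broadcastable dict (stands in for the real
    # broadcast_tensor_dict(..., src=0) call).
    broadcast_data = worker_input.as_broadcastable_tensor_dict()

    # Worker side: rebuild the WorkerInput from the received dict.
    received = WorkerInput.from_broadcasted_tensor_dict(dict(broadcast_data))
    assert received.num_seq_groups == 2
    assert received.blocks_to_copy.tolist() == [[2, 5]]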


class LocalOrDistributedWorkerBase(WorkerBase):
"""
Partial implementation of WorkerBase that has a default `execute_model`
definition to perform metadata transfer between workers when in distributed
mode. Subclasses of this interface should use model runners that inherit
from ModelRunnerBase, and should only need to implement worker-local logic.
If custom control plane logic is needed to transfer metadata, or if the
model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
"""
is_driver_worker: bool
model_runner: ModelRunnerBase

@property
@abstractmethod
def do_metadata_broadcast(self) -> bool:
"""
Used by the default `execute_model` to check whether broadcast is
needed to transfer request inputs from the driver worker to other
workers in the TP group. If WorkerBase subclass only supports
single-worker execution, then this method should return False.
"""
raise NotImplementedError

@property
@abstractmethod
def kv_cache(self) -> Optional[List[torch.Tensor]]:
"""
Get the kv cache to pass to the worker's model runner. Used by the
default `execute_model`. If the worker's model runner does not follow
the ModelRunnerBase interface, then inherit from WorkerBase instead.
"""
raise NotImplementedError

@abstractmethod
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
"""
Prepare the inputs to WorkerBase.execute_worker from an execution
request. This method may move data to the worker's local device. It is
not allowed to communicate with other workers or devices.
"""
raise NotImplementedError

@abstractmethod
def execute_worker(self, worker_input: WorkerInput) -> None:
"""
Process an execution request.
"""
raise NotImplementedError

def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
# This signals that there's no more requests to process for
# now. All workers are running infinite loop with
# broadcast_tensor_dict, and it stops the loop when the
# driver broadcasts an empty input. Send an empty input to
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None

worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list))

if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
else:
assert self.do_metadata_broadcast
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None

worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))

self.execute_worker(worker_input)

# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []

output = self.model_runner.execute_model(model_input, self.kv_cache)
# Worker only supports single-step execution. Wrap the output in a
# list to conform to interface.
return [output]
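
A hedged skeleton of a single-device worker built on the interface above: with do_metadata_broadcast returning False, the default execute_model() runs entirely locally (prepare_worker_input -> prepare_model_input -> execute_worker -> model_runner.execute_model) and never touches broadcast_tensor_dict. The class and its model_runner argument are illustrative, not part of this commit, and the remaining WorkerBase abstract methods (init_device, cache sizing, etc.) are omitted, so this is a sketch rather than something instantiable as-is.

    from typing import List, Optional

    import torch

    from vllm.sequence import ExecuteModelRequest
    from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput


    class SingleDeviceWorkerSketch(LocalOrDistributedWorkerBase):

        def __init__(self, model_runner) -> None:
            self.model_runner = model_runner  # any ModelRunnerBase subclass
            self.is_driver_worker = True      # no TP group, so always the driver

        @property
        def do_metadata_broadcast(self) -> bool:
            return False                      # single worker: nothing to broadcast

        @property
        def kv_cache(self) -> Optional[List[torch.Tensor]]:
            return None                       # backend manages its own cache

        def prepare_worker_input(
                self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
            return WorkerInput(
                num_seq_groups=len(execute_model_req.seq_group_metadata_list))

        def execute_worker(self, worker_input: WorkerInput) -> None:
            pass                              # no swap/copy cache ops in this sketch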


class WorkerWrapperBase:
"""
The whole point of this class is to lazily initialize the worker.

@ -1,4 +1,5 @@
from typing import List, Optional, Tuple
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

import torch
import torch.nn as nn
@ -14,6 +15,15 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
ModelRunnerBase, ModelRunnerInputBase,
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

@ -24,7 +34,42 @@ _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
]


class XPUModelRunner:
@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
"""
Used by the XPUModelRunner.
"""
input_tokens: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_input: Optional[Dict[str, torch.Tensor]] = None

def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
tensor_dict = {
"input_tokens": self.input_tokens,
"input_positions": self.input_positions,
}
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict,
self.sampling_metadata)
return tensor_dict

@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForXPU"],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None,
) -> "ModelInputForXPU":
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)


class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):

def __init__(
self,
@ -130,15 +175,22 @@ class XPUModelRunner:
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [None] * num_layers
self.execute_model(seqs, kv_caches)
model_input = self.prepare_model_input(seqs)
self.execute_model(model_input, kv_caches)
torch.xpu.synchronize()
return

def prepare_input_tensors(
def make_model_input_from_broadcasted_tensor_dict(
self, tensor_dict: Dict[str, Any]) -> ModelInputForXPU:
return (ModelInputForXPU.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
))

def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Optional[torch.Tensor]]:
) -> ModelInputForXPU:
multi_modal_input = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
@ -185,8 +237,11 @@ class XPUModelRunner:
num_prompts=0,
)

return (input_tokens, input_positions, attn_metadata,
sampling_metadata, multi_modal_input)
return ModelInputForXPU(input_tokens=input_tokens,
input_positions=input_positions,
attn_metadata=attn_metadata,
sampling_metadata=sampling_metadata,
multi_modal_input=multi_modal_input)

def _prepare_decode(
self,
@ -277,27 +332,25 @@ class XPUModelRunner:
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
model_input: ModelInputForXPU,
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, attn_metadata, sampling_metadata,
multi_modal_input
) = self.prepare_input_tensors(seq_group_metadata_list)

model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"attn_metadata": model_input.attn_metadata,
}
if self.vision_language_config:
execute_model_kwargs.update({"image_input": multi_modal_input})
execute_model_kwargs.update(
{"image_input": model_input.multi_modal_input})

hidden_states = model_executable(**execute_model_kwargs)

# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
logits = self.model.compute_logits(hidden_states,
model_input.sampling_metadata)

# Only perform sampling in the driver worker.
if not self.is_driver_worker:
@ -306,7 +359,7 @@ class XPUModelRunner:
# Sample the next token.
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
sampling_metadata=model_input.sampling_metadata,
)
return output