mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-09 00:27:07 +08:00
Merge branch 'main' into woosuk/model-runner-v2
This commit is contained in:
commit
6f038fc4fb
@ -49,7 +49,7 @@ repos:
|
||||
rev: 0.6.17
|
||||
hooks:
|
||||
- id: pip-compile
|
||||
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128]
|
||||
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28]
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- repo: local
|
||||
hooks:
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128
|
||||
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 --python-platform x86_64-manylinux_2_28
|
||||
absl-py==2.1.0
|
||||
# via rouge-score
|
||||
accelerate==1.0.1
|
||||
|
||||
@ -83,7 +83,7 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 256])
|
||||
@pytest.mark.parametrize("sliding_window", [None, 64, 128, 256])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
|
||||
505
tests/v1/kv_connector/unit/test_offloading_connector.py
Normal file
505
tests/v1/kv_connector/unit/test_offloading_connector.py
Normal file
@ -0,0 +1,505 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
|
||||
OffloadingConnector, OffloadingConnectorMetadata)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
|
||||
OffloadingManager, PrepareStoreOutput)
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
|
||||
TransferResult, TransferSpec)
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler,
|
||||
create_vllm_config)
|
||||
|
||||
|
||||
class MockLoadStoreSpec(LoadStoreSpec):
    """Offload-medium spec that just remembers which block hashes it covers."""

    def __init__(self, block_hashes: Iterable[BlockHash]):
        # materialize once so the hashes can be inspected multiple times
        self.block_hashes: list[BlockHash] = list(block_hashes)

    @staticmethod
    def medium() -> str:
        """Name of the (mock) storage medium."""
        return "Mock"

    def __repr__(self) -> str:
        return f"{self.block_hashes!r}"
|
||||
|
||||
|
||||
class MockOffloadingHandler(OffloadingHandler):
    """Transfer handler that completes every submitted transfer synchronously."""

    def __init__(self):
        # transfer results produced since the last get_finished() call
        self.completed_transfers: list[TransferResult] = []
        # specs of all transfers executed, drained separately by the test
        self.completed_specs: list[TransferSpec] = []

    def get_finished(self) -> list[TransferResult]:
        """Drain and return the transfer results accumulated so far."""
        finished, self.completed_transfers = self.completed_transfers, []
        return finished

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """Record the transfer and immediately report it as succeeded."""
        self.completed_specs.append(spec)
        self.completed_transfers.append((job_id, True))
        return True
|
||||
|
||||
|
||||
class MockOffloadingSpec(OffloadingSpec):
    """OffloadingSpec backed by a mocked manager and a synchronous handler."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

        # MagicMock manager lets tests program lookup/prepare_store behavior
        self.manager = MagicMock(spec=OffloadingManager)
        self.manager.lookup.return_value = 0
        self.manager.prepare_load = lambda block_hashes: MockLoadStoreSpec(
            block_hashes)
        self.handler = MockOffloadingHandler()

    def get_manager(self) -> OffloadingManager:
        return self.manager

    def get_handlers(
        self, _
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                        OffloadingHandler]]:
        # a single handler serves both directions (GPU<->Mock)
        yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
        yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler

    def get_completed_transfers(self) -> list[TransferSpec]:
        """Drain and return the specs of all transfers completed so far."""
        specs, self.handler.completed_specs = self.handler.completed_specs, []
        return specs
|
||||
|
||||
|
||||
@dataclass
class TransferSummary:
    """Record of one completed transfer, as observed by the test runner."""
    # indices of the GPU blocks that took part in the transfer
    gpu_block_indices: list[int]
    # offload-side addresses; in this suite each address is a
    # (block_hash, sub_block_offset) tuple
    offload_addresses: list[Any]
|
||||
|
||||
|
||||
class RequestRunner:
    """Drives a real Scheduler plus a worker-side OffloadingConnector in
    lockstep, tracking every offload load/store transfer so tests can assert
    exactly which GPU blocks were stored to / loaded from the mock medium.

    Assumes at most a single request is active at any time.
    """

    def __init__(self, offloaded_block_size: int, gpu_block_size: int,
                 num_gpu_blocks: int):
        self.offloaded_block_size: int = offloaded_block_size
        self.gpu_block_size: int = gpu_block_size
        self.num_gpu_blocks: int = num_gpu_blocks

        # incremented per new_request(); used as the request ID
        self.req_id: int = -1

        vllm_config = create_vllm_config(block_size=gpu_block_size,
                                         max_num_batched_tokens=1000)
        vllm_config.kv_transfer_config = KVTransferConfig(
            kv_connector="OffloadingConnector",
            kv_role="kv_both",
            kv_connector_extra_config={
                "spec_name": "MockOffloadingSpec",
                "spec_module_path":
                "tests.v1.kv_connector.unit.test_offloading_connector",
                "block_size": offloaded_block_size,
            })

        self.scheduler: Scheduler = create_scheduler(vllm_config,
                                                     num_blocks=num_gpu_blocks)
        self.worker_connector = OffloadingConnector(vllm_config,
                                                   KVConnectorRole.WORKER)

        # register worker kv_caches to enable OffloadingWorker creations
        self.worker_connector.register_kv_caches(
            kv_caches={"a": torch.empty(0)})

        # extract connector of scheduler
        scheduler_connector = self.scheduler.connector
        assert scheduler_connector is not None
        assert isinstance(scheduler_connector, OffloadingConnector)
        self.scheduler_connector: OffloadingConnector = scheduler_connector

        # extract mocked OffloadingManager of scheduler connector
        connector_scheduler = scheduler_connector.connector_scheduler
        assert connector_scheduler is not None
        manager = connector_scheduler.manager
        assert isinstance(manager, MagicMock)
        self.manager: MagicMock = manager

        assert connector_scheduler.gpu_block_size == gpu_block_size
        assert connector_scheduler.offloaded_block_size == offloaded_block_size

        # extract OffloadingSpec of worker_connector
        connector_worker = self.worker_connector.connector_worker
        assert connector_worker is not None
        offloading_spec = connector_worker.spec
        assert isinstance(offloading_spec, MockOffloadingSpec)
        self.offloading_spec: MockOffloadingSpec = offloading_spec

        # mapping (offloading address) -> gpu_block_index
        self.offloaded: dict[Any, int] = {}

        # transfers announced by scheduler metadata but not yet observed
        # via the handler's completed specs
        self.pending_loads_count: int = 0
        self.pending_stores_count: int = 0

        self.completed_loads: list[TransferSummary] = []
        self.completed_stores: list[TransferSummary] = []

        # maps {block_id: block_offset}
        self.gpu_block_index: dict[int, int] = {}

        init_none_hash(sha256)
        self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)

        self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={},
                                                         attn_metadata={},
                                                         virtual_engine=0)

    def new_request(self, token_ids: list[int]):
        """Add a fresh request with the given prompt to the scheduler.

        Only one request may be live at a time; asserts none is pending.
        """
        assert not self.scheduler.requests
        self.req_id += 1

        req = Request(
            request_id=str(self.req_id),
            prompt_token_ids=token_ids,
            sampling_params=SamplingParams(max_tokens=1000),
            pooling_params=None,
            eos_token_id=EOS_TOKEN_ID,
            block_hasher=self._block_hasher,
        )

        self.scheduler.add_request(req)

    def _wait_for_transfers(self):
        """Drain completed transfers from the mock handler, converting each
        into a TransferSummary until no announced transfer is outstanding."""
        block_size_factor = self.offloaded_block_size // self.gpu_block_size

        while self.pending_loads_count or self.pending_stores_count:
            for transfer_spec in (
                    self.offloading_spec.get_completed_transfers()):
                src_spec, dst_spec = transfer_spec

                # a GPU source means data moved off the GPU, i.e. a store
                if isinstance(src_spec, GPULoadStoreSpec):
                    store = True
                    gpu_spec = src_spec
                    offload_spec = dst_spec
                else:
                    store = False
                    gpu_spec = dst_spec
                    offload_spec = src_spec

                assert isinstance(offload_spec, MockLoadStoreSpec)
                assert isinstance(gpu_spec, GPULoadStoreSpec)

                gpu_block_indices: list[int] = []
                # block IDs arrive as tensor scalars (hence .item())
                for block_id in gpu_spec.block_ids:
                    gpu_block_indices.append(
                        self.gpu_block_index[block_id.item()])

                # list of (block_hash, sub_block_offset)
                offload_addresses: list[Any] = []
                for block_hash in offload_spec.block_hashes:
                    for sub_block_idx in range(block_size_factor):
                        offload_addresses.append((block_hash, sub_block_idx))

                if store:
                    assert len(gpu_block_indices) == len(offload_addresses)

                    self.completed_stores.append(
                        TransferSummary(gpu_block_indices, offload_addresses))
                    self.pending_stores_count -= 1
                else:
                    # loads may skip leading sub-blocks already on the GPU
                    remainder_sub_block_count = (len(offload_addresses) -
                                                 len(gpu_block_indices))
                    assert remainder_sub_block_count >= 0
                    assert remainder_sub_block_count < block_size_factor
                    offload_addresses = offload_addresses[
                        remainder_sub_block_count:]

                    self.completed_loads.append(
                        TransferSummary(gpu_block_indices, offload_addresses))
                    self.pending_loads_count -= 1

    def _update_gpu_block_idx(self):
        """Refresh block_id -> in-request block offset from the KV manager."""
        for blocks in (self.scheduler.kv_cache_manager.coordinator.
                       single_type_managers[0].req_to_blocks.values()):
            for block_idx, block in enumerate(blocks):
                self.gpu_block_index[block.block_id] = block_idx

    def _run(self, decoded_tokens: list[int]):
        """
        Runs multiple engine (scheduler + worker) steps.
        Assumes a single request is running.

        Args:
            decoded_tokens: the tokens to yield at each step.
        """

        tokens_iter = iter(decoded_tokens)
        token_id = next(tokens_iter, None)
        while token_id is not None:
            assert self.scheduler.requests

            scheduler_output = self.scheduler.schedule()
            self._update_gpu_block_idx()

            kv_connector_metadata = scheduler_output.kv_connector_metadata
            assert kv_connector_metadata is not None
            assert isinstance(kv_connector_metadata,
                              OffloadingConnectorMetadata)

            self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
            self.pending_stores_count += len(
                kv_connector_metadata.reqs_to_store)

            self.worker_connector.bind_connector_metadata(
                kv_connector_metadata)
            self.worker_connector.start_load_kv(self._dummy_ctx)

            if scheduler_output.total_num_scheduled_tokens > 0:
                self.worker_connector.wait_for_save()

            finished_sending, finished_recving = (
                self.worker_connector.get_finished(
                    scheduler_output.finished_req_ids))

            self.worker_connector.clear_connector_metadata()

            model_runner_output = create_model_runner_output(
                reqs=self.scheduler.running,
                finished_sending=list(finished_sending),
                finished_recving=list(finished_recving),
                token_id=token_id)

            # only advance to the next token while the request is running
            if self.scheduler.running:
                token_id = next(tokens_iter, None)

            self.scheduler.update_from_output(scheduler_output,
                                              model_runner_output)

            self._wait_for_transfers()

        # run one more step to update finished stored
        if EOS_TOKEN_ID in decoded_tokens:
            assert not self.scheduler.running

            while self.scheduler.requests:
                scheduler_output = self.scheduler.schedule()

                finished_sending, finished_recving = (
                    self.worker_connector.get_finished(
                        scheduler_output.finished_req_ids))

                assert not finished_recving

                model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
                model_runner_output.kv_connector_output = KVConnectorOutput(
                    finished_sending=finished_sending)

                self.scheduler.update_from_output(scheduler_output,
                                                  model_runner_output)

    def run(
        self,
        decoded_tokens: list[int],
        expected_stored_gpu_block_indexes: tuple[int, ...] = (),
        expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
    ):
        """
        Runs multiple engine (scheduler + worker) steps.
        Assumes a single request is running.

        Args:
            decoded_tokens: the tokens to yield at each step.
            expected_stored_gpu_block_indexes: GPU block indexes
                that are expected to be written during the run.
            expected_loaded_gpu_block_indexes: GPU block indexes
                that are expected to be loaded during the run.
        """

        self.manager.reset_mock()
        self._run(decoded_tokens)

        loaded_gpu_block_indexes: set[int] = set()
        for transfer in self.completed_loads:
            for gpu_block_idx, offloaded_address in zip(
                    transfer.gpu_block_indices, transfer.offload_addresses):
                loaded_gpu_block_indexes.add(gpu_block_idx)
                # loads must return data previously stored at that address
                assert gpu_block_idx == self.offloaded[offloaded_address]

        assert (
            set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes)
        self.completed_loads.clear()

        stored_gpu_block_indexes: set[int] = set()
        for transfer in self.completed_stores:
            for gpu_block_idx, offloaded_address in zip(
                    transfer.gpu_block_indices, transfer.offload_addresses):
                stored_gpu_block_indexes.add(gpu_block_idx)
                self.offloaded[offloaded_address] = gpu_block_idx

        assert (
            set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes)
        self.completed_stores.clear()
|
||||
|
||||
|
||||
@pytest.fixture
def request_runner():
    """Yield a factory that builds RequestRunner instances for a test."""
    created = []

    def _factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
        instance = RequestRunner(offloaded_block_size=offloaded_block_size,
                                 gpu_block_size=gpu_block_size,
                                 num_gpu_blocks=num_gpu_blocks)
        created.append(instance)
        return instance

    # hand the factory (not a runner) to the test
    yield _factory
|
||||
|
||||
|
||||
def generate_store_output(block_hashes: Iterable[BlockHash]):
    """Build a PrepareStoreOutput that stores every given block hash."""
    hashes = list(block_hashes)
    return PrepareStoreOutput(
        block_hashes_to_store=list(hashes),
        store_spec=MockLoadStoreSpec(hashes),
        block_hashes_evicted=[],
    )
|
||||
|
||||
|
||||
def test_offloading_connector(request_runner):
    """End-to-end OffloadingConnector test: drives store/load decisions via
    the mocked OffloadingManager and asserts which GPU blocks move, then
    checks KV-event translation in take_events()."""
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100
    block_size_factor = offloaded_block_size // gpu_block_size

    runner = request_runner(offloaded_block_size=offloaded_block_size,
                            gpu_block_size=gpu_block_size,
                            num_gpu_blocks=num_gpu_blocks)

    # 3 blocks, store just the middle block (skip first and last)
    # blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
    runner.new_request(token_ids=[0] * offloaded_block_size * 3)
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
    runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))

    # add block missing 1 token -> no offload
    runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
    runner.manager.prepare_store.assert_not_called()

    # +1 token -> single block, fail prepare_store
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: None
    runner.run(decoded_tokens=[0])
    runner.manager.prepare_store.assert_called()

    # 1 more block, now set block_hashes_to_store = []
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.run(decoded_tokens=[0] * offloaded_block_size)

    # 1 more block, now check touch was called with all 6 blocks
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output(block_hashes)
    runner.run(decoded_tokens=[0] * offloaded_block_size,
               expected_stored_gpu_block_indexes=(15, 16, 17))
    runner.manager.touch.assert_called()
    block_hashes1 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes1) == 6

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # create a new request differing only on the last token
    runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
    runner.run(decoded_tokens=[0],
               expected_stored_gpu_block_indexes=tuple(
                   range(6 * block_size_factor)))
    runner.manager.touch.assert_called()
    block_hashes2 = list(runner.manager.touch.call_args.args[0])
    assert len(block_hashes2) == 6

    # verify hashes are the same, except for the last block
    assert block_hashes1[:5] == block_hashes2[:5]
    assert block_hashes1[5] != block_hashes2[5]

    # terminate request
    runner.run(decoded_tokens=[EOS_TOKEN_ID])

    # full_block_tokens - num_computed_tokens < offloaded_block_size
    runner.new_request(token_ids=[0] * gpu_block_size + [1] *
                       (offloaded_block_size - gpu_block_size))
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_not_called()

    # single block lookup with no hits
    runner.new_request(token_ids=[1] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.run(decoded_tokens=[EOS_TOKEN_ID])
    runner.manager.lookup.assert_called()
    assert len(list(runner.manager.lookup.call_args.args[0])) == 1

    # single block lookup with a hit
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[EOS_TOKEN_ID],
               expected_loaded_gpu_block_indexes=(0, 1, 2))

    # single block lookup with a hit in a middle block
    runner.new_request(token_ids=[0] * offloaded_block_size * 2 +
                       [1] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = \
        lambda block_hashes: generate_store_output([])
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[EOS_TOKEN_ID],
               expected_loaded_gpu_block_indexes=(3, 4, 5))

    # test take_events
    def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
        return [BlockHash(str(i).encode()) for i in int_hashes]

    def take_events() -> Iterable[OffloadingEvent]:
        yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]),
                              block_size=16,
                              medium="A",
                              removed=False)
        yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]),
                              block_size=32,
                              medium="B",
                              removed=True)

    runner.manager.take_events.side_effect = take_events
    events = list(runner.scheduler_connector.take_events())
    assert len(events) == 2
    event = events[0]
    assert isinstance(event, BlockStored)
    assert event.block_hashes == to_hashes([1, 2, 3])
    assert event.block_size == 16
    assert event.medium == "A"
    assert event.token_ids == []
    assert event.parent_block_hash is None
    assert event.lora_id is None
    event = events[1]
    assert isinstance(event, BlockRemoved)
    assert event.block_hashes == to_hashes([4, 5, 6])
    assert event.medium == "B"
|
||||
@ -176,6 +176,7 @@ def create_model_runner_output(
|
||||
finished_sending: Optional[list[str]] = None,
|
||||
finished_recving: Optional[list[str]] = None,
|
||||
use_eos: bool = False,
|
||||
token_id: int = 0,
|
||||
) -> ModelRunnerOutput:
|
||||
"""Make dummy model runner output for testing."""
|
||||
|
||||
@ -184,7 +185,7 @@ def create_model_runner_output(
|
||||
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
|
||||
|
||||
# Make sampled tokens.
|
||||
sampled_token = EOS_TOKEN_ID if use_eos else 0
|
||||
sampled_token = EOS_TOKEN_ID if use_eos else token_id
|
||||
sampled_token_ids = [[sampled_token] for _ in req_ids]
|
||||
|
||||
kv_connector_output = None if (
|
||||
|
||||
@ -184,8 +184,30 @@ def kernel_unified_attention_2d(
|
||||
# this prefix can be skipped)
|
||||
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
|
||||
|
||||
# iterate through tiles
|
||||
for j in range(0, num_tiles):
|
||||
# ---- Sliding-window tile pruning --------------------
|
||||
# Default: keep previous global behavior
|
||||
tile_start = 0
|
||||
tile_end = num_tiles
|
||||
if SLIDING_WINDOW > 0:
|
||||
# Query rows covered by this Q-block
|
||||
qpos_lo = q_block_local_idx * BLOCK_Q
|
||||
qpos_hi = tl.minimum(
|
||||
qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
|
||||
cur_batch_query_len - 1,
|
||||
)
|
||||
# For sliding window, each query position q can only attend to
|
||||
# keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
|
||||
# where q_abs = context_len + q
|
||||
# The union of allowed key positions for this Q-block is:
|
||||
# [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
|
||||
first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
|
||||
last_allowed_key = context_len + qpos_hi
|
||||
# Convert to tile indices and clamp
|
||||
tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
|
||||
tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
|
||||
|
||||
# iterate through tiles (now limited to the sliding window range)
|
||||
for j in range(tile_start, tile_end):
|
||||
seq_offset = j * TILE_SIZE + offs_t
|
||||
tile_mask = seq_offset < max_seq_prefix_len
|
||||
|
||||
|
||||
@ -106,3 +106,8 @@ KVConnectorFactory.register_connector(
|
||||
"MultiConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
|
||||
"MultiConnector")
|
||||
|
||||
# Register OffloadingConnector so it can be selected by name via
# KVTransferConfig(kv_connector="OffloadingConnector").
KVConnectorFactory.register_connector(
    "OffloadingConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
    "OffloadingConnector")
|
||||
|
||||
@ -0,0 +1,485 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||
KVConnectorRole)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
ReqId = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class OffloadingConnectorMetadata(KVConnectorMetadata):
    """Per-step metadata handed from the scheduler side to the worker side."""
    # request ID -> (src, dst) transfer spec for KV data to load onto the GPU
    reqs_to_load: dict[ReqId, TransferSpec]
    # request ID -> (src, dst) transfer spec for KV data to store off the GPU
    reqs_to_store: dict[ReqId, TransferSpec]
|
||||
|
||||
|
||||
class OffloadingConnector(KVConnectorBase_V1):
    """KV connector that offloads KV-cache blocks via an OffloadingSpec.

    Acts as a thin façade: depending on `role` it instantiates either the
    scheduler-side or the worker-side implementation and delegates every
    call to it.
    """

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        super().__init__(vllm_config, role)

        spec = OffloadingSpecFactory.create_spec(vllm_config)

        # exactly one of the two is created, based on `role`
        self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None
        self.connector_worker: Optional[OffloadingConnectorWorker] = None
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler = OffloadingConnectorScheduler(spec)
        elif role == KVConnectorRole.WORKER:
            self.connector_worker = OffloadingConnectorWorker(spec)

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Worker-side: hand the KV cache tensors to the offloading worker."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Worker-side: kick off the loads listed in the bound metadata."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          OffloadingConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No per-layer loading; loads are handled whole-request."""
        pass

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """No per-layer saving; stores are handled in wait_for_save()."""
        pass

    def wait_for_save(self):
        """Worker-side: kick off the stores listed in the bound metadata."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          OffloadingConnectorMetadata)
        self.connector_worker.start_store_kv(self._connector_metadata)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Worker-side: return (finished_sending, finished_recving) req IDs."""
        assert self.connector_worker is not None
        return self.connector_worker.get_finished(finished_req_ids)

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Scheduler-side: how many extra tokens can be served from offload."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Scheduler-side: record allocated blocks / schedule pending loads."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
        """Scheduler-side: build the per-step metadata for the worker."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """Scheduler-side: absorb transfer results reported by the worker."""
        assert self.connector_scheduler is not None
        self.connector_scheduler.update_connector_output(connector_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Scheduler-side: handle request completion; see delegate for rules."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Scheduler-side: drain KV cache events for event publishing."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.take_events()
|
||||
|
||||
|
||||
class OffloadingConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
self.gpu_block_size = spec.gpu_block_size
|
||||
self.offloaded_block_size = spec.offloaded_block_size
|
||||
self.block_size_factor = (self.offloaded_block_size //
|
||||
self.gpu_block_size)
|
||||
self.manager: OffloadingManager = spec.get_manager()
|
||||
|
||||
self._requests: dict[ReqId, Request] = {}
|
||||
# list of GPU block IDs per request
|
||||
self._request_block_ids: dict[ReqId, list[int]] = {}
|
||||
# requests to load for the current scheduler step
|
||||
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
|
||||
# request blocks are stored in order
|
||||
# index of next block (of size offloaded_block_size) to offload
|
||||
self._next_stored_block_idx: dict[ReqId, int] = {}
|
||||
|
||||
# request ID -> set(block hashes being stored/load)
|
||||
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
|
||||
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
|
||||
|
||||
def _get_block_hashes(
|
||||
self,
|
||||
req: Request,
|
||||
start_idx: int = 0,
|
||||
end_idx: Optional[int] = None,
|
||||
) -> Iterable[BlockHash]:
|
||||
return islice(
|
||||
req.block_hashes,
|
||||
self.block_size_factor * start_idx + self.block_size_factor - 1,
|
||||
self.block_size_factor * end_idx if end_idx else None,
|
||||
self.block_size_factor)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: Request,
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded beyond the
|
||||
num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- The number of tokens that can be loaded beyond what is
|
||||
already computed.
|
||||
- `True` if tokens will be loaded asynchronously
|
||||
(between scheduler steps).
|
||||
"""
|
||||
num_blocks = request.num_tokens // self.offloaded_block_size
|
||||
|
||||
assert (len(request.block_hashes) //
|
||||
self.block_size_factor == num_blocks)
|
||||
block_hashes = self._get_block_hashes(request)
|
||||
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
full_block_tokens = self.offloaded_block_size * num_blocks
|
||||
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
|
||||
# we can load less than a block, skip
|
||||
return 0, False
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
hits = self.manager.lookup(
|
||||
self._get_block_hashes(request, start_idx=start_block_idx))
|
||||
if hits == 0:
|
||||
return 0, False
|
||||
|
||||
num_hit_tokens = (self.offloaded_block_size *
|
||||
(start_block_idx + hits) - num_computed_tokens)
|
||||
logger.debug(
|
||||
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
|
||||
request.request_id,
|
||||
num_hit_tokens,
|
||||
num_computed_tokens,
|
||||
)
|
||||
if num_hit_tokens < self.offloaded_block_size:
|
||||
return 0, False
|
||||
|
||||
return num_hit_tokens, True
|
||||
|
||||
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
|
||||
num_external_tokens: int):
|
||||
self._requests[request.request_id] = request
|
||||
# the block ids are updated in _get_reqs_to_store
|
||||
self._request_block_ids[request.request_id] = []
|
||||
|
||||
if num_external_tokens == 0:
|
||||
return
|
||||
|
||||
block_groups = blocks.get_block_ids()
|
||||
block_ids = block_groups[0]
|
||||
|
||||
num_computed_gpu_blocks = sum(block.block_hash is not None
|
||||
for block in blocks.blocks[0])
|
||||
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
|
||||
full_block_tokens = num_computed_tokens + num_external_tokens
|
||||
assert full_block_tokens % self.offloaded_block_size == 0
|
||||
|
||||
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
|
||||
assert (num_external_tokens == num_pending_gpu_blocks *
|
||||
self.gpu_block_size)
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
num_blocks = full_block_tokens // self.offloaded_block_size
|
||||
|
||||
assert (len(request.block_hashes) // self.block_size_factor
|
||||
>= num_blocks)
|
||||
block_hashes = self._get_block_hashes(request,
|
||||
start_idx=start_block_idx,
|
||||
end_idx=num_blocks)
|
||||
|
||||
src_spec = self.manager.prepare_load(block_hashes)
|
||||
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
|
||||
|
||||
block_hashes = self._get_block_hashes(request,
|
||||
start_idx=start_block_idx,
|
||||
end_idx=num_blocks)
|
||||
|
||||
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
|
||||
self._reqs_being_loaded[request.request_id].update(block_hashes)
|
||||
self._next_stored_block_idx[request.request_id] = num_blocks
|
||||
|
||||
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
|
||||
reqs_to_store: dict[ReqId, TransferSpec] = {}
|
||||
# iterate over both new and cached requests
|
||||
for req_id, new_block_id_groups, preempted in yield_req_data(
|
||||
scheduler_output):
|
||||
|
||||
if preempted:
|
||||
self._request_block_ids[req_id] = []
|
||||
|
||||
if new_block_id_groups:
|
||||
new_block_ids = new_block_id_groups[0]
|
||||
self._request_block_ids[req_id] += new_block_ids
|
||||
|
||||
block_ids = self._request_block_ids[req_id]
|
||||
|
||||
req = self._requests[req_id]
|
||||
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
total_tokens = req.num_computed_tokens + new_tokens
|
||||
num_blocks = total_tokens // self.offloaded_block_size
|
||||
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
|
||||
num_new_blocks = num_blocks - start_block_idx
|
||||
|
||||
if num_new_blocks <= 0:
|
||||
continue
|
||||
|
||||
num_gpu_blocks = num_blocks * self.block_size_factor
|
||||
assert len(req.block_hashes) >= num_gpu_blocks
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks)
|
||||
store_output = self.manager.prepare_store(new_block_hashes)
|
||||
if store_output is None:
|
||||
logger.warning("Cannot store %s blocks", num_new_blocks)
|
||||
break
|
||||
|
||||
self._next_stored_block_idx[req_id] = num_blocks
|
||||
|
||||
if not store_output.block_hashes_to_store:
|
||||
continue
|
||||
block_hashes_to_store = set(store_output.block_hashes_to_store)
|
||||
|
||||
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks)
|
||||
dst_spec = store_output.store_spec
|
||||
src_block_ids: list[int] = []
|
||||
for idx, blk_hash in enumerate(new_block_hashes):
|
||||
if blk_hash not in block_hashes_to_store:
|
||||
continue
|
||||
offloaded_block_idx = start_block_idx + idx
|
||||
gpu_block_idx = offloaded_block_idx * self.block_size_factor
|
||||
for i in range(self.block_size_factor):
|
||||
src_block_ids.append(block_ids[gpu_block_idx + i])
|
||||
src_spec = GPULoadStoreSpec(src_block_ids)
|
||||
|
||||
reqs_to_store[req_id] = (src_spec, dst_spec)
|
||||
self._reqs_being_stored[req_id] |= block_hashes_to_store
|
||||
|
||||
logger.debug(
|
||||
"Request %s offloading %s blocks starting from block #%d",
|
||||
req_id,
|
||||
len(block_hashes_to_store),
|
||||
start_block_idx,
|
||||
)
|
||||
|
||||
return reqs_to_store
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
meta = OffloadingConnectorMetadata(
|
||||
reqs_to_load=self._reqs_to_load,
|
||||
reqs_to_store=self._get_reqs_to_store(scheduler_output))
|
||||
self._reqs_to_load = {}
|
||||
return meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
for req_id in connector_output.finished_sending or []:
|
||||
block_hashes = self._reqs_being_stored.pop(req_id, None)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
|
||||
for req_id in connector_output.finished_recving or []:
|
||||
block_hashes = self._reqs_being_loaded.pop(req_id, None)
|
||||
if block_hashes:
|
||||
self.manager.complete_load(block_hashes)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: Request,
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
req_id = request.request_id
|
||||
self._requests.pop(req_id, None)
|
||||
self._request_block_ids.pop(req_id, None)
|
||||
self._next_stored_block_idx.pop(req_id, None)
|
||||
|
||||
request_being_stored = req_id in self._reqs_being_stored
|
||||
return request_being_stored, None
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
"""Take the KV cache events from the connector.
|
||||
|
||||
Returns:
|
||||
A list of KV cache events.
|
||||
"""
|
||||
for event in self.manager.take_events():
|
||||
if event.removed:
|
||||
yield BlockRemoved(block_hashes=event.block_hashes,
|
||||
medium=event.medium)
|
||||
else:
|
||||
yield BlockStored(block_hashes=event.block_hashes,
|
||||
parent_block_hash=None,
|
||||
token_ids=[],
|
||||
lora_id=None,
|
||||
block_size=event.block_size,
|
||||
medium=event.medium)
|
||||
|
||||
|
||||
class OffloadingConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
self.spec = spec
|
||||
self.worker = OffloadingWorker()
|
||||
|
||||
self._job_counter = 0
|
||||
|
||||
# req_id -> (job_id, store)
|
||||
self._jobs: dict[int, tuple[ReqId, bool]] = {}
|
||||
# req_id -> active job IDs
|
||||
self._load_job: dict[ReqId, int] = {}
|
||||
# req_id -> set(active job IDs)
|
||||
self._store_jobs = defaultdict[ReqId, set[int]](set)
|
||||
|
||||
self._finished_reqs_waiting_for_store: set[ReqId] = set()
|
||||
|
||||
def _generate_job_id(self) -> int:
|
||||
job_id = self._job_counter
|
||||
self._job_counter = job_id + 1
|
||||
return job_id
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
for src_cls, dst_cls, handler in (self.spec.get_handlers(kv_caches)):
|
||||
self.worker.register_handler(src_cls, dst_cls, handler)
|
||||
|
||||
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
for req_id, transfer_spec in metadata.reqs_to_load.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, False)
|
||||
assert req_id not in self._load_job
|
||||
self._load_job[req_id] = job_id
|
||||
assert self.worker.transfer_async(job_id, transfer_spec)
|
||||
|
||||
def start_store_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
for req_id, transfer_spec in metadata.reqs_to_store.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, True)
|
||||
self._store_jobs[req_id].add(job_id)
|
||||
assert self.worker.transfer_async(job_id, transfer_spec)
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
Returns a list of request IDs that finished loading or storing.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
"""
|
||||
finished_sending = set()
|
||||
finished_recving = set()
|
||||
for job_id, success in self.worker.get_finished():
|
||||
# we currently do not support job failures
|
||||
assert success
|
||||
req_id, store = self._jobs.pop(job_id)
|
||||
if store:
|
||||
req_jobs = self._store_jobs[req_id]
|
||||
req_jobs.remove(job_id)
|
||||
if req_jobs:
|
||||
continue
|
||||
|
||||
if req_id in self._finished_reqs_waiting_for_store:
|
||||
self._finished_reqs_waiting_for_store.remove(req_id)
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
else:
|
||||
req_job = self._load_job[req_id]
|
||||
assert job_id == req_job
|
||||
del self._load_job[req_id]
|
||||
finished_recving.add(req_id)
|
||||
|
||||
for req_id in finished_req_ids:
|
||||
pending_req_jobs = self._store_jobs.get(req_id)
|
||||
if pending_req_jobs:
|
||||
self._finished_reqs_waiting_for_store.add(req_id)
|
||||
elif pending_req_jobs is not None:
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
|
||||
return finished_sending, finished_recving
|
||||
|
||||
|
||||
def yield_req_data(
|
||||
scheduler_output) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
|
||||
"""
|
||||
Yields:
|
||||
(req_id, new_block_id_groups, preempted)
|
||||
"""
|
||||
# new requests
|
||||
for req_data in scheduler_output.scheduled_new_reqs:
|
||||
yield req_data.req_id, req_data.block_ids, False
|
||||
|
||||
# cached requests
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
yield from zip(cached_reqs.req_ids, cached_reqs.new_block_ids,
|
||||
cached_reqs.resumed_from_preemption)
|
||||
@ -286,6 +286,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
|
||||
@ -126,6 +126,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -136,5 +137,5 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert experts is not None
|
||||
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
activation, global_num_experts, expert_map, a1q_scale,
|
||||
workspace13, workspace2, expert_tokens_meta,
|
||||
a2_scale, workspace13, workspace2, expert_tokens_meta,
|
||||
apply_router_weight_on_input)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.utils import cdiv, has_triton_kernels
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if TYPE_CHECKING and has_triton_kernels:
|
||||
if has_triton_kernels():
|
||||
from triton_kernels.matmul_ogs import PrecisionConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -241,6 +241,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -262,7 +263,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
|
||||
a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
|
||||
a1q_scale, a2_scale, self.ab_strides1, self.ab_strides2,
|
||||
self.c_strides1, self.c_strides2, workspace13, workspace2,
|
||||
expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
@ -705,6 +706,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor], # unused
|
||||
a2_scale: Optional[torch.Tensor], # unused
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
|
||||
@ -214,13 +214,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert a1q_scale is not None
|
||||
assert self.a2_scale is None
|
||||
assert a2_scale is None
|
||||
assert self.block_shape is not None
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
|
||||
@ -129,6 +129,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
|
||||
@ -688,6 +688,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -879,6 +880,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -970,7 +972,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
intermediate_cache1.view(-1, N))
|
||||
|
||||
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
|
||||
intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
|
||||
intermediate_cache2, a2_scale, max_num_tokens, E, N,
|
||||
expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
|
||||
self.block_shape)
|
||||
|
||||
|
||||
@ -1598,6 +1598,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -1690,7 +1691,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2, self.a2_scale, self.quant_dtype,
|
||||
intermediate_cache2, a2_scale, self.quant_dtype,
|
||||
self.per_act_token_quant, self.block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
|
||||
@ -179,6 +179,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
|
||||
@ -519,6 +519,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
@ -634,6 +635,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
local_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
) -> torch.Tensor:
|
||||
@ -671,6 +673,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
@ -718,6 +721,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=self.fused_experts.a2_scale,
|
||||
expert_tokens_meta=expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
@ -803,6 +807,7 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
local_num_experts=local_num_experts,
|
||||
expert_map=expert_map,
|
||||
a1q_scale=c_a1q_scale,
|
||||
a2_scale=c_a2_scale,
|
||||
expert_tokens_meta=c_expert_tokens_meta,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
@ -111,6 +111,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -134,6 +135,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_tokens_meta,
|
||||
|
||||
@ -103,6 +103,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
|
||||
@ -638,8 +638,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
return None
|
||||
|
||||
if self.mxfp4_backend == Mxfp4Backend.TRITON:
|
||||
w1_scale = layer.w13_precision_config
|
||||
w2_scale = layer.w2_precision_config
|
||||
w1_scale = self.w13_precision_config
|
||||
w2_scale = self.w2_precision_config
|
||||
else:
|
||||
w1_scale = layer.w13_weight_scale
|
||||
w2_scale = layer.w2_weight_scale
|
||||
|
||||
@ -6,6 +6,8 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
from .common import apply_rotary_emb_torch
|
||||
|
||||
@ -30,9 +32,17 @@ class RotaryEmbedding(CustomOp):
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
# Flashinfer only supports head_size=64, 128, 256, 512.
|
||||
# https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
|
||||
self.use_flashinfer = (self.enabled()
|
||||
and dtype in (torch.float16, torch.bfloat16)
|
||||
and current_platform.is_cuda()
|
||||
and has_flashinfer()
|
||||
and self.head_size in [64, 128, 256, 512])
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(dtype)
|
||||
if not self.use_flashinfer:
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
@ -57,6 +67,14 @@ class RotaryEmbedding(CustomOp):
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
|
||||
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
|
||||
# is expensive, so avoid calling it if possible
|
||||
if self.cos_sin_cache.device != query.device or \
|
||||
self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@ -94,15 +112,16 @@ class RotaryEmbedding(CustomOp):
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if self.use_flashinfer:
|
||||
torch.ops.vllm.flashinfer_rotary_embedding(positions, query, key,
|
||||
self.head_size,
|
||||
self.cos_sin_cache,
|
||||
self.is_neox_style)
|
||||
return query, key
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
|
||||
# is expensive, so avoid calling it if possible
|
||||
if self.cos_sin_cache.device != query.device or \
|
||||
self.cos_sin_cache.dtype != query.dtype:
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
|
||||
dtype=query.dtype)
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||
@ -117,8 +136,7 @@ class RotaryEmbedding(CustomOp):
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
|
||||
dtype=query.dtype)
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
# ops.rotary_embedding() is an in-place operation
|
||||
# that updates the query and key tensors.
|
||||
if key is None:
|
||||
|
||||
@ -6,6 +6,7 @@ import math
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
@ -103,3 +104,48 @@ def yarn_get_mscale(scale: float = 1) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
|
||||
def _flashinfer_rotary_embedding(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
"""Custom op wrapper for flashinfer's rotary embedding.
|
||||
|
||||
This is an in-place operation that modifies query and key tensors directly.
|
||||
"""
|
||||
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace
|
||||
|
||||
apply_rope_with_cos_sin_cache_inplace(
|
||||
positions=positions,
|
||||
query=query,
|
||||
key=key,
|
||||
head_size=head_size,
|
||||
cos_sin_cache=cos_sin_cache,
|
||||
is_neox=is_neox,
|
||||
)
|
||||
|
||||
|
||||
def _flashinfer_rotary_embedding_fake(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
# Register flashinfer rotary embedding custom op
|
||||
direct_register_custom_op(
|
||||
op_name="flashinfer_rotary_embedding",
|
||||
op_func=_flashinfer_rotary_embedding,
|
||||
mutates_args=["query", "key"], # These tensors are modified in-place
|
||||
fake_impl=_flashinfer_rotary_embedding_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
||||
@ -97,15 +97,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
assert key is not None
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
if self.rotary_dim < self.head_size:
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
|
||||
if self.cos_sin_cache.device != positions.device:
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
|
||||
positions.device)
|
||||
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
|
||||
if offsets is not None else positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
@ -59,7 +59,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
|
||||
key: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert key is not None
|
||||
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
query_ = torch.view_as_complex(query.float().reshape(
|
||||
*query.shape[:-1], -1, 2))
|
||||
key_ = torch.view_as_complex(key.float().reshape(
|
||||
|
||||
@ -245,6 +245,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
@ -293,6 +294,7 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
assert positions.ndim == 1 or positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
self._match_cos_sin_cache_dtype(query)
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
53
vllm/v1/kv_offload/factory.py
Normal file
53
vllm/v1/kv_offload/factory.py
Normal file
@ -0,0 +1,53 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OffloadingSpecFactory:
|
||||
_registry: dict[str, Callable[[], type[OffloadingSpec]]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_spec(cls, name: str, module_path: str,
|
||||
class_name: str) -> None:
|
||||
"""Register a spec with a lazy-loading module and class name."""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Connector '{name}' is already registered.")
|
||||
|
||||
def loader() -> type[OffloadingSpec]:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
cls._registry[name] = loader
|
||||
|
||||
@classmethod
|
||||
def create_spec(
|
||||
cls,
|
||||
config: "VllmConfig",
|
||||
) -> OffloadingSpec:
|
||||
kv_transfer_config = config.kv_transfer_config
|
||||
assert kv_transfer_config is not None
|
||||
extra_config = kv_transfer_config.kv_connector_extra_config
|
||||
spec_name = extra_config.get("spec_name", "CPUOffloadingSpec")
|
||||
if spec_name in cls._registry:
|
||||
spec_cls = cls._registry[spec_name]()
|
||||
else:
|
||||
spec_module_path = extra_config.get("spec_module_path")
|
||||
if spec_module_path is None:
|
||||
raise ValueError(f"Unsupported spec type: {spec_name}")
|
||||
spec_module = importlib.import_module(spec_module_path)
|
||||
spec_cls = getattr(spec_module, spec_name)
|
||||
assert issubclass(spec_cls, OffloadingSpec)
|
||||
logger.info("Creating offloading spec with name: %s", spec_name)
|
||||
return spec_cls(config)
|
||||
|
||||
|
||||
# Register various specs here.
|
||||
61
vllm/v1/kv_offload/spec.py
Normal file
61
vllm/v1/kv_offload/spec.py
Normal file
@ -0,0 +1,61 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
|
||||
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OffloadingSpec(ABC):
|
||||
"""Spec for an offloading connector"""
|
||||
|
||||
def __init__(self, vllm_config: "VllmConfig"):
|
||||
logger.warning(
|
||||
"Initializing OffloadingSpec. This API is experimental and "
|
||||
"subject to change in the future as we iterate the design.")
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
assert kv_transfer_config is not None
|
||||
self.extra_config = kv_transfer_config.kv_connector_extra_config
|
||||
|
||||
self.gpu_block_size = vllm_config.cache_config.block_size
|
||||
self.offloaded_block_size = int(
|
||||
self.extra_config.get("block_size", self.gpu_block_size))
|
||||
|
||||
assert self.offloaded_block_size % self.gpu_block_size == 0
|
||||
|
||||
@abstractmethod
|
||||
def get_manager(self) -> OffloadingManager:
|
||||
"""
|
||||
Get an OffloadingManager that will be used
|
||||
by the scheduler-side offloading connector to track
|
||||
offloaded blocks and manage evictions.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_handlers(
|
||||
self, kv_caches: dict[str, torch.Tensor]
|
||||
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
|
||||
OffloadingHandler]]:
|
||||
"""
|
||||
Get offloading handlers along with their respective src and dst types.
|
||||
|
||||
Args:
|
||||
kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
|
||||
|
||||
Yields:
|
||||
Tuples of (src_type, dst_type, offloading_handler).
|
||||
"""
|
||||
pass
|
||||
Loading…
x
Reference in New Issue
Block a user