fix: NIXL connector transfers partial block to pass full multi-modal context (#21074)

Signed-off-by: GuanLuo <gluo@nvidia.com>

commit 16fb668b61, parent f7dcce7a4a
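In short: when a prompt does not end on a KV-block boundary, the prefill (P) worker used to round the prompt length down to full blocks, so the trailing partial block was never registered for transfer and the decode (D) worker had to recompute those tokens, which breaks requests whose multi-modal context lives in that tail. A toy sketch of the arithmetic (invented numbers, not vLLM code):

# Tokens transferred before vs. after the fix, for a prompt that does not
# end exactly on a block boundary (BLOCK_SIZE and PROMPT_LEN are invented).
BLOCK_SIZE = 16
PROMPT_LEN = 53

full_blocks, remainder = divmod(PROMPT_LEN, BLOCK_SIZE)
old_transferred = full_blocks * BLOCK_SIZE  # 48 tokens: 3 full blocks, 5 tokens dropped
new_transferred = PROMPT_LEN                # 53 tokens: 3 full blocks + 1 partial block

print(f"before: {old_transferred} tokens, {remainder} recomputed on D")
print(f"after:  {new_transferred} tokens, full context preserved")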
@@ -173,9 +173,9 @@ def test_prompt_less_than_block_size():
     """
     Test that we can handle case where prompt is < block.
 
-    In this case, the P worker will send empty remote_block_ids.
-    The D worker should not schedule an async read in this case,
-    since there is nothing to pull.
+    In this case, the P worker will still send remote_block_ids of the
+    partial block. The D worker should schedule an async read
+    in this case.
     """
     vllm_config = create_vllm_config()
     scheduler = create_scheduler(vllm_config)
@@ -184,22 +184,20 @@ def test_prompt_less_than_block_size():
     BLOCK_SIZE = vllm_config.cache_config.block_size
     NUM_TOKENS = int(BLOCK_SIZE * 0.5)
 
-    # Request will have 0 remote blocks.
+    # Request will have 1 partial remote block.
     request = create_request(request_id=1,
                              num_tokens=NUM_TOKENS,
                              do_remote_prefill=True,
-                             num_remote_blocks=0)
+                             num_remote_blocks=1)
     scheduler.add_request(request)
     scheduler_output = scheduler.schedule()
 
-    # This request should not have to read async.
+    # This request will read async.
     kv_connector_metadata = scheduler_output.kv_connector_metadata
     assert kv_connector_metadata is not None
     assert isinstance(kv_connector_metadata, NixlConnectorMetadata)
-    assert len(kv_connector_metadata.reqs_to_recv) == 0
-
-    # This request should be scheduled regularly.
-    assert len(scheduler_output.scheduled_new_reqs) == 1
+    assert len(kv_connector_metadata.reqs_to_recv) == 1
+    assert len(scheduler_output.scheduled_new_reqs) == 0
 
 
 class FakeNixlConnectorWorker(NixlConnectorWorker):
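The updated test encodes the new expectation: a prompt of half a block now yields exactly one (partial) remote block, the D worker queues one async read, and the request is not scheduled as a regular new request in the same step. A minimal sketch of the block-count rule the assertions rely on (the helper name is ours, not vLLM's):

import math

# expected_remote_blocks is a hypothetical helper mirroring the test's math:
# every started block is transferred, including a partial trailing one.
def expected_remote_blocks(num_tokens: int, block_size: int) -> int:
    return math.ceil(num_tokens / block_size)

assert expected_remote_blocks(8, 16) == 1   # NUM_TOKENS = BLOCK_SIZE * 0.5
assert expected_remote_blocks(0, 16) == 0   # nothing prefilled, nothing to read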
@@ -121,13 +121,18 @@ def test_short_prompt_lifecycle():
     model_runner_output = create_model_runner_output(reqs=[request])
 
     # (1c): update_from_output()
-    # Since tokens < block_size, there will be no kv xfer.
-    # So this should be cleaned up immediately.
-    _ = scheduler.update_from_output(scheduler_output, model_runner_output)
+    # Even though tokens < block_size, there will be kv xfer for partial block.
+    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
+    kv_transfer_params = eco[0].outputs[0].kv_transfer_params
+
+    assert (len(kv_transfer_params["remote_block_ids"]) == 1)
 
     # Confirm we do not have any memory leaks after req lifecycle.
-    # We need one more call to schedule() to clear data for persistent batch.
-    _ = scheduler.schedule()
+    # We need to mark sending finish to clear data for persistent batch.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
+    model_runner_output.finished_sending = [request.request_id]
+    scheduler.update_from_output(scheduler_output, model_runner_output)
     assert_scheduler_empty(scheduler)
 
 
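Because a short prompt now triggers a real transfer, the P side can no longer clean the request up immediately; it must wait for the worker to report the send as finished. A toy model of that handshake (class and ids are illustrative, not vLLM's API):

# ToyScheduler sketches the cleanup contract the test now exercises:
# blocks for a sent request stay allocated until finished_sending arrives.
class ToyScheduler:
    def __init__(self):
        self.pending_sends = {"req-1": [0]}  # request id -> block ids still in flight

    def update_from_output(self, finished_sending):
        for req_id in finished_sending:
            self.pending_sends.pop(req_id, None)  # blocks may now be freed

sched = ToyScheduler()
sched.update_from_output(finished_sending=["req-1"])
assert not sched.pending_sends  # scheduler is empty again, as the test asserts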
@@ -169,16 +174,16 @@ def test_prefix_cache_lifecycle():
     eco = scheduler.update_from_output(scheduler_output, model_runner_output)
     kv_transfer_params = eco[0].outputs[0].kv_transfer_params
 
-    # Ensure we send all block ids, even if there is a cache hit.
+    # Ensure we send all block ids, including the partial blocks,
+    # even if there is a cache hit.
     assert (len(
-        kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS)
+        kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
+                                                    1))
 
     # STEP (2): Ensure it is freed.
     scheduler_output = scheduler.schedule()
-    scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
     model_runner_output.kv_connector_output = KVConnectorOutput(
         finished_sending=[request_remote.request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
-    _ = scheduler.schedule()
     assert_scheduler_empty(scheduler)
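The prefix-cache test changes for the same reason: the remote now advertises the partial block on top of the full ones, so the expected count becomes NUM_EXTERNAL_FULL_BLOCKS + 1. Worked through with invented values (the test reads BLOCK_SIZE from the config and defines NUM_EXTERNAL_FULL_BLOCKS elsewhere):

import math

# A request of N full blocks plus half a block (invented values).
BLOCK_SIZE = 16
NUM_EXTERNAL_FULL_BLOCKS = 2
num_tokens = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))  # 40 tokens

# Every started block is sent, so the partial tail adds one more block.
assert math.ceil(num_tokens / BLOCK_SIZE) == NUM_EXTERNAL_FULL_BLOCKS + 1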
@@ -362,7 +362,7 @@ def test_cannot_schedule_after_recv():
     BLOCK_SIZE = vllm_config.cache_config.block_size
     # Prompt will use 2 blocks + 1 block after we schedule.
     NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
-    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
+    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
 
     request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
     request_remote = create_request(request_id=2,
@@ -393,14 +393,24 @@ def test_cannot_schedule_after_recv():
     assert len(scheduler.running) == 1
     assert len(scheduler.waiting) == 1
 
-    # Step 4: try to schedule, not enough blocks.
+    # Step 4: try to schedule, remote request is put to running list
+    # because the transfer is completed.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(
+        reqs=[request_normal, request_remote])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 2
+    assert len(scheduler.waiting) == 0
+
+    # Step 5: Remote request will be put back to waiting list
+    # because it needs new block to hold generated token.
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(reqs=[request_normal])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.running) == 1
     assert len(scheduler.waiting) == 1
 
-    # Step 5: finish the request, free it.
+    # Step 6: finish the request, free it.
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(reqs=[request_normal],
                                                      use_eos=True)
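Steps 4-5 above exist because of the new flow: the remote request's transfer completes, it briefly joins the running list, and is then preempted because no free block is left for its first generated token. The block arithmetic, under the assumption that the pool in this test holds five working blocks (the real constant is defined earlier in the test, outside this diff):

# Assumed pool size; the real constant sits above the hunk shown here.
WORKING_BLOCKS = 5
normal_in_use = 2 + 1      # request_normal: 2 prompt blocks + 1 decode block
remote_in_use = 2          # request_remote: NUM_PROMPT_BLOCKS, transfer done

free = WORKING_BLOCKS - normal_in_use - remote_in_use
assert free == 0           # step 5: no block for the new decode token -> preempt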
@@ -408,15 +418,99 @@ def test_cannot_schedule_after_recv():
     assert len(scheduler.running) == 0
     assert len(scheduler.waiting) == 1
 
-    # Step 6: now we can schedule (with 2 blocks computed).
+    # Step 7: now we can schedule (with 2 blocks computed),
+    # request is retrieved from preempted list.
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(reqs=[request_remote])
-    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
+    assert (scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] ==
             NUM_PROMPT_BLOCKS * BLOCK_SIZE)
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.running) == 1
     assert len(scheduler.waiting) == 0
 
+    # Step 8: free everything.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_remote],
+                                                     use_eos=True)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    _ = scheduler.schedule()
+    assert_scheduler_empty(scheduler)
+
+
+def test_cannot_recv():
+    """
+    Test that we can handle no schedule KV block transfer due to not
+    enough remaining KV blocks.
+    """
+
+    # NOTE: the KVCacheManager will use 1 null block.
+    # So there are 5 total working blocks.
+    TOTAL_NUM_BLOCKS = 6
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
+
+    # Prime the KVCache.
+    NUM_PROMPT_BLOCKS = 2
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    # Prompt will use 2 blocks + 1 block after we schedule.
+    NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
+    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
+
+    request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
+    request_remote = create_request(request_id=2,
+                                    num_tokens=NUM_TOKENS_REMOTE,
+                                    do_remote_prefill=True)
+
+    # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
+    scheduler.add_request(request_normal)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 0
+
+    # Step 2: 3 blocks are in use,
+    # need 3 new for remote blocks but only 2 are available.
+    scheduler.add_request(request_remote)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 1
+    # Should not have KV transfer in progress.
+    assert (request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS)
+
+    # Step 3: finish the request, free it.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal],
+                                                     use_eos=True)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
+
+    # Step 4: now we can initiate KV transfer (with 2 blocks computed).
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
+    assert (request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
+
+    # Step 5: finish recving (5 blocks in use)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(
+        reqs=[], finished_recving=[request_remote.request_id])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
+
+    # Step 6: schedule remote request
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_remote])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 0
+
     # Step 7: free everything.
     scheduler_output = scheduler.schedule()
     model_runner_output = create_model_runner_output(reqs=[request_remote],
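The new test_cannot_recv is driven by the block budget spelled out in its comments; the arithmetic below reproduces it (the test reads BLOCK_SIZE from the config, 16 here is our stand-in):

import math

TOTAL_NUM_BLOCKS = 6
working = TOTAL_NUM_BLOCKS - 1               # the KVCacheManager keeps 1 null block
BLOCK_SIZE = 16                              # stand-in for the config value
remote_tokens = int(BLOCK_SIZE * 2.5)        # NUM_TOKENS_REMOTE: 2.5 blocks of tokens
remote_needed = math.ceil(remote_tokens / BLOCK_SIZE)  # 3 blocks, incl. the partial one

in_use = 2 + 1                               # request_normal: prompt + decode block
assert working - in_use < remote_needed      # step 2: transfer cannot be scheduled
assert working >= remote_needed              # step 4: fine once request_normal is freed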
@@ -29,7 +29,7 @@ from vllm.distributed.utils import divide
 from vllm.forward_context import ForwardContext
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import make_zmq_path, make_zmq_socket, round_down
+from vllm.utils import make_zmq_path, make_zmq_socket
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
 
@@ -275,10 +275,7 @@ class NixlConnectorScheduler:
 
         if params is not None and params.get("do_remote_prefill"):
            # Remote prefill: get all prompt blocks from remote.
-            assert num_computed_tokens % self.block_size == 0
-            rounded_num_prompt_tokens = round_down(
-                len(request.prompt_token_ids), self.block_size)
-            count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
+            count = len(request.prompt_token_ids) - num_computed_tokens
             if count > 0:
                 return count, True
 
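This is the heart of the fix. The old scheduler path asserted block alignment and rounded the prompt down, so a trailing partial block was never requested from the remote; the new path simply pulls every uncomputed prompt token. A side-by-side sketch (standalone functions, not the actual methods):

def tokens_to_pull_old(prompt_len: int, num_computed: int, block_size: int) -> int:
    # Pre-fix: round the prompt down to whole blocks (round_down equivalent).
    rounded = (prompt_len // block_size) * block_size
    return max(rounded - num_computed, 0)

def tokens_to_pull_new(prompt_len: int, num_computed: int) -> int:
    # Post-fix: pull everything left, including the partial tail block.
    return prompt_len - num_computed

assert tokens_to_pull_old(53, 0, 16) == 48  # 5 tail tokens were recomputed on D
assert tokens_to_pull_new(53, 0) == 53      # full multi-modal context is fetched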
@@ -301,18 +298,16 @@ class NixlConnectorScheduler:
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
 
-            # figure out full computed blocks to save
+            # save all blocks
             block_ids = blocks.get_block_ids()[0]
-            all_full = request.num_tokens % self.block_size == 0
-            full_block_ids = (block_ids if all_full else block_ids[:-1])
             # TODO: skip the blocks that are already in the host xfer buffer.
             # Currently, the host xfer buffer block is 1-to-1 mapped to device
             # kv blocks, so host blocks won't be flushed as long as its device
             # block is not overwritten; and it will be safe to skip saving them
             # to host xfer buffer.
-            if full_block_ids:
+            if block_ids:
                 self._reqs_need_save[request.request_id] = \
-                    (request, full_block_ids)
+                    (request, block_ids)
         elif params.get("do_remote_prefill"):
             if params.get("remote_block_ids"):
                 if all(p in params for p in ("remote_engine_id", "remote_host",
@@ -401,12 +396,9 @@ class NixlConnectorScheduler:
                 or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
             return False, None
 
-        # Get computed blocks.
-        all_full = request.num_computed_tokens % self.block_size == 0
-        computed_block_ids = block_ids if all_full else block_ids[:-1]
-
-        # If prompt < block_size, no xfer so free blocks immediately.
-        delay_free_blocks = len(computed_block_ids) > 0
+        # TODO: check whether block_ids actually ever be 0. If not we could
+        # remove the conditional below
+        delay_free_blocks = len(block_ids) > 0
 
         if delay_free_blocks:
             # Prefill request on remote. It will be read from D upon completion
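With computed_block_ids gone, a sub-block prompt no longer yields an empty list, so delay_free_blocks stays True and the P worker holds its blocks until D has read them. The kv_transfer_params it returns then look roughly like this (illustrative values; only the field names come from the diff):

block_ids = [7]                        # one partial block for a short prompt
delay_free_blocks = len(block_ids) > 0
assert delay_free_blocks               # previously False when the list was empty

kv_transfer_params = dict(
    do_remote_prefill=True,
    do_remote_decode=False,
    remote_block_ids=block_ids,        # was [] for prompt < block_size
    remote_engine_id="engine-0",       # hypothetical value
    remote_host="127.0.0.1",           # hypothetical value
    remote_port=5600,                  # hypothetical value
)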
@@ -416,7 +408,7 @@ class NixlConnectorScheduler:
         return delay_free_blocks, dict(
             do_remote_prefill=True,
             do_remote_decode=False,
-            remote_block_ids=computed_block_ids,
+            remote_block_ids=block_ids,
             remote_engine_id=self.engine_id,
             remote_host=self.side_channel_host,
             remote_port=self.side_channel_port,