[Misc][Refactor] Introduce ExecuteModelRequest (#4540)

Cody Yu 2024-05-03 17:47:07 -07:00 committed by GitHub
parent 344bf7cd2d
commit bc8ad68455
23 changed files with 355 additions and 511 deletions

View File

@@ -5,13 +5,12 @@ import pytest
 import torch
 
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
 
 from .utils import (assert_logprobs_dict_allclose, create_batch,
-                    create_execute_model_data,
                     create_seq_group_metadata_from_prompts, create_worker,
                     patch_execute_model_with_seeds, zero_kv_cache)
@@ -105,31 +104,32 @@ def test_same_output_for_single_step():
     final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]
 
-    multi_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    single_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+    multi_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
 
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     actual_output, _ = multi_step_worker.sampler_output(
-        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=multi_step_seq_group),
+        sample_len=num_steps)
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]
 
+    single_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
     zero_kv_cache(worker.cache_engine)
     set_random_seed(seed)
     expected_output = worker.execute_model(
-        **single_step_execute_model_data.to_dict(), )[0]
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=single_step_seq_group))[0]
 
     actual_token_ids = [
         output.samples[0].output_token for output in actual_output
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
     worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
 
     continuations = [[1] for _ in prompts]
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            continuations=continuations,
-            final_prompt_lens=final_prompt_lens), )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        continuations=continuations,
+        final_prompt_lens=final_prompt_lens)
 
     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     multi_step_output, _ = multi_step_worker.sampler_output(
-        **execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=num_steps)
 
     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():
     for _ in multi_step_output:
 
-        execute_model_data = create_execute_model_data(
-            create_seq_group_metadata_from_prompts(
-                prompts,
-                num_gpu_blocks,
-                block_size,
-                continuations=continuations,
-                final_prompt_lens=final_prompt_lens))
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=continuations,
+            final_prompt_lens=final_prompt_lens)
 
         single_step_output.extend(
-            worker.execute_model(**execute_model_data.to_dict(), ))
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))
 
         # Append output tokens to new sequence data.
         for i, seq_group_output in enumerate(single_step_output[-1]):
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
         ) for _ in range(k)
     ], True
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
         max_proposal_len=prompt_len + k - 1,
     )
 
-    execute_model_data, _, _ = create_batch(batch_size,
-                                            k,
-                                            prompt_len=prompt_len)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prompt_len=prompt_len)
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
         ) for _ in range(k)
     ], True
 
-    execute_model_data, _, _ = create_batch(
+    seq_group_metadata_list, _, _ = create_batch(
         batch_size,
         k,
         prompt_len=prompt_len,
         prev_output_token_len=prev_output_token_len,
     )
 
-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)

View File

@@ -1,10 +1,10 @@
 import torch
 
+from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 
-from .utils import (create_execute_model_data,
-                    create_seq_group_metadata_from_prompts, create_worker)
+from .utils import create_seq_group_metadata_from_prompts, create_worker
 
 
 def test_ngram_algo_correctness_for_single_no_match():
@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
 
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
 
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():
     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
 
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
-
-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )
 
     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)

View File

@@ -7,7 +7,7 @@ import torch
 
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)
 
-from .utils import (ExecuteModelData, create_batch, create_sampler_output_list,
-                    mock_worker)
+from .utils import create_batch, create_sampler_output_list, mock_worker
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=execute_model_req)
 
     call_args_list = draft_worker.get_spec_proposals.call_args_list
     assert len(call_args_list) == 1
 
     for args, _ in call_args_list:
-        (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
-         blocks_to_copy, actual_k) = args
-        actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
-                                                     blocks_to_swap_in,
-                                                     blocks_to_swap_out,
-                                                     blocks_to_copy)
-        assert actual_execute_model_data == execute_model_data
-        assert actual_k == k
+        actual_execute_model_data = args[0]
+        assert actual_execute_model_data == execute_model_req
 
 
 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
+    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
         batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     target_worker.execute_model.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))
 
     seen_contexts = []
 
     call_args_list = target_worker.execute_model.call_args_list
     assert len(call_args_list) == 1
-    for args, kwargs in call_args_list:
-        target_execute_model_data = ExecuteModelData.from_dict(kwargs)
-
-        assert len(target_execute_model_data.seq_group_metadata_list) == (
-            k + 1) * batch_size
-
-        for seq_group_metadata in (
-                target_execute_model_data.seq_group_metadata_list):
+    for _, kwargs in call_args_list:
+        seq_group_metadata_list = kwargs[
+            "execute_model_req"].seq_group_metadata_list
+        assert len(seq_group_metadata_list) == (k + 1) * batch_size
+        for seq_group_metadata in seq_group_metadata_list:
             for seq_data in seq_group_metadata.seq_data.values():
                 seen_contexts.append(seq_data.get_token_ids())
@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     rejection_sampler.side_effect = ValueError(exception_secret)
 
     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))
 
     assert len(rejection_sampler.call_args_list) == 1
     _, kwargs = rejection_sampler.call_args_list[0]
@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):
     rejection_sampler.return_value = rejection_sampler_output
 
-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
 
     expected_output = create_sampler_output_list(
         token_ids=rejection_sampler_output.transpose(0, 1),
@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
     seq_ids = [
         next(iter(seq_group_metadata.seq_data.keys()))
-        for seq_group_metadata in execute_model_data.seq_group_metadata_list
+        for seq_group_metadata in seq_group_metadata_list
     ]
     actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
     expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k
 
-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
 
     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     metrics_collector.maybe_collect_rejsample_metrics.return_value = (
         mock_rejsample_metrics)
 
-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
     assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics
 
     call_args_list = (
@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)
 
     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
    assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"
 
-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
 @pytest.mark.parametrize('k', [0, 5])
@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)
 
-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)
 
-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)
 
     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"
 
-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)
 
 
 @pytest.mark.skip_global_cleanup

View File

@@ -1,4 +1,3 @@
-from dataclasses import dataclass, fields
 from itertools import count
 from typing import Dict, Iterable, List, Optional, Union
 from unittest.mock import MagicMock
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.worker import Worker
 
 
-@dataclass
-class ExecuteModelData:
-    """Helper data structure which facilitates cleaner tests.
-    """
-    seq_group_metadata_list: List[SequenceGroupMetadata]
-    blocks_to_swap_in: Dict[int, int]
-    blocks_to_swap_out: Dict[int, int]
-    blocks_to_copy: Dict[int, List[int]]
-
-    def to_dict(self):
-        return dict(
-            (field.name, getattr(self, field.name)) for field in fields(self))
-
-    @classmethod
-    def from_dict(cls, d):
-        cleaned = dict((field.name, d[field.name]) for field in fields(cls))
-        return cls(**cleaned)
-
-
 def round_up_to_next_block(seq_len: int, block_size: int) -> int:
     return (seq_len + block_size - 1) // block_size
 
 
-def create_execute_model_data(
-    seq_group_metadata_list: List[SequenceGroupMetadata],
-    blocks_to_swap_in: Optional[Dict[int, int]] = None,
-    blocks_to_swap_out: Optional[Dict[int, int]] = None,
-    blocks_to_copy: Optional[Dict[int, int]] = None,
-) -> ExecuteModelData:
-    if blocks_to_swap_in is None:
-        blocks_to_swap_in = {}
-    if blocks_to_swap_out is None:
-        blocks_to_swap_out = {}
-    if blocks_to_copy is None:
-        blocks_to_copy = {}
-
-    return ExecuteModelData(
-        seq_group_metadata_list=seq_group_metadata_list,
-        blocks_to_swap_in=blocks_to_swap_in,
-        blocks_to_swap_out=blocks_to_swap_out,
-        blocks_to_copy=blocks_to_copy,
-    )
-
-
 def mock_worker(cls=None,
                 vocab_size: int = 30_000,
                 max_model_len: int = 2048,
@@ -258,8 +217,7 @@ def create_batch(batch_size,
         for prompt, prev_output_token in zip(prompts, prev_output_tokens)
     ]
 
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
-                                               block_size, final_prompt_lens,
-                                               prev_output_tokens, seq_ids), )
-    return execute_model_data, prompts, prev_output_tokens
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts, num_gpu_blocks, block_size, final_prompt_lens,
+        prev_output_tokens, seq_ids)
+    return seq_group_metadata_list, prompts, prev_output_tokens

View File

@@ -1,6 +1,7 @@
 import torch
 
 from vllm.engine.arg_utils import EngineArgs
+from vllm.sequence import ExecuteModelRequest
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.worker import Worker
@@ -54,10 +55,14 @@ def test_swap() -> None:
 
     # Test swap out.
     blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
-    worker.execute_model(seq_group_metadata_list=[],
-                         blocks_to_swap_in={},
-                         blocks_to_swap_out=blocks_to_swap_out,
-                         blocks_to_copy={})
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=[],
+        blocks_to_swap_in={},
+        blocks_to_swap_out=blocks_to_swap_out,
+        blocks_to_copy={},
+    )
+    worker.execute_model(execute_model_req=execute_model_req)
+
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
@@ -66,14 +71,19 @@ def test_swap() -> None:
         assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
 
     # Test swap in.
-    blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71}
-    worker.execute_model(seq_group_metadata_list=[],
-                         blocks_to_swap_in=blocks_to_swap_in,
-                         blocks_to_swap_out={},
-                         blocks_to_copy={})
+    execute_model_req.blocks_to_swap_out = {}
+    execute_model_req.blocks_to_swap_in = {
+        19: 45,
+        67: 23,
+        12: 78,
+        40: 99,
+        1: 71
+    }
+    worker.execute_model(execute_model_req=execute_model_req)
+
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in blocks_to_swap_in.items():
+        for src, dst in execute_model_req.blocks_to_swap_in.items():
            assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
            assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
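Not part of the commit itself: under the new interface, pure cache operations are expressed as an ExecuteModelRequest with an empty batch, and the same request object can be mutated and reused across calls, as the test above does. A minimal hedged sketch of that pattern, with arbitrary block numbers and the worker construction omitted (the execute_model calls are commented out because a real Worker needs full engine setup):

from vllm.sequence import ExecuteModelRequest

# Swap out: GPU -> CPU block numbers, with no sequences to run.
swap_req = ExecuteModelRequest(seq_group_metadata_list=[],
                               blocks_to_swap_out={3: 72})
# worker.execute_model(execute_model_req=swap_req)  # worker built as in the test

# Reuse the same request for the swap-in direction (CPU -> GPU).
swap_req.blocks_to_swap_out = {}
swap_req.blocks_to_swap_in = {19: 45}
# worker.execute_model(execute_model_req=swap_req)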

View File

@@ -128,6 +128,8 @@ class SchedulerOutputs:
     ignored_seq_groups: List[SequenceGroup]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
+    # The number of requests in the running queue.
+    running_queue_size: int
 
     def __post_init__(self):
         # Swap in and swap out should never happen at the same time.
@@ -797,6 +799,7 @@ class Scheduler:
             ignored_seq_groups=prefills.ignored_seq_groups +
             swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
+            running_queue_size=len(self.running),
         )
 
     def _schedule_chunked_prefill(self):
@@ -883,6 +886,7 @@ class Scheduler:
                                             swapped_in.blocks_to_copy),
             ignored_seq_groups=prefills.ignored_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
+            running_queue_size=len(self.running),
         )
 
     def _schedule(self) -> SchedulerOutputs:

View File

@@ -16,7 +16,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import MultiModalData, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
@@ -210,12 +210,16 @@ class _AsyncLLMEngine(LLMEngine):
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
+            execute_model_req = ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list,
+                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                running_queue_size=scheduler_outputs.running_queue_size,
+            )
             output = await self.model_executor.execute_model_async(
-                seq_group_metadata_list,
-                scheduler_outputs.blocks_to_swap_in,
-                scheduler_outputs.blocks_to_swap_out,
-                scheduler_outputs.blocks_to_copy,
-                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
+                execute_model_req)
         else:
             output = []

View File

@@ -22,8 +22,8 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup, SequenceGroupMetadata,
+from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
+                           Sequence, SequenceGroup, SequenceGroupMetadata,
                            SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
@@ -583,12 +583,16 @@ class LLMEngine:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
         if not scheduler_outputs.is_empty():
-            output = self.model_executor.execute_model(
+            execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                 blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                 blocks_to_copy=scheduler_outputs.blocks_to_copy,
-                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                running_queue_size=scheduler_outputs.running_queue_size,
+            )
+            output = self.model_executor.execute_model(
+                execute_model_req=execute_model_req)
         else:
             output = []

View File

@@ -1,4 +1,4 @@
-from typing import Dict, List, Set, Tuple
+from typing import List, Set, Tuple
 
 import torch
@@ -7,7 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
@@ -72,18 +72,10 @@ class CPUExecutor(ExecutorBase):
         logger.info("# CPU blocks: %d", num_gpu_blocks)
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
-        output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = self.driver_worker.execute_model(execute_model_req)
         return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -105,18 +97,9 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
 
     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots)
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = await make_async(self.driver_worker.execute_model
+                                  )(execute_model_req=execute_model_req, )
         return output
 
     async def check_health_async(self) -> None:

View File

@@ -1,11 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          SpeculativeConfig, VisionLanguageConfig)
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 
 
 class ExecutorBase(ABC):
@@ -68,12 +68,9 @@ class ExecutorBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
@@ -108,12 +105,7 @@ class ExecutorAsyncBase(ExecutorBase):
     @abstractmethod
     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Executes one model step on the given sequences."""
         raise NotImplementedError
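Not part of the commit itself: the refactor narrows every executor entry point to a single ExecuteModelRequest argument. A hedged, self-contained sketch of that contract, mirroring how the CPU and GPU executors in this commit simply forward the request to their driver worker (the stub worker below is illustrative only, not a real vLLM worker):

from typing import List

from vllm.sequence import ExecuteModelRequest, SamplerOutput


class _StubDriverWorker:
    """Illustrative stand-in; a real worker runs the model on its device."""

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        # A real worker would run a forward pass; the stub returns nothing.
        return []


class ForwardingExecutor:
    """Mirrors the post-refactor contract: one request in, sampler outputs out."""

    def __init__(self) -> None:
        self.driver_worker = _StubDriverWorker()

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        return self.driver_worker.execute_model(execute_model_req)


outputs = ForwardingExecutor().execute_model(
    ExecuteModelRequest(seq_group_metadata_list=[]))
assert outputs == []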

View File

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -118,19 +118,8 @@ class GPUExecutor(ExecutorBase):
 
     def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
-        output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots,
-        )
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = self.driver_worker.execute_model(execute_model_req)
         return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -154,16 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
 
     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots)
+        output = await make_async(self.driver_worker.execute_model
+                                  )(execute_model_req=execute_model_req, )
         return output

View File

@@ -1,9 +1,9 @@
-from typing import Dict, List, Set, Tuple
+from typing import List, Set, Tuple
 
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
@@ -45,20 +45,18 @@ class NeuronExecutor(ExecutorBase):
         """
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
 
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
-        assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
-                and blocks_to_copy == {}), (
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        assert (execute_model_req.blocks_to_swap_in == {}
+                and execute_model_req.blocks_to_swap_out == {}
+                and execute_model_req.blocks_to_copy == {}), (
                     "Cache operations are not supported for Neuron backend.")
-        assert num_lookahead_slots == 0, (
+        assert execute_model_req.num_lookahead_slots == 0, (
            "lookahead not supported for Neuron backend.")
 
         output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list)
+            execute_model_req.seq_group_metadata_list)
        return output
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -80,14 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
 
     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list, )
+        output = await make_async(
+            self.driver_worker.execute_model
+        )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
         return output
 
     async def check_health_async(self) -> None:

View File

@@ -10,7 +10,7 @@ from vllm.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.ray_utils import RayWorkerWrapper, ray
 from vllm.logger import init_logger
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         get_vllm_instance_id, make_async)
@@ -166,21 +166,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
             max_concurrent_workers=self.parallel_config.
             max_parallel_loading_workers)
 
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int = 0) -> List[SamplerOutput]:
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         all_outputs = self._run_workers(
             "execute_model",
-            driver_kwargs={
-                "seq_group_metadata_list": seq_group_metadata_list,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-                "num_lookahead_slots": num_lookahead_slots,
-            },
+            driver_kwargs={"execute_model_req": execute_model_req},
             use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
 
         # Only the driver worker returns the sampling results.

View File

@@ -1,7 +1,7 @@
 """Sequence and its related classes."""
 import copy
 import enum
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from vllm.block import LogicalTokenBlock
@@ -734,3 +734,33 @@ class SamplerOutput:
                 f"sampled_token_probs={sampled_token_probs_repr}, "
                 f"sampled_token_ids={sampled_token_ids_repr}, "
                 f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
+
+
+@dataclass
+class ExecuteModelRequest:
+    """The model execution request."""
+    # The sequence group metadata list.
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+    # Blocks to swap in. Dict of CPU -> GPU block number.
+    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
+    # Blocks to swap out. Dict of GPU -> CPU block number.
+    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
+    # Blocks to copy. Source to a list of dest blocks.
+    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
+    # The number of slots for lookahead decoding.
+    num_lookahead_slots: int = 0
+    # The number of requests in the running queue.
+    running_queue_size: int = 0
+
+    def clone(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> "ExecuteModelRequest":
+        """Clone the request with a new sequence group metadata list."""
+        return ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
+            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
+            blocks_to_copy=self.blocks_to_copy.copy(),
+            num_lookahead_slots=self.num_lookahead_slots,
+            running_queue_size=self.running_queue_size,
+        )
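Not part of the commit itself: a minimal sketch of how the new request type is built and cloned. The empty metadata list is a placeholder for the SequenceGroupMetadata objects the scheduler normally supplies, and the block maps and counters are arbitrary illustrative values.

from vllm.sequence import ExecuteModelRequest

request = ExecuteModelRequest(
    seq_group_metadata_list=[],    # placeholder; filled by the scheduler
    blocks_to_swap_in={},          # CPU -> GPU block numbers
    blocks_to_swap_out={3: 72},    # GPU -> CPU block numbers
    blocks_to_copy={},             # source block -> list of destination blocks
    num_lookahead_slots=4,         # slots reserved for lookahead decoding
    running_queue_size=8,          # scheduler statistic carried along
)

# clone() keeps the block maps and counters but swaps in a new batch; this is
# how BatchExpansionTop1Scorer re-targets the scorer model in this commit.
retargeted = request.clone(seq_group_metadata_list=[])
assert retargeted.num_lookahead_slots == request.num_lookahead_slots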

View File

@@ -1,9 +1,10 @@
 from itertools import chain, count
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Iterator, List, Tuple
 
 import torch
 
-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
     @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
     def score_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
+        execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
     ) -> SpeculativeScores:
         """Score the proposed tokens via the scorer model.
@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             no speculation is produced for that sequence.
 
         Args:
-            seq_group_metadata_list: The input sequence group metadata.
-            blocks_to_swap_in: This is passed to the worker during scoring.
-            blocks_to_swap_out: This is passed to the worker during scoring.
-            blocks_to_copy: This is passed to the worker during scoring.
-            k: The fixed proposal length.
+            execute_model_req: The execution request.
             proposals: The speculative proposals to score.
         Returns:
             SpeculativeScores: The scores of each speculative token, along with
@@ -80,28 +73,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
 
         (spec_indices, non_spec_indices, target_seq_group_metadata_list,
          num_scoring_tokens) = self._expand_batch(
-             seq_group_metadata_list=seq_group_metadata_list,
+             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
              proposal_token_ids_list=proposal_token_ids_list_without_skips,
             proposal_lens_list=proposal_lens_list,
         )
 
         target_sampler_output = self._scorer_worker.execute_model(
-            seq_group_metadata_list=target_seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+            execute_model_req=execute_model_req.clone(
+                seq_group_metadata_list=target_seq_group_metadata_list, ))
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
 
         all_tokens, all_probs, spec_logprobs = self._contract_batch(
-            contracted_bs=len(seq_group_metadata_list),
+            contracted_bs=len(execute_model_req.seq_group_metadata_list),
             target_sampler_output=target_sampler_output,
             proposals=proposals,
             num_scoring_tokens=num_scoring_tokens,
             non_spec_indices=non_spec_indices,
             spec_indices=spec_indices,
-            k=k,
+            k=execute_model_req.num_lookahead_slots,
         )
 
         return SpeculativeScores(

View File

@@ -1,10 +1,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, List, Optional
 
 import torch
 
-from vllm.sequence import SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest
 
 
 @dataclass
@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
     @abstractmethod
     def get_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         raise NotImplementedError
@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
     @abstractmethod
     def score_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
+        execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
     ) -> SpeculativeScores:
         raise NotImplementedError
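Not part of the commit itself: the proposal length that the old signatures took explicitly (max_proposal_len, or k for scoring) now rides inside the request as num_lookahead_slots, which is why the tests earlier in this commit pass num_lookahead_slots=proposal_len and batch_expansion reads execute_model_req.num_lookahead_slots. A hedged sketch of the call shape, with the proposer assumed to be constructed elsewhere:

from vllm.sequence import ExecuteModelRequest

k = 5  # hypothetical number of speculative tokens per sequence
execute_model_req = ExecuteModelRequest(
    seq_group_metadata_list=[],  # placeholder for the scheduled batch
    num_lookahead_slots=k)

# proposer: any SpeculativeProposer implementation, e.g. Top1Proposer.
# proposals = proposer.get_proposals(execute_model_req)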

View File

@ -1,9 +1,10 @@
import copy import copy
from typing import Dict, List, Tuple from typing import List, Tuple
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
@ -44,10 +45,7 @@ class MultiStepWorker(Worker):
@torch.inference_mode() @torch.inference_mode()
def sampler_output( def sampler_output(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
sample_len: int, sample_len: int,
) -> Tuple[List[SamplerOutput], bool]: ) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass sample_len times. Returns the list of """Run the model forward pass sample_len times. Returns the list of
@ -57,26 +55,24 @@ class MultiStepWorker(Worker):
For multi step worker, this indicator shall be True. For multi step worker, this indicator shall be True.
""" """
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in, self._raise_if_unsupported(execute_model_req)
blocks_to_swap_out, blocks_to_copy)
# Shallow copy input data so modifications (such as appending tokens) # Shallow copy input data so modifications (such as appending tokens)
# do not cause side-effects. # do not cause side-effects.
copied_seq_group_metadata_list = self._shallow_copy_inputs( copied_seq_group_metadata_list = self._shallow_copy_inputs(
seq_group_metadata_list) execute_model_req.seq_group_metadata_list)
copied_execute_model_req = execute_model_req.clone(
copied_seq_group_metadata_list)
# Assert enough KV space for sample_len tokens per sequence. # Assert enough KV space for sample_len tokens per sequence.
self._assert_enough_kv_space(seq_group_metadata_list, sample_len) self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
sample_len)
# Run model sample_len times. # Run model sample_len times.
model_outputs = [] model_outputs = []
for _ in range(sample_len): for _ in range(sample_len):
model_output = super().execute_model( model_output = super().execute_model(
seq_group_metadata_list=copied_seq_group_metadata_list, execute_model_req=copied_execute_model_req)
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
assert (len(model_output) == 1 assert (len(model_output) == 1
), "composing multistep workers not supported" ), "composing multistep workers not supported"
model_output = model_output[0] model_output = model_output[0]
@ -89,23 +85,13 @@ class MultiStepWorker(Worker):
def get_spec_proposals( def get_spec_proposals(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals: ) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of """Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len. speculative tokens per sequence is determined by max_proposal_len.
""" """
return self._proposer.get_proposals( return self._proposer.get_proposals(execute_model_req)
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
def _append_new_tokens( def _append_new_tokens(
self, model_output: SamplerOutput, self, model_output: SamplerOutput,
@ -196,20 +182,22 @@ class MultiStepWorker(Worker):
def _raise_if_unsupported( def _raise_if_unsupported(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None: ) -> None:
"""MultiStepWorker does not yet implement support for cache swap """MultiStepWorker does not yet implement support for cache swap
operations or beam search. operations or beam search.
""" """
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError( raise NotImplementedError(
"MultiStepWorker does not support cache operations") "MultiStepWorker does not support cache operations")
if any( if any(
len(seq_group_metadata.seq_data.keys()) != 1 len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list): for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError( raise NotImplementedError(
"MultiStepWorker does not support beam search.") "MultiStepWorker does not support beam search.")


@ -1,8 +1,8 @@
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.top1_proposer import Top1Proposer from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@ -46,13 +46,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
# NGram don't need gpu sampler # NGram don't need gpu sampler
pass pass
def execute_model( def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> None:
"""NGram doesn't depend on model execution, just pass this function""" """NGram doesn't depend on model execution, just pass this function"""
pass pass
@ -71,10 +65,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def sampler_output( def sampler_output(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
sample_len: int, sample_len: int,
) -> Tuple[Optional[List[SamplerOutput]], bool]: ) -> Tuple[Optional[List[SamplerOutput]], bool]:
"""NGram match algo to pick proposal candidate. Returns the list of """NGram match algo to pick proposal candidate. Returns the list of
@ -83,16 +74,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
For ngram worker, we already done needed transposed internal, so the For ngram worker, we already done needed transposed internal, so the
indicator pass to sampler_output_to_torch shall be False. indicator pass to sampler_output_to_torch shall be False.
""" """
self._raise_if_unsupported( self._raise_if_unsupported(execute_model_req)
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
)
arr = [] arr = []
has_spec_out = False has_spec_out = False
for seq_group_metadata in seq_group_metadata_list: for seq_group_metadata in execute_model_req.seq_group_metadata_list:
seq_data = next(iter(seq_group_metadata.seq_data.values())) seq_data = next(iter(seq_group_metadata.seq_data.values()))
input_ids = torch.as_tensor(seq_data.get_token_ids(), input_ids = torch.as_tensor(seq_data.get_token_ids(),
@ -135,17 +121,19 @@ class NGramWorker(LoraNotSupportedWorkerBase):
indices = token_ids.unsqueeze(2) indices = token_ids.unsqueeze(2)
token_probs = torch.zeros( token_probs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size), (len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32, dtype=torch.float32,
device=self.device, device=self.device,
) )
token_probs.scatter_(2, indices, 1) token_probs.scatter_(2, indices, 1)
token_logprobs = torch.zeros( token_logprobs = torch.zeros(
(len(seq_group_metadata_list), sample_len, self.vocab_size), (len(execute_model_req.seq_group_metadata_list), sample_len,
self.vocab_size),
dtype=torch.float32, dtype=torch.float32,
device=self.device, device=self.device,
) )
for i in range(len(seq_group_metadata_list)): for i in range(len(execute_model_req.seq_group_metadata_list)):
outputs.append( outputs.append(
SamplerOutput( SamplerOutput(
outputs=None, outputs=None,
@ -157,40 +145,32 @@ class NGramWorker(LoraNotSupportedWorkerBase):
def get_spec_proposals( def get_spec_proposals(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
max_proposal_len: int,
) -> SpeculativeProposals: ) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of """Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len. speculative tokens per sequence is determined by max_proposal_len.
""" """
return self._proposer.get_proposals( return self._proposer.get_proposals(execute_model_req)
seq_group_metadata_list,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
max_proposal_len,
)
def _raise_if_unsupported( def _raise_if_unsupported(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None: ) -> None:
"""NGramWorker does not yet implement support for cache swap """NGramWorker does not yet implement support for cache swap
operations or beam search. operations or beam search.
""" """
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]): if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError( raise NotImplementedError(
"NGramWorker does not support cache operations") "NGramWorker does not support cache operations")
if any( if any(
len(seq_group_metadata.seq_data.keys()) != 1 len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in seq_group_metadata_list): for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError( raise NotImplementedError(
"NGramWorker does not support beam search.") "NGramWorker does not support beam search.")


@ -1,11 +1,12 @@
from functools import cached_property from functools import cached_property
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.rejection_sampler import RejectionSampler from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores) SpeculativeScorer, SpeculativeScores)
@ -190,68 +191,36 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
num_lookahead_slots: int,
) -> List[SamplerOutput]:
"""Perform speculative decoding on the input batch. """Perform speculative decoding on the input batch.
""" """
assert seq_group_metadata_list is not None, ( assert execute_model_req.seq_group_metadata_list is not None, (
"speculative decoding " "speculative decoding "
"requires non-None seq_group_metadata_list") "requires non-None seq_group_metadata_list")
#logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
# num_lookahead_slots)
# If no spec tokens, call the proposer and scorer workers normally. # If no spec tokens, call the proposer and scorer workers normally.
# Used for prefill. # Used for prefill.
if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0: if execute_model_req.num_lookahead_slots == 0 or len(
return self._run_no_spec( execute_model_req.seq_group_metadata_list) == 0:
seq_group_metadata_list=seq_group_metadata_list, return self._run_no_spec(execute_model_req)
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return self._run_speculative_decoding_step( return self._run_speculative_decoding_step(execute_model_req)
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
k=num_lookahead_slots,
)
@nvtx_range("spec_decode_worker._run_no_spec") @nvtx_range("spec_decode_worker._run_no_spec")
def _run_no_spec( def _run_no_spec(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
) -> List[SamplerOutput]:
"""Run a prefill step, without any speculation. The input is sent to the """Run a prefill step, without any speculation. The input is sent to the
proposer and scorer model so that the KV cache is consistent between the proposer and scorer model so that the KV cache is consistent between the
two. two.
""" """
#logger.info("run proposer worker no spec") #logger.info("run proposer worker no spec")
self.proposer_worker.execute_model( self.proposer_worker.execute_model(execute_model_req)
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
#logger.info("run target worker no spec") #logger.info("run target worker no spec")
sampler_output = self.scorer_worker.execute_model( sampler_output = self.scorer_worker.execute_model(execute_model_req)
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
assert len(sampler_output) == 1 assert len(sampler_output) == 1
sampler_output = sampler_output[0] sampler_output = sampler_output[0]
@ -265,12 +234,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
@nvtx_range("spec_decode_worker._run_speculative_decoding_step") @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
def _run_speculative_decoding_step( def _run_speculative_decoding_step(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
blocks_to_swap_in: Optional[Dict[int, int]],
blocks_to_swap_out: Optional[Dict[int, int]],
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
) -> List[SamplerOutput]:
"""Execute a single step of speculative decoding. """Execute a single step of speculative decoding.
This invokes the proposer worker to get k speculative tokens for each This invokes the proposer worker to get k speculative tokens for each
@ -282,33 +246,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
#logger.info("get spec proposals") #logger.info("get spec proposals")
# Generate proposals using draft worker. # Generate proposals using draft worker.
assert blocks_to_swap_in is not None proposals = self.proposer_worker.get_spec_proposals(execute_model_req)
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
proposals = self.proposer_worker.get_spec_proposals(
seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
blocks_to_copy, k)
#logger.info("score proposals") #logger.info("score proposals")
proposal_scores = self.scorer.score_proposals( proposal_scores = self.scorer.score_proposals(
seq_group_metadata_list, execute_model_req,
blocks_to_swap_in,
blocks_to_swap_out,
blocks_to_copy,
k,
proposals, proposals,
) )
#logger.info("verify proposals") #logger.info("verify proposals")
accepted_token_ids, target_logprobs = self._verify_tokens( accepted_token_ids, target_logprobs = self._verify_tokens(
seq_group_metadata_list, proposal_scores, proposals, k) execute_model_req.seq_group_metadata_list, proposal_scores,
proposals, execute_model_req.num_lookahead_slots)
#logger.info("create output list") #logger.info("create output list")
return self._create_output_sampler_list( return self._create_output_sampler_list(
seq_group_metadata_list, execute_model_req.seq_group_metadata_list,
accepted_token_ids, accepted_token_ids,
target_logprobs=target_logprobs, target_logprobs=target_logprobs,
k=k) k=execute_model_req.num_lookahead_slots)
@nvtx_range("spec_decode_worker._verify_tokens") @nvtx_range("spec_decode_worker._verify_tokens")
def _verify_tokens( def _verify_tokens(
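At the top level, SpecDecodeWorker.execute_model now receives the one request object and routes on its num_lookahead_slots field: zero (or an empty batch) goes down the plain prefill path, anything else runs a speculative step that reuses the same value as k. A small sketch of the dispatch from the caller's side (the worker itself is assumed to be configured elsewhere):

```python
# Sketch only: `spec_decode_worker` is assumed to be a configured
# SpecDecodeWorker; cache-operation fields are left at their defaults.
from vllm.sequence import ExecuteModelRequest

def run_decode_step(spec_decode_worker, seq_group_metadata_list,
                    num_lookahead_slots: int):
    request = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=num_lookahead_slots)

    # num_lookahead_slots == 0 (prefill) -> _run_no_spec();
    # num_lookahead_slots  > 0           -> _run_speculative_decoding_step(),
    # which also uses the same field as the proposal length k.
    return spec_decode_worker.execute_model(request)
```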


@ -1,8 +1,9 @@
from typing import Dict, List, Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import (SpeculativeProposals, from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeProposer) SpeculativeProposer)
from vllm.spec_decode.util import sampler_output_to_torch from vllm.spec_decode.util import sampler_output_to_torch
@ -40,17 +41,15 @@ class Top1Proposer(SpeculativeProposer):
def get_proposals( def get_proposals(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], execute_model_req: ExecuteModelRequest,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
proposal_len: int,
) -> SpeculativeProposals: ) -> SpeculativeProposals:
"""Get speculative proposals given the input batch. """Get speculative proposals given the input batch.
Sequences which would exceed the max model length are skipped during Sequences which would exceed the max model length are skipped during
speculation. speculation.
""" """
proposal_len = execute_model_req.num_lookahead_slots
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
# Split speculative- and non-speculative- sequences. # Split speculative- and non-speculative- sequences.
( (
@ -66,11 +65,12 @@ class Top1Proposer(SpeculativeProposer):
# token_ids is like [batch] format in proposal_len size list, # token_ids is like [batch] format in proposal_len size list,
# while if it is false, the format would be [proposal_len] # while if it is false, the format would be [proposal_len]
# in batch size list # in batch size list
maybe_sampler_output, transposed = self._worker.sampler_output( nonzero_execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=nonzero_proposal_len_seqs, seq_group_metadata_list=nonzero_proposal_len_seqs,
blocks_to_swap_in=blocks_to_swap_in, num_lookahead_slots=proposal_len,
blocks_to_swap_out=blocks_to_swap_out, )
blocks_to_copy=blocks_to_copy, maybe_sampler_output, transposed = self._worker.sampler_output(
execute_model_req=nonzero_execute_model_req,
sample_len=proposal_len, sample_len=proposal_len,
) )
else: else:
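Inside Top1Proposer the request is not forwarded verbatim: the proposer derives proposal_len and the batch from it, filters out sequences that cannot be speculated, and wraps only the remainder in a fresh ExecuteModelRequest for the draft worker. The sketch below isolates that re-wrapping pattern; the length check is a simplified stand-in for the proposer's real eligibility logic.

```python
# Simplified stand-in for the re-wrapping done by Top1Proposer; the
# eligibility test is illustrative, not the proposer's exact logic.
from vllm.sequence import ExecuteModelRequest

def rewrap_for_draft_worker(draft_worker, execute_model_req, max_model_len):
    proposal_len = execute_model_req.num_lookahead_slots
    eligible = [
        sgm for sgm in execute_model_req.seq_group_metadata_list
        if len(next(iter(sgm.seq_data.values())).get_token_ids())
        + proposal_len <= max_model_len
    ]
    # The real proposer also records which sequences were skipped so their
    # entries can be stitched back into the final proposals.
    nonzero_req = ExecuteModelRequest(
        seq_group_metadata_list=eligible,
        num_lookahead_slots=proposal_len)
    return draft_worker.sampler_output(
        execute_model_req=nonzero_req, sample_len=proposal_len)
```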


@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict,
init_distributed_environment) init_distributed_environment)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.cpu_model_runner import CPUModelRunner
from vllm.worker.worker_base import LoraNotSupportedWorkerBase from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, execute_model_req: Optional[ExecuteModelRequest] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
num_seq_groups: int = len(seq_group_metadata_list) num_seq_groups: int = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None assert execute_model_req is not None
assert blocks_to_swap_out is not None blocks_to_copy = execute_model_req.blocks_to_copy
assert blocks_to_copy is not None assert len(execute_model_req.blocks_to_swap_in) == 0
assert len(blocks_to_swap_in) == 0 assert len(execute_model_req.blocks_to_swap_out) == 0
assert len(blocks_to_swap_out) == 0
data: Dict[str, Any] = { data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups, "num_seq_groups": num_seq_groups,
"blocks_to_copy": blocks_to_copy, "blocks_to_copy": execute_model_req.blocks_to_copy,
} }
broadcast_tensor_dict(data, src=0) broadcast_tensor_dict(data, src=0)
else: else:
@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase):
num_seq_groups = data["num_seq_groups"] num_seq_groups = data["num_seq_groups"]
blocks_to_copy = data["blocks_to_copy"] blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_copy is not None
self.cache_copy(blocks_to_copy) self.cache_copy(blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
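On the CPU backend only the driver rank ever sees an ExecuteModelRequest; it unpacks the few fields it needs and broadcasts a plain dict, so the request object never crosses ranks. A condensed sketch of that split (error handling and the actual model call are omitted):

```python
# Condensed sketch of the driver/non-driver split shown above.
from typing import Any, Dict, Optional

from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest

def exchange_cpu_metadata(is_driver_worker: bool,
                          execute_model_req: Optional[ExecuteModelRequest]):
    if is_driver_worker:
        assert execute_model_req is not None
        # CPU workers never swap KV blocks; only copies are forwarded.
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        data: Dict[str, Any] = {
            "num_seq_groups":
            len(execute_model_req.seq_group_metadata_list),
            "blocks_to_copy": execute_model_req.blocks_to_copy,
        }
        broadcast_tensor_dict(data, src=0)
    else:
        # Non-driver ranks never receive a request; they read the broadcast.
        data = broadcast_tensor_dict(src=0)
    return data["num_seq_groups"], data["blocks_to_copy"]
```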


@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
init_custom_ar) init_custom_ar)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.cache_engine import CacheEngine from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner from vllm.worker.model_runner import ModelRunner
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
@ -211,19 +211,21 @@ class Worker(WorkerBase):
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, execute_model_req: Optional[ExecuteModelRequest] = None
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
num_lookahead_slots: int = 0,
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
if execute_model_req is None:
seq_group_metadata_list = None
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
if self.is_driver_worker: if self.is_driver_worker:
assert seq_group_metadata_list is not None assert seq_group_metadata_list is not None
assert execute_model_req is not None
num_seq_groups = len(seq_group_metadata_list) num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None blocks_to_swap_in = execute_model_req.blocks_to_swap_in
assert blocks_to_swap_out is not None blocks_to_swap_out = execute_model_req.blocks_to_swap_out
assert blocks_to_copy is not None blocks_to_copy = execute_model_req.blocks_to_copy
data: Dict[str, Any] = { data: Dict[str, Any] = {
"num_seq_groups": num_seq_groups, "num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_in": blocks_to_swap_in,
@ -238,9 +240,6 @@ class Worker(WorkerBase):
blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"] blocks_to_copy = data["blocks_to_copy"]
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
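From an executor's point of view, the GPU worker now needs just one object built from the scheduler's output; non-driver ranks are still invoked with no arguments and pick the metadata up from the broadcast. A sketch of the construction, where `scheduler_output` is a hypothetical stand-in and only the ExecuteModelRequest field names come from this change:

```python
# `scheduler_output` is a hypothetical stand-in for whatever produced the
# batch and the cache-management decisions; only the ExecuteModelRequest
# field names below come from this change.
from vllm.sequence import ExecuteModelRequest

def build_worker_request(scheduler_output) -> ExecuteModelRequest:
    return ExecuteModelRequest(
        seq_group_metadata_list=scheduler_output.seq_group_metadata_list,
        blocks_to_swap_in=scheduler_output.blocks_to_swap_in,
        blocks_to_swap_out=scheduler_output.blocks_to_swap_out,
        blocks_to_copy=scheduler_output.blocks_to_copy,
        num_lookahead_slots=scheduler_output.num_lookahead_slots,
    )

# Driver rank:     worker.execute_model(build_worker_request(out))
# Non-driver rank: worker.execute_model()   # metadata arrives via broadcast
```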


@ -5,7 +5,7 @@ from typing import Dict, List, Set, Tuple
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables) update_environment_variables)
@ -48,10 +48,8 @@ class WorkerBase(ABC):
@abstractmethod @abstractmethod
def execute_model( def execute_model(
self, seq_group_metadata_list: List[SequenceGroupMetadata], self,
blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
int],
blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
"""Executes at least one model step on the given sequences, unless no """Executes at least one model step on the given sequences, unless no
sequences are provided.""" sequences are provided."""
raise NotImplementedError raise NotImplementedError
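The abstract base pins down the new single-argument signature, so any future backend only has to accept the request and pull out the fields it cares about. A minimal, illustrative skeleton of a conforming worker (the class name is hypothetical, and a real backend also implements the cache and initialization methods the base class requires):

```python
# Illustrative skeleton only; a real backend also implements the cache,
# initialization and profiling methods required by the base class.
from typing import List

from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.worker.worker_base import LoraNotSupportedWorkerBase

class NullWorker(LoraNotSupportedWorkerBase):

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        # Everything a backend needs now hangs off the request object.
        _ = execute_model_req.seq_group_metadata_list
        _ = (execute_model_req.blocks_to_swap_in,
             execute_model_req.blocks_to_swap_out,
             execute_model_req.blocks_to_copy,
             execute_model_req.num_lookahead_slots)
        return []
```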