[Misc][Refactor] Introduce ExecuteModelData (#4540)
parent 344bf7cd2d
commit bc8ad68455
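This commit replaces the loose keyword arguments that previously flowed through every execute_model call (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, num_lookahead_slots) with a single ExecuteModelRequest dataclass defined in vllm/sequence.py. As a rough sketch of the call-site change (the executor and seq_group_metadata_list names below are placeholders standing in for an executor instance and the scheduler output, not lines taken from the diff):

    from vllm.sequence import ExecuteModelRequest

    # Old shape (before this commit): separate keyword arguments.
    # executor.execute_model(seq_group_metadata_list=seq_group_metadata_list,
    #                        blocks_to_swap_in={}, blocks_to_swap_out={},
    #                        blocks_to_copy={}, num_lookahead_slots=0)

    # New shape (after this commit): one request object carries everything.
    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        blocks_to_swap_in={},
        blocks_to_swap_out={},
        blocks_to_copy={},
        num_lookahead_slots=0,
    )
    output = executor.execute_model(execute_model_req=execute_model_req)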
@@ -5,13 +5,12 @@ import pytest
 import torch

 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker

 from .utils import (assert_logprobs_dict_allclose, create_batch,
-                    create_execute_model_data,
                     create_seq_group_metadata_from_prompts, create_worker,
                     patch_execute_model_with_seeds, zero_kv_cache)

@@ -105,31 +104,32 @@ def test_same_output_for_single_step():

     final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

-    multi_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
+    multi_step_seq_group = create_seq_group_metadata_from_prompts(
         prompts,
         num_gpu_blocks,
         block_size,
-        final_prompt_lens=final_prompt_lens))
-
-    single_step_execute_model_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
-            prompts,
-            num_gpu_blocks,
-            block_size,
-            final_prompt_lens=final_prompt_lens))
+        final_prompt_lens=final_prompt_lens)

     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     actual_output, _ = multi_step_worker.sampler_output(
-        **multi_step_execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=multi_step_seq_group),
+        sample_len=num_steps)
     assert len(actual_output) == num_steps
     actual_output = actual_output[0]

+    single_step_seq_group = create_seq_group_metadata_from_prompts(
+        prompts,
+        num_gpu_blocks,
+        block_size,
+        final_prompt_lens=final_prompt_lens)
+
     zero_kv_cache(worker.cache_engine)
     set_random_seed(seed)
     expected_output = worker.execute_model(
-        **single_step_execute_model_data.to_dict(), )[0]
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=single_step_seq_group))[0]

     actual_token_ids = [
         output.samples[0].output_token for output in actual_output
@@ -193,19 +193,20 @@ def test_same_output_for_multi_step():
     worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

     continuations = [[1] for _ in prompts]
-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
         prompts,
         num_gpu_blocks,
         block_size,
         continuations=continuations,
-        final_prompt_lens=final_prompt_lens), )
+        final_prompt_lens=final_prompt_lens)

     # Run multi-step.
     zero_kv_cache(multi_step_worker.cache_engine)
     set_random_seed(seed)
     multi_step_output, _ = multi_step_worker.sampler_output(
-        **execute_model_data.to_dict(), sample_len=num_steps)
+        execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+        sample_len=num_steps)

     # Run single-step repeatedly.
     zero_kv_cache(worker.cache_engine)
@@ -215,16 +216,16 @@ def test_same_output_for_multi_step():

     for _ in multi_step_output:

-        execute_model_data = create_execute_model_data(
-            create_seq_group_metadata_from_prompts(
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
             prompts,
             num_gpu_blocks,
             block_size,
             continuations=continuations,
-            final_prompt_lens=final_prompt_lens))
+            final_prompt_lens=final_prompt_lens)

         single_step_output.extend(
-            worker.execute_model(**execute_model_data.to_dict(), ))
+            worker.execute_model(execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list)))

         # Append output tokens to new sequence data.
         for i, seq_group_output in enumerate(single_step_output[-1]):
@@ -304,12 +305,11 @@ def test_draft_proposals_full_speculation_len():
         ) for _ in range(k)
     ], True

-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -340,14 +340,13 @@ def test_draft_proposals_no_speculations():
         max_proposal_len=prompt_len + k - 1,
     )

-    execute_model_data, _, _ = create_batch(batch_size,
-                                            k,
-                                            prompt_len=prompt_len)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prompt_len=prompt_len)

-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -409,17 +408,16 @@ def test_draft_proposals_mixed_k():
         ) for _ in range(k)
     ], True

-    execute_model_data, _, _ = create_batch(
+    seq_group_metadata_list, _, _ = create_batch(
         batch_size,
         k,
         prompt_len=prompt_len,
         prev_output_token_len=prev_output_token_len,
     )

-    proposals = proposer.get_proposals(
-        **execute_model_data.to_dict(),
-        proposal_len=k,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)

@@ -1,10 +1,10 @@
 import torch

+from vllm.sequence import ExecuteModelRequest
 from vllm.spec_decode.ngram_worker import NGramWorker
 from vllm.spec_decode.top1_proposer import Top1Proposer

-from .utils import (create_execute_model_data,
-                    create_seq_group_metadata_from_prompts, create_worker)
+from .utils import create_seq_group_metadata_from_prompts, create_worker


 def test_ngram_algo_correctness_for_single_no_match():
@@ -44,17 +44,15 @@ def test_ngram_algo_correctness_for_single_no_match():

     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
         prompts,
         num_gpu_blocks,
         block_size,
-        final_prompt_lens=final_prompt_lens))
+        final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -113,17 +111,15 @@ def test_ngram_algo_correctness_for_batches_not_match_all():

     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
         prompts,
         num_gpu_blocks,
         block_size,
-        final_prompt_lens=final_prompt_lens))
+        final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)
@@ -185,17 +181,15 @@ def test_ngram_algo_correctness_for_batches_match_all():

     proposal_len = 5
     final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
-    ngram_sampler_output_data = create_execute_model_data(
-        seq_group_metadata_list=create_seq_group_metadata_from_prompts(
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
         prompts,
         num_gpu_blocks,
         block_size,
-        final_prompt_lens=final_prompt_lens))
+        final_prompt_lens=final_prompt_lens)

-    proposals = proposer.get_proposals(
-        **ngram_sampler_output_data.to_dict(),
-        proposal_len=proposal_len,
-    )
+    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=proposal_len), )

     assert torch.is_tensor(proposals.proposal_token_ids)
     assert torch.is_tensor(proposals.proposal_probs)

@@ -7,7 +7,7 @@ import torch
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.metrics import (AsyncMetricsCollector,
                                       SpecDecodeWorkerMetrics)
@@ -15,8 +15,7 @@ from vllm.spec_decode.multi_step_worker import MultiStepWorker
 from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                  split_num_cache_blocks_evenly)

-from .utils import (ExecuteModelData, create_batch, create_sampler_output_list,
-                    mock_worker)
+from .utils import create_batch, create_sampler_output_list, mock_worker


 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -36,24 +35,19 @@ def test_correctly_calls_draft_model(k: int, batch_size: int):
     exception_secret = 'artificial stop'
     draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret)

-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=execute_model_req)

     call_args_list = draft_worker.get_spec_proposals.call_args_list
     assert len(call_args_list) == 1

     for args, _ in call_args_list:
-        (seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
-         blocks_to_copy, actual_k) = args
-        actual_execute_model_data = ExecuteModelData(seq_group_metadata_list,
-                                                     blocks_to_swap_in,
-                                                     blocks_to_swap_out,
-                                                     blocks_to_copy)
-        assert actual_execute_model_data == execute_model_data
-        assert actual_k == k
+        actual_execute_model_data = args[0]
+        assert actual_execute_model_data == execute_model_req


 @pytest.mark.parametrize('k', [1, 2, 6])
@@ -93,7 +87,7 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k

-    execute_model_data, prompts, prev_output_tokens = create_batch(
+    seq_group_metadata_list, prompts, prev_output_tokens = create_batch(
         batch_size, k)

     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
@@ -105,20 +99,20 @@ def test_correctly_calls_target_model(k: int, batch_size: int):
     target_worker.execute_model.side_effect = ValueError(exception_secret)

     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))

     seen_contexts = []

     call_args_list = target_worker.execute_model.call_args_list
     assert len(call_args_list) == 1
-    for args, kwargs in call_args_list:
-        target_execute_model_data = ExecuteModelData.from_dict(kwargs)
+    for _, kwargs in call_args_list:
+        seq_group_metadata_list = kwargs[
+            "execute_model_req"].seq_group_metadata_list

-        assert len(target_execute_model_data.seq_group_metadata_list) == (
-            k + 1) * batch_size
-        for seq_group_metadata in (
-                target_execute_model_data.seq_group_metadata_list):
+        assert len(seq_group_metadata_list) == (k + 1) * batch_size
+        for seq_group_metadata in seq_group_metadata_list:
             for seq_data in seq_group_metadata.seq_data.values():
                 seen_contexts.append(seq_data.get_token_ids())

@@ -175,7 +169,7 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k

-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -207,8 +201,9 @@ def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
     rejection_sampler.side_effect = ValueError(exception_secret)

     with pytest.raises(ValueError, match=exception_secret):
-        worker.execute_model(**execute_model_data.to_dict(),
-                             num_lookahead_slots=k)
+        worker.execute_model(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            num_lookahead_slots=k))

     assert len(rejection_sampler.call_args_list) == 1
     _, kwargs = rejection_sampler.call_args_list[0]
@@ -262,7 +257,7 @@ def test_correctly_formats_output(k: int, batch_size: int):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k

-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -302,8 +297,9 @@ def test_correctly_formats_output(k: int, batch_size: int):

     rejection_sampler.return_value = rejection_sampler_output

-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))

     expected_output = create_sampler_output_list(
         token_ids=rejection_sampler_output.transpose(0, 1),
@@ -312,7 +308,7 @@ def test_correctly_formats_output(k: int, batch_size: int):

     seq_ids = [
         next(iter(seq_group_metadata.seq_data.keys()))
-        for seq_group_metadata in execute_model_data.seq_group_metadata_list
+        for seq_group_metadata in seq_group_metadata_list
     ]
     actual_output_by_seq = {seq_id: [] for seq_id in seq_ids}
     expected_output_by_seq = {seq_id: [] for seq_id in seq_ids}
@@ -383,7 +379,7 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                                device='cuda') * k

-    execute_model_data, _, _ = create_batch(batch_size, k)
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

     draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
         proposal_token_ids=proposal_token_ids,
@@ -428,8 +424,9 @@ def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
     metrics_collector.maybe_collect_rejsample_metrics.return_value = (
         mock_rejsample_metrics)

-    output = worker.execute_model(**execute_model_data.to_dict(),
-                                  num_lookahead_slots=k)
+    output = worker.execute_model(execute_model_req=ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list,
+        num_lookahead_slots=k))
     assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

     call_args_list = (
@@ -462,21 +459,21 @@ def test_k_equals_zero(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)

-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)

     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"

-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)


 @pytest.mark.parametrize('k', [0, 5])
@@ -503,21 +500,21 @@ def test_empty_input_batch(k: int, batch_size: int):
     worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                               metrics_collector)

-    execute_model_data, prompts, prev_output_tokens = create_batch(
-        batch_size, k, prev_output_token_len=0)
+    seq_group_metadata_list, _, _ = create_batch(batch_size,
+                                                 k,
+                                                 prev_output_token_len=0)
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k)

-    out = worker.execute_model(**execute_model_data.to_dict(),
-                               num_lookahead_slots=k)
+    out = worker.execute_model(execute_model_req=execute_model_req)

     assert len(out) == 1, f"expected only one token output when {k=}"
     assert out[0].probs is None, "expect gpu tensor references to be None"
     assert out[
         0].sampled_tokens is None, "expect gpu tensor references to be None"

-    draft_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
-    target_worker.execute_model.assert_called_once_with(
-        **execute_model_data.to_dict())
+    draft_worker.execute_model.assert_called_once_with(execute_model_req)
+    target_worker.execute_model.assert_called_once_with(execute_model_req)


 @pytest.mark.skip_global_cleanup

@@ -1,4 +1,3 @@
-from dataclasses import dataclass, fields
 from itertools import count
 from typing import Dict, Iterable, List, Optional, Union
 from unittest.mock import MagicMock
@@ -16,50 +15,10 @@ from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.worker import Worker


-@dataclass
-class ExecuteModelData:
-    """Helper data structure which facilitates cleaner tests.
-    """
-    seq_group_metadata_list: List[SequenceGroupMetadata]
-    blocks_to_swap_in: Dict[int, int]
-    blocks_to_swap_out: Dict[int, int]
-    blocks_to_copy: Dict[int, List[int]]
-
-    def to_dict(self):
-        return dict(
-            (field.name, getattr(self, field.name)) for field in fields(self))
-
-    @classmethod
-    def from_dict(cls, d):
-        cleaned = dict((field.name, d[field.name]) for field in fields(cls))
-        return cls(**cleaned)
-
-
 def round_up_to_next_block(seq_len: int, block_size: int) -> int:
     return (seq_len + block_size - 1) // block_size


-def create_execute_model_data(
-    seq_group_metadata_list: List[SequenceGroupMetadata],
-    blocks_to_swap_in: Optional[Dict[int, int]] = None,
-    blocks_to_swap_out: Optional[Dict[int, int]] = None,
-    blocks_to_copy: Optional[Dict[int, int]] = None,
-) -> ExecuteModelData:
-    if blocks_to_swap_in is None:
-        blocks_to_swap_in = {}
-    if blocks_to_swap_out is None:
-        blocks_to_swap_out = {}
-    if blocks_to_copy is None:
-        blocks_to_copy = {}
-
-    return ExecuteModelData(
-        seq_group_metadata_list=seq_group_metadata_list,
-        blocks_to_swap_in=blocks_to_swap_in,
-        blocks_to_swap_out=blocks_to_swap_out,
-        blocks_to_copy=blocks_to_copy,
-    )
-
-
 def mock_worker(cls=None,
                 vocab_size: int = 30_000,
                 max_model_len: int = 2048,
@@ -258,8 +217,7 @@ def create_batch(batch_size,
         for prompt, prev_output_token in zip(prompts, prev_output_tokens)
     ]

-    execute_model_data = create_execute_model_data(
-        create_seq_group_metadata_from_prompts(prompts, num_gpu_blocks,
-                                               block_size, final_prompt_lens,
-                                               prev_output_tokens, seq_ids), )
-    return execute_model_data, prompts, prev_output_tokens
+    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+        prompts, num_gpu_blocks, block_size, final_prompt_lens,
+        prev_output_tokens, seq_ids)
+    return seq_group_metadata_list, prompts, prev_output_tokens

@@ -1,6 +1,7 @@
 import torch

 from vllm.engine.arg_utils import EngineArgs
+from vllm.sequence import ExecuteModelRequest
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.worker.worker import Worker

@@ -54,10 +55,14 @@ def test_swap() -> None:

     # Test swap out.
     blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
-    worker.execute_model(seq_group_metadata_list=[],
-                         blocks_to_swap_in={},
-                         blocks_to_swap_out=blocks_to_swap_out,
-                         blocks_to_copy={})
+    execute_model_req = ExecuteModelRequest(
+        seq_group_metadata_list=[],
+        blocks_to_swap_in={},
+        blocks_to_swap_out=blocks_to_swap_out,
+        blocks_to_copy={},
+    )
+    worker.execute_model(execute_model_req=execute_model_req)

     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
@@ -66,14 +71,19 @@ def test_swap() -> None:
         assert allclose(gpu_value_cache[src], cpu_value_cache[dst])

     # Test swap in.
-    blocks_to_swap_in = {19: 45, 67: 23, 12: 78, 40: 99, 1: 71}
-    worker.execute_model(seq_group_metadata_list=[],
-                         blocks_to_swap_in=blocks_to_swap_in,
-                         blocks_to_swap_out={},
-                         blocks_to_copy={})
+    execute_model_req.blocks_to_swap_out = {}
+    execute_model_req.blocks_to_swap_in = {
+        19: 45,
+        67: 23,
+        12: 78,
+        40: 99,
+        1: 71
+    }
+    worker.execute_model(execute_model_req=execute_model_req)

     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in blocks_to_swap_in.items():
+        for src, dst in execute_model_req.blocks_to_swap_in.items():
             assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
             assert allclose(gpu_value_cache[dst], cpu_value_cache[src])

@@ -128,6 +128,8 @@ class SchedulerOutputs:
     ignored_seq_groups: List[SequenceGroup]
     # The number of slots for lookahead decoding.
     num_lookahead_slots: int
+    # The number of requests in the running queue
+    running_queue_size: int

     def __post_init__(self):
         # Swap in and swap out should never happen at the same time.
@@ -797,6 +799,7 @@ class Scheduler:
             ignored_seq_groups=prefills.ignored_seq_groups +
             swapped_in.infeasible_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
+            running_queue_size=len(self.running),
         )

     def _schedule_chunked_prefill(self):
@@ -883,6 +886,7 @@ class Scheduler:
                                swapped_in.blocks_to_copy),
             ignored_seq_groups=prefills.ignored_seq_groups,
             num_lookahead_slots=running_scheduled.num_lookahead_slots,
+            running_queue_size=len(self.running),
         )

     def _schedule(self) -> SchedulerOutputs:

@@ -16,7 +16,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import MultiModalData, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
 from vllm.usage.usage_lib import UsageContext

 logger = init_logger(__name__)
@@ -210,12 +210,16 @@ class _AsyncLLMEngine(LLMEngine):

         if not scheduler_outputs.is_empty():
             # Execute the model.
+            execute_model_req = ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list,
+                blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
+                blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
+                blocks_to_copy=scheduler_outputs.blocks_to_copy,
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                running_queue_size=scheduler_outputs.running_queue_size,
+            )
             output = await self.model_executor.execute_model_async(
-                seq_group_metadata_list,
-                scheduler_outputs.blocks_to_swap_in,
-                scheduler_outputs.blocks_to_swap_out,
-                scheduler_outputs.blocks_to_copy,
-                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
+                execute_model_req)
         else:
             output = []

@@ -22,8 +22,8 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (MultiModalData, SamplerOutput, Sequence,
-                           SequenceGroup, SequenceGroupMetadata,
+from vllm.sequence import (ExecuteModelRequest, MultiModalData, SamplerOutput,
+                           Sequence, SequenceGroup, SequenceGroupMetadata,
                            SequenceStatus)
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
@@ -583,12 +583,16 @@ class LLMEngine:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()

         if not scheduler_outputs.is_empty():
-            output = self.model_executor.execute_model(
+            execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=seq_group_metadata_list,
                 blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                 blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                 blocks_to_copy=scheduler_outputs.blocks_to_copy,
-                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
+                running_queue_size=scheduler_outputs.running_queue_size,
+            )
+            output = self.model_executor.execute_model(
+                execute_model_req=execute_model_req)
         else:
             output = []

@@ -1,4 +1,4 @@
-from typing import Dict, List, Set, Tuple
+from typing import List, Set, Tuple

 import torch

@@ -7,7 +7,7 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)

@@ -72,18 +72,10 @@ class CPUExecutor(ExecutorBase):
         logger.info("# CPU blocks: %d", num_gpu_blocks)
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
-        output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = self.driver_worker.execute_model(execute_model_req)
         return output

     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -105,18 +97,9 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):

     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots)
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = await make_async(self.driver_worker.execute_model
+                                  )(execute_model_req=execute_model_req, )
         return output

     async def check_health_async(self) -> None:

@@ -1,11 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple

 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          ModelConfig, ParallelConfig, SchedulerConfig,
                          SpeculativeConfig, VisionLanguageConfig)
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput


 class ExecutorBase(ABC):
@@ -68,12 +68,9 @@ class ExecutorBase(ABC):
         raise NotImplementedError

     @abstractmethod
-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError

@@ -108,12 +105,7 @@ class ExecutorAsyncBase(ExecutorBase):
     @abstractmethod
     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Executes one model step on the given sequences."""
         raise NotImplementedError

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple
 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -118,19 +118,8 @@ class GPUExecutor(ExecutorBase):

     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
-        output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots,
-        )
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = self.driver_worker.execute_model(execute_model_req)
         return output

     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -154,16 +143,8 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):

     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            num_lookahead_slots=num_lookahead_slots)
+        output = await make_async(self.driver_worker.execute_model
+                                  )(execute_model_req=execute_model_req, )
         return output

@@ -1,9 +1,9 @@
-from typing import Dict, List, Set, Tuple
+from typing import List, Set, Tuple

 from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import make_async

 logger = init_logger(__name__)
@@ -45,20 +45,18 @@ class NeuronExecutor(ExecutorBase):
         """
         self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int) -> List[SamplerOutput]:
-        assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
-                and blocks_to_copy == {}), (
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        assert (execute_model_req.blocks_to_swap_in == {}
+                and execute_model_req.blocks_to_swap_out == {}
+                and execute_model_req.blocks_to_copy == {}), (
                     "Cache operations are not supported for Neuron backend.")
-        assert num_lookahead_slots == 0, (
+        assert execute_model_req.num_lookahead_slots == 0, (
             "lookahead not supported for Neuron backend.")

         output = self.driver_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list)
+            execute_model_req.seq_group_metadata_list)
         return output

     def add_lora(self, lora_request: LoRARequest) -> bool:
@@ -80,14 +78,11 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):

     async def execute_model_async(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        num_lookahead_slots: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> List[SamplerOutput]:
-        output = await make_async(self.driver_worker.execute_model)(
-            seq_group_metadata_list=seq_group_metadata_list, )
+        output = await make_async(
+            self.driver_worker.execute_model
+        )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
         return output

     async def check_health_async(self) -> None:

@@ -10,7 +10,7 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.ray_utils import RayWorkerWrapper, ray
 from vllm.logger import init_logger
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         get_vllm_instance_id, make_async)

@@ -166,21 +166,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 max_concurrent_workers=self.parallel_config.
                 max_parallel_loading_workers)

-    def execute_model(self,
-                      seq_group_metadata_list: List[SequenceGroupMetadata],
-                      blocks_to_swap_in: Dict[int, int],
-                      blocks_to_swap_out: Dict[int, int],
-                      blocks_to_copy: Dict[int, List[int]],
-                      num_lookahead_slots: int = 0) -> List[SamplerOutput]:
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         all_outputs = self._run_workers(
             "execute_model",
-            driver_kwargs={
-                "seq_group_metadata_list": seq_group_metadata_list,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-                "num_lookahead_slots": num_lookahead_slots,
-            },
+            driver_kwargs={"execute_model_req": execute_model_req},
             use_ray_compiled_dag=USE_RAY_COMPILED_DAG)

         # Only the driver worker returns the sampling results.

@@ -1,7 +1,7 @@
 """Sequence and its related classes."""
 import copy
 import enum
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Dict, List, Optional, Union

 from vllm.block import LogicalTokenBlock
@@ -734,3 +734,33 @@ class SamplerOutput:
             f"sampled_token_probs={sampled_token_probs_repr}, "
             f"sampled_token_ids={sampled_token_ids_repr}, "
             f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})")
+
+
+@dataclass
+class ExecuteModelRequest:
+    """The model execution request."""
+    # The sequence group metadata list.
+    seq_group_metadata_list: List[SequenceGroupMetadata]
+    # Blocks to swap in. Dict of CPU -> GPU block number.
+    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
+    # Blocks to swap out. Dict of GPU -> CPU block number.
+    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
+    # Blocks to copy. Source to a list of dest blocks.
+    blocks_to_copy: Dict[int, List[int]] = field(default_factory=dict)
+    # The number of slots for lookahead decoding.
+    num_lookahead_slots: int = 0
+    # The number of requests in the running queue.
+    running_queue_size: int = 0
+
+    def clone(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> "ExecuteModelRequest":
+        """Clone the request with a new sequence group metadata list."""
+        return ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=self.blocks_to_swap_in.copy(),
+            blocks_to_swap_out=self.blocks_to_swap_out.copy(),
+            blocks_to_copy=self.blocks_to_copy.copy(),
+            num_lookahead_slots=self.num_lookahead_slots,
+            running_queue_size=self.running_queue_size,
+        )

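For orientation, a minimal usage sketch of the dataclass added in the hunk above; it only exercises the defaults and the clone() helper and assumes a vllm build containing this change:

    from vllm.sequence import ExecuteModelRequest

    # All block mappings default to empty dicts; lookahead and queue size to 0.
    req = ExecuteModelRequest(seq_group_metadata_list=[])
    assert req.blocks_to_swap_in == {} and req.num_lookahead_slots == 0

    # clone() swaps in a new sequence-group list but copies the block dicts,
    # so later in-place edits on one request do not alias the other.
    cloned = req.clone(seq_group_metadata_list=[])
    assert cloned.blocks_to_copy is not req.blocks_to_copy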
|
|||||||
diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py
@@ -1,9 +1,10 @@
 from itertools import chain, count
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Iterator, List, Tuple

 import torch

-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (get_all_seq_ids, nvtx_range,
@@ -40,11 +41,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
     @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
     def score_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
+        execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
     ) -> SpeculativeScores:
         """Score the proposed tokens via the scorer model.
@@ -57,11 +54,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         no speculation is produced for that sequence.

         Args:
-            seq_group_metadata_list: The input sequence group metadata.
-            blocks_to_swap_in: This is passed to the worker during scoring.
-            blocks_to_swap_out: This is passed to the worker during scoring.
-            blocks_to_copy: This is passed to the worker during scoring.
-            k: The fixed proposal length.
+            execute_model_req: The execution request.
             proposals: The speculative proposals to score.
         Returns:
             SpeculativeScores: The scores of each speculative token, along with
@@ -80,28 +73,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):

         (spec_indices, non_spec_indices, target_seq_group_metadata_list,
          num_scoring_tokens) = self._expand_batch(
-             seq_group_metadata_list=seq_group_metadata_list,
+             seq_group_metadata_list=execute_model_req.seq_group_metadata_list,
              proposal_token_ids_list=proposal_token_ids_list_without_skips,
              proposal_lens_list=proposal_lens_list,
          )

         target_sampler_output = self._scorer_worker.execute_model(
-            seq_group_metadata_list=target_seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+            execute_model_req=execute_model_req.clone(
+                seq_group_metadata_list=target_seq_group_metadata_list, ))
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]

         all_tokens, all_probs, spec_logprobs = self._contract_batch(
-            contracted_bs=len(seq_group_metadata_list),
+            contracted_bs=len(execute_model_req.seq_group_metadata_list),
             target_sampler_output=target_sampler_output,
             proposals=proposals,
             num_scoring_tokens=num_scoring_tokens,
             non_spec_indices=non_spec_indices,
             spec_indices=spec_indices,
-            k=k,
+            k=execute_model_req.num_lookahead_slots,
         )

         return SpeculativeScores(
diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py
@@ -1,10 +1,9 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, List, Optional

 import torch

-from vllm.sequence import SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest


 @dataclass
@@ -58,11 +57,7 @@ class SpeculativeProposer(ABC):
     @abstractmethod
     def get_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         raise NotImplementedError

@@ -72,11 +67,7 @@ class SpeculativeScorer(ABC):
     @abstractmethod
     def score_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
+        execute_model_req: ExecuteModelRequest,
         proposals: SpeculativeProposals,
     ) -> SpeculativeScores:
         raise NotImplementedError
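To illustrate the narrowed interface, a hypothetical proposer stub (not part of this change) now only accepts the request object; block-swap details and the proposal length travel inside it:

```python
from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)


class NoopProposer(SpeculativeProposer):
    """Hypothetical stub showing the post-refactor signature."""

    def get_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        # A real proposer would run a draft model over
        # execute_model_req.seq_group_metadata_list for
        # execute_model_req.num_lookahead_slots steps.
        raise NotImplementedError
```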
diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py
@@ -1,9 +1,10 @@
 import copy
-from typing import Dict, List, Tuple
+from typing import List, Tuple

 import torch

-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker import Worker
@@ -44,10 +45,7 @@ class MultiStepWorker(Worker):
     @torch.inference_mode()
     def sampler_output(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
         sample_len: int,
     ) -> Tuple[List[SamplerOutput], bool]:
         """Run the model forward pass sample_len times. Returns the list of
@@ -57,26 +55,24 @@ class MultiStepWorker(Worker):

         For multi step worker, this indicator shall be True.
         """
-        self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
-                                   blocks_to_swap_out, blocks_to_copy)
+        self._raise_if_unsupported(execute_model_req)

         # Shallow copy input data so modifications (such as appending tokens)
         # do not cause side-effects.
         copied_seq_group_metadata_list = self._shallow_copy_inputs(
-            seq_group_metadata_list)
+            execute_model_req.seq_group_metadata_list)
+        copied_execute_model_req = execute_model_req.clone(
+            copied_seq_group_metadata_list)

         # Assert enough KV space for sample_len tokens per sequence.
-        self._assert_enough_kv_space(seq_group_metadata_list, sample_len)
+        self._assert_enough_kv_space(execute_model_req.seq_group_metadata_list,
+                                     sample_len)

         # Run model sample_len times.
         model_outputs = []
         for _ in range(sample_len):
             model_output = super().execute_model(
-                seq_group_metadata_list=copied_seq_group_metadata_list,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
-            )
+                execute_model_req=copied_execute_model_req)
             assert (len(model_output) == 1
                     ), "composing multistep workers not supported"
             model_output = model_output[0]
@@ -89,23 +85,13 @@ class MultiStepWorker(Worker):

     def get_spec_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_proposals(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-            max_proposal_len,
-        )
+        return self._proposer.get_proposals(execute_model_req)

     def _append_new_tokens(
             self, model_output: SamplerOutput,
@@ -196,20 +182,22 @@ class MultiStepWorker(Worker):

     def _raise_if_unsupported(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
     ) -> None:
         """MultiStepWorker does not yet implement support for cache swap
         operations or beam search.
         """
-        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
+        if any([
+                execute_model_req.blocks_to_swap_in,
+                execute_model_req.blocks_to_swap_out,
+                execute_model_req.blocks_to_copy
+        ]):
             raise NotImplementedError(
                 "MultiStepWorker does not support cache operations")

         if any(
                 len(seq_group_metadata.seq_data.keys()) != 1
-                for seq_group_metadata in seq_group_metadata_list):
+                for seq_group_metadata in
+                execute_model_req.seq_group_metadata_list):
             raise NotImplementedError(
                 "MultiStepWorker does not support beam search.")
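The multi-step loop depends on clone() producing a request whose metadata list can be mutated step by step without touching the caller's request. A toy illustration of that pattern follows (plain dicts stand in for SequenceGroupMetadata, so this is only a sketch of the shallow-copy idea, not vLLM's actual worker code):

```python
import copy

from vllm.sequence import ExecuteModelRequest

# Plain dicts stand in for SequenceGroupMetadata objects here.
original = ExecuteModelRequest(
    seq_group_metadata_list=[{"seq_data": [1, 2, 3]}],
    num_lookahead_slots=3,
)

# Shallow-copy the per-sequence entries, then clone the request around them.
copied_list = [copy.copy(m) for m in original.seq_group_metadata_list]
stepped = original.clone(copied_list)

for step in range(stepped.num_lookahead_slots):
    # Each step rebuilds the token list on the copied metadata only.
    copied_list[0]["seq_data"] = copied_list[0]["seq_data"] + [100 + step]

# The caller's request is left untouched.
assert original.seq_group_metadata_list[0]["seq_data"] == [1, 2, 3]
```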
diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py
@@ -1,8 +1,8 @@
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.spec_decode.interfaces import SpeculativeProposals
 from vllm.spec_decode.top1_proposer import Top1Proposer
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@@ -46,13 +46,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         # NGram don't need gpu sampler
         pass

-    def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-    ) -> None:
+    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
         """NGram doesn't depend on model execution, just pass this function"""
         pass

@@ -71,10 +65,7 @@ class NGramWorker(LoraNotSupportedWorkerBase):

     def sampler_output(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
         sample_len: int,
     ) -> Tuple[Optional[List[SamplerOutput]], bool]:
         """NGram match algo to pick proposal candidate. Returns the list of
@@ -83,16 +74,11 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         For ngram worker, we already done needed transposed internal, so the
         indicator pass to sampler_output_to_torch shall be False.
         """
-        self._raise_if_unsupported(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-        )
+        self._raise_if_unsupported(execute_model_req)

         arr = []
         has_spec_out = False
-        for seq_group_metadata in seq_group_metadata_list:
+        for seq_group_metadata in execute_model_req.seq_group_metadata_list:
             seq_data = next(iter(seq_group_metadata.seq_data.values()))

             input_ids = torch.as_tensor(seq_data.get_token_ids(),
@@ -135,17 +121,19 @@ class NGramWorker(LoraNotSupportedWorkerBase):
         indices = token_ids.unsqueeze(2)

         token_probs = torch.zeros(
-            (len(seq_group_metadata_list), sample_len, self.vocab_size),
+            (len(execute_model_req.seq_group_metadata_list), sample_len,
+             self.vocab_size),
             dtype=torch.float32,
             device=self.device,
         )
         token_probs.scatter_(2, indices, 1)
         token_logprobs = torch.zeros(
-            (len(seq_group_metadata_list), sample_len, self.vocab_size),
+            (len(execute_model_req.seq_group_metadata_list), sample_len,
+             self.vocab_size),
             dtype=torch.float32,
             device=self.device,
         )
-        for i in range(len(seq_group_metadata_list)):
+        for i in range(len(execute_model_req.seq_group_metadata_list)):
             outputs.append(
                 SamplerOutput(
                     outputs=None,
@@ -157,40 +145,32 @@ class NGramWorker(LoraNotSupportedWorkerBase):

     def get_spec_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        max_proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
         """Produce speculations given an input batch of sequences. The number of
         speculative tokens per sequence is determined by max_proposal_len.
         """

-        return self._proposer.get_proposals(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-            max_proposal_len,
-        )
+        return self._proposer.get_proposals(execute_model_req)

     def _raise_if_unsupported(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
+        execute_model_req: ExecuteModelRequest,
     ) -> None:
         """NGramWorker does not yet implement support for cache swap
         operations or beam search.
         """
-        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
+        if any([
+                execute_model_req.blocks_to_swap_in,
+                execute_model_req.blocks_to_swap_out,
+                execute_model_req.blocks_to_copy
+        ]):
             raise NotImplementedError(
                 "NGramWorker does not support cache operations")

         if any(
                 len(seq_group_metadata.seq_data.keys()) != 1
-                for seq_group_metadata in seq_group_metadata_list):
+                for seq_group_metadata in
+                execute_model_req.seq_group_metadata_list):
             raise NotImplementedError(
                 "NGramWorker does not support beam search.")
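The proposal probabilities above are one-hot distributions built with scatter_. A self-contained sketch of that construction, using a small fake vocabulary and token ids rather than vLLM's real shapes:

```python
import torch

batch_size, sample_len, vocab_size = 2, 3, 10  # toy sizes, not vLLM's

# Proposed token ids per sequence and per speculative step.
token_ids = torch.tensor([[1, 4, 7],
                          [2, 2, 9]])          # [batch, sample_len]
indices = token_ids.unsqueeze(2)               # [batch, sample_len, 1]

# Start from all-zero probabilities and put mass 1.0 on the chosen token,
# which is the one-hot "probability" the ngram worker reports upstream.
token_probs = torch.zeros(batch_size, sample_len, vocab_size)
token_probs.scatter_(2, indices, 1)

assert torch.equal(token_probs.sum(dim=-1),
                   torch.ones(batch_size, sample_len))
assert token_probs[0, 1, 4] == 1.0
```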
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
@@ -1,11 +1,12 @@
 from functools import cached_property
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
@@ -190,68 +191,36 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        num_lookahead_slots: int,
-    ) -> List[SamplerOutput]:
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Perform speculative decoding on the input batch.
         """

-        assert seq_group_metadata_list is not None, (
+        assert execute_model_req.seq_group_metadata_list is not None, (
             "speculative decoding "
             "requires non-None seq_group_metadata_list")

-        #logger.info("spec_decode_worker.execute_model num_lookahead_slots=%d",
-        #            num_lookahead_slots)
-
         # If no spec tokens, call the proposer and scorer workers normally.
         # Used for prefill.
-        if num_lookahead_slots == 0 or len(seq_group_metadata_list) == 0:
-            return self._run_no_spec(
-                seq_group_metadata_list=seq_group_metadata_list,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
-            )
+        if execute_model_req.num_lookahead_slots == 0 or len(
+                execute_model_req.seq_group_metadata_list) == 0:
+            return self._run_no_spec(execute_model_req)

-        return self._run_speculative_decoding_step(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-            k=num_lookahead_slots,
-        )
+        return self._run_speculative_decoding_step(execute_model_req)

     @nvtx_range("spec_decode_worker._run_no_spec")
     def _run_no_spec(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-    ) -> List[SamplerOutput]:
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Run a prefill step, without any speculation. The input is sent to the
         proposer and scorer model so that the KV cache is consistent between the
         two.
         """
         #logger.info("run proposer worker no spec")

-        self.proposer_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+        self.proposer_worker.execute_model(execute_model_req)

         #logger.info("run target worker no spec")
-        sampler_output = self.scorer_worker.execute_model(
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=blocks_to_swap_in,
-            blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy,
-        )
+        sampler_output = self.scorer_worker.execute_model(execute_model_req)
         assert len(sampler_output) == 1
         sampler_output = sampler_output[0]

@@ -265,12 +234,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
     @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
     def _run_speculative_decoding_step(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Optional[Dict[int, int]],
-        blocks_to_swap_out: Optional[Dict[int, int]],
-        blocks_to_copy: Optional[Dict[int, List[int]]],
-        k: int,
-    ) -> List[SamplerOutput]:
+        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Execute a single step of speculative decoding.

         This invokes the proposer worker to get k speculative tokens for each
@@ -282,33 +246,25 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

         #logger.info("get spec proposals")
         # Generate proposals using draft worker.
-        assert blocks_to_swap_in is not None
-        assert blocks_to_swap_out is not None
-        assert blocks_to_copy is not None
-        proposals = self.proposer_worker.get_spec_proposals(
-            seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out,
-            blocks_to_copy, k)
+        proposals = self.proposer_worker.get_spec_proposals(execute_model_req)

         #logger.info("score proposals")
         proposal_scores = self.scorer.score_proposals(
-            seq_group_metadata_list,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-            k,
+            execute_model_req,
             proposals,
         )

         #logger.info("verify proposals")
         accepted_token_ids, target_logprobs = self._verify_tokens(
-            seq_group_metadata_list, proposal_scores, proposals, k)
+            execute_model_req.seq_group_metadata_list, proposal_scores,
+            proposals, execute_model_req.num_lookahead_slots)

         #logger.info("create output list")
         return self._create_output_sampler_list(
-            seq_group_metadata_list,
+            execute_model_req.seq_group_metadata_list,
             accepted_token_ids,
             target_logprobs=target_logprobs,
-            k=k)
+            k=execute_model_req.num_lookahead_slots)

     @nvtx_range("spec_decode_worker._verify_tokens")
     def _verify_tokens(
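The dispatch logic above now reduces to a simple rule on the request itself; a hypothetical standalone helper (not vLLM code) expressing that branch:

```python
from vllm.sequence import ExecuteModelRequest


def is_prefill_only(execute_model_req: ExecuteModelRequest) -> bool:
    """Mirrors the branch in SpecDecodeWorker.execute_model: skip
    speculation when there are no lookahead slots or no sequences."""
    return (execute_model_req.num_lookahead_slots == 0
            or len(execute_model_req.seq_group_metadata_list) == 0)


# A prefill-style request asks for no lookahead slots, so it takes the
# _run_no_spec path.
assert is_prefill_only(ExecuteModelRequest(seq_group_metadata_list=[]))
```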
diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py
@@ -1,8 +1,9 @@
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
+                           SequenceGroupMetadata)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeProposer)
 from vllm.spec_decode.util import sampler_output_to_torch
@@ -40,17 +41,15 @@ class Top1Proposer(SpeculativeProposer):

     def get_proposals(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
-        blocks_to_copy: Dict[int, List[int]],
-        proposal_len: int,
+        execute_model_req: ExecuteModelRequest,
     ) -> SpeculativeProposals:
         """Get speculative proposals given the input batch.

         Sequences which would exceed the max model length are skipped during
         speculation.
         """
+        proposal_len = execute_model_req.num_lookahead_slots
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

         # Split speculative- and non-speculative- sequences.
         (
@@ -66,11 +65,12 @@ class Top1Proposer(SpeculativeProposer):
             # token_ids is like [batch] format in proposal_len size list,
             # while if it is false, the format would be [proposal_len]
             # in batch size list
-            maybe_sampler_output, transposed = self._worker.sampler_output(
+            nonzero_execute_model_req = ExecuteModelRequest(
                 seq_group_metadata_list=nonzero_proposal_len_seqs,
-                blocks_to_swap_in=blocks_to_swap_in,
-                blocks_to_swap_out=blocks_to_swap_out,
-                blocks_to_copy=blocks_to_copy,
+                num_lookahead_slots=proposal_len,
+            )
+            maybe_sampler_output, transposed = self._worker.sampler_output(
+                execute_model_req=nonzero_execute_model_req,
                 sample_len=proposal_len,
             )
         else:
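Building a sub-request for just the sequences that can still speculate is now a matter of constructing a fresh ExecuteModelRequest. A hedged sketch of that filtering step, with placeholder records instead of real SequenceGroupMetadata:

```python
from vllm.sequence import ExecuteModelRequest

# Placeholder "metadata" records; only the proposal length matters here.
batch = [{"id": 0, "proposal_len": 4},
         {"id": 1, "proposal_len": 0},   # would exceed max len -> skipped
         {"id": 2, "proposal_len": 4}]

nonzero = [m for m in batch if m["proposal_len"] > 0]

# The proposer hands the draft worker a request covering only the
# speculating sequences, with num_lookahead_slots as the proposal length.
nonzero_req = ExecuteModelRequest(
    seq_group_metadata_list=nonzero,
    num_lookahead_slots=4,
)
assert len(nonzero_req.seq_group_metadata_list) == 2
```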
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
@@ -13,7 +13,7 @@ from vllm.distributed import (broadcast_tensor_dict,
                               init_distributed_environment)
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.worker.cpu_model_runner import CPUModelRunner
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@@ -256,22 +256,24 @@ class CPUWorker(LoraNotSupportedWorkerBase):
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
-        blocks_to_swap_in: Optional[Dict[int, int]] = None,
-        blocks_to_swap_out: Optional[Dict[int, int]] = None,
-        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+        execute_model_req: Optional[ExecuteModelRequest] = None,
     ) -> List[SamplerOutput]:
+
+        if execute_model_req is None:
+            seq_group_metadata_list = None
+        else:
+            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
             num_seq_groups: int = len(seq_group_metadata_list)
-            assert blocks_to_swap_in is not None
-            assert blocks_to_swap_out is not None
-            assert blocks_to_copy is not None
-            assert len(blocks_to_swap_in) == 0
-            assert len(blocks_to_swap_out) == 0
+            assert execute_model_req is not None
+            blocks_to_copy = execute_model_req.blocks_to_copy
+            assert len(execute_model_req.blocks_to_swap_in) == 0
+            assert len(execute_model_req.blocks_to_swap_out) == 0
             data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
-                "blocks_to_copy": blocks_to_copy,
+                "blocks_to_copy": execute_model_req.blocks_to_copy,
             }
             broadcast_tensor_dict(data, src=0)
         else:
@@ -279,7 +281,6 @@ class CPUWorker(LoraNotSupportedWorkerBase):
             num_seq_groups = data["num_seq_groups"]
             blocks_to_copy = data["blocks_to_copy"]

-        assert blocks_to_copy is not None
         self.cache_copy(blocks_to_copy)

         # If there is no input, we don't need to execute the model.
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.custom_all_reduce import (
     init_custom_ar)
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker_base import WorkerBase
@@ -211,19 +211,21 @@ class Worker(WorkerBase):
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
-        blocks_to_swap_in: Optional[Dict[int, int]] = None,
-        blocks_to_swap_out: Optional[Dict[int, int]] = None,
-        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
-        num_lookahead_slots: int = 0,
+        execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[SamplerOutput]:

+        if execute_model_req is None:
+            seq_group_metadata_list = None
+        else:
+            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
+            assert execute_model_req is not None
             num_seq_groups = len(seq_group_metadata_list)
-            assert blocks_to_swap_in is not None
-            assert blocks_to_swap_out is not None
-            assert blocks_to_copy is not None
+            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
+            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
+            blocks_to_copy = execute_model_req.blocks_to_copy
             data: Dict[str, Any] = {
                 "num_seq_groups": num_seq_groups,
                 "blocks_to_swap_in": blocks_to_swap_in,
@@ -238,9 +240,6 @@ class Worker(WorkerBase):
             blocks_to_swap_out = data["blocks_to_swap_out"]
             blocks_to_copy = data["blocks_to_copy"]

-        assert blocks_to_swap_in is not None
-        assert blocks_to_swap_out is not None
-        assert blocks_to_copy is not None
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

         # If there is no input, we don't need to execute the model.
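On the driver rank the request is unpacked into a plain dict before being broadcast, so non-driver ranks never see the request object itself. A toy stand-in for that driver branch (no torch.distributed here, just the shape of the payload; driver_payload is illustrative, not a vLLM function):

```python
from typing import Any, Dict, Optional

from vllm.sequence import ExecuteModelRequest


def driver_payload(execute_model_req: Optional[ExecuteModelRequest]
                   ) -> Dict[str, Any]:
    """Build the plain dict the driver would broadcast to other ranks."""
    assert execute_model_req is not None
    return {
        "num_seq_groups": len(execute_model_req.seq_group_metadata_list),
        "blocks_to_swap_in": execute_model_req.blocks_to_swap_in,
        "blocks_to_swap_out": execute_model_req.blocks_to_swap_out,
        "blocks_to_copy": execute_model_req.blocks_to_copy,
    }


payload = driver_payload(
    ExecuteModelRequest(seq_group_metadata_list=[], blocks_to_copy={3: [7]}))
assert payload["num_seq_groups"] == 0
assert payload["blocks_to_copy"] == {3: [7]}
```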
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
@@ -5,7 +5,7 @@ from typing import Dict, List, Set, Tuple

 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (enable_trace_function_call_for_thread,
                         update_environment_variables)

@@ -48,10 +48,8 @@ class WorkerBase(ABC):

     @abstractmethod
     def execute_model(
-            self, seq_group_metadata_list: List[SequenceGroupMetadata],
-            blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int,
-                                                                        int],
-            blocks_to_copy: Dict[int, List[int]]) -> List[SamplerOutput]:
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
         """Executes at least one model step on the given sequences, unless no
         sequences are provided."""
         raise NotImplementedError
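With every worker now exposing the same one-argument entry point, call sites can treat them uniformly. A small duck-typed example (FakeWorker is illustrative only, not a vLLM class):

```python
from typing import List

from vllm.sequence import ExecuteModelRequest, SamplerOutput


class FakeWorker:
    """Illustrative worker with the post-refactor call shape."""

    def execute_model(
            self,
            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
        # A real worker would run its model here; returning a single empty
        # SamplerOutput keeps the "one output per step" contract visible.
        return [SamplerOutput(outputs=[])]


req = ExecuteModelRequest(seq_group_metadata_list=[])
(step_output, ) = FakeWorker().execute_model(execute_model_req=req)
```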