mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:04:58 +08:00
[V1][Spec Decode] Ngram Spec Decode (#12193)
Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
This commit is contained in:
parent
367cb8ce8c
commit
80f63a3966
@ -4,10 +4,12 @@ from typing import List, Optional
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
|
||||
|
||||
def create_scheduler(
|
||||
model: str = "facebook/opt-125m",
|
||||
@ -38,6 +40,7 @@ def create_scheduler(
|
||||
return Scheduler(scheduler_config,
|
||||
model_config,
|
||||
cache_config,
|
||||
speculative_config=None,
|
||||
lora_config=None,
|
||||
log_stats=True)
|
||||
|
||||
@ -46,8 +49,12 @@ def create_requests(
|
||||
num_requests: int,
|
||||
num_tokens: int = 10,
|
||||
mm_positions: Optional[List[PlaceholderRange]] = None,
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
):
|
||||
sampling_params = SamplingParams()
|
||||
sampling_params = SamplingParams(ignore_eos=False,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids)
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
if mm_positions is not None:
|
||||
@ -64,7 +71,7 @@ def create_requests(
|
||||
multi_modal_inputs=mm_inputs,
|
||||
multi_modal_placeholders=mm_position,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
arrival_time=0,
|
||||
)
|
||||
requests.append(request)
|
||||
@ -195,7 +202,7 @@ def test_schedule_partial_requests():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[0] * len(requests),
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
@ -215,6 +222,189 @@ def test_schedule_partial_requests():
|
||||
assert requests[2].request_id not in output.num_scheduled_tokens
|
||||
|
||||
|
||||
def test_stop_via_update_from_output():
|
||||
"""Test stopping behavior through update_from_output"""
|
||||
scheduler = create_scheduler()
|
||||
|
||||
# Test case 1: Stop on EOS token
|
||||
requests = create_requests(num_requests=2, max_tokens=10)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
scheduler.scheduled_req_ids.add(req.request_id)
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[])
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID],
|
||||
[10,
|
||||
11]], # First request hits EOS, second continues
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={})
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped, second continues
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
|
||||
assert list(requests[1].output_token_ids) == [10, 11]
|
||||
|
||||
# Test case 2: Stop on custom stop token
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=2,
|
||||
max_tokens=10,
|
||||
stop_token_ids=[42, 43])
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
scheduler.scheduled_req_ids.add(req.request_id)
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[])
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={})
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped on custom token
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
||||
assert requests[0].stop_reason == 42
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [10, 42]
|
||||
assert list(requests[1].output_token_ids) == [13, 14]
|
||||
|
||||
# Test case 3: Stop on max tokens
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=2, max_tokens=2)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
scheduler.scheduled_req_ids.add(req.request_id)
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[])
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={})
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped due to length
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [10, 11
|
||||
] # Truncated to max_tokens
|
||||
assert list(requests[1].output_token_ids) == [13]
|
||||
|
||||
# Test case 4: Ignore EOS flag
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=1, max_tokens=10)
|
||||
requests[0].sampling_params.ignore_eos = True
|
||||
requests[0].num_computed_tokens = requests[0].num_tokens
|
||||
scheduler.requests[requests[0].request_id] = requests[0]
|
||||
scheduler.running.append(requests[0])
|
||||
scheduler.scheduled_req_ids.add(requests[0].request_id)
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[])
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={})
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify request continues past EOS
|
||||
assert len(scheduler.running) == 1
|
||||
assert not requests[0].is_finished()
|
||||
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
|
||||
|
||||
|
||||
def test_schedule_concurrent_batches():
|
||||
scheduler = create_scheduler(
|
||||
max_num_batched_tokens=1024,
|
||||
@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[0],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[0],
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
)
|
||||
|
||||
49
tests/v1/e2e/test_ngram_spec_decode.py
Normal file
49
tests/v1/e2e/test_ngram_spec_decode.py
Normal file
@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_prompts():
|
||||
return [
|
||||
"Can you repeat the sentence ten times, this is a sentence.",
|
||||
"Can you repeat the sentence ten times, this is a test.",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sampling_config():
|
||||
# Only support greedy for now
|
||||
return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_name():
|
||||
return "meta-llama/Meta-Llama-3-8B-Instruct"
|
||||
|
||||
|
||||
def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
|
||||
model_name):
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using ngram speculative decoding.
|
||||
'''
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
ref_llm = LLM(model=model_name)
|
||||
ref_outputs = ref_llm.generate(test_prompts, sampling_config)
|
||||
del ref_llm
|
||||
|
||||
spec_llm = LLM(model=model_name,
|
||||
speculative_model='[ngram]',
|
||||
ngram_prompt_lookup_max=5,
|
||||
ngram_prompt_lookup_min=3,
|
||||
num_speculative_tokens=3)
|
||||
spec_outputs = spec_llm.generate(test_prompts, sampling_config)
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
|
||||
(f"ref_output: {ref_output.outputs[0].text},"
|
||||
f"spec_output: {spec_output.outputs[0].text}")
|
||||
del spec_llm
|
||||
173
tests/v1/sample/test_rejection_sampler.py
Normal file
173
tests/v1/sample/test_rejection_sampler.py
Normal file
@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sampler():
|
||||
return RejectionSampler()
|
||||
|
||||
|
||||
def create_logits_tensor(token_ids: List[int],
|
||||
vocab_size: int = 100) -> torch.Tensor:
|
||||
"""Helper function to create logits tensor that
|
||||
will produce desired token ids on argmax"""
|
||||
logits = torch.full((len(token_ids), vocab_size), -100.0).cuda()
|
||||
for i, token_id in enumerate(token_ids):
|
||||
logits[i, token_id] = 100.0
|
||||
return logits
|
||||
|
||||
|
||||
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
|
||||
batch_size = len(spec_tokens)
|
||||
return SamplingMetadata(
|
||||
temperature=0.0,
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
rejection_sampling=True,
|
||||
spec_token_ids=spec_tokens,
|
||||
top_p=None,
|
||||
top_k=None,
|
||||
no_top_p=False,
|
||||
no_top_k=False,
|
||||
min_p=torch.empty(batch_size, ),
|
||||
no_min_p=True,
|
||||
generators={},
|
||||
max_num_logprobs=0,
|
||||
no_penalties=False,
|
||||
prompt_token_ids=None,
|
||||
frequency_penalties=torch.tensor([]),
|
||||
presence_penalties=torch.tensor([]),
|
||||
repetition_penalties=torch.tensor([]),
|
||||
output_token_ids=[],
|
||||
min_tokens=[],
|
||||
stop_token_ids=[],
|
||||
logit_bias=[None] * batch_size,
|
||||
)
|
||||
|
||||
|
||||
def test_perfect_match(sampler):
|
||||
"""Test when output tokens perfectly match speculated tokens"""
|
||||
spec_tokens = [[1, 2, 3]]
|
||||
output_tokens = [1, 2, 3, 4] # 4 is the bonus token
|
||||
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 3, 4]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_early_mismatch(sampler):
|
||||
"""Test when there's an early mismatch in tokens"""
|
||||
spec_tokens = [[1, 2, 3]]
|
||||
output_tokens = [1, 5, 3, 4] # Mismatch at position 1
|
||||
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_multiple_sequences(sampler):
|
||||
"""Test handling multiple sequences of speculated tokens"""
|
||||
spec_tokens = [[1, 2], [3]]
|
||||
output_tokens = [1, 2, 5, 3, 4] # Two sequences with bonus tokens 5 and 4
|
||||
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_single_token_sequence(sampler):
|
||||
"""Test handling sequences with single token"""
|
||||
spec_tokens = [[1]]
|
||||
output_tokens = [1, 2] # Single token with bonus token 2
|
||||
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_empty_sequence(sampler):
|
||||
"""Test handling empty sequence of speculated tokens"""
|
||||
spec_tokens: List[List[int]] = [[]]
|
||||
output_tokens = [5] # Just the bonus token
|
||||
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
def test_multiple_mismatches(sampler):
|
||||
"""Test handling multiple sequences with mismatches"""
|
||||
spec_tokens = [[1, 2, 3], [4, 5, 6]]
|
||||
output_tokens = [1, 2, 7, 6, 4, 8, 6, 9] # Mismatches in both sequences
|
||||
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
|
||||
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
[
|
||||
([[1, 2]], [1, 2, 3], [[1, 2, 3]]), # Perfect match with bonus
|
||||
([[1]], [2, 3], [[2, INVALID_TOKEN_ID]]), # First mismatch
|
||||
([[1, 2], [3, 4]], [1, 5, 6, 3, 4, 7], [[1, 5, INVALID_TOKEN_ID],
|
||||
[3, 4, 7]]), # Mixed matches
|
||||
])
|
||||
def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
|
||||
"""Parametrized test for various matching scenarios"""
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
expected_tensor = torch.tensor(expected,
|
||||
dtype=torch.int,
|
||||
device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected_tensor)
|
||||
|
||||
|
||||
def test_logits_shape_handling(sampler):
|
||||
"""Test handling of different logits tensor shapes"""
|
||||
spec_tokens = [[1, 2]]
|
||||
output_tokens = [1, 2, 3]
|
||||
vocab_size = 1000
|
||||
|
||||
metadata = create_sampling_metadata(spec_tokens)
|
||||
logits = create_logits_tensor(output_tokens, vocab_size)
|
||||
|
||||
output = sampler(logits, metadata)
|
||||
expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
|
||||
assert torch.equal(output.sampled_token_ids, expected)
|
||||
assert logits.shape[-1] == vocab_size
|
||||
@ -77,6 +77,7 @@ def _create_default_sampling_metadata(
|
||||
temperature=torch.full((batch_size, ), 0.0),
|
||||
all_greedy=True,
|
||||
all_random=False,
|
||||
rejection_sampling=False,
|
||||
top_p=torch.empty(batch_size, ),
|
||||
top_k=torch.empty(batch_size, ),
|
||||
no_top_p=True,
|
||||
@ -88,6 +89,7 @@ def _create_default_sampling_metadata(
|
||||
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
|
||||
vocab_size, device),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=[],
|
||||
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
|
||||
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
|
||||
|
||||
32
tests/v1/spec_decode/test_ngram.py
Normal file
32
tests/v1/spec_decode/test_ngram.py
Normal file
@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import pytest
|
||||
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def proposer():
|
||||
return NgramProposer()
|
||||
|
||||
|
||||
def test_kmp_lps_array(proposer):
|
||||
assert proposer._kmp_lps_array([]) == []
|
||||
assert proposer._kmp_lps_array([1]) == [0]
|
||||
assert proposer._kmp_lps_array([1, 1, 1]) == [0, 1, 2]
|
||||
assert proposer._kmp_lps_array([1, 2, 3, 4]) == [0, 0, 0, 0]
|
||||
assert proposer._kmp_lps_array([1, 2, 1, 2, 3]) == [0, 0, 1, 2, 0]
|
||||
|
||||
|
||||
def test_find_subarray_kmp(proposer):
|
||||
X = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6])
|
||||
assert proposer._find_subarray_kmp(X, 2, 2) is None
|
||||
X = ConstantList([1, 2, 3, 4, 1, 2, 3])
|
||||
assert proposer._find_subarray_kmp(X, 2, 3) == [4, 1, 2]
|
||||
assert proposer._find_subarray_kmp(X, 2, 2) == [4, 1]
|
||||
assert proposer._find_subarray_kmp(X, 1, 3) == [4, 1, 2]
|
||||
assert proposer._find_subarray_kmp(X, 1, 2) == [4, 1]
|
||||
X = ConstantList([1, 3, 6, 2, 3, 4, 1, 2, 3])
|
||||
assert proposer._find_subarray_kmp(X, 2, 3) == [4, 1, 2]
|
||||
# Return on the first match
|
||||
assert proposer._find_subarray_kmp(X, 1, 3) == [6, 2, 3]
|
||||
@ -92,6 +92,7 @@ def _construct_expected_sampling_metadata(
|
||||
device=device),
|
||||
all_greedy=False,
|
||||
all_random=True,
|
||||
rejection_sampling=False,
|
||||
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
|
||||
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
|
||||
no_top_p=all(x == 1.0 for x in top_p),
|
||||
@ -116,6 +117,7 @@ def _construct_expected_sampling_metadata(
|
||||
dtype=torch.float,
|
||||
device=device),
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=[],
|
||||
min_tokens=min_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
no_penalties=(all(x == 0 for x in presence_penalties)
|
||||
@ -205,7 +207,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
|
||||
# Generate the sampling metadata
|
||||
sampling_metadata = input_batch.make_sampling_metadata(
|
||||
req_id_output_token_ids, skip_copy=False)
|
||||
req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
|
||||
|
||||
# Create expected output.
|
||||
expected_sampling_metadata = _construct_expected_sampling_metadata(
|
||||
|
||||
@ -66,6 +66,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
@ -109,6 +110,7 @@ def test_update_states_request_finished(model_runner):
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids={req_id},
|
||||
@ -137,6 +139,7 @@ def test_update_states_request_resumed(model_runner):
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={},
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids={},
|
||||
@ -160,6 +163,7 @@ def test_update_states_request_resumed(model_runner):
|
||||
scheduled_cached_reqs=[cached_req_data],
|
||||
num_scheduled_tokens={req_id: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
@ -188,6 +192,7 @@ def test_update_states_no_changes(model_runner):
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={req_id: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
@ -220,6 +225,7 @@ def test_update_states_request_unscheduled(model_runner):
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={req_ids[0]: 1},
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
|
||||
@ -124,9 +124,8 @@ class CudaPlatformBase(Platform):
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding is not yet supported on VLLM V1."
|
||||
)
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
|
||||
@ -82,6 +82,11 @@ class KVCacheManager:
|
||||
self.req_to_block_hashes: DefaultDict[
|
||||
str, List[BlockHashType]] = defaultdict(list)
|
||||
|
||||
# {req_id: The number of cached blocks for this given request}
|
||||
# This is used to track the number of cached blocks for each request.
|
||||
# This is only used to track the RUNNING requests, we do not track the
|
||||
# data for reempted ones.
|
||||
self.num_cached_block: Dict[str, int] = defaultdict(int)
|
||||
self.prefix_cache_stats = PrefixCacheStats()
|
||||
|
||||
@property
|
||||
@ -241,23 +246,25 @@ class KVCacheManager:
|
||||
if not self.enable_caching:
|
||||
return new_blocks
|
||||
|
||||
# NOTE(rickyx): We are assuming the `num_tokens` are actual
|
||||
# tokens rather than lookahead slots (e.g. for speculative decoding).
|
||||
# TODO(rickyx): When supporting speculative decoding, we will need to
|
||||
# differentiate between them so that we can know how many blocks are
|
||||
# full after appending the actual tokens.
|
||||
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
|
||||
num_computed_full_blocks = num_computed_tokens // self.block_size
|
||||
new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
|
||||
num_cached_blocks = self.num_cached_block[request.request_id]
|
||||
# Speculated tokens might be rejected in the future, so we does
|
||||
# not cache any speculated tokens. We only cache blocks with
|
||||
# generated (accepted) tokens.
|
||||
num_full_blocks_after_append = (num_computed_tokens + num_tokens - len(
|
||||
request.spec_token_ids)) // self.block_size
|
||||
new_full_blocks = req_blocks[
|
||||
num_cached_blocks:num_full_blocks_after_append]
|
||||
|
||||
if new_full_blocks:
|
||||
self._cache_full_blocks(
|
||||
request=request,
|
||||
blk_start_idx=num_computed_full_blocks,
|
||||
blk_start_idx=num_cached_blocks,
|
||||
# The new full blocks are the full blocks that are not computed.
|
||||
full_blocks=new_full_blocks,
|
||||
prev_block=(req_blocks[num_computed_full_blocks - 1]
|
||||
if num_computed_full_blocks > 0 else None))
|
||||
|
||||
prev_block=(req_blocks[num_cached_blocks -
|
||||
1] if num_cached_blocks > 0 else None))
|
||||
self.num_cached_block[
|
||||
request.request_id] = num_full_blocks_after_append
|
||||
return new_blocks
|
||||
|
||||
def free(self, request: Request) -> None:
|
||||
@ -281,6 +288,8 @@ class KVCacheManager:
|
||||
if block.ref_cnt == 0:
|
||||
self.free_block_queue.append(block)
|
||||
|
||||
self.num_cached_block.pop(request.request_id, None)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
flows to invalid prefix caching after the weights are updated,
|
||||
|
||||
@ -4,7 +4,8 @@ import time
|
||||
from collections import deque
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
|
||||
compute_encoder_budget)
|
||||
@ -28,11 +29,13 @@ class Scheduler:
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
log_stats: bool,
|
||||
) -> None:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.speculative_config = speculative_config
|
||||
self.log_stats = log_stats
|
||||
|
||||
# Scheduling constraints.
|
||||
@ -96,12 +99,14 @@ class Scheduler:
|
||||
def schedule(self) -> "SchedulerOutput":
|
||||
# NOTE(woosuk) on the scheduling algorithm:
|
||||
# There's no "decoding phase" nor "prefill phase" in the scheduler.
|
||||
# Each request just has the num_computed_tokens and num_tokens,
|
||||
# which is equal to len(prompt_token_ids) + len(output_token_ids).
|
||||
# Each request just has the num_computed_tokens and
|
||||
# num_tokens_with_spec. num_tokens_with_spec =
|
||||
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
|
||||
# At each step, the scheduler tries to assign tokens to the requests
|
||||
# so that each request's num_computed_tokens can catch up its
|
||||
# num_tokens. This is general enough to cover chunked prefills,
|
||||
# prefix caching, and the "jump decoding" optimization in the future.
|
||||
# num_tokens_with_spec. This is general enough to cover
|
||||
# chunked prefills, prefix caching, speculative decoding,
|
||||
# and the "jump decoding" optimization in the future.
|
||||
|
||||
scheduled_new_reqs: List[Request] = []
|
||||
scheduled_resumed_reqs: List[Request] = []
|
||||
@ -114,7 +119,8 @@ class Scheduler:
|
||||
# Encoder-related.
|
||||
scheduled_encoder_inputs: Dict[str, List[int]] = {}
|
||||
encoder_budget = self.max_num_encoder_input_tokens
|
||||
|
||||
# Spec decode-related.
|
||||
scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
|
||||
scheduled_timestamp = time.monotonic()
|
||||
|
||||
# First, schedule the RUNNING requests.
|
||||
@ -126,7 +132,8 @@ class Scheduler:
|
||||
req_index += 1
|
||||
continue
|
||||
|
||||
num_new_tokens = request.num_tokens - request.num_computed_tokens
|
||||
num_new_tokens = (request.num_tokens_with_spec -
|
||||
request.num_computed_tokens)
|
||||
num_new_tokens = min(num_new_tokens, token_budget)
|
||||
assert num_new_tokens > 0
|
||||
|
||||
@ -189,6 +196,11 @@ class Scheduler:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_budget = new_encoder_budget
|
||||
|
||||
# Speculative decode related.
|
||||
if request.spec_token_ids:
|
||||
scheduled_spec_decode_tokens[
|
||||
request.request_id] = request.spec_token_ids
|
||||
|
||||
# Record the LoRAs in scheduled_running_reqs
|
||||
requested_loras: Set[int] = set()
|
||||
if self.lora_config:
|
||||
@ -338,6 +350,7 @@ class Scheduler:
|
||||
num_scheduled_tokens=num_scheduled_tokens,
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_encoder_inputs=scheduled_encoder_inputs,
|
||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||
num_common_prefix_blocks=num_common_prefix_blocks,
|
||||
# finished_req_ids is an existing state in the scheduler,
|
||||
# instead of being newly scheduled in this step.
|
||||
@ -447,11 +460,11 @@ class Scheduler:
|
||||
scheduler_output: "SchedulerOutput",
|
||||
model_runner_output: "ModelRunnerOutput",
|
||||
) -> EngineCoreOutputs:
|
||||
# NOTE(woosuk): This method doesn't consider speculative decoding.
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
|
||||
new_running: List[Request] = []
|
||||
outputs: List[EngineCoreOutput] = []
|
||||
|
||||
@ -466,11 +479,30 @@ class Scheduler:
|
||||
new_running.append(request)
|
||||
continue
|
||||
|
||||
request.num_computed_tokens += num_tokens_scheduled
|
||||
# When the request's num_computed_tokens catches up its num_tokens,
|
||||
# the request generates output tokens. Otherwise, we ignore the
|
||||
# sampler output for the request.
|
||||
assert request.num_computed_tokens <= request.num_tokens
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
generated_token_ids = sampled_token_ids[req_index]
|
||||
if req_id not in scheduler_output.scheduled_spec_decode_tokens:
|
||||
# When the request's num_computed_tokens catches up
|
||||
# its num_tokens, the request generates output tokens.
|
||||
# Otherwise, we ignore the sampler output for the request.
|
||||
request.num_computed_tokens += num_tokens_scheduled
|
||||
assert request.num_computed_tokens <= request.num_tokens
|
||||
else:
|
||||
# num_computed_tokens_step represents the number of tokens
|
||||
# processed in the current step, considering scheduled
|
||||
# tokens and rejections.
|
||||
# It is calculated as:
|
||||
# num_computed_tokens_step = num_scheduled_tokens -
|
||||
# num_tokens_rejected,
|
||||
# where num_tokens_rejected is given by:
|
||||
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
|
||||
scheduled_spec_token_ids = (
|
||||
scheduler_output.scheduled_spec_decode_tokens[req_id])
|
||||
|
||||
num_computed_tokens_step = num_scheduled_tokens[req_id] - (
|
||||
len(scheduled_spec_token_ids) + 1 -
|
||||
len(generated_token_ids))
|
||||
request.num_computed_tokens += num_computed_tokens_step
|
||||
|
||||
cached_encoder_input_ids = (
|
||||
self.encoder_cache_manager.get_cached_input_ids(request))
|
||||
@ -485,27 +517,32 @@ class Scheduler:
|
||||
self.encoder_cache_manager.free_encoder_input(
|
||||
request, input_id)
|
||||
|
||||
if request.num_computed_tokens >= request.num_tokens:
|
||||
# Clear the spec tokens as the request has generated
|
||||
# a new token. Here, We assume all spec tokens are verified
|
||||
# if we perform speculative decoding for this request.
|
||||
# Therefore, we can clear all spec tokens after
|
||||
# the generation step.
|
||||
request.clear_spec_tokens()
|
||||
|
||||
# Get prompt logprobs for this request.
|
||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||
|
||||
stopped = False
|
||||
new_logprobs = None
|
||||
new_token_ids = None
|
||||
new_token_ids: List[int] = []
|
||||
|
||||
if request.num_computed_tokens == request.num_tokens:
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
# NOTE(woosuk): Currently, we assume that each request
|
||||
# generates at most one token at each step.
|
||||
token_id = sampled_token_ids[req_index]
|
||||
request.append_output_token_ids(token_id)
|
||||
num_new_tokens = 1
|
||||
# TODO: Update the KV cache manager for prefix caching.
|
||||
if request.num_computed_tokens >= request.num_tokens:
|
||||
for output_token_id in generated_token_ids:
|
||||
request.append_output_token_ids(output_token_id)
|
||||
new_token_ids.append(output_token_id)
|
||||
|
||||
# Check for stop and update request state.
|
||||
# This must be called before we make the EngineCoreOutput.
|
||||
stopped = self._check_stop(request)
|
||||
if stopped:
|
||||
self._free_request(request)
|
||||
# Check for stop and update request state.
|
||||
# This must be called before we make the EngineCoreOutput.
|
||||
stopped = self._check_stop(request)
|
||||
if stopped:
|
||||
self._free_request(request)
|
||||
break
|
||||
|
||||
# Extract sample logprobs if needed.
|
||||
if request.sampling_params.logprobs is not None:
|
||||
@ -514,8 +551,6 @@ class Scheduler:
|
||||
# the outer lists can be of length > 1.
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
new_token_ids = request.output_token_ids[-num_new_tokens:]
|
||||
|
||||
# Transmit partial if chunked prefill & prompt logprobs is enabled
|
||||
if new_token_ids or prompt_logprobs_tensors is not None:
|
||||
# Add EngineCoreOutput for this Request.
|
||||
|
||||
@ -91,6 +91,10 @@ class SchedulerOutput:
|
||||
# Total number of tokens scheduled for all requests.
|
||||
# Equal to sum(num_scheduled_tokens.values())
|
||||
total_num_scheduled_tokens: int
|
||||
# req_id -> spec_decode_tokens
|
||||
# If a request does not have any spec decode tokens, it will
|
||||
# not be included in the dictionary.
|
||||
scheduled_spec_decode_tokens: Dict[str, List[int]]
|
||||
# req_id -> encoder input indices that need processing.
|
||||
# E.g., if a request has [0, 1], it could mean the vision encoder needs
|
||||
# to process that the request's 0-th and 1-th images in the current step.
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -65,6 +66,7 @@ class EngineCore:
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
lora_config=vllm_config.lora_config,
|
||||
speculative_config=vllm_config.speculative_config,
|
||||
log_stats=self.log_stats,
|
||||
)
|
||||
|
||||
@ -84,6 +86,15 @@ class EngineCore:
|
||||
self.batch_queue_size)
|
||||
self.batch_queue = queue.Queue(self.batch_queue_size)
|
||||
|
||||
# Setup speculative decode.
|
||||
# TODO: find a better way to check if we are using ngram.
|
||||
self.use_spec_decode = False
|
||||
if self.scheduler.speculative_config:
|
||||
assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
|
||||
, "Only ngram spec decode is supported in V1."
|
||||
self.proposer = NgramProposer()
|
||||
self.use_spec_decode = True
|
||||
|
||||
def _initialize_kv_caches(self,
|
||||
vllm_config: VllmConfig) -> Tuple[int, int]:
|
||||
start = time.time()
|
||||
@ -147,6 +158,9 @@ class EngineCore:
|
||||
return EngineCoreOutputs(
|
||||
outputs=[], scheduler_stats=self.scheduler.make_stats())
|
||||
|
||||
if self.use_spec_decode:
|
||||
self.propose_tokens()
|
||||
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
output = self.model_executor.execute_model(scheduler_output)
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
@ -207,6 +221,23 @@ class EngineCore:
|
||||
def profile(self, is_start: bool = True):
|
||||
self.model_executor.profile(is_start)
|
||||
|
||||
def propose_tokens(self):
|
||||
assert self.scheduler.speculative_config is not None
|
||||
for req in self.scheduler.running:
|
||||
# Ignore requests that are doing chunked prefill.
|
||||
if req.num_computed_tokens < req.num_tokens - 1:
|
||||
continue
|
||||
# Ignore requests that already have spec tokens.
|
||||
if req.spec_token_ids:
|
||||
continue
|
||||
spec_tokens = self.proposer.propose(
|
||||
req.all_token_ids,
|
||||
self.scheduler.speculative_config.ngram_prompt_lookup_min,
|
||||
self.scheduler.speculative_config.num_speculative_tokens,
|
||||
)
|
||||
if spec_tokens:
|
||||
req.append_spec_token_ids(spec_tokens)
|
||||
|
||||
def reset_prefix_cache(self):
|
||||
self.scheduler.reset_prefix_cache()
|
||||
|
||||
|
||||
@ -43,7 +43,10 @@ class LogprobsTensors(NamedTuple):
|
||||
@dataclass
|
||||
class SamplerOutput:
|
||||
|
||||
# [num_reqs]
|
||||
# [num_reqs, max_num_generated_tokens]
|
||||
# Different requests can have different number of generated tokens.
|
||||
# All requests are padded to max_num_generated_tokens.
|
||||
# INVALID_TOKEN_ID (-1 by default) is used for padding.
|
||||
sampled_token_ids: torch.Tensor
|
||||
logprobs_tensors: Optional[LogprobsTensors]
|
||||
|
||||
@ -58,8 +61,11 @@ class ModelRunnerOutput:
|
||||
# req_id -> index
|
||||
req_id_to_index: Dict[str, int]
|
||||
|
||||
# [num_reqs]
|
||||
sampled_token_ids: List[int]
|
||||
# num_reqs x num_generated_tokens
|
||||
# num_generated_tokens is the number of tokens
|
||||
# generated in the current step. It can be different for
|
||||
# each request due to speculative/jump decoding.
|
||||
sampled_token_ids: List[List[int]]
|
||||
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
# [num_reqs, max_num_logprobs + 1]
|
||||
|
||||
@ -46,6 +46,7 @@ class Request:
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
self._output_token_ids: List[int] = []
|
||||
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
|
||||
self.spec_token_ids: List[int] = []
|
||||
self.num_computed_tokens = 0
|
||||
|
||||
# Multi-modal related
|
||||
@ -103,10 +104,26 @@ class Request:
|
||||
self._output_token_ids.extend(token_ids)
|
||||
self._all_token_ids.extend(token_ids)
|
||||
|
||||
def append_spec_token_ids(
|
||||
self,
|
||||
token_ids: Union[int, List[int]],
|
||||
) -> None:
|
||||
if isinstance(token_ids, int):
|
||||
self.spec_token_ids.append(token_ids)
|
||||
else:
|
||||
self.spec_token_ids.extend(token_ids)
|
||||
|
||||
def clear_spec_tokens(self) -> None:
|
||||
self.spec_token_ids.clear()
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
return len(self._all_token_ids)
|
||||
|
||||
@property
|
||||
def num_tokens_with_spec(self) -> int:
|
||||
return len(self._all_token_ids) + len(self.spec_token_ids)
|
||||
|
||||
@property
|
||||
def num_output_tokens(self) -> int:
|
||||
return len(self._output_token_ids)
|
||||
|
||||
@ -12,6 +12,8 @@ class SamplingMetadata:
|
||||
temperature: torch.Tensor
|
||||
all_greedy: bool
|
||||
all_random: bool
|
||||
rejection_sampling: bool
|
||||
spec_token_ids: List[List[int]]
|
||||
|
||||
top_p: torch.Tensor
|
||||
top_k: torch.Tensor
|
||||
|
||||
160
vllm/v1/sample/rejection_sampler.py
Normal file
160
vllm/v1/sample/rejection_sampler.py
Normal file
@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
try:
|
||||
import flashinfer.sampling as fs
|
||||
is_flashinfer_available = True
|
||||
except ImportError:
|
||||
is_flashinfer_available = False
|
||||
|
||||
logger = init_logger(__name__)
|
||||
INVALID_TOKEN_ID = -1
|
||||
|
||||
|
||||
class RejectionSampler(nn.Module):
|
||||
|
||||
def forward(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> SamplerOutput:
|
||||
if not sampling_metadata.all_greedy:
|
||||
raise NotImplementedError(
|
||||
"Only greedy sampling is supported by rejection sampler.")
|
||||
|
||||
if is_flashinfer_available:
|
||||
logger.info("User FlashInfer for rejection sampling.")
|
||||
return RejectionSampler.flashinfer_sample(logits,
|
||||
sampling_metadata)
|
||||
else:
|
||||
logger.warning(
|
||||
"FlashInfer is not available. Falling back to the PyTorch-"
|
||||
"native implementation of rejection sampling.")
|
||||
return RejectionSampler.greedy_sample_native(
|
||||
logits, sampling_metadata)
|
||||
|
||||
@staticmethod
|
||||
def flashinfer_sample(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> SamplerOutput:
|
||||
# NOTE: The following input preparationg can be moved
|
||||
# to the model runner with a persistent manner for better
|
||||
# performance.
|
||||
spec_token_ids = sampling_metadata.spec_token_ids
|
||||
max_spec_len = max(len(s) for s in spec_token_ids)
|
||||
batch_size = len(spec_token_ids)
|
||||
draft_token_ids = torch.full((batch_size, max_spec_len),
|
||||
INVALID_TOKEN_ID,
|
||||
device="cpu",
|
||||
dtype=torch.long)
|
||||
|
||||
target_token_ids = torch.full((batch_size, max_spec_len + 1),
|
||||
fill_value=INVALID_TOKEN_ID,
|
||||
device=logits.device,
|
||||
dtype=torch.long)
|
||||
|
||||
# TODO: Vectorize the following loop for better performance.
|
||||
start_loc = 0
|
||||
for i in range(batch_size):
|
||||
num_spec_tokens = len(spec_token_ids[i])
|
||||
draft_token_ids[i, :num_spec_tokens] = torch.tensor(
|
||||
spec_token_ids[i], device="cpu", dtype=torch.long)
|
||||
end_loc = start_loc + num_spec_tokens + 1
|
||||
# Assume greedy sampling.
|
||||
target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
|
||||
logits[start_loc:end_loc], dim=-1)
|
||||
start_loc = end_loc
|
||||
|
||||
vocab_size = logits.size(-1)
|
||||
# NOTE: CPU <-> GPU synchronization happens here.
|
||||
draft_token_ids = draft_token_ids.to(logits.device)
|
||||
draft_probs = RejectionSampler._create_greedy_token_probs(
|
||||
draft_token_ids, vocab_size, logits.device)
|
||||
target_probs = RejectionSampler._create_greedy_token_probs(
|
||||
target_token_ids, vocab_size, logits.device)
|
||||
uniform_samples = torch.zeros(batch_size,
|
||||
max_spec_len + 1,
|
||||
device=logits.device)
|
||||
|
||||
sampled_token_ids, _, _ = fs.chain_speculative_sampling(
|
||||
draft_probs,
|
||||
draft_token_ids,
|
||||
uniform_samples,
|
||||
target_probs,
|
||||
)
|
||||
return SamplerOutput(sampled_token_ids=sampled_token_ids,
|
||||
logprobs_tensors=None)
|
||||
|
||||
# TODO: The following method can be optimized for better performance.
|
||||
@staticmethod
|
||||
def greedy_sample_native(
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> SamplerOutput:
|
||||
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
|
||||
# Add 1 to include the 'bonus' token.
|
||||
sample_lens = [x + 1 for x in spec_lens]
|
||||
|
||||
output_token_ids = logits.argmax(dim=-1).view(-1)
|
||||
output_token_ids = output_token_ids.split(sample_lens)
|
||||
output_token_ids = pad_sequence(output_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
|
||||
# Convert spec token IDs to a tensor, split by sample_lens, then pad.
|
||||
spec_token_ids = [
|
||||
torch.tensor(x,
|
||||
dtype=output_token_ids.dtype,
|
||||
device=output_token_ids.device)
|
||||
for x in sampling_metadata.spec_token_ids
|
||||
]
|
||||
spec_token_ids = pad_sequence(spec_token_ids,
|
||||
batch_first=True,
|
||||
padding_value=INVALID_TOKEN_ID)
|
||||
|
||||
# Produce a mask that remains 1 (True) until the first
|
||||
# mismatch (cumprod turns 0 after a mismatch).
|
||||
accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
|
||||
dim=1)
|
||||
# Identify valid positions (non-padding).
|
||||
valid_mask = output_token_ids != INVALID_TOKEN_ID
|
||||
# Generate mask with bonus token.
|
||||
generate_mask = torch.cat([
|
||||
accept_mask,
|
||||
torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
|
||||
],
|
||||
dim=1).to(torch.bool) & valid_mask
|
||||
zeros_mask = (generate_mask == 0)
|
||||
first_zero_idx = zeros_mask.float().argmax(dim=1)
|
||||
# Figure out which rows actually contain at least one zero.
|
||||
rows_with_zero = zeros_mask.any(dim=1)
|
||||
# Use indexing to set the first zero in each of those rows to 1.
|
||||
generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1
|
||||
|
||||
output_token_ids[~generate_mask] = INVALID_TOKEN_ID
|
||||
return SamplerOutput(sampled_token_ids=output_token_ids,
|
||||
logprobs_tensors=None)
|
||||
|
||||
@staticmethod
|
||||
def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
|
||||
out_device: torch.device) -> torch.Tensor:
|
||||
batch_size, num_tokens = token_ids.shape
|
||||
|
||||
token_probs = torch.zeros(batch_size,
|
||||
num_tokens,
|
||||
vocab_size,
|
||||
dtype=torch.float,
|
||||
device=out_device)
|
||||
|
||||
# Ignore INVALID_TOKEN_ID.
|
||||
valid_mask = (token_ids != INVALID_TOKEN_ID)
|
||||
valid_indices = token_ids.clone()
|
||||
valid_indices[~valid_mask] = 0
|
||||
|
||||
token_probs.scatter_(dim=2,
|
||||
index=valid_indices.unsqueeze(-1),
|
||||
src=valid_mask.unsqueeze(-1).float())
|
||||
|
||||
return token_probs
|
||||
@ -9,6 +9,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
|
||||
apply_min_token_penalties)
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
@ -18,12 +19,21 @@ class Sampler(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
self.rejection_sampler = RejectionSampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
if sampling_metadata.rejection_sampling:
|
||||
if sampling_metadata.max_num_logprobs:
|
||||
raise NotImplementedError(
|
||||
"Rejection sampling does not support logprobs.")
|
||||
return self.rejection_sampler(
|
||||
logits,
|
||||
sampling_metadata,
|
||||
)
|
||||
|
||||
# NOTE(woosuk): Use the original logits (before any penalties or
|
||||
# temperature scaling) for the top-k logprobs.
|
||||
@ -54,7 +64,10 @@ class Sampler(nn.Module):
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=sampled,
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
99
vllm/v1/spec_decode/ngram_proposer.py
Normal file
99
vllm/v1/spec_decode/ngram_proposer.py
Normal file
@ -0,0 +1,99 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
|
||||
class NgramProposer:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def propose(self, context_token_ids: ConstantList[int], n: int,
|
||||
k: int) -> Optional[List[int]]:
|
||||
"""Proposes the next sequence of tokens based on n-gram pattern
|
||||
matching in the context. The function finds matches of the last n
|
||||
tokens in the previous context, and returns k tokens that followed
|
||||
that match.
|
||||
|
||||
Args:
|
||||
context_token_ids: List of token IDs representing the
|
||||
context sequence.
|
||||
n: Length of the n-gram to match.
|
||||
k: Number of tokens follow the match. If there are less
|
||||
than k tokens follow the match, we will return
|
||||
the maximum amount of tokens until the end.
|
||||
|
||||
Returns:
|
||||
List[int]: The sequence of tokens that followed
|
||||
the matched n-gram in the context.
|
||||
None: If no matching n-gram pattern is found.
|
||||
|
||||
Example:
|
||||
If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4:
|
||||
- The last 2 tokens [2,3] will be matched against the previous
|
||||
4 tokens [1,2,3,4].
|
||||
- Finding a match of [2,3] would return the tokens that
|
||||
followed that pattern. Here we will return [4,2,3] because
|
||||
we only have three tokens after the match.
|
||||
"""
|
||||
# TODO: Use c++ to implement the _find_subarray_kmp to
|
||||
# improve the efficiency
|
||||
return self._find_subarray_kmp(context_token_ids, n, k)
|
||||
|
||||
@staticmethod
|
||||
def _kmp_lps_array(pattern: List[int]) -> List[int]:
|
||||
"""
|
||||
Build the lps (longest proper prefix which is also suffix)
|
||||
array for the pattern.
|
||||
"""
|
||||
lps = [0] * len(pattern)
|
||||
prev_lps = 0 # length of the previous longest prefix suffix
|
||||
i = 1
|
||||
|
||||
while i < len(pattern):
|
||||
if pattern[i] == pattern[prev_lps]:
|
||||
prev_lps += 1
|
||||
lps[i] = prev_lps
|
||||
i += 1
|
||||
else:
|
||||
if prev_lps != 0:
|
||||
prev_lps = lps[prev_lps - 1]
|
||||
else:
|
||||
lps[i] = 0
|
||||
i += 1
|
||||
|
||||
return lps
|
||||
|
||||
@staticmethod
|
||||
def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
|
||||
k: int) -> Optional[List[int]]:
|
||||
context_len = len(context_token_ids)
|
||||
assert n > 0
|
||||
|
||||
pattern = context_token_ids[-n:]
|
||||
# Precompute lps array for Y
|
||||
lps = NgramProposer._kmp_lps_array(pattern)
|
||||
|
||||
i = 0
|
||||
j = 0
|
||||
# -n because the last n tokens are used as pattern
|
||||
while i < context_len - n:
|
||||
if context_token_ids[i] == pattern[j]:
|
||||
i += 1
|
||||
j += 1
|
||||
|
||||
# If we have matched the entire Y
|
||||
if j == n:
|
||||
# Found pattern in context, gather the next K elements
|
||||
return context_token_ids[i:i + k]
|
||||
else:
|
||||
# Mismatch
|
||||
if j != 0:
|
||||
# Use the lps array to avoid re-checking elements
|
||||
j = lps[j - 1]
|
||||
else:
|
||||
i += 1
|
||||
|
||||
# Y not found
|
||||
return None
|
||||
@ -390,6 +390,7 @@ class InputBatch:
|
||||
def make_sampling_metadata(
|
||||
self,
|
||||
req_id_output_token_ids: Dict[str, List[int]],
|
||||
req_id_to_spec_token_ids: Dict[str, List[int]],
|
||||
skip_copy: bool = False,
|
||||
) -> SamplingMetadata:
|
||||
if not skip_copy:
|
||||
@ -423,7 +424,8 @@ class InputBatch:
|
||||
self.prompt_token_ids = self._make_prompt_token_ids_tensor()
|
||||
|
||||
output_token_ids: List[List[int]] = []
|
||||
|
||||
spec_token_ids: List[List[int]] = []
|
||||
rejection_sampling = False
|
||||
for req_id in self.req_ids[:self.num_reqs]:
|
||||
assert req_id is not None
|
||||
# Currently we create a tensor for output_token_ids from scratch
|
||||
@ -434,11 +436,18 @@ class InputBatch:
|
||||
# TODO - Replace this with incremental update to output token
|
||||
# statistics.
|
||||
output_token_ids.append(req_id_output_token_ids[req_id])
|
||||
req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, [])
|
||||
spec_token_ids.append(req_spec_token_ids)
|
||||
if req_spec_token_ids:
|
||||
# If any of the requests require speculative decoding, set the
|
||||
# flag to True.
|
||||
rejection_sampling = True
|
||||
|
||||
return SamplingMetadata(
|
||||
temperature=self.temperature[:self.num_reqs],
|
||||
all_greedy=self.all_greedy,
|
||||
all_random=self.all_random,
|
||||
rejection_sampling=rejection_sampling,
|
||||
top_p=self.top_p[:self.num_reqs],
|
||||
top_k=self.top_k[:self.num_reqs],
|
||||
min_p=self.min_p[:self.num_reqs],
|
||||
@ -452,6 +461,7 @@ class InputBatch:
|
||||
presence_penalties=self.presence_penalties[:self.num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:self.num_reqs],
|
||||
output_token_ids=output_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
min_tokens=self.min_tokens[:self.num_reqs],
|
||||
stop_token_ids=self.stop_token_ids[:self.num_reqs],
|
||||
no_penalties=self.no_penalties,
|
||||
|
||||
@ -32,6 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
@ -180,6 +181,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.max_model_len,
|
||||
self.max_num_tokens),
|
||||
dtype=np.int32)
|
||||
self.arange_cpu = torch.from_numpy(self.arange_np)
|
||||
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
|
||||
# a faster version of creating a new tensor every time. Thus, we should
|
||||
# not make any assumptions about the values in these tensors.
|
||||
@ -368,7 +370,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return batch_changed
|
||||
|
||||
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
|
||||
def _prepare_inputs(
|
||||
self, scheduler_output: "SchedulerOutput"
|
||||
) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
@ -382,12 +386,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# TODO: The Python loop can be slow. Optimize.
|
||||
num_scheduled_tokens_list: List[int] = []
|
||||
max_num_scheduled_tokens = 0
|
||||
for req_id in self.input_batch.req_ids[:num_reqs]:
|
||||
all_spec_token_ids: List[int] = []
|
||||
num_spec_tokens_list: List[int] = []
|
||||
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
||||
assert req_id is not None
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_scheduled_tokens_list.append(num_tokens)
|
||||
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
|
||||
num_tokens)
|
||||
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||
req_id, [])
|
||||
all_spec_token_ids.extend(spec_token_ids)
|
||||
num_spec_tokens_list.append(len(spec_token_ids))
|
||||
|
||||
num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
|
||||
dtype=np.int32)
|
||||
assert max_num_scheduled_tokens > 0
|
||||
@ -426,6 +437,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])

use_spec_decode = len(all_spec_token_ids) > 0
if use_spec_decode:

# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
# the req indices by the number of spec tokens. Note
# for requests that don't perform spec decode, the
# number of spec tokens is 0 and the req index is
# repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
# spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices = np.repeat(self.arange_np[:num_reqs],
num_spec_tokens_list)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
spec_offsets = np.concatenate(
[self.arange_np[1:val + 1] for val in num_spec_tokens_list])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
# after repeating: [1, 1, 1, 3, 3, 2]
# spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
# = [2, 3, 4, 4, 5, 3]
spec_seq_offsets = np.repeat(
self.input_batch.num_computed_tokens_cpu[:num_reqs],
num_spec_tokens_list) + spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
cumsums_spec_offsets = (
spec_seq_offsets +
spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
torch.int64)
all_spec_token_ids = torch.tensor(all_spec_token_ids,
device="cpu",
dtype=self.input_ids_cpu.dtype)

# Step 2. Write spec token ids to input_ids_cpu.
self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
0, cumsums_spec_offsets, all_spec_token_ids)

# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
num_sampled_tokens = num_spec_tokens_np + 1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
cumsums_sampled_offsets)

# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices = logits_start_loc + sampled_arange

# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
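
The two index computations in this hunk are easiest to verify against the worked example carried in the comments. Below is a minimal standalone numpy sketch, not part of the commit; M stands in for max_model_len, and the arrays are the example values [3, 0, 2, 0, 1], [1, 4, 3, 6, 2] and [4, 100, 3, 100, 2]:

import numpy as np

# Sketch only: replay the index math above with the example values.
M = 1024
num_spec_tokens_list = [3, 0, 2, 0, 1]
num_computed_tokens = np.array([1, 4, 3, 6, 2])
num_scheduled_tokens = np.array([4, 100, 3, 100, 2])
num_reqs = len(num_spec_tokens_list)
arange = np.arange(M)

# Flattened offsets at which the proposed tokens land in the [num_reqs, M]
# token table.
spec_req_indices = np.repeat(arange[:num_reqs], num_spec_tokens_list)
spec_offsets = np.concatenate(
    [arange[1:val + 1] for val in num_spec_tokens_list])
spec_seq_offsets = np.repeat(num_computed_tokens,
                             num_spec_tokens_list) + spec_offsets
flat_offsets = spec_seq_offsets + spec_req_indices * M
assert spec_seq_offsets.tolist() == [2, 3, 4, 4, 5, 3]

# Logits indices: the last (1 + num_spec_tokens) positions of every request.
cu_num_tokens = np.cumsum(num_scheduled_tokens)
num_sampled_tokens = np.array(num_spec_tokens_list) + 1
logits_start_loc = np.repeat(cu_num_tokens - num_sampled_tokens,
                             num_sampled_tokens)
cumsums_sampled_offsets = np.repeat(
    np.cumsum(num_sampled_tokens) - num_sampled_tokens, num_sampled_tokens)
sampled_arange = arange[:num_sampled_tokens.sum()] - cumsums_sampled_offsets
logits_indices = logits_start_loc + sampled_arange
assert logits_indices.tolist() == [
    0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]

flat_offsets corresponds to the quantity the hunk converts to an int64 tensor and feeds to scatter_, and logits_indices matches the spec_decode_logits_indices example in the comments.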
@ -519,16 +603,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
suffix_kv_lens=suffix_kv_lens,
)

if use_spec_decode:
logits_indices = torch.from_numpy(spec_decode_logits_indices).to(
self.device, non_blocking=True)
else:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1

# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)

# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices

def _compute_cascade_attn_prefix_len(
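
For contrast, the fallback branch keeps exactly one logits row per request. A minimal sketch, not part of the commit, assuming query_start_loc is the zero-prefixed cumulative sum of the example scheduled-token counts [4, 100, 3, 100, 2] used above:

import torch

# Sketch only: without spec decode, only the last scheduled position of each
# request produces a logit.
query_start_loc = torch.tensor([0, 4, 104, 107, 207, 209])
logits_indices = query_start_loc[1:] - 1
assert logits_indices.tolist() == [3, 103, 106, 206, 208]
# The spec-decode branch instead selects the last (1 + num_spec_tokens)
# positions per request, e.g. [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208].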
@ -673,6 +762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_sampling(
self,
batch_changed: bool,
req_to_spec_token_ids: Dict[str, List[int]],
) -> SamplingMetadata:
# Create the sampling metadata.
req_id_output_token_ids: Dict[str, List[int]] = \
@ -680,7 +770,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for req_id, req in self.requests.items()}

sampling_metadata = self.input_batch.make_sampling_metadata(
req_id_output_token_ids, skip_copy=not batch_changed)
req_id_output_token_ids, req_to_spec_token_ids, not batch_changed)
return sampling_metadata

def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
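
The extra req_to_spec_token_ids argument is the scheduler's per-request mapping of proposed draft token ids, passed through to InputBatch.make_sampling_metadata. A minimal sketch of its shape, not part of the commit, with made-up ids and values:

from typing import Dict, List

# Sketch only: the mapping handed to _prepare_sampling; in the runner it is
# scheduler_output.scheduled_spec_decode_tokens (see the next hunk).
req_to_spec_token_ids: Dict[str, List[int]] = {
    "req-0": [11, 12, 13],
    "req-2": [21, 22],
}
# make_sampling_metadata forwards these proposals to the sampler via the
# spec_token_ids field shown in the SamplingMetadata constructor call earlier
# in this diff.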
@ -847,7 +937,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits = self.model.compute_logits(sample_hidden_states, None)

# Sample the next token and get logprobs if needed.
sampling_metadata = self._prepare_sampling(batch_changed)
sampling_metadata = self._prepare_sampling(
batch_changed, scheduler_output.scheduled_spec_decode_tokens)
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
@ -857,18 +948,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
for i, req_id in enumerate(  # type: ignore[assignment]
self.input_batch.req_ids[:num_reqs]):
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
assert seq_len <= req_state.num_tokens
if seq_len == req_state.num_tokens:
# Append the sampled token to the output token ids.
self.input_batch.num_tokens[i] += 1
# OPTIMIZATION: Priming the state updates for later updates.
req_state.output_token_ids.append(0)
if seq_len >= req_state.num_tokens:
request_seq_lens.append((i, req_state, seq_len))
else:
# Ignore the sampled token from the partial request.
@ -886,7 +971,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):

# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
sampled_token_ids = sampler_output.sampled_token_ids.tolist()
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
@ -897,16 +981,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output,
)

# Update with the actual token ids
for i, req_state, seq_len in request_seq_lens:
token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids[-1] = token_id
# Update batch with the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
self.input_batch.num_tokens[i] += 1
else:
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
valid_sampled_token_ids = [
seq.tolist()
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
self.input_batch.num_tokens[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens:
target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
self.input_batch.token_ids_cpu[
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])

model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)

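In the multi-token branch above, the sampler output is a [num_reqs, max_gen_len] tensor padded with INVALID_TOKEN_ID, and the mask-and-split recovers the accepted tokens per request. A minimal standalone sketch, not part of the commit, with made-up token values; the padding constant is defined locally as a stand-in for the one imported from vllm.v1.sample.rejection_sampler:

import torch

# Sketch only: recover per-request accepted tokens from a padded batch.
INVALID_TOKEN_ID = -1
sampled_token_ids = torch.tensor([
    [5, 6, 7, 8],                                               # 4 tokens kept
    [9, INVALID_TOKEN_ID, INVALID_TOKEN_ID, INVALID_TOKEN_ID],  # 1 token kept
    [3, 4, INVALID_TOKEN_ID, INVALID_TOKEN_ID],                 # 2 tokens kept
])
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
valid_sampled_token_ids = [
    seq.tolist() for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
assert gen_lens == [4, 1, 2]
assert valid_sampled_token_ids == [[5, 6, 7, 8], [9], [3, 4]]

The nested-list form is also what ModelRunnerOutput.sampled_token_ids now carries, which is why the TPU runner below wraps each of its single sampled tokens in a one-element list.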
@ -695,7 +695,7 @@ class TPUModelRunner:
model_runner_output = ModelRunnerOutput(
req_ids=all_req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
)