[V1][Spec Decode] Ngram Spec Decode (#12193)

Signed-off-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Lily Liu 2025-02-15 18:05:11 -08:00 committed by GitHub
parent 367cb8ce8c
commit 80f63a3966
21 changed files with 1023 additions and 82 deletions

View File

@ -4,10 +4,12 @@ from typing import List, Optional
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
EOS_TOKEN_ID = 50256
def create_scheduler(
model: str = "facebook/opt-125m",
@ -38,6 +40,7 @@ def create_scheduler(
return Scheduler(scheduler_config,
model_config,
cache_config,
speculative_config=None,
lora_config=None,
log_stats=True)
@ -46,8 +49,12 @@ def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[List[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[List[int]] = None,
):
sampling_params = SamplingParams()
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids)
requests = []
for i in range(num_requests):
if mm_positions is not None:
@ -64,7 +71,7 @@ def create_requests(
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
requests.append(request)
@ -195,7 +202,7 @@ def test_schedule_partial_requests():
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[0] * len(requests),
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
)
@ -215,6 +222,189 @@ def test_schedule_partial_requests():
assert requests[2].request_id not in output.num_scheduled_tokens
def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler()
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11]
# Test case 2: Stop on custom stop token
scheduler = create_scheduler()
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42]
assert list(requests[1].output_token_ids) == [13, 14]
# Test case 3: Stop on max tokens
scheduler = create_scheduler()
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11
] # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13]
# Test case 4: Ignore EOS flag
scheduler = create_scheduler()
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler.scheduled_req_ids.add(requests[0].request_id)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
assert len(scheduler.running) == 1
assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
def test_schedule_concurrent_batches():
scheduler = create_scheduler(
max_num_batched_tokens=1024,
@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[0],
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
)
@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[0],
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
)

View File

@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm import LLM, SamplingParams
@pytest.fixture
def test_prompts():
return [
"Can you repeat the sentence ten times, this is a sentence.",
"Can you repeat the sentence ten times, this is a test.",
]
@pytest.fixture
def sampling_config():
# Only support greedy for now
return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)
@pytest.fixture
def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct"
def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
model_name):
'''
Compare the outputs of the original LLM and the speculative LLM.
They should be the same when using ngram speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name)
ref_outputs = ref_llm.generate(test_prompts, sampling_config)
del ref_llm
spec_llm = LLM(model=model_name,
speculative_model='[ngram]',
ngram_prompt_lookup_max=5,
ngram_prompt_lookup_min=3,
num_speculative_tokens=3)
spec_outputs = spec_llm.generate(test_prompts, sampling_config)
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
(f"ref_output: {ref_output.outputs[0].text},"
f"spec_output: {spec_output.outputs[0].text}")
del spec_llm

View File

@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List
import pytest
import torch
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
@pytest.fixture
def sampler():
return RejectionSampler()
def create_logits_tensor(token_ids: List[int],
vocab_size: int = 100) -> torch.Tensor:
"""Helper function to create logits tensor that
will produce desired token ids on argmax"""
logits = torch.full((len(token_ids), vocab_size), -100.0).cuda()
for i, token_id in enumerate(token_ids):
logits[i, token_id] = 100.0
return logits
def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
batch_size = len(spec_tokens)
return SamplingMetadata(
temperature=0.0,
all_greedy=True,
all_random=False,
rejection_sampling=True,
spec_token_ids=spec_tokens,
top_p=None,
top_k=None,
no_top_p=False,
no_top_k=False,
min_p=torch.empty(batch_size, ),
no_min_p=True,
generators={},
max_num_logprobs=0,
no_penalties=False,
prompt_token_ids=None,
frequency_penalties=torch.tensor([]),
presence_penalties=torch.tensor([]),
repetition_penalties=torch.tensor([]),
output_token_ids=[],
min_tokens=[],
stop_token_ids=[],
logit_bias=[None] * batch_size,
)
def test_perfect_match(sampler):
"""Test when output tokens perfectly match speculated tokens"""
spec_tokens = [[1, 2, 3]]
output_tokens = [1, 2, 3, 4] # 4 is the bonus token
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
def test_early_mismatch(sampler):
"""Test when there's an early mismatch in tokens"""
spec_tokens = [[1, 2, 3]]
output_tokens = [1, 5, 3, 4] # Mismatch at position 1
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
def test_multiple_sequences(sampler):
"""Test handling multiple sequences of speculated tokens"""
spec_tokens = [[1, 2], [3]]
output_tokens = [1, 2, 5, 3, 4] # Two sequences with bonus tokens 5 and 4
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
def test_single_token_sequence(sampler):
"""Test handling sequences with single token"""
spec_tokens = [[1]]
output_tokens = [1, 2] # Single token with bonus token 2
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
def test_empty_sequence(sampler):
"""Test handling empty sequence of speculated tokens"""
spec_tokens: List[List[int]] = [[]]
output_tokens = [5] # Just the bonus token
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
def test_multiple_mismatches(sampler):
"""Test handling multiple sequences with mismatches"""
spec_tokens = [[1, 2, 3], [4, 5, 6]]
output_tokens = [1, 2, 7, 6, 4, 8, 6, 9] # Mismatches in both sequences
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2]], [1, 2, 3], [[1, 2, 3]]), # Perfect match with bonus
([[1]], [2, 3], [[2, INVALID_TOKEN_ID]]), # First mismatch
([[1, 2], [3, 4]], [1, 5, 6, 3, 4, 7], [[1, 5, INVALID_TOKEN_ID],
[3, 4, 7]]), # Mixed matches
])
def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
"""Parametrized test for various matching scenarios"""
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens)
output = sampler(logits, metadata)
expected_tensor = torch.tensor(expected,
dtype=torch.int,
device=logits.device)
assert torch.equal(output.sampled_token_ids, expected_tensor)
def test_logits_shape_handling(sampler):
"""Test handling of different logits tensor shapes"""
spec_tokens = [[1, 2]]
output_tokens = [1, 2, 3]
vocab_size = 1000
metadata = create_sampling_metadata(spec_tokens)
logits = create_logits_tensor(output_tokens, vocab_size)
output = sampler(logits, metadata)
expected = torch.tensor([[1, 2, 3]], dtype=torch.int, device=logits.device)
assert torch.equal(output.sampled_token_ids, expected)
assert logits.shape[-1] == vocab_size

View File

@ -77,6 +77,7 @@ def _create_default_sampling_metadata(
temperature=torch.full((batch_size, ), 0.0),
all_greedy=True,
all_random=False,
rejection_sampling=False,
top_p=torch.empty(batch_size, ),
top_k=torch.empty(batch_size, ),
no_top_p=True,
@ -88,6 +89,7 @@ def _create_default_sampling_metadata(
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
vocab_size, device),
output_token_ids=output_token_ids,
spec_token_ids=[],
frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),

View File

@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import ConstantList
@pytest.fixture
def proposer():
return NgramProposer()
def test_kmp_lps_array(proposer):
assert proposer._kmp_lps_array([]) == []
assert proposer._kmp_lps_array([1]) == [0]
assert proposer._kmp_lps_array([1, 1, 1]) == [0, 1, 2]
assert proposer._kmp_lps_array([1, 2, 3, 4]) == [0, 0, 0, 0]
assert proposer._kmp_lps_array([1, 2, 1, 2, 3]) == [0, 0, 1, 2, 0]
def test_find_subarray_kmp(proposer):
X = ConstantList([1, 2, 3, 4, 1, 2, 3, 5, 6])
assert proposer._find_subarray_kmp(X, 2, 2) is None
X = ConstantList([1, 2, 3, 4, 1, 2, 3])
assert proposer._find_subarray_kmp(X, 2, 3) == [4, 1, 2]
assert proposer._find_subarray_kmp(X, 2, 2) == [4, 1]
assert proposer._find_subarray_kmp(X, 1, 3) == [4, 1, 2]
assert proposer._find_subarray_kmp(X, 1, 2) == [4, 1]
X = ConstantList([1, 3, 6, 2, 3, 4, 1, 2, 3])
assert proposer._find_subarray_kmp(X, 2, 3) == [4, 1, 2]
# Return on the first match
assert proposer._find_subarray_kmp(X, 1, 3) == [6, 2, 3]

View File

@ -92,6 +92,7 @@ def _construct_expected_sampling_metadata(
device=device),
all_greedy=False,
all_random=True,
rejection_sampling=False,
top_p=torch.tensor(top_p, dtype=torch.float, device=device),
top_k=torch.tensor(top_k, dtype=torch.int, device=device),
no_top_p=all(x == 1.0 for x in top_p),
@ -116,6 +117,7 @@ def _construct_expected_sampling_metadata(
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
spec_token_ids=[],
min_tokens=min_tokens,
stop_token_ids=stop_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties)
@ -205,7 +207,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
# Generate the sampling metadata
sampling_metadata = input_batch.make_sampling_metadata(
req_id_output_token_ids, skip_copy=False)
req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(

View File

@ -66,6 +66,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
scheduled_cached_reqs=[],
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
@ -109,6 +110,7 @@ def test_update_states_request_finished(model_runner):
scheduled_cached_reqs=[],
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={req_id},
@ -137,6 +139,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_cached_reqs=[],
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids={},
@ -160,6 +163,7 @@ def test_update_states_request_resumed(model_runner):
scheduled_cached_reqs=[cached_req_data],
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
@ -188,6 +192,7 @@ def test_update_states_no_changes(model_runner):
scheduled_cached_reqs=[],
num_scheduled_tokens={req_id: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),
@ -220,6 +225,7 @@ def test_update_states_request_unscheduled(model_runner):
scheduled_cached_reqs=[],
num_scheduled_tokens={req_ids[0]: 1},
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
finished_req_ids=set(),

View File

@ -124,9 +124,8 @@ class CudaPlatformBase(Platform):
"vllm.worker.multi_step_worker.MultiStepWorker"
elif vllm_config.speculative_config:
if envs.VLLM_USE_V1:
raise NotImplementedError(
"Speculative decoding is not yet supported on VLLM V1."
)
parallel_config.worker_cls = \
"vllm.v1.worker.gpu_worker.Worker"
else:
parallel_config.worker_cls = \
"vllm.spec_decode.spec_decode_worker.create_spec_worker"

View File

@ -82,6 +82,11 @@ class KVCacheManager:
self.req_to_block_hashes: DefaultDict[
str, List[BlockHashType]] = defaultdict(list)
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This only tracks RUNNING requests; we do not track the
# data for preempted ones.
self.num_cached_block: Dict[str, int] = defaultdict(int)
self.prefix_cache_stats = PrefixCacheStats()
@property
@ -241,23 +246,25 @@ class KVCacheManager:
if not self.enable_caching:
return new_blocks
# NOTE(rickyx): We are assuming the `num_tokens` are actual
# tokens rather than lookahead slots (e.g. for speculative decoding).
# TODO(rickyx): When supporting speculative decoding, we will need to
# differentiate between them so that we can know how many blocks are
# full after appending the actual tokens.
num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size
num_computed_full_blocks = num_computed_tokens // self.block_size
new_full_blocks = req_blocks[num_computed_full_blocks:num_full_blocks]
num_cached_blocks = self.num_cached_block[request.request_id]
# Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
num_full_blocks_after_append = (num_computed_tokens + num_tokens - len(
request.spec_token_ids)) // self.block_size
new_full_blocks = req_blocks[
num_cached_blocks:num_full_blocks_after_append]
if new_full_blocks:
self._cache_full_blocks(
request=request,
blk_start_idx=num_computed_full_blocks,
blk_start_idx=num_cached_blocks,
# The new full blocks are the full blocks that are not computed.
full_blocks=new_full_blocks,
prev_block=(req_blocks[num_computed_full_blocks - 1]
if num_computed_full_blocks > 0 else None))
prev_block=(req_blocks[num_cached_blocks -
1] if num_cached_blocks > 0 else None))
self.num_cached_block[
request.request_id] = num_full_blocks_after_append
return new_blocks
def free(self, request: Request) -> None:
@ -281,6 +288,8 @@ class KVCacheManager:
if block.ref_cnt == 0:
self.free_block_queue.append(block)
self.num_cached_block.pop(request.request_id, None)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
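For concreteness, a minimal sketch (not part of the diff, with made-up numbers) of the block-caching arithmetic used in cache_blocks above: speculated tokens are subtracted before counting full blocks, so blocks that only hold draft tokens are never cached.
# Minimal sketch of the block-caching arithmetic, with made-up numbers.
block_size = 16
num_computed_tokens = 30      # tokens already in the KV cache
num_tokens = 8                # tokens appended this step (includes spec tokens)
num_spec_tokens = 3           # draft tokens that may still be rejected
num_cached_blocks = 1         # blocks already cached for this request
num_full_blocks_after_append = (
    num_computed_tokens + num_tokens - num_spec_tokens) // block_size
# (30 + 8 - 3) // 16 = 2, so exactly one new block (index 1) becomes cacheable.
new_full_block_indices = range(num_cached_blocks, num_full_blocks_after_append)
assert list(new_full_block_indices) == [1]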

View File

@ -4,7 +4,8 @@ import time
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
from vllm.config import (CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
compute_encoder_budget)
@ -28,11 +29,13 @@ class Scheduler:
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
log_stats: bool,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
self.speculative_config = speculative_config
self.log_stats = log_stats
# Scheduling constraints.
@ -96,12 +99,14 @@ class Scheduler:
def schedule(self) -> "SchedulerOutput":
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and num_tokens,
# which is equal to len(prompt_token_ids) + len(output_token_ids).
# Each request just has the num_computed_tokens and
# num_tokens_with_spec. num_tokens_with_spec =
# len(prompt_token_ids) + len(output_token_ids) + len(spec_token_ids).
# At each step, the scheduler tries to assign tokens to the requests
# so that each request's num_computed_tokens can catch up its
# num_tokens. This is general enough to cover chunked prefills,
# prefix caching, and the "jump decoding" optimization in the future.
# num_tokens_with_spec. This is general enough to cover
# chunked prefills, prefix caching, speculative decoding,
# and the "jump decoding" optimization in the future.
scheduled_new_reqs: List[Request] = []
scheduled_resumed_reqs: List[Request] = []
@ -114,7 +119,8 @@ class Scheduler:
# Encoder-related.
scheduled_encoder_inputs: Dict[str, List[int]] = {}
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: Dict[str, List[int]] = {}
scheduled_timestamp = time.monotonic()
# First, schedule the RUNNING requests.
@ -126,7 +132,8 @@ class Scheduler:
req_index += 1
continue
num_new_tokens = request.num_tokens - request.num_computed_tokens
num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens)
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
@ -189,6 +196,11 @@ class Scheduler:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Speculative decode related.
if request.spec_token_ids:
scheduled_spec_decode_tokens[
request.request_id] = request.spec_token_ids
# Record the LoRAs in scheduled_running_reqs
requested_loras: Set[int] = set()
if self.lora_config:
@ -338,6 +350,7 @@ class Scheduler:
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
@ -447,11 +460,11 @@ class Scheduler:
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> EngineCoreOutputs:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
outputs: List[EngineCoreOutput] = []
@ -466,11 +479,30 @@ class Scheduler:
new_running.append(request)
continue
request.num_computed_tokens += num_tokens_scheduled
# When the request's num_computed_tokens catches up its num_tokens,
# the request generates output tokens. Otherwise, we ignore the
# sampler output for the request.
assert request.num_computed_tokens <= request.num_tokens
req_index = model_runner_output.req_id_to_index[req_id]
generated_token_ids = sampled_token_ids[req_index]
if req_id not in scheduler_output.scheduled_spec_decode_tokens:
# When the request's num_computed_tokens catches up
# its num_tokens, the request generates output tokens.
# Otherwise, we ignore the sampler output for the request.
request.num_computed_tokens += num_tokens_scheduled
assert request.num_computed_tokens <= request.num_tokens
else:
# num_computed_tokens_step represents the number of tokens
# processed in the current step, considering scheduled
# tokens and rejections.
# It is calculated as:
# num_computed_tokens_step = num_scheduled_tokens -
# num_tokens_rejected,
# where num_tokens_rejected is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens[req_id])
num_computed_tokens_step = num_scheduled_tokens[req_id] - (
len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens += num_computed_tokens_step
cached_encoder_input_ids = (
self.encoder_cache_manager.get_cached_input_ids(request))
@ -485,27 +517,32 @@ class Scheduler:
self.encoder_cache_manager.free_encoder_input(
request, input_id)
if request.num_computed_tokens >= request.num_tokens:
# Clear the spec tokens as the request has generated
# a new token. Here, we assume all spec tokens are verified
# if we perform speculative decoding for this request.
# Therefore, we can clear all spec tokens after
# the generation step.
request.clear_spec_tokens()
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
stopped = False
new_logprobs = None
new_token_ids = None
new_token_ids: List[int] = []
if request.num_computed_tokens == request.num_tokens:
req_index = model_runner_output.req_id_to_index[req_id]
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
request.append_output_token_ids(token_id)
num_new_tokens = 1
# TODO: Update the KV cache manager for prefix caching.
if request.num_computed_tokens >= request.num_tokens:
for output_token_id in generated_token_ids:
request.append_output_token_ids(output_token_id)
new_token_ids.append(output_token_id)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = self._check_stop(request)
if stopped:
self._free_request(request)
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
stopped = self._check_stop(request)
if stopped:
self._free_request(request)
break
# Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None:
@ -514,8 +551,6 @@ class Scheduler:
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
new_token_ids = request.output_token_ids[-num_new_tokens:]
# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
# Add EngineCoreOutput for this Request.
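A small worked example (not part of the diff, with made-up numbers) of the num_computed_tokens_step accounting described in the comment above, for a single request whose second draft token is rejected:
# Illustrative accounting for one request in a spec-decode step.
num_scheduled_tokens = 3           # 1 "real" token + 2 scheduled spec tokens
scheduled_spec_token_ids = [7, 9]  # draft tokens verified this step
generated_token_ids = [7, 11]      # first draft accepted, second rejected;
                                   # 11 is the token sampled at the rejection point
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
                       len(generated_token_ids))
num_computed_tokens_step = num_scheduled_tokens - num_tokens_rejected
assert num_computed_tokens_step == 2  # 3 scheduled - 1 rejected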

View File

@ -91,6 +91,10 @@ class SchedulerOutput:
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_decode_tokens
# If a request does not have any spec decode tokens, it will
# not be included in the dictionary.
scheduled_spec_decode_tokens: Dict[str, List[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process the request's 0-th and 1-st images in the current step.

View File

@ -27,6 +27,7 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -65,6 +66,7 @@ class EngineCore:
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config,
log_stats=self.log_stats,
)
@ -84,6 +86,15 @@ class EngineCore:
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
# Setup speculative decode.
# TODO: find a better way to check if we are using ngram.
self.use_spec_decode = False
if self.scheduler.speculative_config:
assert self.scheduler.speculative_config.ngram_prompt_lookup_min \
, "Only ngram spec decode is supported in V1."
self.proposer = NgramProposer()
self.use_spec_decode = True
def _initialize_kv_caches(self,
vllm_config: VllmConfig) -> Tuple[int, int]:
start = time.time()
@ -147,6 +158,9 @@ class EngineCore:
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
if self.use_spec_decode:
self.propose_tokens()
scheduler_output = self.scheduler.schedule()
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
@ -207,6 +221,23 @@ class EngineCore:
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)
def propose_tokens(self):
assert self.scheduler.speculative_config is not None
for req in self.scheduler.running:
# Ignore requests that are doing chunked prefill.
if req.num_computed_tokens < req.num_tokens - 1:
continue
# Ignore requests that already have spec tokens.
if req.spec_token_ids:
continue
spec_tokens = self.proposer.propose(
req.all_token_ids,
self.scheduler.speculative_config.ngram_prompt_lookup_min,
self.scheduler.speculative_config.num_speculative_tokens,
)
if spec_tokens:
req.append_spec_token_ids(spec_tokens)
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()

View File

@ -43,7 +43,10 @@ class LogprobsTensors(NamedTuple):
@dataclass
class SamplerOutput:
# [num_reqs]
# [num_reqs, max_num_generated_tokens]
# Different requests can have different numbers of generated tokens.
# All requests are padded to max_num_generated_tokens.
# INVALID_TOKEN_ID (-1 by default) is used for padding.
sampled_token_ids: torch.Tensor
logprobs_tensors: Optional[LogprobsTensors]
@ -58,8 +61,11 @@ class ModelRunnerOutput:
# req_id -> index
req_id_to_index: Dict[str, int]
# [num_reqs]
sampled_token_ids: List[int]
# num_reqs x num_generated_tokens
# num_generated_tokens is the number of tokens
# generated in the current step. It can be different for
# each request due to speculative/jump decoding.
sampled_token_ids: List[List[int]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
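A tiny illustration (not part of the diff) of the two layouts described above, assuming INVALID_TOKEN_ID is -1: the sampler emits a padded [num_reqs, max_num_generated_tokens] tensor, while ModelRunnerOutput carries ragged per-request lists.
# Padded sampler tensor vs. ragged per-request lists, with made-up values.
import torch
padded = torch.tensor([[5, 8, 13],      # request 0 generated 3 tokens
                       [21, -1, -1]])   # request 1 generated 1 token, padded
valid = padded != -1
ragged = [row[mask].tolist() for row, mask in zip(padded, valid)]
assert ragged == [[5, 8, 13], [21]]     # ModelRunnerOutput.sampled_token_ids form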

View File

@ -46,6 +46,7 @@ class Request:
self.num_prompt_tokens = len(self.prompt_token_ids)
self._output_token_ids: List[int] = []
self._all_token_ids: List[int] = self.prompt_token_ids.copy()
self.spec_token_ids: List[int] = []
self.num_computed_tokens = 0
# Multi-modal related
@ -103,10 +104,26 @@ class Request:
self._output_token_ids.extend(token_ids)
self._all_token_ids.extend(token_ids)
def append_spec_token_ids(
self,
token_ids: Union[int, List[int]],
) -> None:
if isinstance(token_ids, int):
self.spec_token_ids.append(token_ids)
else:
self.spec_token_ids.extend(token_ids)
def clear_spec_tokens(self) -> None:
self.spec_token_ids.clear()
@property
def num_tokens(self) -> int:
return len(self._all_token_ids)
@property
def num_tokens_with_spec(self) -> int:
return len(self._all_token_ids) + len(self.spec_token_ids)
@property
def num_output_tokens(self) -> int:
return len(self._output_token_ids)

View File

@ -12,6 +12,8 @@ class SamplingMetadata:
temperature: torch.Tensor
all_greedy: bool
all_random: bool
rejection_sampling: bool
spec_token_ids: List[List[int]]
top_p: torch.Tensor
top_k: torch.Tensor

View File

@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from vllm.logger import init_logger
from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
try:
import flashinfer.sampling as fs
is_flashinfer_available = True
except ImportError:
is_flashinfer_available = False
logger = init_logger(__name__)
INVALID_TOKEN_ID = -1
class RejectionSampler(nn.Module):
def forward(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
if not sampling_metadata.all_greedy:
raise NotImplementedError(
"Only greedy sampling is supported by rejection sampler.")
if is_flashinfer_available:
logger.info("User FlashInfer for rejection sampling.")
return RejectionSampler.flashinfer_sample(logits,
sampling_metadata)
else:
logger.warning(
"FlashInfer is not available. Falling back to the PyTorch-"
"native implementation of rejection sampling.")
return RejectionSampler.greedy_sample_native(
logits, sampling_metadata)
@staticmethod
def flashinfer_sample(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
# NOTE: The following input preparation can be moved
# to the model runner in a persistent manner for better
# performance.
spec_token_ids = sampling_metadata.spec_token_ids
max_spec_len = max(len(s) for s in spec_token_ids)
batch_size = len(spec_token_ids)
draft_token_ids = torch.full((batch_size, max_spec_len),
INVALID_TOKEN_ID,
device="cpu",
dtype=torch.long)
target_token_ids = torch.full((batch_size, max_spec_len + 1),
fill_value=INVALID_TOKEN_ID,
device=logits.device,
dtype=torch.long)
# TODO: Vectorize the following loop for better performance.
start_loc = 0
for i in range(batch_size):
num_spec_tokens = len(spec_token_ids[i])
draft_token_ids[i, :num_spec_tokens] = torch.tensor(
spec_token_ids[i], device="cpu", dtype=torch.long)
end_loc = start_loc + num_spec_tokens + 1
# Assume greedy sampling.
target_token_ids[i, :num_spec_tokens + 1] = torch.argmax(
logits[start_loc:end_loc], dim=-1)
start_loc = end_loc
vocab_size = logits.size(-1)
# NOTE: CPU <-> GPU synchronization happens here.
draft_token_ids = draft_token_ids.to(logits.device)
draft_probs = RejectionSampler._create_greedy_token_probs(
draft_token_ids, vocab_size, logits.device)
target_probs = RejectionSampler._create_greedy_token_probs(
target_token_ids, vocab_size, logits.device)
uniform_samples = torch.zeros(batch_size,
max_spec_len + 1,
device=logits.device)
sampled_token_ids, _, _ = fs.chain_speculative_sampling(
draft_probs,
draft_token_ids,
uniform_samples,
target_probs,
)
return SamplerOutput(sampled_token_ids=sampled_token_ids,
logprobs_tensors=None)
# TODO: The following method can be optimized for better performance.
@staticmethod
def greedy_sample_native(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> SamplerOutput:
spec_lens = [len(x) for x in sampling_metadata.spec_token_ids]
# Add 1 to include the 'bonus' token.
sample_lens = [x + 1 for x in spec_lens]
output_token_ids = logits.argmax(dim=-1).view(-1)
output_token_ids = output_token_ids.split(sample_lens)
output_token_ids = pad_sequence(output_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Convert spec token IDs to a tensor, split by sample_lens, then pad.
spec_token_ids = [
torch.tensor(x,
dtype=output_token_ids.dtype,
device=output_token_ids.device)
for x in sampling_metadata.spec_token_ids
]
spec_token_ids = pad_sequence(spec_token_ids,
batch_first=True,
padding_value=INVALID_TOKEN_ID)
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
accept_mask = (output_token_ids[:, :-1] == spec_token_ids).cumprod(
dim=1)
# Identify valid positions (non-padding).
valid_mask = output_token_ids != INVALID_TOKEN_ID
# Generate mask with bonus token.
generate_mask = torch.cat([
accept_mask,
torch.zeros(accept_mask.size(0), 1, device=accept_mask.device)
],
dim=1).to(torch.bool) & valid_mask
zeros_mask = (generate_mask == 0)
first_zero_idx = zeros_mask.float().argmax(dim=1)
# Figure out which rows actually contain at least one zero.
rows_with_zero = zeros_mask.any(dim=1)
# Use indexing to set the first zero in each of those rows to 1.
generate_mask[rows_with_zero, first_zero_idx[rows_with_zero]] = 1
output_token_ids[~generate_mask] = INVALID_TOKEN_ID
return SamplerOutput(sampled_token_ids=output_token_ids,
logprobs_tensors=None)
@staticmethod
def _create_greedy_token_probs(token_ids: torch.Tensor, vocab_size: int,
out_device: torch.device) -> torch.Tensor:
batch_size, num_tokens = token_ids.shape
token_probs = torch.zeros(batch_size,
num_tokens,
vocab_size,
dtype=torch.float,
device=out_device)
# Ignore INVALID_TOKEN_ID.
valid_mask = (token_ids != INVALID_TOKEN_ID)
valid_indices = token_ids.clone()
valid_indices[~valid_mask] = 0
token_probs.scatter_(dim=2,
index=valid_indices.unsqueeze(-1),
src=valid_mask.unsqueeze(-1).float())
return token_probs
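The cumprod/scatter masks above implement a simple greedy acceptance rule: keep target tokens while they match the draft, keep the first non-matching (or bonus) token, and pad the rest with INVALID_TOKEN_ID. A plain-Python reference sketch (not part of the diff) of that rule, checked against the expectations of test_multiple_mismatches:
# Reference sketch of the greedy acceptance rule, not an optimized path.
INVALID_TOKEN_ID = -1

def greedy_accept(spec_tokens, target_argmax):
    out = []
    for draft, target in zip(spec_tokens, target_argmax):
        row = []
        for pos, token in enumerate(target):
            row.append(token)
            # Stop after the bonus token or the first mismatch with the draft.
            if pos >= len(draft) or token != draft[pos]:
                break
        row += [INVALID_TOKEN_ID] * (len(target) - len(row))
        out.append(row)
    return out

# Mirrors test_multiple_mismatches above.
assert greedy_accept([[1, 2, 3], [4, 5, 6]],
                     [[1, 2, 7, 6], [4, 8, 6, 9]]) == \
    [[1, 2, 7, INVALID_TOKEN_ID], [4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]]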

View File

@ -9,6 +9,7 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.rejection_sampler import RejectionSampler
_SAMPLING_EPS = 1e-5
@ -18,12 +19,21 @@ class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()
self.rejection_sampler = RejectionSampler()
def forward(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
if sampling_metadata.rejection_sampling:
if sampling_metadata.max_num_logprobs:
raise NotImplementedError(
"Rejection sampling does not support logprobs.")
return self.rejection_sampler(
logits,
sampling_metadata,
)
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
@ -54,7 +64,10 @@ class Sampler(nn.Module):
# These are GPU tensors.
sampler_output = SamplerOutput(
sampled_token_ids=sampled,
# The sampled tokens are expanded to a 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.unsqueeze(-1),
logprobs_tensors=logprobs_tensors,
)
return sampler_output

View File

@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional
from vllm.v1.utils import ConstantList
class NgramProposer:
def __init__(self):
pass
def propose(self, context_token_ids: ConstantList[int], n: int,
k: int) -> Optional[List[int]]:
"""Proposes the next sequence of tokens based on n-gram pattern
matching in the context. The function finds matches of the last n
tokens in the previous context, and returns k tokens that followed
that match.
Args:
context_token_ids: List of token IDs representing the
context sequence.
n: Length of the n-gram to match.
k: Number of tokens to return following the match. If fewer
   than k tokens follow the match, we return as many
   as are available up to the end of the context.
Returns:
List[int]: The sequence of tokens that followed
the matched n-gram in the context.
None: If no matching n-gram pattern is found.
Example:
If context_token_ids = [1,2,3,4,2,3], n = 2, and k = 4:
- The last 2 tokens [2,3] will be matched against the previous
4 tokens [1,2,3,4].
- Finding a match of [2,3] would return the tokens that
followed that pattern. Here we will return [4,2,3] because
we only have three tokens after the match.
"""
# TODO: Implement _find_subarray_kmp in C++ to
# improve efficiency
return self._find_subarray_kmp(context_token_ids, n, k)
@staticmethod
def _kmp_lps_array(pattern: List[int]) -> List[int]:
"""
Build the lps (longest proper prefix which is also suffix)
array for the pattern.
"""
lps = [0] * len(pattern)
prev_lps = 0 # length of the previous longest prefix suffix
i = 1
while i < len(pattern):
if pattern[i] == pattern[prev_lps]:
prev_lps += 1
lps[i] = prev_lps
i += 1
else:
if prev_lps != 0:
prev_lps = lps[prev_lps - 1]
else:
lps[i] = 0
i += 1
return lps
@staticmethod
def _find_subarray_kmp(context_token_ids: ConstantList[int], n: int,
k: int) -> Optional[List[int]]:
context_len = len(context_token_ids)
assert n > 0
pattern = context_token_ids[-n:]
# Precompute the lps array for the pattern
lps = NgramProposer._kmp_lps_array(pattern)
i = 0
j = 0
# -n because the last n tokens are used as pattern
while i < context_len - n:
if context_token_ids[i] == pattern[j]:
i += 1
j += 1
# If we have matched the entire pattern
if j == n:
# Found pattern in context, gather the next K elements
return context_token_ids[i:i + k]
else:
# Mismatch
if j != 0:
# Use the lps array to avoid re-checking elements
j = lps[j - 1]
else:
i += 1
# Pattern not found
return None
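A usage sketch (not part of the diff) mirroring the docstring example above: the last n=2 tokens [2, 3] are matched earlier in the context and the (at most) k tokens that followed the match are proposed.
# Usage sketch of NgramProposer.propose, following the docstring example.
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import ConstantList

proposer = NgramProposer()
context = ConstantList([1, 2, 3, 4, 2, 3])
# Only three tokens follow the earlier [2, 3] match, so three are returned.
assert proposer.propose(context, n=2, k=4) == [4, 2, 3]
# No earlier occurrence of the last 2 tokens -> no proposal.
assert proposer.propose(ConstantList([1, 2, 3, 4]), n=2, k=2) is None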

View File

@ -390,6 +390,7 @@ class InputBatch:
def make_sampling_metadata(
self,
req_id_output_token_ids: Dict[str, List[int]],
req_id_to_spec_token_ids: Dict[str, List[int]],
skip_copy: bool = False,
) -> SamplingMetadata:
if not skip_copy:
@ -423,7 +424,8 @@ class InputBatch:
self.prompt_token_ids = self._make_prompt_token_ids_tensor()
output_token_ids: List[List[int]] = []
spec_token_ids: List[List[int]] = []
rejection_sampling = False
for req_id in self.req_ids[:self.num_reqs]:
assert req_id is not None
# Currently we create a tensor for output_token_ids from scratch
@ -434,11 +436,18 @@ class InputBatch:
# TODO - Replace this with incremental update to output token
# statistics.
output_token_ids.append(req_id_output_token_ids[req_id])
req_spec_token_ids = req_id_to_spec_token_ids.get(req_id, [])
spec_token_ids.append(req_spec_token_ids)
if req_spec_token_ids:
# If any of the requests require speculative decoding, set the
# flag to True.
rejection_sampling = True
return SamplingMetadata(
temperature=self.temperature[:self.num_reqs],
all_greedy=self.all_greedy,
all_random=self.all_random,
rejection_sampling=rejection_sampling,
top_p=self.top_p[:self.num_reqs],
top_k=self.top_k[:self.num_reqs],
min_p=self.min_p[:self.num_reqs],
@ -452,6 +461,7 @@ class InputBatch:
presence_penalties=self.presence_penalties[:self.num_reqs],
repetition_penalties=self.repetition_penalties[:self.num_reqs],
output_token_ids=output_token_ids,
spec_token_ids=spec_token_ids,
min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties,

View File

@ -32,6 +32,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@ -180,6 +181,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.max_model_len,
self.max_num_tokens),
dtype=np.int32)
self.arange_cpu = torch.from_numpy(self.arange_np)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
# not make any assumptions about the values in these tensors.
@ -368,7 +370,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return batch_changed
def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
) -> Tuple[FlashAttentionMetadata, torch.Tensor]:
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
num_reqs = self.input_batch.num_reqs
@ -382,12 +386,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens_list: List[int] = []
max_num_scheduled_tokens = 0
for req_id in self.input_batch.req_ids[:num_reqs]:
all_spec_token_ids: List[int] = []
num_spec_tokens_list: List[int] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens_list.append(num_tokens)
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, [])
all_spec_token_ids.extend(spec_token_ids)
num_spec_tokens_list.append(len(spec_token_ids))
num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list,
dtype=np.int32)
assert max_num_scheduled_tokens > 0
@ -426,6 +437,79 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
use_spec_decode = len(all_spec_token_ids) > 0
if use_spec_decode:
# 1. Write spec_token_ids to input batch.
# Step 1. Get req indices that perform spec decode and repeat
# the req indices by the number of spec tokens. Note
# for requests that don't perform spec decode, the
# number of spec tokens is 0 and the req index is
# repeated 0 times.
# E.g., num_spec_tokens_list: [3, 0, 2, 0, 1]
# spec_req_indices: [0, 0, 0, 2, 2, 4]
spec_req_indices = np.repeat(self.arange_np[:num_reqs],
num_spec_tokens_list)
# spec_offsets: offsets within each spec token list.
# E.g., [1, 2, 3, 1, 2, 1], TODO: avoid the for loop here
spec_offsets = np.concatenate(
[self.arange_np[1:val + 1] for val in num_spec_tokens_list])
# spec_seq_offsets: offsets within each sequence.
# E.g., num_computed_tokens_cpu: [1, 4, 3, 6, 2]
# after repeating: [1, 1, 1, 3, 3, 2]
# spec_seq_offsets: [1, 1, 1, 3, 3, 2] + [1, 2, 3, 1, 2, 1]
# = [2, 3, 4, 4, 5, 3]
spec_seq_offsets = np.repeat(
self.input_batch.num_computed_tokens_cpu[:num_reqs],
num_spec_tokens_list) + spec_offsets
# cumsums_spec_offsets: [0, 0, 0, 2M, 2M, 4M] + [2, 3, 4, 4, 5, 3]
cumsums_spec_offsets = (
spec_seq_offsets +
spec_req_indices * self.input_batch.token_ids_cpu.shape[1])
cumsums_spec_offsets = torch.from_numpy(cumsums_spec_offsets).to(
torch.int64)
all_spec_token_ids = torch.tensor(all_spec_token_ids,
device="cpu",
dtype=self.input_ids_cpu.dtype)
# Step 2. Write spec token ids to input_ids_cpu.
self.input_batch.token_ids_cpu_tensor.flatten().scatter_(
0, cumsums_spec_offsets, all_spec_token_ids)
# 2. Get spec decode logits indices.
# E.g., num_scheduled_tokens: [4, 100, 3, 100, 2]
# cu_num_tokens: [4, 104, 107, 207, 209]
# num_spec_tokens_list: [3, 0, 2, 0, 1]
# num_sampled_tokens: [4, 1, 3, 1, 2]
# spec_decode_logits_indices:
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
num_spec_tokens_np = np.array(num_spec_tokens_list, dtype=np.int32)
num_sampled_tokens = num_spec_tokens_np + 1
# logits_start_loc: [0, 103, 104, 206, 207]
logits_start_loc = cu_num_tokens - num_sampled_tokens
# [0, 103, 104, 206, 207] ->
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)
# The following three lines:
# [4, 1, 3, 1, 2] -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
# Step 1. [4, 1, 3, 1, 2] -> [4, 5, 8, 9, 11]
cu_num_sampled_tokens = np.cumsum(num_sampled_tokens)
# Step 2. [4, 5, 8, 9, 11] -> [0, 4, 5, 8, 9]
# -> [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
cumsums_sampled_offsets = np.repeat(
cu_num_sampled_tokens - num_sampled_tokens, num_sampled_tokens)
# Step 3. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# - [0, 0, 0, 0, 4, 5, 5, 5, 8, 9, 9]
# -> [0, 1, 2, 3, 0, 0, 1, 2, 0, 0, 1]
total_num_sampled_tokens = num_sampled_tokens.sum()
sampled_arange = (self.arange_np[:total_num_sampled_tokens] -
cumsums_sampled_offsets)
# [0, 0, 0, 0, 103, 104, 104, 104, 206, 207, 207] ->
# [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
spec_decode_logits_indices = logits_start_loc + sampled_arange
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
@ -519,16 +603,21 @@ class GPUModelRunner(LoRAModelRunnerMixin):
suffix_kv_lens=suffix_kv_lens,
)
if use_spec_decode:
logits_indices = torch.from_numpy(spec_decode_logits_indices).to(
self.device, non_blocking=True)
else:
# NOTE(woosuk): Due to chunked prefills, the batch may contain
# partial requests. While we should not sample any token
# from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
# Hot-Swap lora model
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# NOTE(woosuk): Due to chunked prefills, the batch may contain partial
# requests. While we should not sample any token from these partial
# requests, we do so for simplicity. We will ignore the sampled
# tokens from the partial requests.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return attn_metadata, logits_indices
def _compute_cascade_attn_prefix_len(
@ -673,6 +762,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _prepare_sampling(
self,
batch_changed: bool,
req_to_spec_token_ids: Dict[str, List[int]],
) -> SamplingMetadata:
# Create the sampling metadata.
req_id_output_token_ids: Dict[str, List[int]] = \
@ -680,7 +770,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for req_id, req in self.requests.items()}
sampling_metadata = self.input_batch.make_sampling_metadata(
req_id_output_token_ids, skip_copy=not batch_changed)
req_id_output_token_ids, req_to_spec_token_ids, not batch_changed)
return sampling_metadata
def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
@ -847,7 +937,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits = self.model.compute_logits(sample_hidden_states, None)
# Sample the next token and get logprobs if needed.
sampling_metadata = self._prepare_sampling(batch_changed)
sampling_metadata = self._prepare_sampling(
batch_changed, scheduler_output.scheduled_spec_decode_tokens)
sampler_output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
@ -857,18 +948,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
for i, req_id in enumerate( # type: ignore[assignment]
self.input_batch.req_ids[:num_reqs]):
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
assert seq_len <= req_state.num_tokens
if seq_len == req_state.num_tokens:
# Append the sampled token to the output token ids.
self.input_batch.num_tokens[i] += 1
# OPTIMIZATION: Priming the state updates for later updates.
req_state.output_token_ids.append(0)
if seq_len >= req_state.num_tokens:
request_seq_lens.append((i, req_state, seq_len))
else:
# Ignore the sampled token from the partial request.
@ -886,7 +971,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
sampled_token_ids = sampler_output.sampled_token_ids.tolist()
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
@ -897,16 +981,34 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output,
)
# Update with the actual token ids
for i, req_state, seq_len in request_seq_lens:
token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids[-1] = token_id
# Update batch with the valid generated tokens.
sampled_token_ids = sampler_output.sampled_token_ids
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = sampled_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)
self.input_batch.num_tokens[i] += 1
else:
valid_mask = sampled_token_ids != INVALID_TOKEN_ID
gen_lens = valid_mask.sum(dim=1).tolist()
valid_sampled_token_ids = [
seq.tolist()
for seq in sampled_token_ids[valid_mask].split(gen_lens)
]
self.input_batch.num_tokens[:num_reqs] += gen_lens
for i, req_state, seq_len in request_seq_lens:
target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1)
self.input_batch.token_ids_cpu[
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=valid_sampled_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)
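A NumPy sketch (not part of the diff) that reproduces the spec-decode logits-index arithmetic from the comments above, using the same example numbers:
# Reproduce spec_decode_logits_indices from the commented example.
import numpy as np

num_scheduled_tokens = np.array([4, 100, 3, 100, 2])
num_spec_tokens = np.array([3, 0, 2, 0, 1])

cu_num_tokens = np.cumsum(num_scheduled_tokens)        # [4, 104, 107, 207, 209]
num_sampled_tokens = num_spec_tokens + 1               # [4, 1, 3, 1, 2]
logits_start_loc = cu_num_tokens - num_sampled_tokens  # [0, 103, 104, 206, 207]
logits_start_loc = np.repeat(logits_start_loc, num_sampled_tokens)

cu_num_sampled = np.cumsum(num_sampled_tokens)
offsets = np.repeat(cu_num_sampled - num_sampled_tokens, num_sampled_tokens)
sampled_arange = np.arange(num_sampled_tokens.sum()) - offsets

spec_decode_logits_indices = logits_start_loc + sampled_arange
assert spec_decode_logits_indices.tolist() == \
    [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]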

View File

@ -695,7 +695,7 @@ class TPUModelRunner:
model_runner_output = ModelRunnerOutput(
req_ids=all_req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
sampled_token_ids=[[token_id] for token_id in sampled_token_ids],
logprobs=None,
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
)