mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 05:15:42 +08:00
[Spec Decode] (1/2) Remove batch expansion (#8839)
This commit is contained in:
parent
22f5851b80
commit
1570203864
@ -208,7 +208,7 @@ steps:
|
|||||||
- tests/spec_decode
|
- tests/spec_decode
|
||||||
commands:
|
commands:
|
||||||
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
||||||
- pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
|
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
|
||||||
|
|
||||||
- label: LoRA Test %N # 15min each
|
- label: LoRA Test %N # 15min each
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
|
|||||||
@ -434,7 +434,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
|||||||
sampling_metadata = SamplingMetadata.prepare(
|
sampling_metadata = SamplingMetadata.prepare(
|
||||||
seq_group_metadata_list,
|
seq_group_metadata_list,
|
||||||
seq_lens=seq_lens if seq_lens else None,
|
seq_lens=seq_lens if seq_lens else None,
|
||||||
query_lens=seq_lens if seq_lens else None,
|
query_lens=seq_lens if seq_lens else [1] * batch_size,
|
||||||
device=device,
|
device=device,
|
||||||
pin_memory=is_pin_memory_available())
|
pin_memory=is_pin_memory_available())
|
||||||
# the logits tensor is modified in-place by the sampler
|
# the logits tensor is modified in-place by the sampler
|
||||||
|
|||||||
@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
|
|||||||
max_output_len=32,
|
max_output_len=32,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model_name": MAIN_MODEL,
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
"speculative_model": "JackFram/llama-68m",
|
||||||
|
"num_speculative_tokens": 3,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"speculative_disable_mqa_scorer": True,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use smaller output len for fast test.
|
||||||
|
32,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||||
|
output_len: int, seed: int):
|
||||||
|
"""Verify that speculative decoding generates the same output
|
||||||
|
with batch expansion scorer and mqa scorer.
|
||||||
|
"""
|
||||||
|
run_equality_correctness_test(vllm_runner,
|
||||||
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
|||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
|
||||||
|
# Precision
|
||||||
|
"dtype": PRECISION,
|
||||||
|
|
||||||
|
# Main model
|
||||||
|
"model_name": MAIN_MODEL,
|
||||||
|
"speculative_model": SPEC_MODEL,
|
||||||
|
"num_speculative_tokens": MAX_SPEC_TOKENS,
|
||||||
|
"speculative_disable_by_batch_size": 4
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"speculative_disable_mqa_scorer": True,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use smaller output len for fast test.
|
||||||
|
32,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||||
|
output_len: int, seed: int):
|
||||||
|
"""Verify that speculative decoding generates the same output
|
||||||
|
with batch expansion scorer and mqa scorer.
|
||||||
|
"""
|
||||||
|
run_equality_correctness_test(vllm_runner,
|
||||||
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pytest
|
import pytest
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|||||||
@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
|||||||
max_output_len=output_len,
|
max_output_len=output_len,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model_name": MAIN_MODEL,
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
"speculative_model": SPEC_MODEL,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"speculative_disable_mqa_scorer": True,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use smaller output len for fast test.
|
||||||
|
32,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||||
|
output_len: int, seed: int):
|
||||||
|
"""Verify that speculative decoding generates the same output
|
||||||
|
with batch expansion scorer and mqa scorer.
|
||||||
|
"""
|
||||||
|
run_equality_correctness_test(vllm_runner,
|
||||||
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
|||||||
max_output_len=output_len,
|
max_output_len=output_len,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
temperature=0.0)
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
"speculative_model": "[ngram]",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
"ngram_prompt_lookup_max": 3,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("test_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"speculative_disable_mqa_scorer": True,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("batch_size", [1, 5])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use smaller output len for fast test.
|
||||||
|
32,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_ngram_scorer(vllm_runner, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
|
"""Verify that ngram speculative decoding generates the same output
|
||||||
|
with batch expansion scorer and mqa scorer.
|
||||||
|
"""
|
||||||
|
run_equality_correctness_test(vllm_runner,
|
||||||
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -173,7 +173,6 @@ def test_same_output_for_multi_step():
|
|||||||
block_size,
|
block_size,
|
||||||
num_gpu_blocks,
|
num_gpu_blocks,
|
||||||
seed,
|
seed,
|
||||||
model_runner_cls=TP1DraftModelRunner,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
worker = create_worker(
|
worker = create_worker(
|
||||||
|
|||||||
65
tests/spec_decode/test_scorer.py
Normal file
65
tests/spec_decode/test_scorer.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.sequence import ExecuteModelRequest
|
||||||
|
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
|
||||||
|
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
|
||||||
|
from vllm.spec_decode.mqa_scorer import MQAScorer
|
||||||
|
from vllm.worker.worker import Worker
|
||||||
|
|
||||||
|
from .utils import create_batch, create_worker
|
||||||
|
|
||||||
|
|
||||||
|
def create_proposal(batch_size: int, propose_len: int, vocab_size: int,
                    device: str) -> SpeculativeProposals:
    """Build random speculative proposals for a whole batch.

    A uniform random tensor of shape (batch_size, propose_len, vocab_size)
    is drawn on ``device``; the arg-max over the vocabulary axis is used as
    the proposed token at each position. Every sequence in the batch is
    given the same proposal length, ``propose_len``.
    """
    shape = (batch_size, propose_len, vocab_size)
    probs = torch.rand(shape, device=device)
    token_ids = probs.argmax(dim=-1)
    lens = torch.tensor([propose_len for _ in range(batch_size)],
                        device=device)
    return SpeculativeProposals(token_ids, probs, lens)
|
||||||
|
|
||||||
|
|
||||||
|
def assert_score_equal(score1: SpeculativeScores,
                       score2: SpeculativeScores) -> None:
    """Assert that two scorer outputs agree.

    ``probs`` and ``logprobs`` are compared with ``torch.allclose`` using
    its default tolerances; ``token_ids`` must match exactly.
    """
    # Floating-point fields: tolerance-based comparison, probs first.
    for field in ("probs", "logprobs"):
        assert torch.allclose(getattr(score1, field), getattr(score2, field))
    # Integer token ids: exact comparison.
    assert torch.equal(score1.token_ids, score2.token_ids)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('model_name', ['facebook/opt-125m'])
@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16])
@pytest.mark.parametrize('propose_len', [1, 3, 5])
@pytest.mark.parametrize('device', ['cuda'])
def test_scoroer(model_name: str, batch_size: int, propose_len: int,
                 device: str) -> None:
    """
    Check that the batch expansion scorer and the MQA scorer return the
    same scores when given the same requests and proposals.
    """
    seed = 0
    block_size = 32
    num_gpu_blocks = 2048 // block_size

    # A single worker is shared by both scorers.
    scorer_worker = create_worker(Worker, model_name, block_size,
                                  num_gpu_blocks, seed)
    sampler = scorer_worker.model_runner.model.sampler
    sampler.include_gpu_probs_tensor = True
    sampler.should_modify_greedy_probs_inplace = True

    vocab_size = scorer_worker.vocab_size
    proposals = create_proposal(batch_size, propose_len, vocab_size, device)
    seq_group_metadata_list, _, _ = create_batch(
        batch_size,
        propose_len,
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks)
    requests = ExecuteModelRequest(seq_group_metadata_list,
                                   num_lookahead_slots=propose_len)

    # Score with batch expansion first, then with MQA, on identical inputs.
    scores = []
    for scorer_cls in (BatchExpansionTop1Scorer, MQAScorer):
        scorer = scorer_cls(scorer_worker, device, vocab_size)
        scores.append(scorer.score_proposals(requests, proposals))

    batch_expansion_score, mqa_score = scores
    assert_score_equal(batch_expansion_score, mqa_score)
|
||||||
@ -63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int,
|
|||||||
@pytest.mark.parametrize("acceptance_sampler_method",
|
@pytest.mark.parametrize("acceptance_sampler_method",
|
||||||
["rejection_sampler", "typical_acceptance_sampler"])
|
["rejection_sampler", "typical_acceptance_sampler"])
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def test_correctly_calls_target_model(k: int, batch_size: int,
|
def test_batch_expansion_correctly_calls_target_model(
|
||||||
acceptance_sampler_method: str):
|
k: int, batch_size: int, acceptance_sampler_method: str):
|
||||||
"""Verify SpecDecodeWorker calls the target model with correct
|
"""Verify SpecDecodeWorker calls the target model with correct
|
||||||
inputs. Everything else is mocked out.
|
inputs with batch expansion. Everything else is mocked out.
|
||||||
"""
|
"""
|
||||||
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False)
|
||||||
target_worker = mock_worker(use_spec=False)
|
target_worker = mock_worker(use_spec=False)
|
||||||
@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int,
|
|||||||
target_worker,
|
target_worker,
|
||||||
mock_spec_decode_sampler(acceptance_sampler_method),
|
mock_spec_decode_sampler(acceptance_sampler_method),
|
||||||
disable_logprobs=False,
|
disable_logprobs=False,
|
||||||
metrics_collector=metrics_collector)
|
metrics_collector=metrics_collector,
|
||||||
|
disable_mqa_scorer=True)
|
||||||
worker.init_device()
|
worker.init_device()
|
||||||
|
|
||||||
vocab_size = 32_000
|
vocab_size = 32_000
|
||||||
|
|||||||
@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts(
|
|||||||
for i, final_len in enumerate(final_prompt_lens)
|
for i, final_len in enumerate(final_prompt_lens)
|
||||||
}
|
}
|
||||||
|
|
||||||
return [
|
seq_grou_metadata_list = []
|
||||||
SequenceGroupMetadata(
|
for i, (prompt_token_ids,
|
||||||
request_id=str(i),
|
cont_token_ids) in enumerate(zip(prompts, continuations)):
|
||||||
is_prompt=len(cont_token_ids) == 0,
|
data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids)
|
||||||
seq_data={
|
data.update_num_computed_tokens(
|
||||||
i: SequenceData.from_seqs(prompt_token_ids[:],
|
len(prompt_token_ids) + len(cont_token_ids) - 1)
|
||||||
cont_token_ids[:]),
|
seq_data = {i: data}
|
||||||
},
|
seq_grou_metadata_list.append(
|
||||||
sampling_params=SamplingParams(temperature=0.0, ),
|
SequenceGroupMetadata(
|
||||||
block_tables={i: block_allocations[i][:]},
|
request_id=str(i),
|
||||||
) for i, (prompt_token_ids,
|
is_prompt=len(cont_token_ids) == 0,
|
||||||
cont_token_ids) in enumerate(zip(prompts, continuations))
|
seq_data=seq_data,
|
||||||
]
|
sampling_params=SamplingParams(temperature=0.0),
|
||||||
|
block_tables={i: block_allocations[i][:]},
|
||||||
|
))
|
||||||
|
return seq_grou_metadata_list
|
||||||
|
|
||||||
|
|
||||||
def assert_logprobs_dict_allclose(
|
def assert_logprobs_dict_allclose(
|
||||||
|
|||||||
@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
|
|||||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||||
use_cuda_graph: bool
|
use_cuda_graph: bool
|
||||||
|
|
||||||
|
# Number of query tokens for each request in the batch.
|
||||||
|
# Currently, we require that all requests have the same number of query
|
||||||
|
# tokens during the decoding phase. When speculative decoding is enabled,
|
||||||
|
# decode_query_len might be greater than 1. In all other cases, it is 1.
|
||||||
|
decode_query_len: Optional[int] = None
|
||||||
|
|
||||||
_cached_prefill_metadata: Optional[
|
_cached_prefill_metadata: Optional[
|
||||||
"BlocksparseFlashAttentionMetadata"] = None
|
"BlocksparseFlashAttentionMetadata"] = None
|
||||||
_cached_decode_metadata: Optional[
|
_cached_decode_metadata: Optional[
|
||||||
|
|||||||
@ -245,8 +245,15 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
# |-------------------- seq_len ---------------------|
|
# |-------------------- seq_len ---------------------|
|
||||||
# |-- query_len ---|
|
# |-- query_len ---|
|
||||||
|
|
||||||
# Maximum query length in the batch. None for decoding.
|
# Maximum query length in the batch.
|
||||||
max_query_len: Optional[int]
|
max_query_len: Optional[int]
|
||||||
|
|
||||||
|
# Number of query tokens for each request in the batch.
|
||||||
|
# Currently, we require that all requests have the same number of query
|
||||||
|
# tokens during the decoding phase. When speculative decoding is enabled,
|
||||||
|
# decode_query_len might be greater than 1. In all other cases, it is 1.
|
||||||
|
decode_query_len: Optional[int]
|
||||||
|
|
||||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||||
# requests only.
|
# requests only.
|
||||||
max_prefill_seq_len: int
|
max_prefill_seq_len: int
|
||||||
@ -303,6 +310,7 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
|
||||||
seq_lens=self.seq_lens[:self.num_prefills],
|
seq_lens=self.seq_lens[:self.num_prefills],
|
||||||
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
|
||||||
|
decode_query_len=0,
|
||||||
max_query_len=self.max_query_len,
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||||
max_decode_seq_len=0,
|
max_decode_seq_len=0,
|
||||||
@ -331,7 +339,8 @@ class FlashAttentionMetadata(AttentionMetadata):
|
|||||||
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
|
||||||
max_query_len=None,
|
decode_query_len=self.decode_query_len,
|
||||||
|
max_query_len=self.max_query_len,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=self.max_decode_seq_len,
|
max_decode_seq_len=self.max_decode_seq_len,
|
||||||
query_start_loc=None,
|
query_start_loc=None,
|
||||||
@ -461,9 +470,6 @@ class FlashAttentionMetadataBuilder(
|
|||||||
self.num_prefill_tokens += token_len
|
self.num_prefill_tokens += token_len
|
||||||
self.prefill_seq_lens.append(seq_len)
|
self.prefill_seq_lens.append(seq_len)
|
||||||
else:
|
else:
|
||||||
assert query_len == 1, (
|
|
||||||
"seq_len: {}, context_len: {}, query_len: {}".format(
|
|
||||||
seq_len, context_len, query_len))
|
|
||||||
self.num_decode_tokens += query_len
|
self.num_decode_tokens += query_len
|
||||||
self.curr_seq_lens.append(curr_seq_len)
|
self.curr_seq_lens.append(curr_seq_len)
|
||||||
|
|
||||||
@ -518,6 +524,11 @@ class FlashAttentionMetadataBuilder(
|
|||||||
use_captured_graph = cuda_graph_pad_size != -1
|
use_captured_graph = cuda_graph_pad_size != -1
|
||||||
|
|
||||||
max_query_len = max(query_lens)
|
max_query_len = max(query_lens)
|
||||||
|
decode_query_lens = query_lens[self.num_prefills:]
|
||||||
|
if len(decode_query_lens) > 0:
|
||||||
|
decode_query_len = max(decode_query_lens)
|
||||||
|
else:
|
||||||
|
decode_query_len = 1
|
||||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||||
num_decode_tokens = self.num_decode_tokens
|
num_decode_tokens = self.num_decode_tokens
|
||||||
@ -586,6 +597,7 @@ class FlashAttentionMetadataBuilder(
|
|||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
seq_lens_tensor=seq_lens_tensor,
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
|
decode_query_len=decode_query_len,
|
||||||
max_prefill_seq_len=max_prefill_seq_len,
|
max_prefill_seq_len=max_prefill_seq_len,
|
||||||
max_decode_seq_len=max_decode_seq_len,
|
max_decode_seq_len=max_decode_seq_len,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
@ -786,8 +798,12 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
if decode_meta := attn_metadata.decode_metadata:
|
if decode_meta := attn_metadata.decode_metadata:
|
||||||
# Decoding run.
|
# Decoding run.
|
||||||
|
_, num_head, head_dim = decode_query.shape
|
||||||
|
decode_query = decode_query.reshape(-1,
|
||||||
|
decode_meta.decode_query_len,
|
||||||
|
num_head, head_dim)
|
||||||
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
|
decode_output = torch.ops.vllm.flash_attn_with_kvcache(
|
||||||
decode_query.unsqueeze(1),
|
decode_query,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
block_table=decode_meta.block_tables,
|
block_table=decode_meta.block_tables,
|
||||||
@ -796,7 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
causal=True,
|
causal=True,
|
||||||
alibi_slopes=self.alibi_slopes,
|
alibi_slopes=self.alibi_slopes,
|
||||||
softcap=self.logits_soft_cap,
|
softcap=self.logits_soft_cap,
|
||||||
).squeeze(1)
|
)
|
||||||
|
|
||||||
if prefill_output is None:
|
if prefill_output is None:
|
||||||
assert decode_output is not None
|
assert decode_output is not None
|
||||||
@ -804,5 +820,11 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
if decode_output is None:
|
if decode_output is None:
|
||||||
assert prefill_output is not None
|
assert prefill_output is not None
|
||||||
return prefill_output.view(num_prefill_tokens, hidden_size)
|
return prefill_output.view(num_prefill_tokens, hidden_size)
|
||||||
|
|
||||||
|
# Chunked prefill does not work with speculative decoding.
|
||||||
|
# Therefore, the query length for decode should be 1 in chunked prefill.
|
||||||
|
assert decode_meta is not None
|
||||||
|
assert decode_meta.decode_query_len == 1
|
||||||
|
decode_output = decode_output.squeeze(1)
|
||||||
output = torch.cat([prefill_output, decode_output], dim=0)
|
output = torch.cat([prefill_output, decode_output], dim=0)
|
||||||
return output.view(num_tokens, hidden_size)
|
return output.view(num_tokens, hidden_size)
|
||||||
|
|||||||
@ -595,7 +595,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
device = self.runner.device
|
device = self.runner.device
|
||||||
use_captured_graph = cuda_graph_pad_size != -1
|
use_captured_graph = cuda_graph_pad_size != -1
|
||||||
|
|
||||||
max_query_len = max(query_lens)
|
|
||||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||||
num_decode_tokens = self.num_decode_tokens
|
num_decode_tokens = self.num_decode_tokens
|
||||||
|
|
||||||
@ -634,7 +633,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
dtype=torch.int,
|
dtype=torch.int,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
|
||||||
|
|
||||||
assert device is not None
|
assert device is not None
|
||||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||||
|
|||||||
@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
# Cuda-graph is currently enabled for decoding only.
|
# Cuda-graph is currently enabled for decoding only.
|
||||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||||
use_cuda_graph: bool
|
use_cuda_graph: bool
|
||||||
|
|
||||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||||
# so far).
|
# so far).
|
||||||
context_lens_tensor: Optional[torch.Tensor]
|
context_lens_tensor: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# Number of query tokens for each request in the batch.
|
||||||
|
# Currently, we require that all requests have the same number of query
|
||||||
|
# tokens during the decoding phase. When speculative decoding is enabled,
|
||||||
|
# decode_query_len might be greater than 1. In all other cases, it is 1.
|
||||||
|
decode_query_len: Optional[int] = None
|
||||||
|
|
||||||
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
_cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||||
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
_cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
|
||||||
|
|
||||||
|
|||||||
@ -312,7 +312,8 @@ class CommonAttentionState(AttentionState):
|
|||||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||||
seq_lens=None,
|
seq_lens=None,
|
||||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||||
max_query_len=None,
|
max_query_len=1,
|
||||||
|
decode_query_len=1,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
||||||
query_start_loc=None,
|
query_start_loc=None,
|
||||||
|
|||||||
@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
# Maximum query length in the batch. None for decoding.
|
# Maximum query length in the batch. None for decoding.
|
||||||
max_query_len: Optional[int] = None
|
max_query_len: Optional[int] = None
|
||||||
|
|
||||||
|
# Number of query tokens for each request in the batch.
|
||||||
|
# Currently, we require that all requests have the same number of query
|
||||||
|
# tokens during the decoding phase. When speculative decoding is enabled,
|
||||||
|
# decode_query_len might be greater than 1. In all other cases, it is 1.
|
||||||
|
decode_query_len: Optional[int] = None
|
||||||
|
|
||||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||||
# the batch, used to index into subquery. E.g., if the subquery length
|
# the batch, used to index into subquery. E.g., if the subquery length
|
||||||
# is [4, 6], it is [0, 4, 10].
|
# is [4, 6], it is [0, 4, 10].
|
||||||
|
|||||||
@ -1116,6 +1116,7 @@ class SpeculativeConfig:
|
|||||||
speculative_model_quantization: Optional[str],
|
speculative_model_quantization: Optional[str],
|
||||||
speculative_draft_tensor_parallel_size: Optional[int],
|
speculative_draft_tensor_parallel_size: Optional[int],
|
||||||
num_speculative_tokens: Optional[int],
|
num_speculative_tokens: Optional[int],
|
||||||
|
speculative_disable_mqa_scorer: Optional[bool],
|
||||||
speculative_max_model_len: Optional[int],
|
speculative_max_model_len: Optional[int],
|
||||||
enable_chunked_prefill: bool,
|
enable_chunked_prefill: bool,
|
||||||
use_v2_block_manager: bool,
|
use_v2_block_manager: bool,
|
||||||
@ -1150,6 +1151,9 @@ class SpeculativeConfig:
|
|||||||
num_speculative_tokens (Optional[int]): The number of speculative
|
num_speculative_tokens (Optional[int]): The number of speculative
|
||||||
tokens, if provided. Will default to the number in the draft
|
tokens, if provided. Will default to the number in the draft
|
||||||
model config if present, otherwise is required.
|
model config if present, otherwise is required.
|
||||||
|
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
|
||||||
|
scorer for the speculative model and fall back to batch
|
||||||
|
expansion for scoring.
|
||||||
speculative_max_model_len (Optional[int]): The maximum model len of
|
speculative_max_model_len (Optional[int]): The maximum model len of
|
||||||
the speculative model. Used when testing the ability to skip
|
the speculative model. Used when testing the ability to skip
|
||||||
speculation for some sequences.
|
speculation for some sequences.
|
||||||
@ -1304,6 +1308,7 @@ class SpeculativeConfig:
|
|||||||
draft_model_config,
|
draft_model_config,
|
||||||
draft_parallel_config,
|
draft_parallel_config,
|
||||||
num_speculative_tokens,
|
num_speculative_tokens,
|
||||||
|
speculative_disable_mqa_scorer,
|
||||||
speculative_disable_by_batch_size,
|
speculative_disable_by_batch_size,
|
||||||
ngram_prompt_lookup_max,
|
ngram_prompt_lookup_max,
|
||||||
ngram_prompt_lookup_min,
|
ngram_prompt_lookup_min,
|
||||||
@ -1400,6 +1405,7 @@ class SpeculativeConfig:
|
|||||||
draft_model_config: ModelConfig,
|
draft_model_config: ModelConfig,
|
||||||
draft_parallel_config: ParallelConfig,
|
draft_parallel_config: ParallelConfig,
|
||||||
num_speculative_tokens: int,
|
num_speculative_tokens: int,
|
||||||
|
speculative_disable_mqa_scorer: Optional[bool],
|
||||||
speculative_disable_by_batch_size: Optional[int],
|
speculative_disable_by_batch_size: Optional[int],
|
||||||
ngram_prompt_lookup_max: Optional[int],
|
ngram_prompt_lookup_max: Optional[int],
|
||||||
ngram_prompt_lookup_min: Optional[int],
|
ngram_prompt_lookup_min: Optional[int],
|
||||||
@ -1446,6 +1452,7 @@ class SpeculativeConfig:
|
|||||||
self.draft_model_config = draft_model_config
|
self.draft_model_config = draft_model_config
|
||||||
self.draft_parallel_config = draft_parallel_config
|
self.draft_parallel_config = draft_parallel_config
|
||||||
self.num_speculative_tokens = num_speculative_tokens
|
self.num_speculative_tokens = num_speculative_tokens
|
||||||
|
self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
|
||||||
self.speculative_disable_by_batch_size = \
|
self.speculative_disable_by_batch_size = \
|
||||||
speculative_disable_by_batch_size
|
speculative_disable_by_batch_size
|
||||||
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
|
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
|
||||||
|
|||||||
@ -162,6 +162,7 @@ class EngineArgs:
|
|||||||
speculative_model_quantization: Optional[str] = None
|
speculative_model_quantization: Optional[str] = None
|
||||||
speculative_draft_tensor_parallel_size: Optional[int] = None
|
speculative_draft_tensor_parallel_size: Optional[int] = None
|
||||||
num_speculative_tokens: Optional[int] = None
|
num_speculative_tokens: Optional[int] = None
|
||||||
|
speculative_disable_mqa_scorer: Optional[bool] = False
|
||||||
speculative_max_model_len: Optional[int] = None
|
speculative_max_model_len: Optional[int] = None
|
||||||
speculative_disable_by_batch_size: Optional[int] = None
|
speculative_disable_by_batch_size: Optional[int] = None
|
||||||
ngram_prompt_lookup_max: Optional[int] = None
|
ngram_prompt_lookup_max: Optional[int] = None
|
||||||
@ -640,6 +641,12 @@ class EngineArgs:
|
|||||||
default=EngineArgs.num_speculative_tokens,
|
default=EngineArgs.num_speculative_tokens,
|
||||||
help='The number of speculative tokens to sample from '
|
help='The number of speculative tokens to sample from '
|
||||||
'the draft model in speculative decoding.')
|
'the draft model in speculative decoding.')
|
||||||
|
parser.add_argument(
|
||||||
|
'--speculative-disable-mqa-scorer',
|
||||||
|
action='store_true',
|
||||||
|
help=
|
||||||
|
'If set to True, the MQA scorer will be disabled in speculative '
|
||||||
|
'decoding and fall back to batch expansion.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--speculative-draft-tensor-parallel-size',
|
'--speculative-draft-tensor-parallel-size',
|
||||||
'-spec-draft-tp',
|
'-spec-draft-tp',
|
||||||
@ -970,6 +977,7 @@ class EngineArgs:
|
|||||||
speculative_draft_tensor_parallel_size = \
|
speculative_draft_tensor_parallel_size = \
|
||||||
self.speculative_draft_tensor_parallel_size,
|
self.speculative_draft_tensor_parallel_size,
|
||||||
num_speculative_tokens=self.num_speculative_tokens,
|
num_speculative_tokens=self.num_speculative_tokens,
|
||||||
|
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
|
||||||
speculative_disable_by_batch_size=self.
|
speculative_disable_by_batch_size=self.
|
||||||
speculative_disable_by_batch_size,
|
speculative_disable_by_batch_size,
|
||||||
speculative_max_model_len=self.speculative_max_model_len,
|
speculative_max_model_len=self.speculative_max_model_len,
|
||||||
|
|||||||
@ -1110,6 +1110,8 @@ class LLMEngine:
|
|||||||
update_prefill_num_computed_tokens(seq_group, seq_group_meta,
|
update_prefill_num_computed_tokens(seq_group, seq_group_meta,
|
||||||
len(output),
|
len(output),
|
||||||
is_first_step_output)
|
is_first_step_output)
|
||||||
|
elif not is_async:
|
||||||
|
seq_group.update_num_computed_tokens(1)
|
||||||
|
|
||||||
if outputs:
|
if outputs:
|
||||||
for o in outputs:
|
for o in outputs:
|
||||||
@ -1133,8 +1135,16 @@ class LLMEngine:
|
|||||||
else:
|
else:
|
||||||
self.output_processor.process_prompt_logprob(seq_group, output)
|
self.output_processor.process_prompt_logprob(seq_group, output)
|
||||||
if seq_group_meta.do_sample:
|
if seq_group_meta.do_sample:
|
||||||
self.output_processor.process_outputs(
|
output_token_num = self.output_processor.process_outputs(
|
||||||
seq_group, output, is_async)
|
seq_group, output, is_async)
|
||||||
|
if self.speculative_config:
|
||||||
|
# Subtract 1 here because the decoding phase always
|
||||||
|
# (even without speculative decoding) increases the number of
|
||||||
|
# computed tokens by one.
|
||||||
|
# Therefore, we exclude that one token, which
|
||||||
|
# has already been counted.
|
||||||
|
seq_group.update_num_computed_tokens(output_token_num -
|
||||||
|
1)
|
||||||
|
|
||||||
if seq_group.is_finished():
|
if seq_group.is_finished():
|
||||||
finished_now.append(i)
|
finished_now.append(i)
|
||||||
@ -1251,11 +1261,12 @@ class LLMEngine:
|
|||||||
# decodes after the very first step. Therefore,
|
# decodes after the very first step. Therefore,
|
||||||
# we skip the update to the num_computed_tokens
|
# we skip the update to the num_computed_tokens
|
||||||
# here.
|
# here.
|
||||||
pass
|
seq_group.update_num_computed_tokens(1)
|
||||||
else:
|
else:
|
||||||
seq_group.update_num_computed_tokens(
|
seq_group.update_num_computed_tokens(
|
||||||
seq_group_metadata.token_chunk_size)
|
seq_group_metadata.token_chunk_size)
|
||||||
|
else:
|
||||||
|
seq_group.update_num_computed_tokens(1)
|
||||||
if seq_group_metadata.do_sample:
|
if seq_group_metadata.do_sample:
|
||||||
assert len(sequence_group_outputs.samples) == 1, (
|
assert len(sequence_group_outputs.samples) == 1, (
|
||||||
"Async output processor expects a single sample"
|
"Async output processor expects a single sample"
|
||||||
@ -1266,7 +1277,6 @@ class LLMEngine:
|
|||||||
assert len(seq_group.seqs) == 1
|
assert len(seq_group.seqs) == 1
|
||||||
seq = seq_group.seqs[0]
|
seq = seq_group.seqs[0]
|
||||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||||
seq_group.update_num_computed_tokens(1)
|
|
||||||
|
|
||||||
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||||
"""Performs one decoding iteration and returns newly generated results.
|
"""Performs one decoding iteration and returns newly generated results.
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable, List
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
from vllm.config import SchedulerConfig
|
from vllm.config import SchedulerConfig
|
||||||
from vllm.core.scheduler import Scheduler
|
from vllm.core.scheduler import Scheduler
|
||||||
@ -58,10 +58,14 @@ class SequenceGroupOutputProcessor(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process_outputs(self, sequence_group: SequenceGroup,
|
def process_outputs(self, sequence_group: SequenceGroup,
|
||||||
outputs: List[SequenceGroupOutput],
|
outputs: List[SequenceGroupOutput],
|
||||||
is_async: bool) -> None:
|
is_async: bool) -> Optional[int]:
|
||||||
"""Process new token ids for the sequence group. Handles logic such as
|
"""Process new token ids for the sequence group. Handles logic such as
|
||||||
detokenization, stop checking, and freeing/forking sequences in the
|
detokenization, stop checking, and freeing/forking sequences in the
|
||||||
scheduler.
|
scheduler.
|
||||||
|
|
||||||
|
Return the number of new tokens generated in the sequence group.
|
||||||
|
The returned value is optional because it is only used for
|
||||||
|
the speculative decoding MQA scorer.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import functools
|
import functools
|
||||||
from typing import Callable, List
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
from vllm.core.scheduler import Scheduler
|
from vllm.core.scheduler import Scheduler
|
||||||
from vllm.engine.output_processor.interfaces import (
|
from vllm.engine.output_processor.interfaces import (
|
||||||
@ -69,7 +69,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
def process_outputs(self,
|
def process_outputs(self,
|
||||||
sequence_group: SequenceGroup,
|
sequence_group: SequenceGroup,
|
||||||
outputs: List[SequenceGroupOutput],
|
outputs: List[SequenceGroupOutput],
|
||||||
is_async: bool = False) -> None:
|
is_async: bool = False) -> Optional[int]:
|
||||||
"""Append new tokens in the outputs to sequences in the sequence group.
|
"""Append new tokens in the outputs to sequences in the sequence group.
|
||||||
|
|
||||||
This only supports sequence groups of size 1. It supports greater than
|
This only supports sequence groups of size 1. It supports greater than
|
||||||
@ -84,6 +84,10 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
tokens from the previous step. If this is true, then
|
tokens from the previous step. If this is true, then
|
||||||
no tokens need to be appended since it is already done
|
no tokens need to be appended since it is already done
|
||||||
externally (before the next schedule() call)
|
externally (before the next schedule() call)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The number of tokens appended to the sequence. This is optional
|
||||||
|
because only speculative decode uses this return value.
|
||||||
"""
|
"""
|
||||||
# Sequences can be in RUNNING or FINISHED_ABORTED state
|
# Sequences can be in RUNNING or FINISHED_ABORTED state
|
||||||
# once scheduled, as a sequence is moved to FINISHED_ABORTED
|
# once scheduled, as a sequence is moved to FINISHED_ABORTED
|
||||||
@ -106,6 +110,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
# was already appended, so we only need to do the rest of the
|
# was already appended, so we only need to do the rest of the
|
||||||
# postprocessor: Detokenization + stopping logic
|
# postprocessor: Detokenization + stopping logic
|
||||||
self._process_decode_and_stop(seq, sequence_group.sampling_params)
|
self._process_decode_and_stop(seq, sequence_group.sampling_params)
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
# Standard multi-step case
|
# Standard multi-step case
|
||||||
|
|
||||||
@ -121,8 +126,8 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
]
|
]
|
||||||
assert valid_samples
|
assert valid_samples
|
||||||
|
|
||||||
self._process_seq_outputs(seq, valid_samples,
|
return self._process_seq_outputs(seq, valid_samples,
|
||||||
sequence_group.sampling_params)
|
sequence_group.sampling_params)
|
||||||
|
|
||||||
def _process_decode_and_stop(self, seq: Sequence,
|
def _process_decode_and_stop(self, seq: Sequence,
|
||||||
sampling_params: SamplingParams) -> None:
|
sampling_params: SamplingParams) -> None:
|
||||||
@ -140,7 +145,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
|
|
||||||
def _process_seq_outputs(self, seq: Sequence,
|
def _process_seq_outputs(self, seq: Sequence,
|
||||||
valid_samples: List[SequenceOutput],
|
valid_samples: List[SequenceOutput],
|
||||||
sampling_params: SamplingParams) -> None:
|
sampling_params: SamplingParams) -> int:
|
||||||
output_token_ids = [sample.output_token for sample in valid_samples]
|
output_token_ids = [sample.output_token for sample in valid_samples]
|
||||||
output_logprobs = [sample.logprobs for sample in valid_samples]
|
output_logprobs = [sample.logprobs for sample in valid_samples]
|
||||||
|
|
||||||
@ -148,7 +153,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
|
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
|
||||||
len(output_token_ids))
|
len(output_token_ids))
|
||||||
if remaining_tokens < 0:
|
if remaining_tokens < 0:
|
||||||
valid_samples = valid_samples[:remaining_tokens]
|
|
||||||
output_token_ids = output_token_ids[:remaining_tokens]
|
output_token_ids = output_token_ids[:remaining_tokens]
|
||||||
|
|
||||||
# Truncate any tokens after EOS. This is required as spec decode
|
# Truncate any tokens after EOS. This is required as spec decode
|
||||||
@ -162,7 +166,6 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
for i in range(len(output_token_ids)):
|
for i in range(len(output_token_ids)):
|
||||||
if output_token_ids[i] == eos_token_id:
|
if output_token_ids[i] == eos_token_id:
|
||||||
output_token_ids = output_token_ids[:i + 1]
|
output_token_ids = output_token_ids[:i + 1]
|
||||||
valid_samples = valid_samples[:i + 1]
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# Incrementally append tokens to the sequence, as if we had only one new
|
# Incrementally append tokens to the sequence, as if we had only one new
|
||||||
@ -173,9 +176,9 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
|||||||
token_id=output_token_id,
|
token_id=output_token_id,
|
||||||
logprobs=output_logprob,
|
logprobs=output_logprob,
|
||||||
)
|
)
|
||||||
seq.data.update_num_computed_tokens(1)
|
|
||||||
|
|
||||||
self._process_decode_and_stop(seq, sampling_params)
|
self._process_decode_and_stop(seq, sampling_params)
|
||||||
|
|
||||||
if seq.is_finished():
|
if seq.is_finished():
|
||||||
break
|
break
|
||||||
|
return len(output_token_ids)
|
||||||
|
|||||||
@ -912,7 +912,7 @@ def get_logprobs(
|
|||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
sample_results: SampleResultType,
|
sample_results: SampleResultType,
|
||||||
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
|
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
|
||||||
"""Return sample lobprobs and prompt logprobs.
|
"""Return sample logprobs and prompt logprobs.
|
||||||
|
|
||||||
The logic consists of 3 parts.
|
The logic consists of 3 parts.
|
||||||
- Select indices to compute logprob from, ranks of token ids, and
|
- Select indices to compute logprob from, ranks of token ids, and
|
||||||
|
|||||||
@ -146,7 +146,7 @@ class SamplingMetadata:
|
|||||||
def prepare(
|
def prepare(
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
seq_lens: List[int],
|
seq_lens: List[int],
|
||||||
query_lens: Optional[List[int]],
|
query_lens: List[int],
|
||||||
device: str,
|
device: str,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
generators: Optional[Dict[str, torch.Generator]] = None,
|
generators: Optional[Dict[str, torch.Generator]] = None,
|
||||||
@ -194,7 +194,7 @@ class SamplingMetadata:
|
|||||||
def _prepare_seq_groups(
|
def _prepare_seq_groups(
|
||||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
seq_lens: List[int],
|
seq_lens: List[int],
|
||||||
query_lens: Optional[List[int]],
|
query_lens: List[int],
|
||||||
device: str,
|
device: str,
|
||||||
generators: Optional[Dict[str, torch.Generator]] = None,
|
generators: Optional[Dict[str, torch.Generator]] = None,
|
||||||
cache: Optional[SamplingMetadataCache] = None,
|
cache: Optional[SamplingMetadataCache] = None,
|
||||||
@ -284,7 +284,8 @@ def _prepare_seq_groups(
|
|||||||
else:
|
else:
|
||||||
# Decode
|
# Decode
|
||||||
prompt_logprob_len = 0
|
prompt_logprob_len = 0
|
||||||
sample_len = len(seq_ids) if do_sample else 0
|
query_len = query_lens[i] if query_lens is not None else 1
|
||||||
|
sample_len = len(seq_ids) * query_len if do_sample else 0
|
||||||
|
|
||||||
if sampling_params.seed is not None and generators is not None:
|
if sampling_params.seed is not None and generators is not None:
|
||||||
generator = generators.get(seq_group_metadata.request_id)
|
generator = generators.get(seq_group_metadata.request_id)
|
||||||
@ -440,14 +441,14 @@ class SamplingTensors:
|
|||||||
|
|
||||||
if seq_group.do_sample:
|
if seq_group.do_sample:
|
||||||
sample_lens = len(seq_group.sample_indices)
|
sample_lens = len(seq_group.sample_indices)
|
||||||
assert sample_lens == len(seq_ids)
|
assert sample_lens >= len(seq_ids)
|
||||||
temperatures += [temperature] * len(seq_ids)
|
temperatures += [temperature] * sample_lens
|
||||||
top_ps += [top_p] * len(seq_ids)
|
top_ps += [top_p] * sample_lens
|
||||||
top_ks += [top_k] * len(seq_ids)
|
top_ks += [top_k] * sample_lens
|
||||||
min_ps += [min_p] * len(seq_ids)
|
min_ps += [min_p] * sample_lens
|
||||||
presence_penalties += [p] * len(seq_ids)
|
presence_penalties += [p] * sample_lens
|
||||||
frequency_penalties += [f] * len(seq_ids)
|
frequency_penalties += [f] * sample_lens
|
||||||
repetition_penalties += [r] * len(seq_ids)
|
repetition_penalties += [r] * sample_lens
|
||||||
|
|
||||||
if do_penalties:
|
if do_penalties:
|
||||||
for seq_group in sampling_metadata.seq_groups:
|
for seq_group in sampling_metadata.seq_groups:
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
|
|||||||
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||||
SpeculativeScorer, SpeculativeScores)
|
SpeculativeScorer, SpeculativeScores)
|
||||||
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
|
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
|
||||||
from vllm.worker.worker_base import WorkerBase
|
|
||||||
|
|
||||||
SeqId = int
|
SeqId = int
|
||||||
TargetSeqId = int
|
TargetSeqId = int
|
||||||
@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
|
|||||||
of topk/tree.
|
of topk/tree.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, scorer_worker: WorkerBase, device: str,
|
|
||||||
vocab_size: int):
|
|
||||||
self._scorer_worker = scorer_worker
|
|
||||||
self._device = device
|
|
||||||
self._vocab_size = vocab_size
|
|
||||||
|
|
||||||
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
|
@nvtx_range("BatchExpansionTop1Scorer.score_proposals")
|
||||||
def score_proposals(
|
def score_proposals(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -94,8 +94,6 @@ class TP1DraftModelRunner(ModelRunner):
|
|||||||
assert seq_group.is_prompt is False # No prompt
|
assert seq_group.is_prompt is False # No prompt
|
||||||
assert seq_group.prompt_logprob_indices == [] # No prompt
|
assert seq_group.prompt_logprob_indices == [] # No prompt
|
||||||
assert seq_group.sample_indices == [i] # Simple
|
assert seq_group.sample_indices == [i] # Simple
|
||||||
assert seq_group.seq_len is None # Decode
|
|
||||||
assert seq_group.query_len is None # Decode
|
|
||||||
|
|
||||||
def _gpu_advance_step(
|
def _gpu_advance_step(
|
||||||
self, model_input: ModelInputForGPUWithSamplingMetadata,
|
self, model_input: ModelInputForGPUWithSamplingMetadata,
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from typing import Optional, Set
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.sequence import ExecuteModelRequest
|
from vllm.sequence import ExecuteModelRequest
|
||||||
|
from vllm.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -74,6 +75,12 @@ class SpeculativeProposer(ABC):
|
|||||||
|
|
||||||
class SpeculativeScorer(ABC):
|
class SpeculativeScorer(ABC):
|
||||||
|
|
||||||
|
def __init__(self, scorer_worker: WorkerBase, device: str,
|
||||||
|
vocab_size: int):
|
||||||
|
self._scorer_worker = scorer_worker
|
||||||
|
self._device = device
|
||||||
|
self._vocab_size = vocab_size
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def score_proposals(
|
def score_proposals(
|
||||||
self,
|
self,
|
||||||
|
|||||||
80
vllm/spec_decode/mqa_scorer.py
Normal file
80
vllm/spec_decode/mqa_scorer.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from vllm.sequence import (ExecuteModelRequest, SequenceData,
|
||||||
|
SequenceGroupMetadata, get_all_seq_ids)
|
||||||
|
from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
||||||
|
SpeculativeScorer, SpeculativeScores)
|
||||||
|
|
||||||
|
SeqId = int
|
||||||
|
TargetSeqId = int
|
||||||
|
|
||||||
|
|
||||||
|
class MQAScorer(SpeculativeScorer):
|
||||||
|
|
||||||
|
def score_proposals(
|
||||||
|
self,
|
||||||
|
execute_model_req: ExecuteModelRequest,
|
||||||
|
proposals: SpeculativeProposals,
|
||||||
|
) -> SpeculativeScores:
|
||||||
|
target_seq_group_metadata_list = []
|
||||||
|
target_seq_id_start = max(
|
||||||
|
get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1
|
||||||
|
all_proposal_tokens = proposals.proposal_token_ids.tolist()
|
||||||
|
for i, seq_group_metadata in enumerate(
|
||||||
|
execute_model_req.seq_group_metadata_list):
|
||||||
|
seq_data_dict = seq_group_metadata.seq_data
|
||||||
|
assert len(seq_data_dict) == 1
|
||||||
|
seq_id = next(iter(seq_data_dict.keys()))
|
||||||
|
|
||||||
|
seq_data: SequenceData = seq_data_dict[seq_id]
|
||||||
|
prompt_token_ids = seq_data.get_prompt_token_ids()
|
||||||
|
output_token_ids = seq_data.get_output_token_ids()
|
||||||
|
proposal_token_ids = all_proposal_tokens[i]
|
||||||
|
new_output_token_ids = [*output_token_ids, *proposal_token_ids]
|
||||||
|
|
||||||
|
target_seq_id = target_seq_id_start + i
|
||||||
|
new_seq_data = SequenceData.from_seqs(
|
||||||
|
prompt_token_ids=prompt_token_ids,
|
||||||
|
output_token_ids=new_output_token_ids,
|
||||||
|
)
|
||||||
|
new_seq_data.update_num_computed_tokens(
|
||||||
|
len(prompt_token_ids) + len(output_token_ids) - 1)
|
||||||
|
|
||||||
|
# Ensure that the sequence already has at least one output token
|
||||||
|
# because the MQA scorer is only used in the decoding stage.
|
||||||
|
assert len(output_token_ids) >= 1
|
||||||
|
new_seq_data_dict = {target_seq_id: new_seq_data}
|
||||||
|
|
||||||
|
new_seq_group_metadata = SequenceGroupMetadata(
|
||||||
|
request_id=seq_group_metadata.request_id,
|
||||||
|
is_prompt=seq_group_metadata.is_prompt,
|
||||||
|
seq_data=new_seq_data_dict,
|
||||||
|
sampling_params=seq_group_metadata.sampling_params,
|
||||||
|
block_tables={
|
||||||
|
target_seq_id: seq_group_metadata.block_tables[seq_id],
|
||||||
|
},
|
||||||
|
lora_request=None,
|
||||||
|
token_chunk_size=1,
|
||||||
|
)
|
||||||
|
target_seq_group_metadata_list.append(new_seq_group_metadata)
|
||||||
|
|
||||||
|
target_sampler_output = self._scorer_worker.execute_model(
|
||||||
|
execute_model_req=execute_model_req.clone(
|
||||||
|
seq_group_metadata_list=target_seq_group_metadata_list))
|
||||||
|
|
||||||
|
target_sampler_output = target_sampler_output[0]
|
||||||
|
|
||||||
|
bs, k = proposals.proposal_token_ids.shape
|
||||||
|
all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1)
|
||||||
|
|
||||||
|
all_probs = target_sampler_output.sampled_token_probs.reshape(
|
||||||
|
bs, k + 1, self._vocab_size)
|
||||||
|
all_logprobs = target_sampler_output.logprobs.reshape(
|
||||||
|
bs, k + 1, self._vocab_size)
|
||||||
|
|
||||||
|
hidden_states = None
|
||||||
|
if target_sampler_output.hidden_states is not None:
|
||||||
|
hidden_states = target_sampler_output.hidden_states.reshape(
|
||||||
|
bs, (k + 1), -1)
|
||||||
|
return SpeculativeScores(probs=all_probs,
|
||||||
|
token_ids=all_tokens,
|
||||||
|
logprobs=all_logprobs,
|
||||||
|
hidden_states=hidden_states)
|
||||||
@ -1,6 +1,6 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
from typing import Any, Dict, List, Optional, Set, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -24,6 +24,7 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals,
|
|||||||
from vllm.spec_decode.medusa_worker import MedusaWorker
|
from vllm.spec_decode.medusa_worker import MedusaWorker
|
||||||
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
from vllm.spec_decode.metrics import AsyncMetricsCollector
|
||||||
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
|
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
|
||||||
|
from vllm.spec_decode.mqa_scorer import MQAScorer
|
||||||
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
from vllm.spec_decode.multi_step_worker import MultiStepWorker
|
||||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||||
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
|
||||||
@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
|
|||||||
spec_decode_worker = SpecDecodeWorker.create_worker(
|
spec_decode_worker = SpecDecodeWorker.create_worker(
|
||||||
scorer_worker=target_worker,
|
scorer_worker=target_worker,
|
||||||
draft_worker_kwargs=draft_worker_kwargs,
|
draft_worker_kwargs=draft_worker_kwargs,
|
||||||
|
disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer,
|
||||||
disable_by_batch_size=speculative_config.
|
disable_by_batch_size=speculative_config.
|
||||||
speculative_disable_by_batch_size,
|
speculative_disable_by_batch_size,
|
||||||
draft_token_acceptance_method=speculative_config.
|
draft_token_acceptance_method=speculative_config.
|
||||||
@ -116,6 +118,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
cls,
|
cls,
|
||||||
scorer_worker: Worker,
|
scorer_worker: Worker,
|
||||||
draft_worker_kwargs: Dict[str, Any],
|
draft_worker_kwargs: Dict[str, Any],
|
||||||
|
disable_mqa_scorer: bool,
|
||||||
disable_by_batch_size: Optional[int],
|
disable_by_batch_size: Optional[int],
|
||||||
draft_token_acceptance_method: str,
|
draft_token_acceptance_method: str,
|
||||||
typical_acceptance_sampler_posterior_threshold: float,
|
typical_acceptance_sampler_posterior_threshold: float,
|
||||||
@ -173,12 +176,43 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
typical_acceptance_sampler_posterior_threshold,
|
typical_acceptance_sampler_posterior_threshold,
|
||||||
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
|
posterior_alpha=typical_acceptance_sampler_posterior_alpha,
|
||||||
)
|
)
|
||||||
logger.info("Configuring SpecDecodeWorker with sampler=%s",
|
logger.info(
|
||||||
type(spec_decode_sampler))
|
"[Speculative Decoding] Configuring"
|
||||||
|
" SpecDecodeWorker with sampler=%s", type(spec_decode_sampler))
|
||||||
|
|
||||||
|
if not disable_mqa_scorer:
|
||||||
|
if scorer_worker.model_runner.attn_backend.get_name(
|
||||||
|
) != "flash-attn":
|
||||||
|
disable_mqa_scorer = True
|
||||||
|
logger.info(
|
||||||
|
"[Speculative Decoding] Disabling MQA scorer as the "
|
||||||
|
"MQA is only available with flash attn backend.")
|
||||||
|
|
||||||
|
if ngram_prompt_lookup_max > 0:
|
||||||
|
disable_mqa_scorer = True
|
||||||
|
logger.info(
|
||||||
|
"[Speculative Decoding] Disabling MQA scorer as the "
|
||||||
|
"NGramWorker does not support MQA scorer.")
|
||||||
|
|
||||||
|
if "model_config" in draft_worker_kwargs and \
|
||||||
|
draft_worker_kwargs["model_config"].max_model_len < \
|
||||||
|
scorer_worker.model_config.max_model_len:
|
||||||
|
disable_mqa_scorer = True
|
||||||
|
logger.info(
|
||||||
|
"[Speculative Decoding] Disabling MQA scorer as the "
|
||||||
|
"draft model max_model_len is smaller than the target "
|
||||||
|
"model max_model_len.")
|
||||||
|
|
||||||
|
if not scorer_worker.model_runner.model_config.enforce_eager:
|
||||||
|
disable_mqa_scorer = True
|
||||||
|
logger.info(
|
||||||
|
"[Speculative Decoding] Disabling MQA scorer as the "
|
||||||
|
"target model is not running in eager mode.")
|
||||||
|
|
||||||
return SpecDecodeWorker(
|
return SpecDecodeWorker(
|
||||||
proposer_worker,
|
proposer_worker,
|
||||||
scorer_worker,
|
scorer_worker,
|
||||||
|
disable_mqa_scorer=disable_mqa_scorer,
|
||||||
disable_logprobs=disable_logprobs,
|
disable_logprobs=disable_logprobs,
|
||||||
disable_log_stats=disable_log_stats,
|
disable_log_stats=disable_log_stats,
|
||||||
disable_by_batch_size=disable_by_batch_size,
|
disable_by_batch_size=disable_by_batch_size,
|
||||||
@ -190,6 +224,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
proposer_worker: ProposerWorkerBase,
|
proposer_worker: ProposerWorkerBase,
|
||||||
scorer_worker: WorkerBase,
|
scorer_worker: WorkerBase,
|
||||||
spec_decode_sampler: SpecDecodeBaseSampler,
|
spec_decode_sampler: SpecDecodeBaseSampler,
|
||||||
|
disable_mqa_scorer: bool = False,
|
||||||
disable_logprobs: bool = False,
|
disable_logprobs: bool = False,
|
||||||
disable_log_stats: bool = False,
|
disable_log_stats: bool = False,
|
||||||
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
metrics_collector: Optional[AsyncMetricsCollector] = None,
|
||||||
@ -211,6 +246,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
types of sampler namely RejectionSampler and
|
types of sampler namely RejectionSampler and
|
||||||
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
|
TypicalAcceptanceSampler. 'spec_decode_sampler' is either an
|
||||||
instance of RejectionSampler or TypicalAcceptanceSampler.
|
instance of RejectionSampler or TypicalAcceptanceSampler.
|
||||||
|
disable_mqa_scorer: If set to True, disable the MQA scorer and use
|
||||||
|
the BatchExpansionTop1Scorer instead.
|
||||||
disable_logprobs: If set to True, token log probabilities will
|
disable_logprobs: If set to True, token log probabilities will
|
||||||
not be output in both the draft worker and the target worker.
|
not be output in both the draft worker and the target worker.
|
||||||
If set to False, log probabilities will be output by both.
|
If set to False, log probabilities will be output by both.
|
||||||
@ -248,6 +285,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
|
self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
|
||||||
# Lazy initialization.
|
# Lazy initialization.
|
||||||
self.scorer: SpeculativeScorer
|
self.scorer: SpeculativeScorer
|
||||||
|
self.disable_mqa_scorer = disable_mqa_scorer
|
||||||
|
|
||||||
# Hidden states from target model to pass to proposer
|
# Hidden states from target model to pass to proposer
|
||||||
# in the subsequent step.
|
# in the subsequent step.
|
||||||
@ -270,10 +308,19 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
|||||||
self._metrics.init_gpu_tensors(self.rank)
|
self._metrics.init_gpu_tensors(self.rank)
|
||||||
self.spec_decode_sampler.init_gpu_tensors(self.rank)
|
self.spec_decode_sampler.init_gpu_tensors(self.rank)
|
||||||
|
|
||||||
self.scorer = BatchExpansionTop1Scorer(
|
scorer_cls: Type[SpeculativeScorer]
|
||||||
scorer_worker=self.scorer_worker,
|
if self.disable_mqa_scorer:
|
||||||
device=self.device,
|
scorer_cls = BatchExpansionTop1Scorer
|
||||||
vocab_size=self._vocab_size)
|
logger.info("[Speculative Decoding] Use batch "
|
||||||
|
"expansion for scoring proposals.")
|
||||||
|
else:
|
||||||
|
scorer_cls = MQAScorer
|
||||||
|
logger.info(
|
||||||
|
"[Speculative Decoding] Use MQA scorer for scoring proposals.")
|
||||||
|
|
||||||
|
self.scorer = scorer_cls(scorer_worker=self.scorer_worker,
|
||||||
|
device=self.device,
|
||||||
|
vocab_size=self._vocab_size)
|
||||||
|
|
||||||
self._configure_model_sampler_for_spec_decode()
|
self._configure_model_sampler_for_spec_decode()
|
||||||
|
|
||||||
|
|||||||
@ -468,43 +468,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
|
|
||||||
# Compute context length (the number of tokens that are
|
# Compute context length (the number of tokens that are
|
||||||
# already computed) and sequence length (total number of tokens).
|
# already computed) and sequence length (total number of tokens).
|
||||||
|
|
||||||
seq_len = seq_data.get_len()
|
seq_len = seq_data.get_len()
|
||||||
if inter_data.is_prompt:
|
if inter_data.is_prompt:
|
||||||
context_len = seq_data.get_num_computed_tokens()
|
context_len = seq_data.get_num_computed_tokens()
|
||||||
else:
|
seq_len = min(seq_len, context_len + token_chunk_size)
|
||||||
# get_num_computed_tokens is incorrect for spec decoding.
|
elif self.runner.scheduler_config.is_multi_step or \
|
||||||
# So, we should have a special logic here.
|
self.runner.model_config.is_encoder_decoder_model:
|
||||||
# TODO(sang): Fix it.
|
|
||||||
context_len = seq_len - 1
|
context_len = seq_len - 1
|
||||||
seq_len = min(seq_len, context_len + token_chunk_size)
|
else:
|
||||||
|
context_len = seq_data.get_num_computed_tokens()
|
||||||
|
|
||||||
# Compute tokens.
|
# Compute tokens.
|
||||||
if inter_data.is_prompt:
|
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||||
tokens = seq_data.get_token_ids()
|
|
||||||
if context_len != 0 or seq_len < len(tokens):
|
|
||||||
tokens = tokens[context_len:seq_len]
|
|
||||||
else:
|
|
||||||
# Optimization. get_token_ids requires the entire copy of
|
|
||||||
# tokens.
|
|
||||||
tokens = seq_data.get_last_token_id()
|
|
||||||
|
|
||||||
inter_data.seq_lens[seq_idx] = seq_len
|
inter_data.seq_lens[seq_idx] = seq_len
|
||||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||||
inter_data.context_lens[seq_idx] = context_len
|
inter_data.context_lens[seq_idx] = context_len
|
||||||
|
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||||
if isinstance(tokens, list):
|
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
||||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
inter_data.query_lens[seq_idx] = seq_len - context_len
|
||||||
else:
|
|
||||||
inter_data.input_tokens[seq_idx].append(tokens)
|
|
||||||
|
|
||||||
if (seq_len - context_len) == 1:
|
|
||||||
inter_data.input_positions[seq_idx].append(seq_len - 1)
|
|
||||||
else:
|
|
||||||
inter_data.input_positions[seq_idx].extend(
|
|
||||||
range(context_len, seq_len))
|
|
||||||
|
|
||||||
inter_data.query_lens[
|
|
||||||
seq_idx] = seq_len - context_len if inter_data.is_prompt else 1
|
|
||||||
|
|
||||||
if seq_data.mrope_position_delta is not None:
|
if seq_data.mrope_position_delta is not None:
|
||||||
if inter_data.mrope_input_positions is None:
|
if inter_data.mrope_input_positions is None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user