[Bugfix] Make spec. decode respect per-request seed. (#6034)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>

This commit is contained in:
parent b5672a112c
commit d4201e06d5
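This commit threads a per-request torch.Generator, created from SamplingParams.seed, from the sequence group state through the spec-decode scorer into the rejection sampler, so sampling with temperature > 0 stays reproducible per request. As orientation before the diff, a minimal sketch of the batch-aligned list of optional generators that the new code passes around; make_generators and the seed values are hypothetical, illustration only:

from typing import List, Optional

import torch


def make_generators(seeds: List[Optional[int]],
                    device: str = "cpu") -> List[Optional[torch.Generator]]:
    # One Optional[torch.Generator] per request; None for unseeded requests.
    return [
        torch.Generator(device=device).manual_seed(s) if s is not None else None
        for s in seeds
    ]


# Requests 0 and 2 are seeded, request 1 is not.
generators = make_generators([7, None, 13])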
@@ -150,9 +150,54 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     high=vocab_size,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
+    generators = [None] * batch_size
 
     rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                      draft_token_ids)
+                      draft_token_ids, generators)
 
 
+@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
+@pytest.mark.parametrize("n_rep", [100])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
+                                   frac_seeded: float, n_rep: int,
+                                   device: str):
+    torch.set_default_device(device)
+    rejection_sampler = RejectionSampler()
+    rejection_sampler.init_gpu_tensors(rank=0)
+
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
+
+    results = []
+    for _ in range(n_rep):
+        generators = [
+            torch.Generator(
+                device=device).manual_seed(i) if seeded_mask[i] else None
+            for i in range(batch_size)
+        ]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, generators))
+
+    for i in range(batch_size):
+        if seeded_mask[i]:
+            for j in range(1, n_rep):
+                assert torch.equal(results[j][i], results[0][i])
+
+
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -197,10 +242,11 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
             raise AssertionError()
 
     oob_token_ids[0][0] = rogue_token_id
+    generators = [None] * batch_size
 
     with pytest.raises(AssertionError):
         rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                          draft_token_ids)
+                          draft_token_ids, generators)
 
 
 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -371,11 +417,15 @@ class _CorrectnessTestHelper:
                                        dtype=torch.int64,
                                        device="cuda").repeat(num_samples, 1)
 
+        # unseeded
+        generators = [None]
+
         # Get output tokens via rejection sampling.
        output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                   bonus_token_ids.to("cuda"),
                                                   draft_probs.to("cuda"),
-                                                  draft_token_ids.to("cuda"))
+                                                  draft_token_ids.to("cuda"),
+                                                  generators)
 
         # Remove bonus tokens
         output_token_ids = output_token_ids[:, :-1].flatten()
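The new test_deterministic_when_seeded above relies on a basic PyTorch property: a freshly seeded torch.Generator reproduces the same random stream, so two runs that rebuild the generator with the same seed must agree row by row. A standalone CPU check of that property (illustrative only, not part of the test suite):

import torch


def draws(seed: int) -> torch.Tensor:
    gen = torch.Generator(device="cpu").manual_seed(seed)
    return torch.rand(4, generator=gen)


# Same seed -> identical draws; different seeds -> (almost surely) different draws.
assert torch.equal(draws(42), draws(42))
assert not torch.equal(draws(42), draws(43))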
@@ -1,6 +1,6 @@
 import asyncio
 from itertools import cycle
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import pytest
 import ray
@@ -128,7 +128,9 @@ class AsyncLLM:
         try:
             for i in range(num_requests):
                 prompt = prompts[i] if prompts is not None else None
-                res = asyncio.run(get_output(prompt, sampling_params))
+                params = sampling_params[i] if isinstance(
+                    sampling_params, Sequence) else sampling_params
+                res = asyncio.run(get_output(prompt, params))
                 outputs.append(res)
         finally:
             ray.shutdown()
@@ -267,7 +269,31 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
     """
-    temperature = 0.0
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len,
+                                  temperature=0.0,
+                                  seeded=False,
+                                  print_tokens=print_tokens,
+                                  ensure_all_accepted=ensure_all_accepted)
+
+
+def run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len: bool,
+                                  temperature: float,
+                                  seeded: bool,
+                                  print_tokens: bool = False,
+                                  ensure_all_accepted: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
+    the same when temperature is zero (or when temperature is > 0 and seeded).
+    """
 
     prompts = [
         "Hello, my name is",
@@ -286,11 +312,21 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     # sampling params to ignore eos token.
     ignore_eos = force_output_len
 
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
-        temperature=temperature,
-    )
+    if seeded:
+        sampling_params = [
+            SamplingParams(
+                max_tokens=max_output_len,
+                ignore_eos=ignore_eos,
+                temperature=temperature,
+                seed=i,
+            ) for i in range(len(prompts))
+        ]
+    else:
+        sampling_params = SamplingParams(
+            max_tokens=max_output_len,
+            ignore_eos=ignore_eos,
+            temperature=temperature,
+        )
 
     (spec_batch_tokens, spec_batch_token_ids,
      acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
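The seeded branch above builds one SamplingParams per prompt, each carrying its own seed, so a temperature > 0 run can still be compared across repetitions. A hedged usage sketch (prompt strings and parameter values here are illustrative, not taken from the test file):

from vllm import SamplingParams

prompts = ["Hello, my name is", "The capital of France is"]

# One SamplingParams per prompt, seeded by position, mirroring the branch above.
seeded_params = [
    SamplingParams(max_tokens=10, temperature=1.0, seed=i)
    for i in range(len(prompts))
]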
tests/spec_decode/e2e/test_seed.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+import pytest
+
+from .conftest import run_equality_correctness_test
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # speculative model
+        "speculative_model": "JackFram/llama-160m",
+
+        # num speculative tokens
+        "num_speculative_tokens": 3,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [1, 8, 32])
+@pytest.mark.parametrize("temperature", [0.1, 1.0])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        10,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_seeded_consistency(baseline_llm_generator, batch_size: int,
+                            temperature: float, output_len: int):
+    """Verify outputs are consistent across multiple runs with same seed
+    """
+    run_equality_correctness_test(baseline_llm_generator,
+                                  baseline_llm_generator,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  temperature=temperature,
+                                  seeded=True,
+                                  force_output_len=True)
@@ -1,14 +1,14 @@
 from functools import cached_property
-from typing import Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.jit
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeStochasticBaseSampler)
 
 
-class RejectionSampler(SpecDecodeBaseSampler):
+class RejectionSampler(SpecDecodeStochasticBaseSampler):
     """Apply modified rejection sampling as described in "Accelerating Large
     Language Model Decoding with Speculative Sampling"
     https://arxiv.org/pdf/2302.01318.pdf.
@@ -36,6 +36,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
+        generators: List[Optional[torch.Generator]],
     ) -> torch.Tensor:
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
@@ -82,6 +83,7 @@ class RejectionSampler(SpecDecodeBaseSampler):
                 target_probs,
                 draft_probs,
                 draft_token_ids,
+                generators,
             ))
 
         output_token_ids = self._create_output(
@@ -94,10 +96,11 @@ class RejectionSampler(SpecDecodeBaseSampler):
         return output_token_ids
 
     def _batch_modified_rejection_sampling(
-        self,
-        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_token_ids: torch.Tensor,  # [batch_size, k]
+            self,
+            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_token_ids: torch.Tensor,  # [batch_size, k]
+            generators: List[Optional[torch.Generator]],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
 
@@ -114,22 +117,33 @@ class RejectionSampler(SpecDecodeBaseSampler):
 
         # shape [batch_size, k]
         accepted = self._get_accepted(target_probs, draft_probs,
-                                      draft_token_ids)
+                                      draft_token_ids, generators)
 
         recovered_probs = self._get_recovered_probs(
             target_probs, draft_probs).reshape(batch_size * k, vocab_size)
 
+        seed_indices, non_seed_indices = self._split_batch_by_seeded(
+            generators, k=k)
+
         # NOTE: the recovered_probs are overwritten by this method.
-        recovered_token_ids = _multinomial(recovered_probs,
-                                           num_samples=1).reshape(
-                                               batch_size, k)
+        recovered_token_ids = _multinomial(
+            recovered_probs,
+            num_samples=1,
+            k=k,
+            generators=generators,
+            seed_indices=seed_indices,
+            # this arg is unused when None but torch.jit requires a list
+            non_seed_indices=non_seed_indices or [],
+        ).reshape(batch_size, k)
+
         return accepted, recovered_token_ids
 
     def _get_accepted(
-        self,
-        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_token_ids: torch.Tensor,  # [batch_size, k]
+            self,
+            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_token_ids: torch.Tensor,  # [batch_size, k]
+            generators: List[Optional[torch.Generator]],
     ) -> torch.Tensor:
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
@@ -164,10 +178,28 @@ class RejectionSampler(SpecDecodeBaseSampler):
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]
 
-        uniform_rand = torch.rand(batch_size,
-                                  k,
-                                  dtype=self.probs_dtype,
-                                  device=target_probs.device)
+        seed_indices, non_seed_indices = self._split_batch_by_seeded(
+            generators)
+
+        if len(seed_indices) == 0:
+            uniform_rand = torch.rand_like(selected_target_probs)
+        else:
+            uniform_rand = torch.empty_like(selected_target_probs)
+
+            for idx in seed_indices:
+                uniform_rand[idx, :] = torch.rand(1,
+                                                  k,
+                                                  dtype=self.probs_dtype,
+                                                  device=target_probs.device,
+                                                  generator=generators[idx])
+
+            if non_seed_indices:
+                uniform_rand[non_seed_indices, :] = torch.rand(
+                    len(non_seed_indices),
+                    k,
+                    dtype=self.probs_dtype,
+                    device=target_probs.device)
 
         capped_ratio = torch.minimum(
             selected_target_probs / selected_draft_probs,
             torch.full((1, ), 1, device=target_probs.device))
@@ -240,6 +272,27 @@ class RejectionSampler(SpecDecodeBaseSampler):
         """
         return torch.finfo(self.probs_dtype).tiny
 
+    # partition batch into indices for which a generator is provided
+    # and indices for which no generator is provided
+    @staticmethod
+    def _split_batch_by_seeded(
+        generators: List[Optional[torch.Generator]],
+        k: int = 1,
+    ) -> Tuple[List[int], Optional[List[int]]]:
+
+        if all(generator is None for generator in generators):
+            seed_indices: List[int] = []
+            non_seed_indices: Optional[List[int]] = None
+        else:
+            seed_indices, non_seed_indices = [], []
+            for i, generator in enumerate(generators):
+                if generator is None:
+                    non_seed_indices.extend(range(k * i, k * (i + 1)))
+                else:
+                    seed_indices.extend(range(k * i, k * (i + 1)))
+
+        return seed_indices, non_seed_indices
+
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
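To make the index layout of _split_batch_by_seeded concrete: with k speculative tokens per request, each request owns a contiguous block of k flattened row indices, and the method splits those blocks by whether the request carries a generator. A quick check that calls the private staticmethod directly (for demonstration only; it assumes vllm is importable and is not how the sampler is normally used):

import torch

from vllm.model_executor.layers.rejection_sampler import RejectionSampler

# Three requests, k=2 speculative tokens each; only request 1 carries a generator.
generators = [None, torch.Generator().manual_seed(0), None]
seed_idx, non_seed_idx = RejectionSampler._split_batch_by_seeded(generators, k=2)

assert seed_idx == [2, 3]            # rows of the seeded request
assert non_seed_idx == [0, 1, 4, 5]  # rows of the unseeded requests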
@@ -250,12 +303,25 @@ class RejectionSampler(SpecDecodeBaseSampler):
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
+    k: int,
+    generators: List[Optional[torch.Generator]],
+    seed_indices: List[int],
+    non_seed_indices: List[int],
 ) -> torch.Tensor:
+
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
         probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                          probs.shape[1]).contiguous().view(
                                              -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1.0)
+
+    q = torch.empty_like(probs)
+    if len(seed_indices) == 0:
+        q.exponential_(1.0)
+    else:
+        q[non_seed_indices].exponential_(1.0)
+        for idx in seed_indices:
+            q[idx].exponential_(1.0, generator=generators[idx // k])
+
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
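The _multinomial helper keeps avoiding torch.multinomial's GPU<->CPU sync by using the exponential-race trick: dividing the probabilities by i.i.d. Exp(1) noise and taking the argmax yields one categorical sample, and routing the exponential_ draw through a per-request generator is what makes that sample reproducible. A small self-contained sketch of the same idea on CPU (illustrative only):

import torch

probs = torch.tensor([0.1, 0.2, 0.7])

# argmax(probs / Exp(1)) is one sample from the categorical distribution `probs`;
# seeding the Exp(1) noise makes the sample reproducible.
q1 = torch.empty_like(probs).exponential_(
    1.0, generator=torch.Generator(device="cpu").manual_seed(0))
q2 = torch.empty_like(probs).exponential_(
    1.0, generator=torch.Generator(device="cpu").manual_seed(0))

assert torch.equal(torch.argmax(probs / q1), torch.argmax(probs / q2))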
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.jit
@@ -54,16 +54,6 @@ class SpecDecodeBaseSampler(nn.Module):
     def token_id_dtype(self):
         return torch.int64
 
-    @abstractmethod
-    def forward(
-        self,
-        target_probs: torch.Tensor,
-        bonus_token_ids: torch.Tensor,
-        draft_probs: torch.Tensor,
-        draft_token_ids: torch.Tensor,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
     def _create_output(
         self,
         accepted: torch.Tensor,  # [batch_size, k]
@@ -217,3 +207,36 @@ class SpecDecodeBaseSampler(nn.Module):
         assert torch.all(bonus_token_ids >= 0)
         assert torch.all(draft_token_ids < vocab_size)
         assert torch.all(draft_token_ids >= 0)
+
+
+class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler):
+    """Base class for samplers used for Speculative Decoding verification
+       step which are deterministic.
+    """
+
+    @abstractmethod
+    def forward(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler):
+    """Base class for samplers used for Speculative Decoding verification
+       step which are stochastic
+    """
+
+    @abstractmethod
+    def forward(
+        self,
+        target_probs: torch.Tensor,
+        bonus_token_ids: torch.Tensor,
+        draft_probs: torch.Tensor,
+        draft_token_ids: torch.Tensor,
+        generators: List[Optional[torch.Generator]],
+    ) -> torch.Tensor:
+        raise NotImplementedError
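The split above leaves SpecDecodeBaseSampler with the shared validation and output-construction logic, while the forward signature moves into two abstract subclasses: deterministic samplers keep the old signature and stochastic ones additionally take generators. A hedged sketch of how a caller can use the split for dispatch (it mirrors the isinstance check added in spec_decode_worker.py further below; sampler_kwargs is a hypothetical helper):

from vllm.model_executor.layers.spec_decode_base_sampler import (
    SpecDecodeStochasticBaseSampler)


def sampler_kwargs(sampler, generators):
    # Only stochastic samplers (e.g. RejectionSampler) accept `generators`;
    # deterministic ones (e.g. TypicalAcceptanceSampler) do not.
    return {"generators": generators} if isinstance(
        sampler, SpecDecodeStochasticBaseSampler) else {}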
@@ -2,10 +2,10 @@ import torch
 import torch.jit
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeDeterministicBaseSampler)
 
 
-class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
+class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
     """Apply typical acceptance sampling as described in section 3.3.1 in
     "MEDUSA: Simple LLM Inference Acceleration Framework with
     Multiple Decoding Heads"
@@ -4,7 +4,8 @@ from typing import Iterator, List, Tuple
 import torch
 
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata, get_all_seq_ids)
+                           SequenceGroupMetadata, SequenceGroupState,
+                           get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
 from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
@@ -292,6 +293,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
         for data in new_seq_data_dict.values():
             data.update_num_computed_tokens(data.get_len() - 1)
 
+        if (seq_group_metadata.state is not None
+                and seq_group_metadata.state.generator is not None):
+            generator = torch.Generator(
+                device=seq_group_metadata.state.generator.device)
+            generator.set_state(seq_group_metadata.state.generator.get_state())
+            state = SequenceGroupState(generator=generator)
+        else:
+            state = None
+
         return SequenceGroupMetadata(
             request_id=seq_group_metadata.request_id,
             is_prompt=seq_group_metadata.is_prompt,
@@ -302,6 +312,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
             },
             lora_request=None,
             token_chunk_size=1,
+            state=state,
         )
 
     def _split_scoring_output(
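For the generator handling above: get_state()/set_state() copy a generator's RNG state, so the expanded sequence groups carry the request's RNG state without sharing (and advancing) the original generator object. A minimal CPU illustration (not part of the diff):

import torch

src = torch.Generator().manual_seed(123)

clone = torch.Generator(device=src.device)
clone.set_state(src.get_state())  # clone starts from src's exact RNG state

# Both generators now produce the same stream, independently of each other.
assert torch.equal(torch.rand(3, generator=src),
                   torch.rand(3, generator=clone))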
@@ -9,7 +9,7 @@ from vllm.distributed.communication_op import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler)
 from vllm.model_executor.layers.typical_acceptance_sampler import (
     TypicalAcceptanceSampler)
 from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest,
@@ -521,11 +521,28 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         # Get proposed tokens.
         proposal_token_ids = proposals.proposal_token_ids[spec_indices]
 
+        # Sampler arguments
+        sampler_extra_kwargs = {}
+        if isinstance(self.spec_decode_sampler,
+                      SpecDecodeStochasticBaseSampler):
+
+            # Get sequence group state
+            generators = []
+            for seq_group_metadata in seq_group_metadata_list:
+                if (seq_group_metadata.state is not None
+                        and seq_group_metadata.state.generator is not None):
+                    generators.append(seq_group_metadata.state.generator)
+                else:
+                    generators.append(None)
+
+            sampler_extra_kwargs["generators"] = generators
+
         accepted_token_ids = self.spec_decode_sampler(
             target_probs=proposal_verifier_probs,
             bonus_token_ids=bonus_token_ids,
             draft_probs=proposal_probs,
             draft_token_ids=proposal_token_ids,
+            **sampler_extra_kwargs,
         )
 
         # Append output tokens from non-speculative sequences to