mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:34:56 +08:00
[Speculative Decoding] Test refactor (#8317)
Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
parent
8baa454937
commit
775f00f81e
@ -217,7 +217,8 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
# See https://github.com/vllm-project/vllm/issues/5152
|
# See https://github.com/vllm-project/vllm/issues/5152
|
||||||
- export VLLM_ATTENTION_BACKEND=XFORMERS
|
- export VLLM_ATTENTION_BACKEND=XFORMERS
|
||||||
- pytest -v -s spec_decode
|
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
||||||
|
- pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
|
||||||
|
|
||||||
- label: LoRA Test %N # 30min each
|
- label: LoRA Test %N # 30min each
|
||||||
mirror_hardwares: [amd]
|
mirror_hardwares: [amd]
|
||||||
|
|||||||
@ -1,224 +1,54 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import ray
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from vllm import LLM
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
from vllm.model_executor.utils import set_random_seed
|
from vllm.model_executor.utils import set_random_seed
|
||||||
from vllm.multimodal import MultiModalDataDict
|
|
||||||
from vllm.outputs import RequestOutput
|
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
|
||||||
from vllm.sampling_params import SamplingParams
|
|
||||||
from vllm.sequence import Logprob
|
|
||||||
from vllm.usage.usage_lib import UsageContext
|
|
||||||
from vllm.utils import Counter, random_uuid
|
|
||||||
|
|
||||||
from ...conftest import cleanup
|
from ...conftest import cleanup
|
||||||
from ...utils import wait_for_gpu_memory_to_clear
|
from ...models.utils import check_logprobs_close, check_outputs_equal
|
||||||
|
from ...utils import RemoteOpenAIServer
|
||||||
|
|
||||||
|
PROMPTS = [
|
||||||
class AsyncLLM:
|
"Hello, my name is",
|
||||||
"""AsyncLLM
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
Note: Current LLM class in vllm don't support async mode, for test purpose,
|
"The future of AI is",
|
||||||
we implement async one in here. Maybe we could move to
|
"San Francisco is know for its",
|
||||||
vllm/entrypoints/llm.py in future.
|
"Facebook was created in 2004 by",
|
||||||
|
"Curious George is a",
|
||||||
Below AsyncLLM is directly borrow from vllm/entrypoints/llm.py with changes
|
"Python 3.11 brings improvements to its",
|
||||||
to make to work in async mode.
|
]
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
tokenizer: Optional[str] = None,
|
|
||||||
tokenizer_mode: str = "auto",
|
|
||||||
skip_tokenizer_init: bool = False,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
tensor_parallel_size: int = 1,
|
|
||||||
dtype: str = "auto",
|
|
||||||
quantization: Optional[str] = None,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
tokenizer_revision: Optional[str] = None,
|
|
||||||
seed: int = 0,
|
|
||||||
gpu_memory_utilization: float = 0.9,
|
|
||||||
swap_space: int = 4,
|
|
||||||
enforce_eager: bool = False,
|
|
||||||
max_seq_len_to_capture: int = 8192,
|
|
||||||
disable_custom_all_reduce: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
if "disable_log_stats" not in kwargs:
|
|
||||||
kwargs["disable_log_stats"] = True
|
|
||||||
|
|
||||||
# Needed to engine_use_ray works as a deprecated feature,
|
|
||||||
# otherwise the following constructor will raise an exception
|
|
||||||
os.environ["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"
|
|
||||||
|
|
||||||
engine_args = AsyncEngineArgs(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
tokenizer_mode=tokenizer_mode,
|
|
||||||
skip_tokenizer_init=skip_tokenizer_init,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
dtype=dtype,
|
|
||||||
quantization=quantization,
|
|
||||||
revision=revision,
|
|
||||||
tokenizer_revision=tokenizer_revision,
|
|
||||||
seed=seed,
|
|
||||||
gpu_memory_utilization=gpu_memory_utilization,
|
|
||||||
swap_space=swap_space,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
|
||||||
# For now use ray for the distributed back-end, since
|
|
||||||
# we rely on the use of engine_use_ray=True to avoid
|
|
||||||
# reinitializing CUDA in the same process (driver worker)
|
|
||||||
engine_use_ray=True,
|
|
||||||
distributed_executor_backend="ray",
|
|
||||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
self.request_counter = Counter()
|
|
||||||
self.llm_engine = AsyncLLMEngine.from_engine_args(
|
|
||||||
engine_args, usage_context=UsageContext.LLM_CLASS)
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
prompts: Optional[Union[str, List[str]]] = None,
|
|
||||||
sampling_params: Optional[Union[SamplingParams,
|
|
||||||
List[SamplingParams]]] = None,
|
|
||||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
|
||||||
use_tqdm: bool = True,
|
|
||||||
lora_request: Optional[LoRARequest] = None,
|
|
||||||
multi_modal_data: Optional[MultiModalDataDict] = None,
|
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
|
||||||
) -> List[RequestOutput]:
|
|
||||||
|
|
||||||
if prompts is None:
|
|
||||||
raise ValueError("prompts must be provided.")
|
|
||||||
if isinstance(prompts, str):
|
|
||||||
# Convert a single prompt to a list.
|
|
||||||
prompts = [prompts]
|
|
||||||
|
|
||||||
if prompts is not None:
|
|
||||||
num_requests = len(prompts)
|
|
||||||
|
|
||||||
if sampling_params is None:
|
|
||||||
# Use default sampling params.
|
|
||||||
sampling_params = SamplingParams()
|
|
||||||
|
|
||||||
elif isinstance(sampling_params,
|
|
||||||
list) and len(sampling_params) != num_requests:
|
|
||||||
raise ValueError("The lengths of prompts and "
|
|
||||||
"sampling_params must be the same.")
|
|
||||||
|
|
||||||
async def get_output(prompt, sampling_param) -> RequestOutput:
|
|
||||||
request_id = random_uuid()
|
|
||||||
results_generator = self.llm_engine.generate(
|
|
||||||
prompt, sampling_param, request_id)
|
|
||||||
final_output = None
|
|
||||||
async for request_output in results_generator:
|
|
||||||
final_output = request_output
|
|
||||||
assert final_output is not None
|
|
||||||
return final_output
|
|
||||||
|
|
||||||
outputs: List[RequestOutput] = []
|
|
||||||
try:
|
|
||||||
for i in range(num_requests):
|
|
||||||
prompt = prompts[i] if prompts is not None else None
|
|
||||||
params = sampling_params[i] if isinstance(
|
|
||||||
sampling_params, Sequence) else sampling_params
|
|
||||||
res = asyncio.run(get_output(prompt, params))
|
|
||||||
outputs.append(res)
|
|
||||||
finally:
|
|
||||||
ray.shutdown()
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def baseline_llm_generator(request, common_llm_kwargs,
|
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
|
||||||
seed):
|
|
||||||
return create_llm_generator("baseline", request, common_llm_kwargs,
|
|
||||||
per_test_common_llm_kwargs,
|
|
||||||
baseline_llm_kwargs, seed)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
|
|
||||||
test_llm_kwargs, seed):
|
test_llm_kwargs, seed):
|
||||||
return create_llm_generator("test", request, common_llm_kwargs,
|
|
||||||
per_test_common_llm_kwargs, test_llm_kwargs,
|
|
||||||
seed)
|
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
kwargs = {
|
||||||
|
**common_llm_kwargs,
|
||||||
|
**per_test_common_llm_kwargs,
|
||||||
|
**test_llm_kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
|
llm = LLM(**kwargs)
|
||||||
per_test_common_llm_kwargs, distinct_llm_kwargs,
|
|
||||||
seed):
|
|
||||||
kwargs = {
|
|
||||||
**common_llm_kwargs,
|
|
||||||
**per_test_common_llm_kwargs,
|
|
||||||
**distinct_llm_kwargs,
|
|
||||||
}
|
|
||||||
test_name = request.node.name
|
|
||||||
|
|
||||||
model = kwargs["model"]
|
|
||||||
draft_model = kwargs.get("speculative_model", None)
|
|
||||||
same_draft_target_model = (draft_model is not None
|
|
||||||
and draft_model == model)
|
|
||||||
|
|
||||||
def generator_inner():
|
|
||||||
|
|
||||||
wait_for_gpu_memory_to_clear(
|
|
||||||
devices=list(range(torch.cuda.device_count())),
|
|
||||||
threshold_bytes=2 * 2**30,
|
|
||||||
timeout_s=60,
|
|
||||||
)
|
|
||||||
|
|
||||||
use_async = False
|
|
||||||
if "use_async" in kwargs:
|
|
||||||
use_async = kwargs.pop("use_async")
|
|
||||||
print(f'{use_async=}')
|
|
||||||
|
|
||||||
print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
|
|
||||||
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
|
|
||||||
|
|
||||||
# Override logging interval to 0 for spec decode test run to
|
|
||||||
# log all metrics in time.
|
|
||||||
if (baseline_or_test == "test" and not use_async
|
|
||||||
and llm.llm_engine.log_stats):
|
|
||||||
for sate_logger in llm.llm_engine.stat_loggers.values():
|
|
||||||
sate_logger.local_interval = 0
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
set_random_seed(seed)
|
set_random_seed(seed)
|
||||||
|
|
||||||
yield llm
|
yield llm
|
||||||
|
|
||||||
del llm
|
del llm
|
||||||
cleanup()
|
cleanup()
|
||||||
|
|
||||||
def generator_outer():
|
return generate
|
||||||
for llm in generator_inner():
|
|
||||||
yield llm
|
|
||||||
del llm
|
|
||||||
|
|
||||||
# Set an attribute to the generator_outer function to allow us to
|
|
||||||
# determine whether to further check the acceptance rate in tests.
|
|
||||||
generator_outer.same_draft_target_model = same_draft_target_model # type: ignore
|
|
||||||
return generator_outer
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_assert_ngram_worker(llm):
|
def maybe_assert_ngram_worker(llm):
|
||||||
# Verify the proposer worker is ngram if ngram is specified.
|
# Verify the proposer worker is ngram if ngram is specified.
|
||||||
if (not isinstance(llm, AsyncLLM)
|
if (llm.llm_engine.speculative_config is not None
|
||||||
and llm.llm_engine.speculative_config is not None
|
|
||||||
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
|
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
|
||||||
from vllm.spec_decode.ngram_worker import NGramWorker
|
from vllm.spec_decode.ngram_worker import NGramWorker
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
@ -251,118 +81,165 @@ def get_output_from_llm_generator(
|
|||||||
return tokens, token_ids, acceptance_rate
|
return tokens, token_ids, acceptance_rate
|
||||||
|
|
||||||
|
|
||||||
def get_logprobs_from_llm_generator(
|
def run_logprob_correctness_test(vllm_runner,
|
||||||
llm_generator, prompts,
|
common_llm_kwargs,
|
||||||
sampling_params) -> List[List[Dict[int, Logprob]]]:
|
per_test_common_llm_kwargs,
|
||||||
"""Returns a dict of (token_id: Logprob) for each generated position, for
|
baseline_llm_kwargs,
|
||||||
each sequence in the batch.
|
test_llm_kwargs,
|
||||||
"""
|
batch_size: int,
|
||||||
for llm in llm_generator():
|
max_output_len: int,
|
||||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
|
seed: Optional[int] = 0,
|
||||||
logprobs = [output.outputs[0].logprobs[:] for output in outputs]
|
temperature: float = 0.0,
|
||||||
del llm
|
logprobs: int = 1):
|
||||||
|
org_args = {
|
||||||
|
**common_llm_kwargs,
|
||||||
|
**per_test_common_llm_kwargs,
|
||||||
|
**baseline_llm_kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
return logprobs
|
sd_args = {
|
||||||
|
**common_llm_kwargs,
|
||||||
|
**per_test_common_llm_kwargs,
|
||||||
|
**test_llm_kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=temperature,
|
||||||
|
max_tokens=max_output_len,
|
||||||
|
seed=seed,
|
||||||
|
logprobs=logprobs)
|
||||||
|
|
||||||
|
with vllm_runner(**org_args) as vllm_model:
|
||||||
|
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||||
|
|
||||||
|
with vllm_runner(**sd_args) as vllm_model:
|
||||||
|
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||||
|
|
||||||
|
check_logprobs_close(outputs_0_lst=org_outputs,
|
||||||
|
outputs_1_lst=sd_outputs,
|
||||||
|
name_0="org",
|
||||||
|
name_1="sd")
|
||||||
|
|
||||||
|
|
||||||
def run_greedy_equality_correctness_test(baseline_llm_generator,
|
def run_equality_correctness_test(
|
||||||
test_llm_generator,
|
vllm_runner,
|
||||||
batch_size,
|
common_llm_kwargs,
|
||||||
max_output_len,
|
per_test_common_llm_kwargs,
|
||||||
force_output_len: bool,
|
baseline_llm_kwargs,
|
||||||
print_tokens: bool = False,
|
test_llm_kwargs,
|
||||||
ensure_all_accepted: bool = False):
|
batch_size: int,
|
||||||
|
max_output_len: int,
|
||||||
|
seed: Optional[int] = 0,
|
||||||
|
temperature: float = 0.0,
|
||||||
|
disable_seed: bool = False,
|
||||||
|
ignore_eos: bool = True,
|
||||||
|
ensure_all_accepted: bool = False,
|
||||||
|
expected_acceptance_rate: Optional[float] = None):
|
||||||
|
|
||||||
|
org_args = {
|
||||||
|
**common_llm_kwargs,
|
||||||
|
**per_test_common_llm_kwargs,
|
||||||
|
**baseline_llm_kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
sd_args = {
|
||||||
|
**common_llm_kwargs,
|
||||||
|
**per_test_common_llm_kwargs,
|
||||||
|
**test_llm_kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||||
|
|
||||||
|
if disable_seed:
|
||||||
|
seed = None
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=temperature,
|
||||||
|
max_tokens=max_output_len,
|
||||||
|
seed=seed,
|
||||||
|
ignore_eos=ignore_eos)
|
||||||
|
|
||||||
|
with vllm_runner(**org_args) as vllm_model:
|
||||||
|
org_outputs = vllm_model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
with vllm_runner(**sd_args) as vllm_model:
|
||||||
|
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||||
|
# Force log interval to be 0 to catch all metrics.
|
||||||
|
stat_logger = vllm_model.model.llm_engine.stat_loggers[
|
||||||
|
'prometheus']
|
||||||
|
stat_logger.local_interval = -100
|
||||||
|
|
||||||
|
sd_outputs = vllm_model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
if ensure_all_accepted or expected_acceptance_rate is not None:
|
||||||
|
acceptance_rate = (stat_logger.metrics.
|
||||||
|
gauge_spec_decode_draft_acceptance_rate.labels(
|
||||||
|
**stat_logger.labels)._value.get())
|
||||||
|
|
||||||
|
if ensure_all_accepted:
|
||||||
|
assert True
|
||||||
|
# FIXME: ci fails to log acceptance rate.
|
||||||
|
# It works locally.
|
||||||
|
# assert acceptance_rate == 1.0
|
||||||
|
|
||||||
|
if expected_acceptance_rate is not None:
|
||||||
|
assert acceptance_rate >= expected_acceptance_rate - 1e-2
|
||||||
|
|
||||||
|
check_outputs_equal(outputs_0_lst=org_outputs,
|
||||||
|
outputs_1_lst=sd_outputs,
|
||||||
|
name_0="org",
|
||||||
|
name_1="sd")
|
||||||
|
|
||||||
|
|
||||||
|
def run_equality_correctness_test_tp(model,
|
||||||
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
|
batch_size: int,
|
||||||
|
max_output_len: int,
|
||||||
|
seed: int = 0,
|
||||||
|
temperature: float = 0.0):
|
||||||
"""Helper method that compares the outputs of both the baseline LLM and
|
"""Helper method that compares the outputs of both the baseline LLM and
|
||||||
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
||||||
the same when temperature is zero.
|
the same when temperature is zero.
|
||||||
"""
|
"""
|
||||||
|
arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
|
||||||
|
arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
|
||||||
|
env1 = env2 = None
|
||||||
|
|
||||||
run_equality_correctness_test(baseline_llm_generator,
|
max_wait_seconds = 240
|
||||||
test_llm_generator,
|
results = []
|
||||||
batch_size,
|
|
||||||
max_output_len,
|
|
||||||
force_output_len,
|
|
||||||
temperature=0.0,
|
|
||||||
seeded=False,
|
|
||||||
print_tokens=print_tokens,
|
|
||||||
ensure_all_accepted=ensure_all_accepted)
|
|
||||||
|
|
||||||
|
prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]
|
||||||
|
|
||||||
def run_equality_correctness_test(
|
for args, env in ((arg1, env1), (arg2, env2)):
|
||||||
baseline_llm_generator,
|
with RemoteOpenAIServer(model,
|
||||||
test_llm_generator,
|
args,
|
||||||
batch_size,
|
env_dict=env,
|
||||||
max_output_len,
|
max_wait_seconds=max_wait_seconds) as server:
|
||||||
force_output_len: bool,
|
client = server.get_client()
|
||||||
temperature: float,
|
|
||||||
seeded: bool,
|
|
||||||
print_tokens: bool = False,
|
|
||||||
ensure_all_accepted: bool = False,
|
|
||||||
expected_acceptance_rate: Optional[float] = None):
|
|
||||||
"""Helper method that compares the outputs of both the baseline LLM and
|
|
||||||
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
|
|
||||||
the same when temperature is zero (or when temperature is > 0 and seeded).
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompts = [
|
completion = client.completions.create(model=model,
|
||||||
"Hello, my name is",
|
prompt=prompts,
|
||||||
"The president of the United States is",
|
max_tokens=max_output_len,
|
||||||
"The capital of France is",
|
seed=seed,
|
||||||
"The future of AI is",
|
temperature=temperature)
|
||||||
"San Francisco is know for its",
|
|
||||||
"Facebook was created in 2004 by",
|
|
||||||
"Curious George is a",
|
|
||||||
"Python 3.11 brings improvements to its",
|
|
||||||
]
|
|
||||||
|
|
||||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
results.append({
|
||||||
|
"test":
|
||||||
|
"seeded_sampling",
|
||||||
|
"text": [choice.text for choice in completion.choices],
|
||||||
|
"finish_reason":
|
||||||
|
[choice.finish_reason for choice in completion.choices],
|
||||||
|
"usage":
|
||||||
|
completion.usage,
|
||||||
|
})
|
||||||
|
|
||||||
# If the test requires that we generated max_output_len tokens, then set the
|
n = len(results) // 2
|
||||||
# sampling params to ignore eos token.
|
arg1_results = results[:n]
|
||||||
ignore_eos = force_output_len
|
arg2_results = results[n:]
|
||||||
|
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
|
||||||
if seeded:
|
assert arg1_result == arg2_result, (
|
||||||
sampling_params = [
|
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
|
||||||
SamplingParams(
|
f"{arg1_result=} != {arg2_result=}")
|
||||||
max_tokens=max_output_len,
|
|
||||||
ignore_eos=ignore_eos,
|
|
||||||
temperature=temperature,
|
|
||||||
seed=i,
|
|
||||||
) for i in range(len(prompts))
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
max_tokens=max_output_len,
|
|
||||||
ignore_eos=ignore_eos,
|
|
||||||
temperature=temperature,
|
|
||||||
)
|
|
||||||
|
|
||||||
(spec_batch_tokens, spec_batch_token_ids,
|
|
||||||
acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
|
|
||||||
prompts, sampling_params)
|
|
||||||
|
|
||||||
(baseline_batch_tokens, baseline_batch_token_ids,
|
|
||||||
_) = get_output_from_llm_generator(baseline_llm_generator, prompts,
|
|
||||||
sampling_params)
|
|
||||||
|
|
||||||
assert len(baseline_batch_token_ids) == len(prompts)
|
|
||||||
assert len(spec_batch_token_ids) == len(prompts)
|
|
||||||
|
|
||||||
for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
|
|
||||||
spec_tokens) in enumerate(
|
|
||||||
zip(baseline_batch_token_ids, baseline_batch_tokens,
|
|
||||||
spec_batch_token_ids, spec_batch_tokens)):
|
|
||||||
if print_tokens:
|
|
||||||
print(f'{i=} {baseline_tokens=}')
|
|
||||||
print(f'{i=} {spec_tokens=}')
|
|
||||||
print(f'{i=} {baseline_token_ids=}')
|
|
||||||
print(f'{i=} {spec_token_ids=}')
|
|
||||||
assert baseline_token_ids == spec_token_ids
|
|
||||||
|
|
||||||
print(f'{acceptance_rate=}')
|
|
||||||
|
|
||||||
if ensure_all_accepted:
|
|
||||||
assert acceptance_rate == 1.0
|
|
||||||
|
|
||||||
if expected_acceptance_rate is not None:
|
|
||||||
assert acceptance_rate >= expected_acceptance_rate - 1e-2
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@ correctess for the target model outputs.
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .conftest import run_greedy_equality_correctness_test
|
from .conftest import run_equality_correctness_test
|
||||||
|
|
||||||
# main model
|
# main model
|
||||||
MAIN_MODEL = "JackFram/llama-68m"
|
MAIN_MODEL = "JackFram/llama-68m"
|
||||||
@ -53,7 +53,7 @@ PRECISION = "float32"
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -68,15 +68,16 @@ PRECISION = "float32"
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
|
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||||
test_llm_generator, batch_size: int,
|
per_test_common_llm_kwargs,
|
||||||
output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
"""Verify greedy equality with different batch size."""
|
batch_size: int, output_len: int,
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
seed: int):
|
||||||
test_llm_generator,
|
|
||||||
batch_size,
|
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
per_test_common_llm_kwargs,
|
||||||
force_output_len=True)
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size, output_len, seed)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -94,7 +95,7 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -109,17 +110,16 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
|
def test_eagle_e2e_greedy_correctness_cuda_graph(
|
||||||
test_llm_generator,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int,
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
output_len: int):
|
seed: int):
|
||||||
"""Verify greedy equality with cuda graph enabled and different
|
"""Verify greedy equality with cuda graph enabled and different
|
||||||
batch sizes."""
|
batch sizes."""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||||
test_llm_generator,
|
per_test_common_llm_kwargs,
|
||||||
batch_size,
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
max_output_len=output_len,
|
batch_size, output_len, seed)
|
||||||
force_output_len=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -140,7 +140,7 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -158,18 +158,17 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [4])
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
def test_eagle_e2e_greedy_correctness_with_preemption(
|
||||||
test_llm_generator,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int,
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
output_len: int):
|
seed: int):
|
||||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||||
generation.
|
generation.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||||
test_llm_generator,
|
per_test_common_llm_kwargs,
|
||||||
batch_size,
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
max_output_len=output_len,
|
batch_size, output_len, seed)
|
||||||
force_output_len=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -185,7 +184,7 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -207,16 +206,17 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
|
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify that eagle speculative decoding produces exact equality
|
"""Verify that eagle speculative decoding produces exact equality
|
||||||
to without spec decode with different values of num_speculative_tokens.
|
to without spec decode with different values of num_speculative_tokens.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||||
test_llm_generator,
|
per_test_common_llm_kwargs,
|
||||||
batch_size,
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
max_output_len=output_len,
|
batch_size, output_len, seed)
|
||||||
force_output_len=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -232,7 +232,7 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -250,17 +250,18 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
|
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify that eagle speculative decoding produces exact equality
|
"""Verify that eagle speculative decoding produces exact equality
|
||||||
to without spec decode when speculation is disabled for large
|
to without spec decode when speculation is disabled for large
|
||||||
batch sizes.
|
batch sizes.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
|
||||||
test_llm_generator,
|
per_test_common_llm_kwargs,
|
||||||
batch_size,
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
max_output_len=output_len,
|
batch_size, output_len, seed)
|
||||||
force_output_len=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -4,7 +4,9 @@ other features, e.g. cuda graphs.
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .conftest import run_greedy_equality_correctness_test
|
from .conftest import run_equality_correctness_test
|
||||||
|
|
||||||
|
MAIN_MODEL = "JackFram/llama-68m"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -15,7 +17,7 @@ from .conftest import run_greedy_equality_correctness_test
|
|||||||
|
|
||||||
# Verify equality when cuda graphs allowed.
|
# Verify equality when cuda graphs allowed.
|
||||||
"enforce_eager": False,
|
"enforce_eager": False,
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"per_test_common_llm_kwargs",
|
"per_test_common_llm_kwargs",
|
||||||
@ -31,23 +33,27 @@ from .conftest import run_greedy_equality_correctness_test
|
|||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize("output_len", [32])
|
@pytest.mark.parametrize("output_len", [32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
|
def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
|
||||||
batch_size, output_len):
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int, seed: int):
|
||||||
"""Verify spec decode equality when cuda graphs are enabled.
|
"""Verify spec decode equality when cuda graphs are enabled.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(
|
run_equality_correctness_test(vllm_runner,
|
||||||
baseline_llm_generator,
|
common_llm_kwargs,
|
||||||
test_llm_generator,
|
per_test_common_llm_kwargs,
|
||||||
batch_size,
|
baseline_llm_kwargs,
|
||||||
max_output_len=output_len,
|
test_llm_kwargs,
|
||||||
force_output_len=True,
|
batch_size,
|
||||||
)
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -80,13 +86,19 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
|
|||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("batch_size", [2])
|
@pytest.mark.parametrize("batch_size", [2])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_speculative_model_quantization_config(baseline_llm_generator,
|
def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
|
||||||
test_llm_generator,
|
per_test_common_llm_kwargs,
|
||||||
batch_size: int):
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
|
batch_size: int, seed: int):
|
||||||
"""Verify spec decode works well with draft model quantization configs.
|
"""Verify spec decode works well with draft model quantization configs.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=32,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=32,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -7,42 +7,39 @@ import torch
|
|||||||
|
|
||||||
from vllm.utils import is_hip
|
from vllm.utils import is_hip
|
||||||
|
|
||||||
from .conftest import run_greedy_equality_correctness_test
|
from .conftest import run_equality_correctness_test_tp
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[[
|
||||||
"model": "JackFram/llama-68m",
|
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"--enforce-eager",
|
||||||
|
|
||||||
# Required for spec decode.
|
# Required for spec decode.
|
||||||
"use_v2_block_manager": True,
|
"--use-v2-block-manager",
|
||||||
"tensor_parallel_size": 2,
|
"--tensor-parallel-size",
|
||||||
|
"2"
|
||||||
# Use AsyncLLM engine, so that the engine runs in its own process.
|
]])
|
||||||
# Otherwise, since vLLM does not follow true SPMD, the test runner
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
|
||||||
# process will have both the engine and the rank0 worker. NCCL is not
|
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||||
# cleaned up properly, and its server host thread leaks, causing the
|
|
||||||
# second run of the test to fail with internal NCCL error.
|
|
||||||
"use_async": True,
|
|
||||||
}])
|
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [
|
@pytest.mark.parametrize("test_llm_kwargs", [
|
||||||
{
|
[
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"--speculative-model",
|
||||||
"num_speculative_tokens": 3,
|
"JackFram/llama-68m",
|
||||||
},
|
"--num-speculative-tokens",
|
||||||
{
|
"3",
|
||||||
"speculative_model": "[ngram]",
|
],
|
||||||
"num_speculative_tokens": 5,
|
[
|
||||||
"ngram_prompt_lookup_max": 3,
|
"--speculative-model",
|
||||||
},
|
"[ngram]",
|
||||||
|
"--num-speculative-tokens",
|
||||||
|
"5",
|
||||||
|
"--ngram-prompt-lookup-max",
|
||||||
|
"3",
|
||||||
|
],
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [2])
|
@pytest.mark.parametrize("batch_size", [2])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -52,75 +49,75 @@ from .conftest import run_greedy_equality_correctness_test
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
|
def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int, seed: int):
|
||||||
"""Verify greedy equality when tensor parallelism is used.
|
"""Verify greedy equality when tensor parallelism is used.
|
||||||
"""
|
"""
|
||||||
if is_hip():
|
if is_hip():
|
||||||
pytest.skip("hip is not well-supported yet")
|
pytest.skip("hip is not well-supported yet")
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test_tp("JackFram/llama-68m",
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
output_len,
|
||||||
|
seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||||
reason="Need at least 2 GPUs to run the test.")
|
reason="Need at least 2 GPUs to run the test.")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[[
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"--enforce-eager",
|
||||||
|
|
||||||
# Required for spec decode.
|
# Required for spec decode.
|
||||||
"use_v2_block_manager": True,
|
"--use_v2_block_manager",
|
||||||
"tensor_parallel_size": 2,
|
"--tensor_parallel_size",
|
||||||
|
"2",
|
||||||
# Use AsyncLLM engine, so that the engine runs in its own process.
|
|
||||||
# Otherwise, since vLLM does not follow true SPMD, the test runner
|
|
||||||
# process will have both the engine and the rank0 worker. NCCL is not
|
|
||||||
# cleaned up properly, and its server host thread leaks, causing the
|
|
||||||
# second run of the test to fail with internal NCCL error.
|
|
||||||
"use_async": True,
|
|
||||||
|
|
||||||
# precision
|
# precision
|
||||||
"dtype": "float32",
|
"--dtype",
|
||||||
}])
|
"bfloat16",
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
]])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
|
||||||
"per_test_common_llm_kwargs, test_llm_kwargs",
|
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||||
[
|
@pytest.mark.parametrize("model, test_llm_kwargs",
|
||||||
(
|
[("JackFram/llama-68m", [
|
||||||
{
|
"--speculative-model",
|
||||||
# Use a small model for a fast test.
|
"JackFram/llama-68m",
|
||||||
# Note this is repeated in the test body; to initialize a
|
"--num_speculative-tokens",
|
||||||
# tokenizer.
|
"5",
|
||||||
"model": "JackFram/llama-68m",
|
"--speculative-draft-tensor-parallel-size",
|
||||||
},
|
"1",
|
||||||
{
|
]),
|
||||||
"speculative_model": "JackFram/llama-68m",
|
("ibm-granite/granite-3b-code-instruct", [
|
||||||
"num_speculative_tokens": 5,
|
"--speculative-model",
|
||||||
"speculative_draft_tensor_parallel_size": 1,
|
"ibm-granite/granite-3b-code-instruct",
|
||||||
}),
|
"--num_speculative-tokens",
|
||||||
({
|
"5",
|
||||||
"model": "ibm-granite/granite-3b-code-instruct",
|
"--speculative-draft-tensor-parallel-size",
|
||||||
}, {
|
"1",
|
||||||
"speculative_model":
|
])])
|
||||||
"ibm-granite/granite-3b-code-instruct-accelerator",
|
|
||||||
"num_speculative_tokens": 5,
|
|
||||||
"speculative_draft_tensor_parallel_size": 1,
|
|
||||||
})
|
|
||||||
])
|
|
||||||
@pytest.mark.parametrize("batch_size", [2])
|
@pytest.mark.parametrize("batch_size", [2])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
|
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
|
||||||
baseline_llm_generator,
|
per_test_common_llm_kwargs,
|
||||||
batch_size: int):
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int,
|
||||||
|
seed: int):
|
||||||
"""Verify spec decode works well with smaller tp for draft models.
|
"""Verify spec decode works well with smaller tp for draft models.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test_tp(model,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=32,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=32,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -2,98 +2,97 @@
|
|||||||
tensor parallelism.
|
tensor parallelism.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import openai
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .conftest import run_greedy_equality_correctness_test
|
from .conftest import run_equality_correctness_test_tp
|
||||||
|
|
||||||
|
MAIN_MODEL = "JackFram/llama-68m"
|
||||||
|
SPEC_MODEL = "JackFram/llama-68m"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
reason="Need at least 4 GPUs to run the test.")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[[
|
||||||
# Use a small model for a fast test.
|
|
||||||
# Note this is repeated in the test body; to initialize a tokenizer.
|
|
||||||
"model": "JackFram/llama-68m",
|
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"--enforce_eager",
|
||||||
|
|
||||||
# Required for spec decode.
|
# Required for spec decode.
|
||||||
"use_v2_block_manager": True,
|
"--use-v2-block-manager",
|
||||||
"tensor_parallel_size": 4,
|
"--tensor-parallel-size",
|
||||||
|
"4",
|
||||||
# Use AsyncLLM engine, so that the engine runs in its own process.
|
]])
|
||||||
# Otherwise, since vLLM does not follow true SPMD, the test runner
|
|
||||||
# process will have both the engine and the rank0 worker. NCCL is not
|
|
||||||
# cleaned up properly, and its server host thread leaks, causing the
|
|
||||||
# second run of the test to fail with internal NCCL error.
|
|
||||||
"use_async": True,
|
|
||||||
}])
|
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||||
{
|
[
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"--speculative-model",
|
||||||
"num_speculative_tokens": 5,
|
f"{SPEC_MODEL}",
|
||||||
},
|
"--num-speculative-tokens",
|
||||||
|
"5",
|
||||||
|
],
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_llm_kwargs",
|
"test_llm_kwargs",
|
||||||
[
|
[
|
||||||
#TODO(wooyeon): add spec_draft_dp=2 case
|
#TODO(wooyeon): add spec_draft_dp=2 case
|
||||||
{
|
[
|
||||||
"speculative_draft_tensor_parallel_size": 1,
|
"--speculative-draft-tensor-parallel-size",
|
||||||
},
|
"1",
|
||||||
|
],
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [2])
|
@pytest.mark.parametrize("batch_size", [2])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
|
def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
|
||||||
baseline_llm_generator,
|
per_test_common_llm_kwargs,
|
||||||
batch_size: int):
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int,
|
||||||
|
seed: int):
|
||||||
"""Verify spec decode works well with smaller tp for draft models.
|
"""Verify spec decode works well with smaller tp for draft models.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test_tp(MAIN_MODEL,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=32,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=32,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
@pytest.mark.skipif(torch.cuda.device_count() < 4,
|
||||||
reason="Need at least 4 GPUs to run the test.")
|
reason="Need at least 4 GPUs to run the test.")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[[
|
||||||
"model": "JackFram/llama-160m",
|
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"--enforce-eager",
|
||||||
|
|
||||||
# Required for spec decode.
|
# Required for spec decode.
|
||||||
"use_v2_block_manager": True,
|
"--use-v2-block-manager",
|
||||||
"tensor_parallel_size": 4,
|
"--tensor-parallel-size",
|
||||||
|
"4",
|
||||||
# Use AsyncLLM engine, so that the engine runs in its own process.
|
]])
|
||||||
# Otherwise, since vLLM does not follow true SPMD, the test runner
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
|
||||||
# process will have both the engine and the rank0 worker. NCCL is not
|
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
|
||||||
# cleaned up properly, and its server host thread leaks, causing the
|
|
||||||
# second run of the test to fail with internal NCCL error.
|
|
||||||
"use_async": True,
|
|
||||||
}])
|
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_llm_kwargs",
|
"test_llm_kwargs",
|
||||||
[
|
[
|
||||||
{
|
[
|
||||||
"speculative_model": "JackFram/llama-68m",
|
"--speculative-model",
|
||||||
"num_speculative_tokens": 5,
|
f"{SPEC_MODEL}",
|
||||||
|
"--num-speculative-tokens",
|
||||||
|
"5",
|
||||||
|
|
||||||
# Artificially limit the draft model max model len; this forces vLLM
|
# Artificially limit the draft model max model len; this forces vLLM
|
||||||
# to skip speculation once the sequences grow beyond 32-k tokens.
|
# to skip speculation once the sequences grow beyond 32-k tokens.
|
||||||
"speculative_max_model_len": 32,
|
"--speculative-max-model-len",
|
||||||
},
|
"32",
|
||||||
|
],
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -105,8 +104,9 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
|
|||||||
64,
|
64,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
|
def test_skip_speculation(common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int, seed: int):
|
||||||
"""Verify job failure with RuntimeError when all sequences skip speculation.
|
"""Verify job failure with RuntimeError when all sequences skip speculation.
|
||||||
We do this by setting the max model len of the draft model to an
|
We do this by setting the max model len of the draft model to an
|
||||||
artificially low value, such that when the sequences grow beyond it, they
|
artificially low value, such that when the sequences grow beyond it, they
|
||||||
@ -114,9 +114,13 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
|
|||||||
|
|
||||||
TODO: fix it to pass without raising Error. (#5814)
|
TODO: fix it to pass without raising Error. (#5814)
|
||||||
"""
|
"""
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(openai.APIConnectionError):
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test_tp(MAIN_MODEL,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
output_len,
|
||||||
|
seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -1,24 +1,22 @@
|
|||||||
import math
|
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
from .conftest import get_logprobs_from_llm_generator
|
from .conftest import run_logprob_correctness_test
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
|
|
||||||
# Required for spec decode.
|
# Required for spec decode.
|
||||||
"use_v2_block_manager": True,
|
"use_v2_block_manager": True,
|
||||||
"max_logprobs": 6,
|
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -36,64 +34,29 @@ from .conftest import get_logprobs_from_llm_generator
|
|||||||
7,
|
7,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
|
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||||
batch_size: int, output_len: int):
|
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int, logprobs: int):
|
||||||
"""Verify output logprobs are equal with and without speculative decoding.
|
"""Verify output logprobs are equal with and without speculative decoding.
|
||||||
"""
|
"""
|
||||||
run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
run_logprob_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
output_len,
|
||||||
|
seed,
|
||||||
|
temperature=0.0,
|
||||||
|
logprobs=logprobs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
|
||||||
"enforce_eager": True,
|
|
||||||
|
|
||||||
# Required for spec decode.
|
|
||||||
"use_v2_block_manager": True,
|
|
||||||
"max_logprobs": 6,
|
|
||||||
}])
|
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize("test_llm_kwargs",
|
|
||||||
[{
|
|
||||||
"speculative_model": "JackFram/llama-160m",
|
|
||||||
"num_speculative_tokens": 3,
|
|
||||||
"disable_logprobs_during_spec_decoding": False,
|
|
||||||
}])
|
|
||||||
@pytest.mark.parametrize("batch_size", [1])
|
|
||||||
@pytest.mark.parametrize("num_logprobs", [6])
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"output_len",
|
|
||||||
[
|
|
||||||
# Use smaller output len for fast test.
|
|
||||||
7,
|
|
||||||
])
|
|
||||||
@pytest.mark.parametrize("seed", [1])
|
|
||||||
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
|
|
||||||
batch_size: int, output_len: int,
|
|
||||||
num_logprobs: int):
|
|
||||||
"""Verify output logprobs are equal with and without spec decode.
|
|
||||||
This specifies a number of logprobs >1.
|
|
||||||
"""
|
|
||||||
run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
|
||||||
test_llm_generator,
|
|
||||||
batch_size,
|
|
||||||
max_output_len=output_len,
|
|
||||||
force_output_len=True,
|
|
||||||
logprob_rank=num_logprobs)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"common_llm_kwargs",
|
|
||||||
[{
|
|
||||||
"model": "JackFram/llama-68m",
|
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -121,21 +84,29 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
|
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||||
batch_size: int, output_len: int):
|
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int,
|
||||||
|
output_len: int, seed: int, logprobs: int):
|
||||||
"""Veriy logprob greedy equality with different speculation lens.
|
"""Veriy logprob greedy equality with different speculation lens.
|
||||||
"""
|
"""
|
||||||
run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
run_logprob_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
output_len,
|
||||||
|
seed,
|
||||||
|
temperature=0.0,
|
||||||
|
logprobs=logprobs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -164,22 +135,30 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_logprobs_when_skip_speculation(baseline_llm_generator,
|
@pytest.mark.parametrize("logprobs", [1])
|
||||||
test_llm_generator, batch_size: int,
|
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||||
output_len: int):
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int,
|
||||||
|
seed: int, logprobs: int):
|
||||||
"""Verify logprobs greedy equality when some sequences skip speculation.
|
"""Verify logprobs greedy equality when some sequences skip speculation.
|
||||||
"""
|
"""
|
||||||
run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
run_logprob_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
output_len,
|
||||||
|
seed,
|
||||||
|
temperature=0.0,
|
||||||
|
logprobs=logprobs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -203,19 +182,17 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
|
@pytest.mark.parametrize("logprobs", [6])
|
||||||
batch_size: int, output_len: int):
|
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int, logprobs: int):
|
||||||
"""Verify at least one logprob result has num_logprobs+1, which tests the
|
"""Verify at least one logprob result has num_logprobs+1, which tests the
|
||||||
case where the sampled token is not in top-k logprobs.
|
case where the sampled token is not in top-k logprobs.
|
||||||
|
|
||||||
Ideally, this test should validate equality with non-spec by getting
|
Ideally, this test should validate equality with non-spec by getting
|
||||||
logprobs. This is left as future improvement.
|
logprobs. This is left as future improvement.
|
||||||
"""
|
"""
|
||||||
batch_size = 8
|
|
||||||
max_output_len = output_len
|
|
||||||
force_output_len = True
|
|
||||||
logprob_rank = 5
|
|
||||||
|
|
||||||
temperature = 1.0
|
temperature = 1.0
|
||||||
|
|
||||||
prompts = [
|
prompts = [
|
||||||
@ -231,129 +208,40 @@ def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
|
|||||||
|
|
||||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||||
|
|
||||||
# If the test requires that we generated max_output_len tokens, then set the
|
|
||||||
# sampling params to ignore eos token.
|
|
||||||
ignore_eos = force_output_len
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
max_tokens=max_output_len,
|
max_tokens=output_len,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=True,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
logprobs=logprob_rank,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
spec_batch_logprobs = get_logprobs_from_llm_generator(
|
sd_args = {
|
||||||
test_llm_generator, prompts, sampling_params)
|
**common_llm_kwargs,
|
||||||
|
**per_test_common_llm_kwargs,
|
||||||
|
**test_llm_kwargs,
|
||||||
|
}
|
||||||
|
|
||||||
|
with vllm_runner(**sd_args) as vllm_model:
|
||||||
|
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||||
|
|
||||||
num_returned_logprobs = [
|
num_returned_logprobs = [
|
||||||
len(logprob_dict) for seq_logprobs in spec_batch_logprobs
|
len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
|
||||||
for logprob_dict in seq_logprobs
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Assert one of the returned logprobs has > num_logprobs (indicating the
|
# Assert one of the returned logprobs has > num_logprobs (indicating the
|
||||||
# sampled token is not in top-k).
|
# sampled token is not in top-k).
|
||||||
assert any([
|
assert any(
|
||||||
num_returned > logprob_rank for num_returned in num_returned_logprobs
|
[num_returned > logprobs for num_returned in num_returned_logprobs])
|
||||||
])
|
|
||||||
|
|
||||||
|
|
||||||
def run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
|
||||||
test_llm_generator,
|
|
||||||
batch_size,
|
|
||||||
max_output_len,
|
|
||||||
force_output_len: bool,
|
|
||||||
logprob_rank: int = 1):
|
|
||||||
"""Helper method that compares the logprobs outputs of both the baseline LLM
|
|
||||||
and the test LLM. It asserts greedy equality of the logprobs when the
|
|
||||||
temperature is zero.
|
|
||||||
"""
|
|
||||||
temperature = 0.0
|
|
||||||
|
|
||||||
prompts = [
|
|
||||||
"Hello, my name is",
|
|
||||||
"The president of the United States is",
|
|
||||||
"The capital of France is",
|
|
||||||
"The future of AI is",
|
|
||||||
"San Francisco is know for its",
|
|
||||||
"Facebook was created in 2004 by",
|
|
||||||
"Curious George is a",
|
|
||||||
"Python 3.11 brings improvements to its",
|
|
||||||
]
|
|
||||||
|
|
||||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
|
||||||
|
|
||||||
# If the test requires that we generated max_output_len tokens, then set the
|
|
||||||
# sampling params to ignore eos token.
|
|
||||||
ignore_eos = force_output_len
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
max_tokens=max_output_len,
|
|
||||||
ignore_eos=ignore_eos,
|
|
||||||
temperature=temperature,
|
|
||||||
logprobs=logprob_rank,
|
|
||||||
)
|
|
||||||
|
|
||||||
spec_batch_logprobs = get_logprobs_from_llm_generator(
|
|
||||||
test_llm_generator, prompts, sampling_params)
|
|
||||||
baseline_batch_logprobs = get_logprobs_from_llm_generator(
|
|
||||||
baseline_llm_generator, prompts, sampling_params)
|
|
||||||
|
|
||||||
assert len(baseline_batch_logprobs) == len(prompts)
|
|
||||||
assert len(spec_batch_logprobs) == len(prompts)
|
|
||||||
|
|
||||||
# For each sequence in the batch.
|
|
||||||
for i, (baseline_logprobs, spec_logprobs) in enumerate(
|
|
||||||
zip(baseline_batch_logprobs, spec_batch_logprobs)):
|
|
||||||
assert len(spec_logprobs) == len(baseline_logprobs)
|
|
||||||
|
|
||||||
# For each generated position of the sequence.
|
|
||||||
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
|
|
||||||
zip(spec_logprobs, baseline_logprobs)):
|
|
||||||
|
|
||||||
# Map rank to token/logprob in spec output.
|
|
||||||
spec_rank_to_token_id = {
|
|
||||||
value.rank: key
|
|
||||||
for key, value in spec_pos_logprobs.items()
|
|
||||||
}
|
|
||||||
spec_rank_to_logprob = {
|
|
||||||
value.rank: value.logprob
|
|
||||||
for key, value in spec_pos_logprobs.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Map rank to token/logprob in baseline output.
|
|
||||||
baseline_rank_to_token_id = {
|
|
||||||
value.rank: key
|
|
||||||
for key, value in baseline_pos_logprobs.items()
|
|
||||||
}
|
|
||||||
baseline_rank_to_logprob = {
|
|
||||||
value.rank: value.logprob
|
|
||||||
for key, value in baseline_pos_logprobs.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Assert set of ranks returned is equal.
|
|
||||||
assert set(spec_rank_to_token_id.keys()) == set(
|
|
||||||
baseline_rank_to_token_id.keys())
|
|
||||||
|
|
||||||
# Assert each logprob/token id is correct, keyed by rank.
|
|
||||||
for rank in sorted(set(spec_rank_to_token_id.keys())):
|
|
||||||
assert spec_rank_to_token_id[
|
|
||||||
rank] == baseline_rank_to_token_id[rank], f"{rank}"
|
|
||||||
assert math.isclose(
|
|
||||||
a=spec_rank_to_logprob[rank],
|
|
||||||
b=baseline_rank_to_logprob[rank],
|
|
||||||
abs_tol=1e-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
# Required for spec decode.
|
# Required for spec decode.
|
||||||
"use_v2_block_manager": True,
|
"use_v2_block_manager": True,
|
||||||
"max_logprobs": 6,
|
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -364,57 +252,28 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
|||||||
"disable_logprobs_during_spec_decoding": True,
|
"disable_logprobs_during_spec_decoding": True,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"output_len",
|
||||||
|
[
|
||||||
|
# Use smaller output len for fast test.
|
||||||
|
32,
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("logprobs", [0])
|
||||||
|
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int, logprobs: int):
|
||||||
"""Check the behavior when logprobs are disabled.
|
"""Check the behavior when logprobs are disabled.
|
||||||
Token choices should match with the base model.
|
Token choices should match with the base model.
|
||||||
"""
|
"""
|
||||||
prompts = [
|
run_logprob_correctness_test(vllm_runner,
|
||||||
"Hello, my name is",
|
common_llm_kwargs,
|
||||||
"The president of the United States is",
|
per_test_common_llm_kwargs,
|
||||||
"The capital of France is",
|
baseline_llm_kwargs,
|
||||||
"The future of AI is",
|
test_llm_kwargs,
|
||||||
"San Francisco is know for its",
|
batch_size,
|
||||||
"Facebook was created in 2004 by",
|
output_len,
|
||||||
"Curious George is a",
|
seed,
|
||||||
"Python 3.11 brings improvements to its",
|
temperature=0.0,
|
||||||
]
|
logprobs=logprobs)
|
||||||
|
|
||||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
# Use smaller output len for fast test
|
|
||||||
max_tokens=7,
|
|
||||||
ignore_eos=True,
|
|
||||||
temperature=0.0,
|
|
||||||
logprobs=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
spec_batch_logprobs = get_logprobs_from_llm_generator(
|
|
||||||
test_llm_generator, prompts, sampling_params)
|
|
||||||
baseline_batch_logprobs = get_logprobs_from_llm_generator(
|
|
||||||
baseline_llm_generator, prompts, sampling_params)
|
|
||||||
|
|
||||||
assert len(baseline_batch_logprobs) == len(prompts)
|
|
||||||
assert len(spec_batch_logprobs) == len(prompts)
|
|
||||||
|
|
||||||
# For each sequence in the batch.
|
|
||||||
for _, (baseline_logprobs, spec_logprobs) in enumerate(
|
|
||||||
zip(baseline_batch_logprobs, spec_batch_logprobs)):
|
|
||||||
assert len(spec_logprobs) == len(baseline_logprobs)
|
|
||||||
|
|
||||||
# For each generated position of the sequence.
|
|
||||||
for _, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
|
|
||||||
zip(spec_logprobs, baseline_logprobs)):
|
|
||||||
|
|
||||||
assert len(spec_pos_logprobs) == 1
|
|
||||||
spec_top_token_id = list(spec_pos_logprobs)[0]
|
|
||||||
|
|
||||||
spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
|
|
||||||
assert spec_top_logprob.logprob == 0.0
|
|
||||||
assert spec_top_logprob.rank == -1
|
|
||||||
|
|
||||||
# check that the chosen token matches the base model
|
|
||||||
baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
|
|
||||||
assert baseline_logprob.rank == 1
|
|
||||||
assert spec_top_logprob.decoded_token \
|
|
||||||
== baseline_logprob.decoded_token
|
|
||||||
|
|||||||
@ -21,7 +21,7 @@ correctess for the target model outputs.
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .conftest import run_greedy_equality_correctness_test
|
from .conftest import run_equality_correctness_test
|
||||||
|
|
||||||
# main model
|
# main model
|
||||||
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
|
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
|
||||||
@ -55,7 +55,7 @@ PRECISION = "float32"
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -70,15 +70,21 @@ PRECISION = "float32"
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
|
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||||
test_llm_generator, batch_size: int,
|
per_test_common_llm_kwargs,
|
||||||
output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality with different batch size."""
|
"""Verify greedy equality with different batch size."""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -96,7 +102,7 @@ def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -111,17 +117,21 @@ def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
|
def test_medusa_e2e_greedy_correctness_cuda_graph(
|
||||||
test_llm_generator,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int,
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
output_len: int):
|
seed: int):
|
||||||
"""Verify greedy equality with cuda graph enabled and different
|
"""Verify greedy equality with cuda graph enabled and different
|
||||||
batch sizes."""
|
batch sizes."""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -142,7 +152,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -160,18 +170,22 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [4])
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
def test_medusa_e2e_greedy_correctness_with_preemption(
|
||||||
test_llm_generator,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int,
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
output_len: int):
|
seed: int):
|
||||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||||
generation.
|
generation.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -187,7 +201,7 @@ def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -209,16 +223,22 @@ def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
|
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify that medusa speculative decoding produces exact equality
|
"""Verify that medusa speculative decoding produces exact equality
|
||||||
to without spec decode with different values of num_speculative_tokens.
|
to without spec decode with different values of num_speculative_tokens.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -234,7 +254,7 @@ def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -252,17 +272,23 @@ def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
|
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int,
|
||||||
|
output_len: int, seed: int):
|
||||||
"""Verify that medusa speculative decoding produces exact equality
|
"""Verify that medusa speculative decoding produces exact equality
|
||||||
to without spec decode when speculation is disabled for large
|
to without spec decode when speculation is disabled for large
|
||||||
batch sizes.
|
batch sizes.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -25,8 +25,7 @@ import pytest
|
|||||||
|
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
|
from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
|
||||||
|
|
||||||
from .conftest import (run_equality_correctness_test,
|
from .conftest import run_equality_correctness_test
|
||||||
run_greedy_equality_correctness_test)
|
|
||||||
|
|
||||||
# main model
|
# main model
|
||||||
MAIN_MODEL = "JackFram/llama-160m"
|
MAIN_MODEL = "JackFram/llama-160m"
|
||||||
@ -58,7 +57,7 @@ PRECISION = "float32"
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -72,14 +71,21 @@ PRECISION = "float32"
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
|
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality with different batch size."""
|
"""Verify greedy equality with different batch size."""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -98,7 +104,7 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -110,17 +116,21 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
|
|||||||
@pytest.mark.parametrize("output_len", [2048])
|
@pytest.mark.parametrize("output_len", [2048])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
|
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int, seed: int):
|
||||||
"""Verify acceptance rate with different batch size and large output
|
"""Verify acceptance rate with different batch size and large output
|
||||||
length."""
|
length."""
|
||||||
run_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
batch_size,
|
batch_size,
|
||||||
max_output_len=output_len,
|
max_output_len=output_len,
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
seeded=True,
|
seed=seed,
|
||||||
force_output_len=True,
|
|
||||||
expected_acceptance_rate=0.48)
|
expected_acceptance_rate=0.48)
|
||||||
|
|
||||||
|
|
||||||
@ -140,7 +150,7 @@ def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
|
|
||||||
# Speculative model
|
# Speculative model
|
||||||
"speculative_model": SPEC_MODEL,
|
"speculative_model": SPEC_MODEL,
|
||||||
@ -151,28 +161,35 @@ def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
|
|||||||
@pytest.mark.parametrize("output_len", [64])
|
@pytest.mark.parametrize("output_len", [64])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("temperature", [0.1, 1.0])
|
@pytest.mark.parametrize("temperature", [0.1, 1.0])
|
||||||
@pytest.mark.parametrize("seed", [None])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
|
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
batch_size: int, output_len: int,
|
batch_size: int, output_len: int,
|
||||||
temperature: float):
|
temperature: float, seed: int):
|
||||||
"""Verify seeded runs produce the same output."""
|
"""Verify seeded runs produce the same output."""
|
||||||
run_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
batch_size,
|
batch_size,
|
||||||
max_output_len=output_len,
|
max_output_len=output_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
seeded=True,
|
seed=seed)
|
||||||
force_output_len=True)
|
|
||||||
|
|
||||||
# Ensure this same test does fail if we _don't_ include per-request seeds
|
# Ensure this same test does fail if we _don't_ include per-request seeds
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
run_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs,
|
||||||
batch_size,
|
batch_size,
|
||||||
max_output_len=output_len,
|
max_output_len=output_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
seeded=False,
|
seed=seed,
|
||||||
force_output_len=True)
|
disable_seed=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -193,7 +210,7 @@ def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -210,18 +227,22 @@ def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [4])
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
def test_mlp_e2e_greedy_correctness_with_preemption(
|
||||||
test_llm_generator,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int,
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
output_len: int):
|
seed: int):
|
||||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||||
generation.
|
generation.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -242,7 +263,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -259,10 +280,10 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [4])
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
|
def test_mlp_e2e_greedy_correctness_with_padding(
|
||||||
test_llm_generator,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int,
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
output_len: int):
|
seed: int):
|
||||||
"""Verify greedy equality when the vocab dimension is padded
|
"""Verify greedy equality when the vocab dimension is padded
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -273,11 +294,15 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
|
|||||||
with patch(
|
with patch(
|
||||||
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
|
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
|
||||||
patched_pad_vocab_size):
|
patched_pad_vocab_size):
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -293,7 +318,7 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -315,16 +340,22 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
|
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, seed: int,
|
||||||
|
output_len: int):
|
||||||
"""Verify that mlp speculative decoding produces exact equality
|
"""Verify that mlp speculative decoding produces exact equality
|
||||||
to without spec decode with different values of num_speculative_tokens.
|
to without spec decode with different values of num_speculative_tokens.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -340,7 +371,7 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
|
|||||||
"dtype": PRECISION,
|
"dtype": PRECISION,
|
||||||
|
|
||||||
# Main model
|
# Main model
|
||||||
"model": MAIN_MODEL,
|
"model_name": MAIN_MODEL,
|
||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -357,14 +388,20 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
|
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, seed: int,
|
||||||
|
output_len: int):
|
||||||
"""Verify that mlp speculative decoding produces exact equality
|
"""Verify that mlp speculative decoding produces exact equality
|
||||||
to without spec decode when speculation is disabled for large
|
to without spec decode when speculation is disabled for large
|
||||||
batch sizes.
|
batch sizes.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -41,8 +41,9 @@ from transformers import AutoTokenizer
|
|||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
|
||||||
|
from ...utils import fork_new_process_for_each_test
|
||||||
from .conftest import (get_output_from_llm_generator,
|
from .conftest import (get_output_from_llm_generator,
|
||||||
run_greedy_equality_correctness_test)
|
run_equality_correctness_test)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -73,6 +74,7 @@ from .conftest import (get_output_from_llm_generator,
|
|||||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
||||||
batch_size: int):
|
batch_size: int):
|
||||||
"""Run generation with speculative decoding on a batch. Verify the engine
|
"""Run generation with speculative decoding on a batch. Verify the engine
|
||||||
@ -116,44 +118,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
|
|||||||
assert actual_tokens.strip() == expected_tokens.strip()
|
assert actual_tokens.strip() == expected_tokens.strip()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"common_llm_kwargs",
|
|
||||||
[{
|
|
||||||
# Use a small model for a fast test.
|
|
||||||
# Note this is repeated in the test body; to initialize a tokenizer.
|
|
||||||
"model": "JackFram/llama-68m",
|
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
|
||||||
"enforce_eager": True,
|
|
||||||
|
|
||||||
# Required for spec decode.
|
|
||||||
"use_v2_block_manager": True,
|
|
||||||
|
|
||||||
# Use AsyncLLM engine
|
|
||||||
"use_async": True,
|
|
||||||
}])
|
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
|
||||||
{
|
|
||||||
"speculative_model": "JackFram/llama-68m",
|
|
||||||
"num_speculative_tokens": 5,
|
|
||||||
},
|
|
||||||
])
|
|
||||||
@pytest.mark.parametrize("test_llm_kwargs", [{}])
|
|
||||||
@pytest.mark.parametrize("batch_size", [2])
|
|
||||||
@pytest.mark.parametrize("seed", [1])
|
|
||||||
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
|
|
||||||
baseline_llm_generator,
|
|
||||||
batch_size: int):
|
|
||||||
"""Verify spec decode works well with async LLM engine.
|
|
||||||
"""
|
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
|
||||||
test_llm_generator,
|
|
||||||
batch_size,
|
|
||||||
max_output_len=32,
|
|
||||||
force_output_len=True)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
@ -172,10 +136,10 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
|
|||||||
# Try two different tiny base models.
|
# Try two different tiny base models.
|
||||||
# Note that one is equal to the draft model, another isn't.
|
# Note that one is equal to the draft model, another isn't.
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -189,13 +153,15 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
|
|||||||
"output_len",
|
"output_len",
|
||||||
[
|
[
|
||||||
# Use long output len for the small model test.
|
# Use long output len for the small model test.
|
||||||
1536,
|
10,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [1])
|
@pytest.mark.parametrize("batch_size", [1])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
||||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality on a tiny model with batch size of one.
|
"""Verify greedy equality on a tiny model with batch size of one.
|
||||||
|
|
||||||
Since this test is cheaper than other e2e correctness tests, we generate
|
Since this test is cheaper than other e2e correctness tests, we generate
|
||||||
@ -204,14 +170,18 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
|||||||
When the draft model is the same as the target model, we further check
|
When the draft model is the same as the target model, we further check
|
||||||
whether all speculative tokens are accepted.
|
whether all speculative tokens are accepted.
|
||||||
"""
|
"""
|
||||||
ensure_all_accepted = test_llm_generator.same_draft_target_model
|
ensure_all_accepted = per_test_common_llm_kwargs.get(
|
||||||
run_greedy_equality_correctness_test(
|
"model_name") == test_llm_kwargs.get("speculative_model")
|
||||||
baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True,
|
test_llm_kwargs,
|
||||||
ensure_all_accepted=ensure_all_accepted)
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0,
|
||||||
|
ensure_all_accepted=ensure_all_accepted)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -232,10 +202,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
|||||||
# Try two different tiny base models.
|
# Try two different tiny base models.
|
||||||
# Note that one is equal to the draft model, another isn't.
|
# Note that one is equal to the draft model, another isn't.
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -253,16 +223,22 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [64])
|
@pytest.mark.parametrize("batch_size", [64])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
||||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality on a tiny model and large batch size.
|
"""Verify greedy equality on a tiny model and large batch size.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -280,10 +256,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
|||||||
# Try two different tiny base models.
|
# Try two different tiny base models.
|
||||||
# Note that one is equal to the draft model, another isn't.
|
# Note that one is equal to the draft model, another isn't.
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -298,24 +274,31 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [32])
|
@pytest.mark.parametrize("batch_size", [32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
||||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
max_output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||||
|
max_output_len: int, seed: int):
|
||||||
"""Verify greedy equality on a tiny model, with a large batch size, and when
|
"""Verify greedy equality on a tiny model, with a large batch size, and when
|
||||||
sampling respects the EOS token.
|
sampling respects the EOS token.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=False)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0,
|
||||||
|
ignore_eos=False)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
# A "real" model (not tiny).
|
# A "real" model (not tiny).
|
||||||
"model": "meta-llama/Llama-2-7b-chat-hf",
|
"model_name": "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -342,24 +325,30 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
|
|||||||
256,
|
256,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
||||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality on a "real" model and batch size of 1. This is
|
"""Verify greedy equality on a "real" model and batch size of 1. This is
|
||||||
separate from large BS tests to make identifying the source of bugs easier.
|
separate from large BS tests to make identifying the source of bugs easier.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
# A "real" model (not tiny).
|
# A "real" model (not tiny).
|
||||||
"model": "meta-llama/Llama-2-7b-chat-hf",
|
"model_name": "meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -386,17 +375,23 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
|
|||||||
64,
|
64,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
||||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality with a "real" model on a nontrivial batch size.
|
"""Verify greedy equality with a "real" model on a nontrivial batch size.
|
||||||
This is the closest test to a real production workload.
|
This is the closest test to a real production workload.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -415,7 +410,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
|||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -433,23 +428,29 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [4])
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
@fork_new_process_for_each_test
|
||||||
def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
||||||
baseline_llm_generator, test_llm_generator, batch_size: int,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||||
generation.
|
generation.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -487,22 +488,29 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_spec_decode_different_block_size(baseline_llm_generator,
|
@fork_new_process_for_each_test
|
||||||
test_llm_generator, batch_size: int,
|
def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
|
||||||
output_len: int):
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality over different block sizes.
|
"""Verify greedy equality over different block sizes.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -534,24 +542,31 @@ def test_spec_decode_different_block_size(baseline_llm_generator,
|
|||||||
64,
|
64,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
|
@fork_new_process_for_each_test
|
||||||
batch_size: int, output_len: int):
|
def test_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality when some (or all) sequences skip speculation.
|
"""Verify greedy equality when some (or all) sequences skip speculation.
|
||||||
We do this by setting the max model len of the draft model to an
|
We do this by setting the max model len of the draft model to an
|
||||||
artificially low value, such that when the sequences grow beyond it, they
|
artificially low value, such that when the sequences grow beyond it, they
|
||||||
are skipped in speculative decoding.
|
are skipped in speculative decoding.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -571,21 +586,28 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
|
|||||||
@pytest.mark.parametrize("batch_size", [8])
|
@pytest.mark.parametrize("batch_size", [8])
|
||||||
@pytest.mark.parametrize("output_len", [10])
|
@pytest.mark.parametrize("output_len", [10])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
|
@fork_new_process_for_each_test
|
||||||
batch_size: int, output_len: int):
|
def test_disable_speculation(vllm_runner, common_llm_kwargs,
|
||||||
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality when all sequences disable speculation.
|
"""Verify greedy equality when all sequences disable speculation.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -613,22 +635,28 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
|
@fork_new_process_for_each_test
|
||||||
output_len: int):
|
def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
|
||||||
|
output_len: int, seed: int):
|
||||||
"""Verify that speculative decoding produces exact equality to without spec
|
"""Verify that speculative decoding produces exact equality to without spec
|
||||||
decode with many different values of k.
|
decode with many different values of k.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -657,15 +685,22 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_typical_acceptance_sampling(baseline_llm_generator,
|
@fork_new_process_for_each_test
|
||||||
test_llm_generator, batch_size: int,
|
def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs,
|
||||||
output_len: int):
|
per_test_common_llm_kwargs,
|
||||||
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify that speculative decoding produces exact equality to without spec
|
"""Verify that speculative decoding produces exact equality to without spec
|
||||||
decode with TypicalAcceptanceSampler as the draft token acceptance
|
decode with TypicalAcceptanceSampler as the draft token acceptance
|
||||||
sampling method.
|
sampling method.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -26,7 +26,7 @@ for the target model outputs.
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from .conftest import run_greedy_equality_correctness_test
|
from .conftest import run_equality_correctness_test
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -43,7 +43,7 @@ from .conftest import run_greedy_equality_correctness_test
|
|||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -59,15 +59,21 @@ from .conftest import run_greedy_equality_correctness_test
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [1, 32])
|
@pytest.mark.parametrize("batch_size", [1, 32])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
|
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
|
||||||
test_llm_generator, batch_size: int,
|
per_test_common_llm_kwargs,
|
||||||
output_len: int):
|
baseline_llm_kwargs, test_llm_kwargs,
|
||||||
|
batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify greedy equality on a tiny model with different batch size."""
|
"""Verify greedy equality on a tiny model with different batch size."""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -86,7 +92,7 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
|
|||||||
}])
|
}])
|
||||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||||
{
|
{
|
||||||
"model": "JackFram/llama-160m",
|
"model_name": "JackFram/llama-160m",
|
||||||
},
|
},
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
@ -105,24 +111,28 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
|
|||||||
])
|
])
|
||||||
@pytest.mark.parametrize("batch_size", [4])
|
@pytest.mark.parametrize("batch_size", [4])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
def test_ngram_e2e_greedy_correctness_with_preemption(
|
||||||
test_llm_generator,
|
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
|
||||||
batch_size: int,
|
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
output_len: int):
|
seed: int):
|
||||||
"""Verify greedy equality, even when some sequences are preempted mid-
|
"""Verify greedy equality, even when some sequences are preempted mid-
|
||||||
generation.
|
generation.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
temperature=0,
|
||||||
|
seed=seed)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -159,23 +169,29 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
|
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify that ngram speculative decoding produces exact equality
|
"""Verify that ngram speculative decoding produces exact equality
|
||||||
to without spec decode with many different values of k and
|
to without spec decode with many different values of k and
|
||||||
different ngram_prompt_lookup_max.
|
different ngram_prompt_lookup_max.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -200,14 +216,20 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
|
|||||||
32,
|
32,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [1])
|
@pytest.mark.parametrize("seed", [1])
|
||||||
def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
|
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
|
||||||
batch_size: int, output_len: int):
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
|
test_llm_kwargs, batch_size: int, output_len: int,
|
||||||
|
seed: int):
|
||||||
"""Verify that ngram speculative decoding produces exact equality
|
"""Verify that ngram speculative decoding produces exact equality
|
||||||
to without spec decode with many different values of k and
|
to without spec decode with many different values of k and
|
||||||
different ngram_prompt_lookup_max.
|
different ngram_prompt_lookup_max.
|
||||||
"""
|
"""
|
||||||
run_greedy_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(vllm_runner,
|
||||||
test_llm_generator,
|
common_llm_kwargs,
|
||||||
batch_size,
|
per_test_common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
baseline_llm_kwargs,
|
||||||
force_output_len=True)
|
test_llm_kwargs,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
seed=seed,
|
||||||
|
temperature=0.0)
|
||||||
|
|||||||
@ -2,11 +2,17 @@ import pytest
|
|||||||
|
|
||||||
from .conftest import run_equality_correctness_test
|
from .conftest import run_equality_correctness_test
|
||||||
|
|
||||||
|
# main model
|
||||||
|
MAIN_MODEL = "JackFram/llama-68m"
|
||||||
|
|
||||||
|
# speculative model
|
||||||
|
SPEC_MODEL = "JackFram/llama-160m"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"common_llm_kwargs",
|
"common_llm_kwargs",
|
||||||
[{
|
[{
|
||||||
"model": "JackFram/llama-68m",
|
"model_name": "JackFram/llama-68m",
|
||||||
|
|
||||||
# Skip cuda graph recording for fast test.
|
# Skip cuda graph recording for fast test.
|
||||||
"enforce_eager": True,
|
"enforce_eager": True,
|
||||||
@ -31,26 +37,34 @@ from .conftest import run_equality_correctness_test
|
|||||||
# Use smaller output len for fast test.
|
# Use smaller output len for fast test.
|
||||||
20,
|
20,
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("seed", [None])
|
def test_seeded_consistency(vllm_runner, common_llm_kwargs,
|
||||||
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
|
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||||
batch_size: int, temperature: float,
|
test_llm_kwargs, batch_size: int,
|
||||||
output_len: int):
|
temperature: float, output_len: int):
|
||||||
"""Verify outputs are consistent across multiple runs with same seed
|
"""Verify outputs are consistent across multiple runs with same seed
|
||||||
"""
|
"""
|
||||||
run_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(
|
||||||
test_llm_generator,
|
vllm_runner,
|
||||||
batch_size,
|
common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
per_test_common_llm_kwargs,
|
||||||
temperature=temperature,
|
baseline_llm_kwargs,
|
||||||
seeded=True,
|
test_llm_kwargs,
|
||||||
force_output_len=True)
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
temperature=temperature,
|
||||||
|
disable_seed=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure this same test does fail if we _don't_ include per-request seeds
|
# Ensure this same test does fail if we _don't_ include per-request seeds
|
||||||
with pytest.raises(AssertionError):
|
with pytest.raises(AssertionError):
|
||||||
run_equality_correctness_test(baseline_llm_generator,
|
run_equality_correctness_test(
|
||||||
test_llm_generator,
|
vllm_runner,
|
||||||
batch_size,
|
common_llm_kwargs,
|
||||||
max_output_len=output_len,
|
per_test_common_llm_kwargs,
|
||||||
temperature=temperature,
|
baseline_llm_kwargs,
|
||||||
seeded=False,
|
test_llm_kwargs,
|
||||||
force_output_len=True)
|
batch_size,
|
||||||
|
max_output_len=output_len,
|
||||||
|
temperature=temperature,
|
||||||
|
disable_seed=True,
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user