Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:24:56 +08:00)

[Speculative Decoding] Test refactor (#8317)
Co-authored-by: youkaichao <youkaichao@126.com>

This commit is contained in: parent 8baa454937 / commit 775f00f81e
@@ -217,7 +217,8 @@ steps:
commands:
# See https://github.com/vllm-project/vllm/issues/5152
- export VLLM_ATTENTION_BACKEND=XFORMERS
- pytest -v -s spec_decode
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py

- label: LoRA Test %N # 30min each
mirror_hardwares: [amd]

@@ -1,224 +1,54 @@
import asyncio
import os
from itertools import cycle
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Tuple

import pytest
import ray
import torch

from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm import LLM, SamplingParams
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalDataDict
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid

from ...conftest import cleanup
from ...utils import wait_for_gpu_memory_to_clear
from ...models.utils import check_logprobs_close, check_outputs_equal
from ...utils import RemoteOpenAIServer


class AsyncLLM:
"""AsyncLLM

Note: Current LLM class in vllm don't support async mode, for test purpose,
we implement async one in here. Maybe we could move to
vllm/entrypoints/llm.py in future.

Below AsyncLLM is directly borrow from vllm/entrypoints/llm.py with changes
to make to work in async mode.
"""

def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True

# Needed to engine_use_ray works as a deprecated feature,
# otherwise the following constructor will raise an exception
os.environ["VLLM_ALLOW_ENGINE_USE_RAY"] = "1"

engine_args = AsyncEngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_seq_len_to_capture=max_seq_len_to_capture,
# For now use ray for the distributed back-end, since
# we rely on the use of engine_use_ray=True to avoid
# reinitializing CUDA in the same process (driver worker)
engine_use_ray=True,
distributed_executor_backend="ray",
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
self.request_counter = Counter()
self.llm_engine = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.LLM_CLASS)

def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalDataDict] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> List[RequestOutput]:

if prompts is None:
raise ValueError("prompts must be provided.")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]

if prompts is not None:
num_requests = len(prompts)

if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()

elif isinstance(sampling_params,
list) and len(sampling_params) != num_requests:
raise ValueError("The lengths of prompts and "
"sampling_params must be the same.")

async def get_output(prompt, sampling_param) -> RequestOutput:
request_id = random_uuid()
results_generator = self.llm_engine.generate(
prompt, sampling_param, request_id)
final_output = None
async for request_output in results_generator:
final_output = request_output
assert final_output is not None
return final_output

outputs: List[RequestOutput] = []
try:
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
params = sampling_params[i] if isinstance(
sampling_params, Sequence) else sampling_params
res = asyncio.run(get_output(prompt, params))
outputs.append(res)
finally:
ray.shutdown()
return outputs
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]


@pytest.fixture
def baseline_llm_generator(request, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
seed):
return create_llm_generator("baseline", request, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
test_llm_kwargs, seed):
return create_llm_generator("test", request, common_llm_kwargs,
per_test_common_llm_kwargs, test_llm_kwargs,
seed)


def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
per_test_common_llm_kwargs, distinct_llm_kwargs,
seed):
def generate():
kwargs = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**distinct_llm_kwargs,
**test_llm_kwargs,
}
test_name = request.node.name

model = kwargs["model"]
draft_model = kwargs.get("speculative_model", None)
same_draft_target_model = (draft_model is not None
and draft_model == model)
llm = LLM(**kwargs)

def generator_inner():

wait_for_gpu_memory_to_clear(
devices=list(range(torch.cuda.device_count())),
threshold_bytes=2 * 2**30,
timeout_s=60,
)

use_async = False
if "use_async" in kwargs:
use_async = kwargs.pop("use_async")
print(f'{use_async=}')

print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)

# Override logging interval to 0 for spec decode test run to
# log all metrics in time.
if (baseline_or_test == "test" and not use_async
and llm.llm_engine.log_stats):
for sate_logger in llm.llm_engine.stat_loggers.values():
sate_logger.local_interval = 0
if seed is not None:
set_random_seed(seed)

yield llm

del llm
cleanup()

def generator_outer():
for llm in generator_inner():
yield llm
del llm

# Set an attribute to the generator_outer function to allow us to
# determine whether to further check the acceptance rate in tests.
generator_outer.same_draft_target_model = same_draft_target_model # type: ignore
return generator_outer
return generate


def maybe_assert_ngram_worker(llm):
# Verify the proposer worker is ngram if ngram is specified.
if (not isinstance(llm, AsyncLLM)
and llm.llm_engine.speculative_config is not None
if (llm.llm_engine.speculative_config is not None
and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
from vllm.spec_decode.ngram_worker import NGramWorker
assert isinstance(
@@ -251,118 +81,165 @@ def get_output_from_llm_generator(
return tokens, token_ids, acceptance_rate


def get_logprobs_from_llm_generator(
llm_generator, prompts,
sampling_params) -> List[List[Dict[int, Logprob]]]:
"""Returns a dict of (token_id: Logprob) for each generated position, for
each sequence in the batch.
"""
for llm in llm_generator():
outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
logprobs = [output.outputs[0].logprobs[:] for output in outputs]
del llm
def run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
logprobs: int = 1):
org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}

return logprobs
sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}

prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
logprobs=logprobs)

with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

with vllm_runner(**sd_args) as vllm_model:
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)

check_logprobs_close(outputs_0_lst=org_outputs,
outputs_1_lst=sd_outputs,
name_0="org",
name_1="sd")


def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False):
def run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: Optional[int] = 0,
temperature: float = 0.0,
disable_seed: bool = False,
ignore_eos: bool = True,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):

org_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**baseline_llm_kwargs,
}

sd_args = {
**common_llm_kwargs,
**per_test_common_llm_kwargs,
**test_llm_kwargs,
}

prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

if disable_seed:
seed = None

sampling_params = SamplingParams(temperature=temperature,
max_tokens=max_output_len,
seed=seed,
ignore_eos=ignore_eos)

with vllm_runner(**org_args) as vllm_model:
org_outputs = vllm_model.generate(prompts, sampling_params)

with vllm_runner(**sd_args) as vllm_model:
if ensure_all_accepted or expected_acceptance_rate is not None:
# Force log interval to be 0 to catch all metrics.
stat_logger = vllm_model.model.llm_engine.stat_loggers[
'prometheus']
stat_logger.local_interval = -100

sd_outputs = vllm_model.generate(prompts, sampling_params)

if ensure_all_accepted or expected_acceptance_rate is not None:
acceptance_rate = (stat_logger.metrics.
gauge_spec_decode_draft_acceptance_rate.labels(
**stat_logger.labels)._value.get())

if ensure_all_accepted:
assert True
# FIXME: ci fails to log acceptance rate.
# It works locally.
# assert acceptance_rate == 1.0

if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2

check_outputs_equal(outputs_0_lst=org_outputs,
outputs_1_lst=sd_outputs,
name_0="org",
name_1="sd")


def run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int,
max_output_len: int,
seed: int = 0,
temperature: float = 0.0):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
env1 = env2 = None

run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len,
temperature=0.0,
seeded=False,
print_tokens=print_tokens,
ensure_all_accepted=ensure_all_accepted)
max_wait_seconds = 240
results = []

prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

def run_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
temperature: float,
seeded: bool,
print_tokens: bool = False,
ensure_all_accepted: bool = False,
expected_acceptance_rate: Optional[float] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero (or when temperature is > 0 and seeded).
"""
for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model,
args,
env_dict=env,
max_wait_seconds=max_wait_seconds) as server:
client = server.get_client()

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos = force_output_len

if seeded:
sampling_params = [
SamplingParams(
completion = client.completions.create(model=model,
prompt=prompts,
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
seed=i,
) for i in range(len(prompts))
]
else:
sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)
seed=seed,
temperature=temperature)

(spec_batch_tokens, spec_batch_token_ids,
acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
prompts, sampling_params)
results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
completion.usage,
})

(baseline_batch_tokens, baseline_batch_token_ids,
_) = get_output_from_llm_generator(baseline_llm_generator, prompts,
sampling_params)

assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)

for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids

print(f'{acceptance_rate=}')

if ensure_all_accepted:
assert acceptance_rate == 1.0

if expected_acceptance_rate is not None:
assert acceptance_rate >= expected_acceptance_rate - 1e-2
n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")

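The new helpers above build their engine arguments by merging the parametrized kwargs dicts in a fixed order, so later dicts win on key clashes. A minimal, self-contained sketch of that merge semantics (plain Python, no vLLM required; the values are illustrative, not the actual CI configuration):

# Sketch only: mirrors how org_args / sd_args are assembled above.
common_llm_kwargs = {"model": "JackFram/llama-68m", "enforce_eager": True}
per_test_common_llm_kwargs = {"dtype": "float32"}
test_llm_kwargs = {
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 3,
    "enforce_eager": False,  # overrides the common setting
}

sd_args = {
    **common_llm_kwargs,
    **per_test_common_llm_kwargs,
    **test_llm_kwargs,
}
assert sd_args["enforce_eager"] is False
assert sd_args["dtype"] == "float32"

Because the baseline merge stops at baseline_llm_kwargs, the speculative settings only reach the spec-decode run.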
@@ -21,7 +21,7 @@ correctess for the target model outputs.

import pytest

from .conftest import run_greedy_equality_correctness_test
from .conftest import run_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-68m"
@@ -53,7 +53,7 @@ PRECISION = "float32"
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -68,15 +68,16 @@ PRECISION = "float32"
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):

run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
@@ -94,7 +95,7 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -109,17 +110,16 @@ def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_eagle_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
@@ -140,7 +140,7 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -158,18 +158,17 @@ def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_eagle_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
@@ -185,7 +184,7 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -207,16 +206,17 @@ def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_eagle_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


@pytest.mark.parametrize(
@@ -232,7 +232,7 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -250,17 +250,18 @@ def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_eagle_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that eagle speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
run_equality_correctness_test(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size, output_len, seed)


if __name__ == "__main__":

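Illustrative only (not part of the commit): a stripped-down pytest sketch of the parametrization pattern the refactored tests above rely on, with the vLLM runner and the real helper replaced by a stub so it runs anywhere. Names mirror the test file; the stub is hypothetical.

import pytest


def stub_equality_check(common, per_test, baseline, test, batch_size, seed):
    # The real helper builds baseline and spec-decode engine args from these
    # dicts and compares generations; here we only sanity-check the merge.
    merged = {**common, **per_test, **test}
    assert merged["model_name"] and batch_size > 0 and seed >= 0


@pytest.mark.parametrize("common_llm_kwargs",
                         [{"model_name": "JackFram/llama-68m"}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{"num_speculative_tokens": 3}])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_parametrization_sketch(common_llm_kwargs, per_test_common_llm_kwargs,
                                baseline_llm_kwargs, test_llm_kwargs,
                                batch_size: int, seed: int):
    stub_equality_check(common_llm_kwargs, per_test_common_llm_kwargs,
                        baseline_llm_kwargs, test_llm_kwargs, batch_size, seed)
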
@@ -4,7 +4,9 @@ other features, e.g. cuda graphs.

import pytest

from .conftest import run_greedy_equality_correctness_test
from .conftest import run_equality_correctness_test

MAIN_MODEL = "JackFram/llama-68m"


@pytest.mark.parametrize(
@@ -15,7 +17,7 @@ from .conftest import run_greedy_equality_correctness_test

# Verify equality when cuda graphs allowed.
"enforce_eager": False,
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",
}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs",
@@ -31,23 +33,27 @@ from .conftest import run_greedy_equality_correctness_test
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [32])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
batch_size, output_len):
def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify spec decode equality when cuda graphs are enabled.
"""
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True,
)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -80,13 +86,19 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_speculative_model_quantization_config(baseline_llm_generator,
test_llm_generator,
batch_size: int):
def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size: int, seed: int):
"""Verify spec decode works well with draft model quantization configs.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
force_output_len=True)
seed=seed,
temperature=0.0)

@@ -7,42 +7,39 @@ import torch

from vllm.utils import is_hip

from .conftest import run_greedy_equality_correctness_test
from .conftest import run_equality_correctness_test_tp


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",

[[
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"--enforce-eager",

# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,

# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
"--use-v2-block-manager",
"--tensor-parallel-size",
"2"
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
},
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
[
"--speculative-model",
"JackFram/llama-68m",
"--num-speculative-tokens",
"3",
],
[
"--speculative-model",
"[ngram]",
"--num-speculative-tokens",
"5",
"--ngram-prompt-lookup-max",
"3",
],
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
@@ -52,75 +49,75 @@ from .conftest import run_greedy_equality_correctness_test
32,
])
@pytest.mark.parametrize("seed", [1])
def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify greedy equality when tensor parallelism is used.
"""
if is_hip():
pytest.skip("hip is not well-supported yet")
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test_tp("JackFram/llama-68m",
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
output_len,
seed,
temperature=0.0)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
[[
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"--enforce-eager",

# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 2,

# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
"--use_v2_block_manager",
"--tensor_parallel_size",
"2",

# precision
"dtype": "float32",
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"per_test_common_llm_kwargs, test_llm_kwargs",
[
(
{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a
# tokenizer.
"model": "JackFram/llama-68m",
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
}),
({
"model": "ibm-granite/granite-3b-code-instruct",
}, {
"speculative_model":
"ibm-granite/granite-3b-code-instruct-accelerator",
"num_speculative_tokens": 5,
"speculative_draft_tensor_parallel_size": 1,
})
])
"--dtype",
"bfloat16",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize("model, test_llm_kwargs",
[("JackFram/llama-68m", [
"--speculative-model",
"JackFram/llama-68m",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
]),
("ibm-granite/granite-3b-code-instruct", [
"--speculative-model",
"ibm-granite/granite-3b-code-instruct",
"--num_speculative-tokens",
"5",
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
baseline_llm_generator,
batch_size: int):
def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
seed: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
force_output_len=True)
seed=seed,
temperature=0.0)

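The TP tests above now pass plain CLI argument lists, which run_equality_correctness_test_tp concatenates before launching the remote OpenAI-compatible server. A self-contained sketch of that composition (plain Python; the values are examples, not the exact CI configuration):

# Sketch only: mirrors the list concatenation done in the TP helper.
common_llm_kwargs = ["--enforce-eager", "--use-v2-block-manager",
                     "--tensor-parallel-size", "2"]
per_test_common_llm_kwargs = []
test_llm_kwargs = ["--speculative-model", "JackFram/llama-68m",
                   "--num-speculative-tokens", "3"]

server_args = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
assert server_args[-1] == "3"
print(" ".join(server_args))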
@@ -2,98 +2,97 @@
tensor parallelism.
"""

import openai
import pytest
import torch

from .conftest import run_greedy_equality_correctness_test
from .conftest import run_equality_correctness_test_tp

MAIN_MODEL = "JackFram/llama-68m"
SPEC_MODEL = "JackFram/llama-68m"


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",

[[
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"--enforce_eager",

# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 4,

# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
"--use-v2-block-manager",
"--tensor-parallel-size",
"4",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",
],
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
#TODO(wooyeon): add spec_draft_dp=2 case
{
"speculative_draft_tensor_parallel_size": 1,
},
[
"--speculative-draft-tensor-parallel-size",
"1",
],
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
baseline_llm_generator,
batch_size: int):
def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
seed: int):
"""Verify spec decode works well with smaller tp for draft models.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test_tp(MAIN_MODEL,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=32,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="Need at least 4 GPUs to run the test.")
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
[[

# Skip cuda graph recording for fast test.
"enforce_eager": True,
"--enforce-eager",

# Required for spec decode.
"use_v2_block_manager": True,
"tensor_parallel_size": 4,

# Use AsyncLLM engine, so that the engine runs in its own process.
# Otherwise, since vLLM does not follow true SPMD, the test runner
# process will have both the engine and the rank0 worker. NCCL is not
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
"use_async": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
"--use-v2-block-manager",
"--tensor-parallel-size",
"4",
]])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]])
@pytest.mark.parametrize("baseline_llm_kwargs", [[]])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
[
"--speculative-model",
f"{SPEC_MODEL}",
"--num-speculative-tokens",
"5",

# Artificially limit the draft model max model len; this forces vLLM
# to skip speculation once the sequences grow beyond 32-k tokens.
"speculative_max_model_len": 32,
},
"--speculative-max-model-len",
"32",
],
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize(
@@ -105,8 +104,9 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_skip_speculation(common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify job failure with RuntimeError when all sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
@@ -114,9 +114,13 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,

TODO: fix it to pass without raising Error. (#5814)
"""
with pytest.raises(RuntimeError):
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
with pytest.raises(openai.APIConnectionError):
run_equality_correctness_test_tp(MAIN_MODEL,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
output_len,
seed,
temperature=0.0)

@ -1,24 +1,22 @@
|
||||
import math
|
||||
from itertools import cycle
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
|
||||
from .conftest import get_logprobs_from_llm_generator
|
||||
from .conftest import run_logprob_correctness_test
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-68m",
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
"max_logprobs": 6,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@ -36,64 +34,29 @@ from .conftest import get_logprobs_from_llm_generator
|
||||
7,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int):
|
||||
"""Verify output logprobs are equal with and without speculative decoding.
|
||||
"""
|
||||
run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
run_logprob_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
|
||||
# Required for spec decode.
|
||||
"use_v2_block_manager": True,
|
||||
"max_logprobs": 6,
|
||||
}])
|
||||
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||
@pytest.mark.parametrize("test_llm_kwargs",
|
||||
[{
|
||||
"speculative_model": "JackFram/llama-160m",
|
||||
"num_speculative_tokens": 3,
|
||||
"disable_logprobs_during_spec_decoding": False,
|
||||
}])
|
||||
@pytest.mark.parametrize("batch_size", [1])
|
||||
@pytest.mark.parametrize("num_logprobs", [6])
|
||||
@pytest.mark.parametrize(
|
||||
"output_len",
|
||||
[
|
||||
# Use smaller output len for fast test.
|
||||
7,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int,
|
||||
num_logprobs: int):
|
||||
"""Verify output logprobs are equal with and without spec decode.
|
||||
This specifies a number of logprobs >1.
|
||||
"""
|
||||
run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True,
|
||||
logprob_rank=num_logprobs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-68m",
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
@ -121,21 +84,29 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
@pytest.mark.parametrize("logprobs", [1, 6])
|
||||
def test_logprobs_different_k(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int,
|
||||
output_len: int, seed: int, logprobs: int):
|
||||
"""Veriy logprob greedy equality with different speculation lens.
|
||||
"""
|
||||
run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
run_logprob_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-68m",
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
@ -164,22 +135,30 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_logprobs_when_skip_speculation(baseline_llm_generator,
|
||||
test_llm_generator, batch_size: int,
|
||||
output_len: int):
|
||||
@pytest.mark.parametrize("logprobs", [1])
|
||||
def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs, test_llm_kwargs,
|
||||
batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int):
|
||||
"""Verify logprobs greedy equality when some sequences skip speculation.
|
||||
"""
|
||||
run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
run_logprob_correctness_test(vllm_runner,
|
||||
common_llm_kwargs,
|
||||
per_test_common_llm_kwargs,
|
||||
baseline_llm_kwargs,
|
||||
test_llm_kwargs,
|
||||
batch_size,
|
||||
max_output_len=output_len,
|
||||
force_output_len=True)
|
||||
output_len,
|
||||
seed,
|
||||
temperature=0.0,
|
||||
logprobs=logprobs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"common_llm_kwargs",
|
||||
[{
|
||||
"model": "JackFram/llama-68m",
|
||||
"model_name": "JackFram/llama-68m",
|
||||
|
||||
# Skip cuda graph recording for fast test.
|
||||
"enforce_eager": True,
|
||||
@ -203,19 +182,17 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator,
|
||||
32,
|
||||
])
|
||||
@pytest.mark.parametrize("seed", [1])
|
||||
def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
|
||||
batch_size: int, output_len: int):
|
||||
@pytest.mark.parametrize("logprobs", [6])
|
||||
def test_logprobs_temp_1(vllm_runner, common_llm_kwargs,
|
||||
per_test_common_llm_kwargs, baseline_llm_kwargs,
|
||||
test_llm_kwargs, batch_size: int, output_len: int,
|
||||
seed: int, logprobs: int):
|
||||
"""Verify at least one logprob result has num_logprobs+1, which tests the
|
||||
case where the sampled token is not in top-k logprobs.
|
||||
|
||||
Ideally, this test should validate equality with non-spec by getting
|
||||
logprobs. This is left as future improvement.
|
||||
"""
|
||||
batch_size = 8
|
||||
max_output_len = output_len
|
||||
force_output_len = True
|
||||
logprob_rank = 5
|
||||
|
||||
temperature = 1.0
|
||||
|
||||
prompts = [
|
||||
@ -231,129 +208,40 @@ def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
# If the test requires that we generated max_output_len tokens, then set the
|
||||
# sampling params to ignore eos token.
|
||||
ignore_eos = force_output_len
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_len,
|
||||
ignore_eos=ignore_eos,
|
||||
max_tokens=output_len,
|
||||
ignore_eos=True,
|
||||
temperature=temperature,
|
||||
logprobs=logprob_rank,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
spec_batch_logprobs = get_logprobs_from_llm_generator(
|
||||
test_llm_generator, prompts, sampling_params)
|
||||
sd_args = {
|
||||
**common_llm_kwargs,
|
||||
**per_test_common_llm_kwargs,
|
||||
**test_llm_kwargs,
|
||||
}
|
||||
|
||||
with vllm_runner(**sd_args) as vllm_model:
|
||||
sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params)
|
||||
|
||||
num_returned_logprobs = [
|
||||
len(logprob_dict) for seq_logprobs in spec_batch_logprobs
|
||||
for logprob_dict in seq_logprobs
|
||||
len(seq_logprobs) for seq_logprobs in sd_outputs[-1]
|
||||
]
|
||||
|
||||
# Assert one of the returned logprobs has > num_logprobs (indicating the
|
||||
# sampled token is not in top-k).
|
||||
assert any([
|
||||
num_returned > logprob_rank for num_returned in num_returned_logprobs
|
||||
])
|
||||
|
||||
|
||||
def run_greedy_logprobs_correctness_test(baseline_llm_generator,
|
||||
test_llm_generator,
|
||||
batch_size,
|
||||
max_output_len,
|
||||
force_output_len: bool,
|
||||
logprob_rank: int = 1):
|
||||
"""Helper method that compares the logprobs outputs of both the baseline LLM
|
||||
and the test LLM. It asserts greedy equality of the logprobs when the
|
||||
temperature is zero.
|
||||
"""
|
||||
temperature = 0.0
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
"San Francisco is know for its",
|
||||
"Facebook was created in 2004 by",
|
||||
"Curious George is a",
|
||||
"Python 3.11 brings improvements to its",
|
||||
]
|
||||
|
||||
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
|
||||
|
||||
# If the test requires that we generated max_output_len tokens, then set the
|
||||
# sampling params to ignore eos token.
|
||||
ignore_eos = force_output_len
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_output_len,
|
||||
ignore_eos=ignore_eos,
|
||||
temperature=temperature,
|
||||
logprobs=logprob_rank,
|
||||
)
|
||||
|
||||
spec_batch_logprobs = get_logprobs_from_llm_generator(
|
||||
test_llm_generator, prompts, sampling_params)
|
||||
baseline_batch_logprobs = get_logprobs_from_llm_generator(
|
||||
baseline_llm_generator, prompts, sampling_params)
|
||||
|
||||
assert len(baseline_batch_logprobs) == len(prompts)
|
||||
assert len(spec_batch_logprobs) == len(prompts)
|
||||
|
||||
# For each sequence in the batch.
|
||||
for i, (baseline_logprobs, spec_logprobs) in enumerate(
|
||||
zip(baseline_batch_logprobs, spec_batch_logprobs)):
|
||||
assert len(spec_logprobs) == len(baseline_logprobs)
|
||||
|
||||
# For each generated position of the sequence.
|
||||
for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
|
||||
zip(spec_logprobs, baseline_logprobs)):
|
||||
|
||||
# Map rank to token/logprob in spec output.
|
||||
spec_rank_to_token_id = {
|
||||
value.rank: key
|
||||
for key, value in spec_pos_logprobs.items()
|
||||
}
|
||||
spec_rank_to_logprob = {
|
||||
value.rank: value.logprob
|
||||
for key, value in spec_pos_logprobs.items()
|
||||
}
|
||||
|
||||
# Map rank to token/logprob in baseline output.
|
||||
baseline_rank_to_token_id = {
|
||||
value.rank: key
|
||||
for key, value in baseline_pos_logprobs.items()
|
||||
}
|
||||
baseline_rank_to_logprob = {
|
||||
value.rank: value.logprob
|
||||
for key, value in baseline_pos_logprobs.items()
|
||||
}
|
||||
|
||||
# Assert set of ranks returned is equal.
|
||||
assert set(spec_rank_to_token_id.keys()) == set(
|
||||
baseline_rank_to_token_id.keys())
|
||||
|
||||
# Assert each logprob/token id is correct, keyed by rank.
|
||||
for rank in sorted(set(spec_rank_to_token_id.keys())):
|
||||
assert spec_rank_to_token_id[
|
||||
rank] == baseline_rank_to_token_id[rank], f"{rank}"
|
||||
assert math.isclose(
|
||||
a=spec_rank_to_logprob[rank],
|
||||
b=baseline_rank_to_logprob[rank],
abs_tol=1e-1,
)
assert any(
[num_returned > logprobs for num_returned in num_returned_logprobs])


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
"max_logprobs": 6,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -364,57 +252,28 @@ def run_greedy_logprobs_correctness_test(baseline_llm_generator,
"disable_logprobs_during_spec_decoding": True,
}])
@pytest.mark.parametrize("seed", [1])
def test_logprobs_disabled(baseline_llm_generator, test_llm_generator):
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("logprobs", [0])
def test_logprobs_disabled(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Check the behavior when logprobs are disabled.
Token choices should match with the base model.
"""
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(4))]

sampling_params = SamplingParams(
# Use smaller output len for fast test
max_tokens=7,
ignore_eos=True,
run_logprob_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
output_len,
seed,
temperature=0.0,
logprobs=2,
)

spec_batch_logprobs = get_logprobs_from_llm_generator(
test_llm_generator, prompts, sampling_params)
baseline_batch_logprobs = get_logprobs_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

assert len(baseline_batch_logprobs) == len(prompts)
assert len(spec_batch_logprobs) == len(prompts)

# For each sequence in the batch.
for _, (baseline_logprobs, spec_logprobs) in enumerate(
zip(baseline_batch_logprobs, spec_batch_logprobs)):
assert len(spec_logprobs) == len(baseline_logprobs)

# For each generated position of the sequence.
for _, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate(
zip(spec_logprobs, baseline_logprobs)):

assert len(spec_pos_logprobs) == 1
spec_top_token_id = list(spec_pos_logprobs)[0]

spec_top_logprob = spec_pos_logprobs[spec_top_token_id]
assert spec_top_logprob.logprob == 0.0
assert spec_top_logprob.rank == -1

# check that the chosen token matches the base model
baseline_logprob = baseline_pos_logprobs[spec_top_token_id]
assert baseline_logprob.rank == 1
assert spec_top_logprob.decoded_token \
== baseline_logprob.decoded_token
logprobs=logprobs)

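The per-position assertions above are what the shared run_logprob_correctness_test helper is expected to centralize after this refactor. As a rough sketch only (illustrative names, not the actual conftest code), the per-position comparison amounts to:

import math
from typing import Dict


def logprobs_match(spec_pos: Dict[int, float],
                   baseline_pos: Dict[int, float],
                   abs_tol: float = 1e-1) -> bool:
    # Same candidate tokens at this position, and each logprob within tolerance.
    if set(spec_pos) != set(baseline_pos):
        return False
    return all(
        math.isclose(spec_pos[tok], baseline_pos[tok], abs_tol=abs_tol)
        for tok in spec_pos)
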
@@ -21,7 +21,7 @@ correctess for the target model outputs.

import pytest

from .conftest import run_greedy_equality_correctness_test
from .conftest import run_equality_correctness_test

# main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing
@@ -55,7 +55,7 @@ PRECISION = "float32"
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -70,15 +70,21 @@ PRECISION = "float32"
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)

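All of the rewritten tests in this file funnel through run_equality_correctness_test, driven by the vllm_runner fixture instead of the old generator fixtures. A minimal sketch of what such a helper plausibly does, assuming the layered kwargs dicts and a generate_greedy-style runner API (the real implementation lives in tests/spec_decode/e2e/conftest.py and is not reproduced here):

def run_equality_correctness_test_sketch(vllm_runner, common_llm_kwargs,
                                         per_test_common_llm_kwargs,
                                         baseline_llm_kwargs, test_llm_kwargs,
                                         batch_size, max_output_len, seed=0,
                                         temperature=0.0):
    # Greedy sketch only (temperature=0.0); the real helper also covers sampling.
    prompts = ["Hello, my name is"] * batch_size
    shared = {**common_llm_kwargs, **per_test_common_llm_kwargs}

    with vllm_runner(**shared, **baseline_llm_kwargs, seed=seed) as baseline_llm:
        baseline_outputs = baseline_llm.generate_greedy(prompts, max_output_len)

    with vllm_runner(**shared, **test_llm_kwargs, seed=seed) as spec_llm:
        spec_outputs = spec_llm.generate_greedy(prompts, max_output_len)

    # Spec decode must not change greedy outputs relative to the baseline.
    assert baseline_outputs == spec_outputs
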
@pytest.mark.parametrize(
@@ -96,7 +102,7 @@ def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -111,17 +117,21 @@ def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_medusa_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
@@ -142,7 +152,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -160,18 +170,22 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_medusa_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
@@ -187,7 +201,7 @@ def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -209,16 +223,22 @@ def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
@@ -234,7 +254,7 @@ def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -252,17 +272,23 @@ def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


if __name__ == "__main__":

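The four parametrized dicts used throughout these files compose in the order shown above. A quick, self-contained illustration of how baseline and spec-decode engine arguments are assumed to merge (values are illustrative, drawn from the tiny-model tests in this refactor; later dicts win on key collisions):

common = {"model_name": "JackFram/llama-160m", "enforce_eager": True,
          "use_v2_block_manager": True}
per_test = {}
baseline = {}
test = {"speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 3}

baseline_args = {**common, **per_test, **baseline}
test_args = {**common, **per_test, **test}

assert "speculative_model" not in baseline_args  # baseline runs without spec decode
assert test_args["num_speculative_tokens"] == 3  # the "different k" dimension
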
@@ -25,8 +25,7 @@ import pytest

from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size

from .conftest import (run_equality_correctness_test,
run_greedy_equality_correctness_test)
from .conftest import run_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-160m"
@@ -58,7 +57,7 @@ PRECISION = "float32"
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -72,14 +71,21 @@ PRECISION = "float32"
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
@@ -98,7 +104,7 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -110,17 +116,21 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("output_len", [2048])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int, seed: int):
"""Verify acceptance rate with different batch size and large output
length."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=0.0,
seeded=True,
force_output_len=True,
seed=seed,
expected_acceptance_rate=0.48)

@@ -140,7 +150,7 @@ def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,

# Speculative model
"speculative_model": SPEC_MODEL,
@@ -151,28 +161,35 @@ def test_mlp_e2e_acceptance_rate(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("seed", [None])
def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
temperature: float):
temperature: float, seed: int):
"""Verify seeded runs produce the same output."""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=True,
force_output_len=True)
seed=seed)

# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=False,
force_output_len=True)
seed=seed,
disable_seed=True)


@pytest.mark.parametrize(
@@ -193,7 +210,7 @@ def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -210,18 +227,22 @@ def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_mlp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)

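The seeded test above relies on per-request seeds: with the same seed, sampling at temperature > 0 is reproducible, and the negative branch checks that the equality assertion really fails once seeding is disabled. A small illustration of the request-level knob involved (SamplingParams carries a seed field; values here are illustrative):

from vllm import SamplingParams

seeded = SamplingParams(temperature=1.0, max_tokens=64, seed=1234)
unseeded = SamplingParams(temperature=1.0, max_tokens=64)  # seed defaults to None

assert seeded.seed == 1234
assert unseeded.seed is None
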
@pytest.mark.parametrize(
@@ -242,7 +263,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -259,10 +280,10 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_mlp_e2e_greedy_correctness_with_padding(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when the vocab dimension is padded
"""

@@ -273,11 +294,15 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
patched_pad_vocab_size):
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
@@ -293,7 +318,7 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -315,16 +340,22 @@ def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
@@ -340,7 +371,7 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
"dtype": PRECISION,

# Main model
"model": MAIN_MODEL,
"model_name": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -357,14 +388,20 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)

@@ -41,8 +41,9 @@ from transformers import AutoTokenizer

from vllm import SamplingParams

from ...utils import fork_new_process_for_each_test
from .conftest import (get_output_from_llm_generator,
run_greedy_equality_correctness_test)
run_equality_correctness_test)


@pytest.mark.parametrize(
@@ -73,6 +74,7 @@ from .conftest import (get_output_from_llm_generator,
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_with_detokenization(test_llm_generator,
batch_size: int):
"""Run generation with speculative decoding on a batch. Verify the engine
@@ -116,44 +118,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
assert actual_tokens.strip() == expected_tokens.strip()


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
# Note this is repeated in the test body; to initialize a tokenizer.
"model": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,

# Required for spec decode.
"use_v2_block_manager": True,

# Use AsyncLLM engine
"use_async": True,
}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
},
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_e2e_with_async_engine(test_llm_generator,
baseline_llm_generator,
batch_size: int):
"""Verify spec decode works well with async LLM engine.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=32,
force_output_len=True)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
@@ -172,10 +136,10 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -189,13 +153,15 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
"output_len",
[
# Use long output len for the small model test.
1536,
10,
])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model with batch size of one.

Since this test is cheaper than other e2e correctness tests, we generate
@@ -204,13 +170,17 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
When the draft model is the same as the target model, we further check
whether all speculative tokens are accepted.
"""
ensure_all_accepted = test_llm_generator.same_draft_target_model
run_greedy_equality_correctness_test(
baseline_llm_generator,
test_llm_generator,
ensure_all_accepted = per_test_common_llm_kwargs.get(
"model_name") == test_llm_kwargs.get("speculative_model")
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True,
seed=seed,
temperature=0.0,
ensure_all_accepted=ensure_all_accepted)


@@ -232,10 +202,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -253,16 +223,22 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
])
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model and large batch size.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
@@ -280,10 +256,10 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
# Try two different tiny base models.
# Note that one is equal to the draft model, another isn't.
{
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",
},
{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -298,24 +274,31 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
baseline_llm_generator, test_llm_generator, batch_size: int,
max_output_len: int):
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
max_output_len: int, seed: int):
"""Verify greedy equality on a tiny model, with a large batch size, and when
sampling respects the EOS token.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len,
force_output_len=False)
seed=seed,
temperature=0.0,
ignore_eos=False)

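The tiny-model bs1 test now derives ensure_all_accepted directly from the parametrized kwargs instead of a generator attribute. Restated in isolation (same logic as above, wrapped in a hypothetical helper for clarity):

def same_draft_target_model(per_test_common_llm_kwargs: dict,
                            test_llm_kwargs: dict) -> bool:
    # Draft and target count as "the same" when the speculative model equals
    # the base model name.
    return (per_test_common_llm_kwargs.get("model_name") ==
            test_llm_kwargs.get("speculative_model"))


assert same_draft_target_model({"model_name": "JackFram/llama-68m"},
                               {"speculative_model": "JackFram/llama-68m"})
assert not same_draft_target_model({"model_name": "JackFram/llama-160m"},
                                   {"speculative_model": "JackFram/llama-68m"})
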
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf",
"model_name": "meta-llama/Llama-2-7b-chat-hf",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -342,24 +325,30 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
256,
])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a "real" model and batch size of 1. This is
separate from large BS tests to make identifying the source of bugs easier.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# A "real" model (not tiny).
"model": "meta-llama/Llama-2-7b-chat-hf",
"model_name": "meta-llama/Llama-2-7b-chat-hf",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -386,17 +375,23 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
64,
])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality with a "real" model on a nontrivial batch size.
This is the closest test to a real production workload.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)

@pytest.mark.parametrize(
@@ -415,7 +410,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -433,23 +428,29 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@fork_new_process_for_each_test
def test_spec_decode_e2e_greedy_correctness_with_preemption(
baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -487,22 +488,29 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
32,
])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_different_block_size(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
@fork_new_process_for_each_test
def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality over different block sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -534,24 +542,31 @@ def test_spec_decode_different_block_size(baseline_llm_generator,
64,
])
@pytest.mark.parametrize("seed", [1])
def test_skip_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
@fork_new_process_for_each_test
def test_skip_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when some (or all) sequences skip speculation.
We do this by setting the max model len of the draft model to an
artificially low value, such that when the sequences grow beyond it, they
are skipped in speculative decoding.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -571,21 +586,28 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("output_len", [10])
@pytest.mark.parametrize("seed", [1])
def test_disable_speculation(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
@fork_new_process_for_each_test
def test_disable_speculation(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality when all sequences disable speculation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -613,22 +635,28 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
output_len: int):
@fork_new_process_for_each_test
def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with many different values of k.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -657,15 +685,22 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_typical_acceptance_sampling(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
@fork_new_process_for_each_test
def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify that speculative decoding produces exact equality to without spec
decode with TypicalAcceptanceSampler as the draft token acceptance
sampling method.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)

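Sweeps such as test_many_k and test_typical_acceptance_sampling are expressed purely through lists of test_llm_kwargs; each dict becomes one spec-decode configuration. The pattern, reduced to a runnable pytest snippet with placeholder names:

import pytest


@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": k,
} for k in range(1, 6)])
def test_k_sweep_placeholder(test_llm_kwargs):
    # Each dict is one engine configuration for the spec-decode side.
    assert 1 <= test_llm_kwargs["num_speculative_tokens"] <= 5
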
@@ -26,7 +26,7 @@ for the target model outputs.

import pytest

from .conftest import run_greedy_equality_correctness_test
from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
@@ -43,7 +43,7 @@ from .conftest import run_greedy_equality_correctness_test
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -59,15 +59,21 @@ from .conftest import run_greedy_equality_correctness_test
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
@@ -86,7 +92,7 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-160m",
"model_name": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -105,24 +111,28 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
def test_ngram_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
temperature=0,
seed=seed)

@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -159,23 +169,29 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_ngram_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -200,14 +216,20 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
force_output_len=True)
seed=seed,
temperature=0.0)

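For the ngram tests, the spec-decode side uses the built-in prompt-lookup proposer rather than a separate draft model. A typical test_llm_kwargs entry looks roughly like this (values are illustrative, following the knobs named in the docstrings above):

ngram_test_llm_kwargs = {
    "speculative_model": "[ngram]",      # prompt-lookup proposer, no draft weights
    "num_speculative_tokens": 5,         # k
    "ngram_prompt_lookup_max": 3,        # longest n-gram matched against the prompt
}
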
@@ -2,11 +2,17 @@ import pytest

from .conftest import run_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "JackFram/llama-160m"


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
"model_name": "JackFram/llama-68m",

# Skip cuda graph recording for fast test.
"enforce_eager": True,
@@ -31,26 +37,34 @@ from .conftest import run_equality_correctness_test
# Use smaller output len for fast test.
20,
])
@pytest.mark.parametrize("seed", [None])
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
batch_size: int, temperature: float,
output_len: int):
def test_seeded_consistency(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
temperature: float, output_len: int):
"""Verify outputs are consistent across multiple runs with same seed
"""
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=True,
force_output_len=True)
disable_seed=False,
)

# Ensure this same test does fail if we _don't_ include per-request seeds
with pytest.raises(AssertionError):
run_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
run_equality_correctness_test(
vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs,
test_llm_kwargs,
batch_size,
max_output_len=output_len,
temperature=temperature,
seeded=False,
force_output_len=True)
disable_seed=True,
)