diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh index a67fc89d54e6..897f84d1e360 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh @@ -2,7 +2,7 @@ # We can use this script to compute baseline accuracy on GSM for transformers. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index b98d42aa7b82..792f355c47a5 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -3,7 +3,7 @@ # We use this for fp8, which HF does not support. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.4 +# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] usage() { echo`` diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 4f40f32a39f2..f16485732504 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -71,7 +71,7 @@ COPY --from=build_vllm ${COMMON_WORKDIR}/vllm /vllm-workspace RUN cd /vllm-workspace \ && rm -rf vllm \ && python3 -m pip install -e tests/vllm_test_utils \ - && python3 -m pip install lm-eval[api]==0.4.4 \ + && python3 -m pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] \ && python3 -m pip install pytest-shard # ----------------------- diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md index 0661933acd61..834c03cbe05b 100644 --- a/docs/features/quantization/fp8.md +++ b/docs/features/quantization/fp8.md @@ -79,7 +79,7 @@ Since simple RTN does not require data for weight quantization and the activatio Install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` Load and run the model in `vllm`: diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md index 127e40398994..d6fdac7b07f7 100644 --- a/docs/features/quantization/int4.md +++ b/docs/features/quantization/int4.md @@ -18,7 +18,7 @@ pip install llmcompressor Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md index 45fae58a6486..247d0cbdd3f1 100644 --- a/docs/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -19,7 +19,7 @@ pip install llmcompressor Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index 
e8ed2155375d..047cc8382445 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -20,7 +20,7 @@ for more installation details. Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: ```bash -pip install vllm lm-eval==0.4.4 +pip install vllm git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d#egg=lm-eval[api] ``` ## Quantization Process diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 184c30891eca..c4972f02d0f8 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -5,6 +5,7 @@ from transformers import AutoTokenizer from vllm import LLM, SamplingParams from vllm.benchmarks.datasets import add_dataset_parser, get_samples +from vllm.inputs import TokensPrompt from vllm.v1.metrics.reader import Counter, Vector try: @@ -137,7 +138,8 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - prompt_token_ids=prompt_ids, sampling_params=sampling_params + TokensPrompt(prompt_token_ids=prompt_ids), + sampling_params=sampling_params, ) else: outputs = llm.chat(prompts, sampling_params=sampling_params) diff --git a/examples/offline_inference/structured_outputs.py b/examples/offline_inference/structured_outputs.py index f46064931dba..88d87beb4874 100644 --- a/examples/offline_inference/structured_outputs.py +++ b/examples/offline_inference/structured_outputs.py @@ -85,7 +85,7 @@ def format_output(title: str, output: str): def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM): - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) return outputs[0].outputs[0].text diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 491fa0625963..a529bf4504e4 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -27,7 +27,7 @@ mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test -lm-eval[api]==0.4.8 # required for model evaluation test +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb>=1.38.11, <2 # required for mteb test transformers==4.52.4 tokenizers==0.21.1 diff --git a/requirements/test.in b/requirements/test.in index 7f141fe281d6..098a9242bc3a 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -32,7 +32,8 @@ num2words # required for smolvlm test open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test -lm-eval[api]==0.4.8 # required for model evaluation test +# TODO: Use lm-eval[api]==0.4.10 once released +lm-eval[api] @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test transformers==4.55.2 tokenizers==0.21.1 diff --git a/requirements/test.txt b/requirements/test.txt index 48eb09811bcc..85b677c00b1d 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -408,7 +408,7 @@ 
lightning-utilities==0.14.3 # torchmetrics llvmlite==0.44.0 # via numba -lm-eval==0.4.8 +lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@206b7722158f58c35b7ffcd53b035fdbdda5126d # via -r requirements/test.in lxml==5.3.0 # via diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 97cf3b5ce8fc..2cbfed98a577 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -18,10 +18,9 @@ def text_llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -88,10 +87,9 @@ def vision_llm(): seed=0, ) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() @@ -158,10 +156,9 @@ def thinking_llm(): seed=0, ) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_classify.py b/tests/entrypoints/llm/test_classify.py index 71e76abcb7d2..57705ff66907 100644 --- a/tests/entrypoints/llm/test_classify.py +++ b/tests/entrypoints/llm/test_classify.py @@ -35,10 +35,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_embedding.py b/tests/entrypoints/llm/test_embedding.py index ba20d7b9548e..485f04ed6d84 100644 --- a/tests/entrypoints/llm/test_embedding.py +++ b/tests/entrypoints/llm/test_embedding.py @@ -26,10 +26,9 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index b930f05bebd0..cb54b16b0b04 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -5,11 +5,9 @@ import weakref import pytest -from vllm import LLM, PoolingParams, PoolingRequestOutput +from vllm import LLM, PoolingParams from vllm.distributed import cleanup_dist_env_and_memory -from ...models.utils import check_embeddings_close - MODEL_NAME = "intfloat/multilingual-e5-small" PROMPTS = [ @@ -48,57 +46,13 @@ def llm(): enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() -def assert_outputs_match(o1: list[PoolingRequestOutput], - o2: list[PoolingRequestOutput]): - check_embeddings_close( - embeddings_0_lst=[o.outputs.data for o in o1], - embeddings_1_lst=[o.outputs.data for o in o2], - name_0="hf", - name_1="vllm", - tol=1e-2, - ) - - -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=prompt_token_ids, - pooling_params=pooling_params) - - v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, - pooling_params=pooling_params) - assert_outputs_match(v1_output, v2_output) - - -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): - pooling_params = PoolingParams() - - with 
pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, - pooling_params=pooling_params) - - v2_output = llm.encode( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], - pooling_params=pooling_params, - ) - assert_outputs_match(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_multiple_pooling_params(llm: LLM): pooling_params = [ diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 707891f6bdd8..3bbbcc755d13 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -5,7 +5,7 @@ import weakref import pytest -from vllm import LLM, RequestOutput, SamplingParams +from vllm import LLM, SamplingParams from vllm.distributed import cleanup_dist_env_and_memory MODEL_NAME = "distilbert/distilgpt2" @@ -41,50 +41,13 @@ def llm(): gpu_memory_utilization=0.10, enforce_eager=True) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() -def assert_outputs_equal(o1: list[RequestOutput], o2: list[RequestOutput]): - assert [o.outputs for o in o1] == [o.outputs for o in o2] - - -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) -def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, - prompt_token_ids): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params) - - v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): - v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, - sampling_params=sampling_params) - - v2_output = llm.generate( - [{ - "prompt_token_ids": p - } for p in TOKEN_IDS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_multiple_sampling_params(llm: LLM): sampling_params = [ diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index b7d53e31fd71..a04f195692e9 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -48,10 +48,9 @@ def llm(request, monkeypatch_module): max_num_seqs=128, enforce_eager=True) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_reward.py b/tests/entrypoints/llm/test_reward.py index 361e2d0e1047..de82cf8d4038 100644 --- a/tests/entrypoints/llm/test_reward.py +++ b/tests/entrypoints/llm/test_reward.py @@ -36,10 +36,9 @@ def llm(): trust_remote_code=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/llm/test_score.py b/tests/entrypoints/llm/test_score.py index dd4eae0ccc06..5a1339b2addf 100644 --- a/tests/entrypoints/llm/test_score.py +++ b/tests/entrypoints/llm/test_score.py @@ -33,10 +33,9 @@ def llm(): 
enforce_eager=True, seed=0) - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) + yield weakref.proxy(llm) - del llm + del llm cleanup_dist_env_and_memory() diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 0b37c83c92c2..d781f462b4ad 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -38,8 +38,7 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool, with vllm_runner(model_id) as llm: # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy - outputs = llm.generate_greedy(prompts=["Hello my name is"], - max_tokens=10) + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10) print(outputs[0][1]) @@ -90,8 +89,7 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, # note: this does not test accuracy, just that we can run through # see lm-eval tests for accuracy - outputs = llm.generate_greedy(prompts=["Hello my name is"], - max_tokens=10) + outputs = llm.generate_greedy(["Hello my name is"], max_tokens=10) print(outputs[0][1]) diff --git a/tests/quantization/test_lm_head.py b/tests/quantization/test_lm_head.py index 5ec8b27c1571..b24964a9d0a9 100644 --- a/tests/quantization/test_lm_head.py +++ b/tests/quantization/test_lm_head.py @@ -46,5 +46,5 @@ def test_lm_head( vllm_model.apply_model(check_model) print( - vllm_model.generate_greedy(prompts=["Hello my name is"], + vllm_model.generate_greedy(["Hello my name is"], max_tokens=10)[0][1]) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 8bddfb0b48a5..58b6297762d3 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -127,13 +127,15 @@ def test_structured_output( temperature=1.0, max_tokens=4096, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) - outputs = llm.generate(prompts=[ - (f"Give an example JSON for an employee profile that fits this " - f"schema. Make the response as short as possible. Schema: " - f"{sample_json_schema}") - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + + prompt = ("Give an example JSON for an employee profile that fits this " + "schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}") + outputs = llm.generate( + [prompt] * 2, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -191,20 +193,24 @@ def test_structured_output( with pytest.raises(ValueError, match="The provided JSON schema contains features " "not supported by xgrammar."): + + prompt = (f"Give an example JSON for an employee profile that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible.") llm.generate( - prompts=[(f"Give an example JSON for an employee profile that " - f"fits this schema: {unsupported_json_schema}. " - f"Make the response as short as possible.")] * 2, + [prompt] * 2, sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) else: - outputs = llm.generate(prompts=( - "Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True) + prompt = (f"Give an example JSON object for a grade that " + f"fits this schema: {unsupported_json_schema}. 
" + f"Make the response as short as possible.") + outputs = llm.generate( + prompt, + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None for output in outputs: assert output is not None @@ -227,10 +233,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -261,10 +266,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -301,7 +305,6 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(grammar="not a grammar")) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( - prompts= ("Generate a sql statement that selects col_1 from " "table_1 where it is equal to 1. Make the response as short " "as possible."), @@ -316,11 +319,11 @@ def test_structured_output( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams(regex=sample_regex)) + + prompt = (f"Give an example IPv4 address with this regex: {sample_regex}. " + f"Make the response as short as possible.") outputs = llm.generate( - prompts=[ - (f"Give an example IPv4 address with this regex: {sample_regex}. " - f"Make the response as short as possible.") - ] * 2, + [prompt] * 2, sampling_params=sampling_params, use_tqdm=True, ) @@ -343,11 +346,13 @@ def test_structured_output( temperature=0.8, top_p=0.95, guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) + outputs = llm.generate( - prompts=("The best language for type-safe systems programming is " - "(Make the response as short as possible.) "), + ("The best language for type-safe systems programming is " + "(Make the response as short as possible.) "), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None for output in outputs: assert output is not None @@ -367,12 +372,14 @@ def test_structured_output( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams(json=json_schema)) - outputs = llm.generate(prompts=( - "Generate a JSON with the brand, model and car_type of the most " - "iconic car from the 90's. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True) + + outputs = llm.generate( + ("Generate a JSON with the brand, model and car_type of the most " + "iconic car from the 90's. Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True, + ) assert outputs is not None @@ -411,10 +418,11 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(json=json_schema)) outputs = llm.generate( - prompts=("Generate a description of a frog using 50 characters. " - "Make the response as short as possible."), + ("Generate a description of a frog using 50 characters. 
" + "Make the response as short as possible."), sampling_params=sampling_params, - use_tqdm=True) + use_tqdm=True, + ) assert outputs is not None @@ -498,7 +506,7 @@ Make the response as short as possible. """ # Change this once other backends support structural_tag - outputs = llm.generate(prompts=prompt, + outputs = llm.generate(prompt, sampling_params=sampling_params, use_tqdm=True) assert outputs is not None @@ -639,15 +647,13 @@ def test_structured_output_auto_mode( f"{unsupported_json_schema}. Make the response as short as possible.") # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. - outputs = llm.generate(prompts=prompts, + outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True) # Make sure `auto` backend handling doesn't mess up sampling_params # and that we can reuse it without error. outputs.extend( - llm.generate(prompts=prompts, - sampling_params=sampling_params, - use_tqdm=True)) + llm.generate(prompts, sampling_params=sampling_params, use_tqdm=True)) assert outputs is not None for output in outputs: @@ -705,7 +711,7 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): max_tokens=256, guided_decoding=guided_params) - outputs = llm.generate(prompts=prompt, sampling_params=sampling_params) + outputs = llm.generate(prompt, sampling_params=sampling_params) assert outputs is not None generated_text = outputs[0].outputs[0].text assert generated_text is not None diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b002f234c043..728ed8328d36 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,15 +3,13 @@ import itertools from collections.abc import Sequence -from contextlib import contextmanager -from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union, - cast, overload) +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast import cloudpickle import torch.nn as nn from pydantic import ValidationError from tqdm.auto import tqdm -from typing_extensions import TypeVar, deprecated +from typing_extensions import TypeVar import vllm.envs as envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, @@ -40,7 +38,6 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam, from vllm.entrypoints.utils import (_validate_truncation_size, log_non_default_args) from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt -from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.quantization import QuantizationMethods @@ -54,7 +51,7 @@ from vllm.tasks import PoolingTask from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, get_cached_tokenizer) from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of +from vllm.utils import Counter, Device, is_list_of from vllm.v1.sample.logits_processor import LogitsProcessor if TYPE_CHECKING: @@ -157,18 +154,6 @@ class LLM: serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead. 
""" - DEPRECATE_LEGACY: ClassVar[bool] = True - """A flag to toggle whether to deprecate the legacy generate/encode API.""" - - @classmethod - @contextmanager - def deprecate_legacy_api(cls): - cls.DEPRECATE_LEGACY = True - - yield - - cls.DEPRECATE_LEGACY = False - def __init__( self, model: str, @@ -325,99 +310,14 @@ class LLM: return SamplingParams.from_optional(**self.default_sampling_params) return SamplingParams() - @overload def generate( self, prompts: Union[PromptType, Sequence[PromptType]], - /, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: str, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - prompt_token_ids: Optional[list[int]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: multi (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: list[str], - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - prompt_token_ids: Optional[list[list[int]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: Optional[str] = None, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - *, - prompt_token_ids: list[int], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: multi (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: Optional[list[str]] = None, - sampling_params: Optional[Union[SamplingParams, - list[SamplingParams]]] = None, - *, - prompt_token_ids: list[list[int]], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... - - @overload # LEGACY: single or multi token ids [pos-only] - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def generate( - self, - prompts: None, - sampling_params: None, - prompt_token_ids: Union[list[int], list[list[int]]], - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - ) -> list[RequestOutput]: - ... 
- - @deprecate_kwargs( - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", - ) - def generate( - self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, list[str]]]] = None, - sampling_params: Optional[Union[SamplingParams, - Sequence[SamplingParams]]] = None, - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, priority: Optional[list[int]] = None, ) -> list[RequestOutput]: """Generates the completions for the input prompts. @@ -460,15 +360,6 @@ class LLM: "Try passing `--runner generate` to use the model as a " "generative model.") - if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, list[str]]], prompts), - prompt_token_ids=prompt_token_ids, - ) - else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) - if sampling_params is None: # Use default sampling params. sampling_params = self.get_default_sampling_params() @@ -483,10 +374,10 @@ class LLM: # Add any modality specific loras to the corresponding prompts lora_request = self._get_modality_specific_lora_reqs( - parsed_prompts, lora_request) + prompts, lora_request) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -498,7 +389,7 @@ class LLM: return self.engine_class.validate_outputs(outputs, RequestOutput) def _get_modality_specific_lora_reqs( - self, parsed_prompts: Union[PromptType, Sequence[PromptType]], + self, prompts: Union[PromptType, Sequence[PromptType]], lora_request: Optional[Union[list[LoRARequest], LoRARequest]]): # Grab the lora config off the vllm config on the engine, # since this is the same for both v0 & v1. 
@@ -511,35 +402,33 @@ class LLM: or (lora_config and lora_config.default_mm_loras is None)): return lora_request - if not isinstance(parsed_prompts, Sequence): - parsed_prompts = [parsed_prompts] + if not isinstance(prompts, Sequence): + prompts = [prompts] - optional_loras = ([lora_request] * len(parsed_prompts) + optional_loras = ([lora_request] * len(prompts) if not isinstance(lora_request, Sequence) else lora_request) return [ self._resolve_single_prompt_mm_lora( - parsed_prompt, + prompt, opt_lora_req, lora_config.default_mm_loras, - ) for parsed_prompt, opt_lora_req in zip(parsed_prompts, - optional_loras) + ) for prompt, opt_lora_req in zip(prompts, optional_loras) ] - def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType, + def _resolve_single_prompt_mm_lora(self, prompt: PromptType, lora_request: Optional[LoRARequest], default_mm_loras: Optional[dict[str, str]]): - if (not default_mm_loras or not isinstance(parsed_prompt, dict) - or "multi_modal_data" not in parsed_prompt): + if (not default_mm_loras or not isinstance(prompt, dict) + or "multi_modal_data" not in prompt): return lora_request - parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt) + prompt = cast(Union[TextPrompt, TokensPrompt], prompt) - intersection = set( - parsed_prompt["multi_modal_data"].keys()).intersection( - default_mm_loras.keys()) + intersection = set(prompt["multi_modal_data"].keys()) \ + .intersection(default_mm_loras.keys()) if not intersection: return lora_request if len(intersection) > 1: @@ -933,11 +822,9 @@ class LLM: lora_request=lora_request, ) - @overload def encode( self, prompts: Union[PromptType, Sequence[PromptType]], - /, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, *, @@ -946,107 +833,6 @@ class LLM: lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, pooling_task: PoolingTask = "encode", tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: str, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[list[int]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: multi (prompt + optional token ids) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: list[str], - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[list[list[int]]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... 
- - @overload # LEGACY: single (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: Optional[str] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - *, - prompt_token_ids: list[int], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: multi (token ids + optional prompt) - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: Optional[list[str]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - *, - prompt_token_ids: list[list[int]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @overload # LEGACY: single or multi token ids [pos-only] - @deprecated("'prompt_token_ids' will become part of 'prompts'") - def encode( - self, - prompts: None, - pooling_params: None, - prompt_token_ids: Union[list[int], list[list[int]]], - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: PoolingTask = "encode", - tokenization_kwargs: Optional[dict[str, Any]] = None, - ) -> list[PoolingRequestOutput]: - ... - - @deprecate_kwargs( - "prompt_token_ids", - is_deprecated=lambda: LLM.DEPRECATE_LEGACY, - additional_message="Please use the 'prompts' parameter instead.", - ) - def encode( - self, - prompts: Union[Union[PromptType, Sequence[PromptType]], - Optional[Union[str, list[str]]]] = None, - pooling_params: Optional[Union[PoolingParams, - Sequence[PoolingParams]]] = None, - prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, - truncate_prompt_tokens: Optional[int] = None, - use_tqdm: Union[bool, Callable[..., tqdm]] = True, - lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - pooling_task: Optional[PoolingTask] = None, - tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input prompts. @@ -1108,15 +894,6 @@ class LLM: raise ValueError( f"pooling_task must be one of {self.supported_tasks}.") - if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( - prompts=cast(Optional[Union[str, list[str]]], prompts), - prompt_token_ids=prompt_token_ids, - ) - else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) - if pooling_params is None: # Use default pooling params. 
pooling_params = PoolingParams() @@ -1134,7 +911,7 @@ class LLM: tokenization_kwargs) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -1148,7 +925,6 @@ class LLM: def embed( self, prompts: Union[PromptType, Sequence[PromptType]], - /, *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, @@ -1198,7 +974,6 @@ class LLM: def classify( self, prompts: Union[PromptType, Sequence[PromptType]], - /, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, pooling_params: Optional[Union[PoolingParams, @@ -1348,7 +1123,7 @@ class LLM: _validate_truncation_size(model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs) - parsed_prompts = [] + prompts = list[PromptType]() input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] @@ -1372,10 +1147,10 @@ class LLM: else: pooling_params_list.append(pooling_params) - parsed_prompts.append(engine_prompt) + prompts.append(engine_prompt) self._validate_and_add_requests( - prompts=parsed_prompts, + prompts=prompts, params=pooling_params_list, use_tqdm=use_tqdm, lora_request=lora_request, @@ -1585,48 +1360,6 @@ class LLM: assert isinstance(self.llm_engine, V1LLMEngine) return self.llm_engine.get_metrics() - # LEGACY - def _convert_v1_inputs( - self, - prompts: Optional[Union[str, list[str]]], - prompt_token_ids: Optional[Union[list[int], list[list[int]]]], - ): - # skip_tokenizer_init is now checked in engine - - if prompts is None and prompt_token_ids is None: - raise ValueError( - "Either prompts or prompt_token_ids must be provided.") - if prompts is not None and prompt_token_ids is not None \ - and len(prompts) != len(prompt_token_ids): - raise ValueError( - "The lengths of prompts and prompt_token_ids must be the same." - ) - - if prompts is not None: - prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] - if prompt_token_ids is not None: - prompt_token_ids = [ - p["content"] for p in parse_and_batch_prompt(prompt_token_ids) - ] - if prompts is not None: - num_requests = len(prompts) - elif prompt_token_ids is not None: - num_requests = len(prompt_token_ids) - parsed_prompts: list[PromptType] = [] - for i in range(num_requests): - item: PromptType - - if prompts is not None: - item = TextPrompt(prompt=prompts[i]) - elif prompt_token_ids is not None: - item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) - else: - raise AssertionError - - parsed_prompts.append(item) - - return parsed_prompts - def _validate_and_add_requests( self, prompts: Union[PromptType, Sequence[PromptType]],
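For reference, a minimal sketch of the caller-side migration implied by the `vllm/entrypoints/llm.py` hunks above: the legacy `prompt_token_ids=` keyword is removed from `LLM.generate`/`LLM.encode`, so token-id inputs are wrapped in `TokensPrompt` (or a `{"prompt_token_ids": ...}` dict) and passed as the positional `prompts` argument, mirroring the `examples/offline_inference/spec_decode.py` change. The model name and token ids below are placeholders, not part of this patch.

```python
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="distilbert/distilgpt2", enforce_eager=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=10)

token_ids = [[101, 102, 103], [104, 105, 106]]  # placeholder prompt token ids

# Removed by this change:
#   llm.generate(prompt_token_ids=token_ids, sampling_params=sampling_params)

# Current form: wrap each token-id list in a TokensPrompt and pass it
# positionally as the prompts argument.
outputs = llm.generate(
    [TokensPrompt(prompt_token_ids=ids) for ids in token_ids],
    sampling_params=sampling_params,
)
for output in outputs:
    print(output.outputs[0].text)
```

The same pattern applies to `LLM.encode`, whose legacy `prompt_token_ids=` keyword is removed by the same patch.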