[misc][core] lazy import outlines (#7831)

This commit is contained in:
youkaichao 2024-08-24 00:51:38 -07:00 committed by GitHub
parent d81abefd2e
commit 7d9ffa2ae1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 7 deletions

View File

@ -87,7 +87,8 @@ steps:
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
- pytest -v -s entrypoints/llm
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/openai
- label: Distributed Tests (4 GPUs) # 10min

View File

@ -0,0 +1,48 @@
import sys
from vllm import LLM, SamplingParams
def test_lazy_outlines(sample_regex):
"""If users don't use guided decoding, outlines should not be imported.
"""
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
gpu_memory_utilization=0.3)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# make sure outlines is not imported
assert 'outlines' not in sys.modules
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
guided_options_request=dict(guided_regex=sample_regex))
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
# make sure outlines is not imported
assert 'outlines' not in sys.modules

View File

@ -5,9 +5,6 @@ from vllm.entrypoints.openai.protocol import (
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
@ -18,6 +15,9 @@ async def get_guided_decoding_logits_processor(
request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
@ -37,6 +37,9 @@ def get_local_guided_decoding_logits_processor(
# request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':

View File

@ -14,9 +14,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
@ -43,6 +40,10 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
@ -87,6 +88,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
character_level_parser = RegexParser(guided_options.guided_regex)
elif guided_options.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
elif guided_options.guided_json_object: