mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-11 00:15:01 +08:00
[misc][core] lazy import outlines (#7831)
This commit is contained in:
parent
d81abefd2e
commit
7d9ffa2ae1
@ -87,7 +87,8 @@ steps:
|
||||
commands:
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
|
||||
- pytest -v -s entrypoints/llm
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/openai
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
|
||||
48
tests/entrypoints/llm/test_lazy_outlines.py
Normal file
48
tests/entrypoints/llm/test_lazy_outlines.py
Normal file
@ -0,0 +1,48 @@
|
||||
import sys
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def test_lazy_outlines(sample_regex):
|
||||
"""If users don't use guided decoding, outlines should not be imported.
|
||||
"""
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.3)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# make sure outlines is not imported
|
||||
assert 'outlines' not in sys.modules
|
||||
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
guided_decoding_backend="lm-format-enforcer",
|
||||
gpu_memory_utilization=0.3)
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
outputs = llm.generate(
|
||||
prompts=[
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
guided_options_request=dict(guided_regex=sample_regex))
|
||||
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
# make sure outlines is not imported
|
||||
assert 'outlines' not in sys.modules
|
||||
@ -5,9 +5,6 @@ from vllm.entrypoints.openai.protocol import (
|
||||
CompletionRequest)
|
||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||
GuidedDecodingRequest)
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||
get_local_outlines_guided_decoding_logits_processor,
|
||||
get_outlines_guided_decoding_logits_processor)
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
|
||||
|
||||
@ -18,6 +15,9 @@ async def get_guided_decoding_logits_processor(
|
||||
request = _adapt_request_for_tool_use(request)
|
||||
|
||||
if guided_decoding_backend == 'outlines':
|
||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||
get_outlines_guided_decoding_logits_processor)
|
||||
return await get_outlines_guided_decoding_logits_processor(
|
||||
request, tokenizer)
|
||||
if guided_decoding_backend == 'lm-format-enforcer':
|
||||
@ -37,6 +37,9 @@ def get_local_guided_decoding_logits_processor(
|
||||
# request = _adapt_request_for_tool_use(request)
|
||||
|
||||
if guided_decoding_backend == 'outlines':
|
||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||
get_local_outlines_guided_decoding_logits_processor)
|
||||
return get_local_outlines_guided_decoding_logits_processor(
|
||||
guided_options, tokenizer)
|
||||
if guided_decoding_backend == 'lm-format-enforcer':
|
||||
|
||||
@ -14,9 +14,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest)
|
||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||
GuidedDecodingRequest)
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||
get_local_outlines_guided_decoding_logits_processor,
|
||||
get_outlines_guided_decoding_logits_processor)
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
|
||||
|
||||
@ -43,6 +40,10 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
|
||||
character_level_parser = RegexParser(request.guided_regex)
|
||||
elif request.guided_grammar:
|
||||
# CFG grammar not supported by LMFE, revert to outlines
|
||||
|
||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||
get_outlines_guided_decoding_logits_processor)
|
||||
return await get_outlines_guided_decoding_logits_processor(
|
||||
request, tokenizer)
|
||||
elif (request.response_format is not None
|
||||
@ -87,6 +88,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
||||
character_level_parser = RegexParser(guided_options.guided_regex)
|
||||
elif guided_options.guided_grammar:
|
||||
# CFG grammar not supported by LMFE, revert to outlines
|
||||
|
||||
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||
get_local_outlines_guided_decoding_logits_processor)
|
||||
return get_local_outlines_guided_decoding_logits_processor(
|
||||
guided_options, tokenizer)
|
||||
elif guided_options.guided_json_object:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user