mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 17:05:53 +08:00
[misc][core] lazy import outlines (#7831)
This commit is contained in:
parent
d81abefd2e
commit
7d9ffa2ae1
@ -87,7 +87,8 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pip install -e ./plugins/vllm_add_dummy_model
|
- pip install -e ./plugins/vllm_add_dummy_model
|
||||||
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
|
- pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
|
||||||
- pytest -v -s entrypoints/llm
|
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
|
||||||
|
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||||
- pytest -v -s entrypoints/openai
|
- pytest -v -s entrypoints/openai
|
||||||
|
|
||||||
- label: Distributed Tests (4 GPUs) # 10min
|
- label: Distributed Tests (4 GPUs) # 10min
|
||||||
|
|||||||
48
tests/entrypoints/llm/test_lazy_outlines.py
Normal file
48
tests/entrypoints/llm/test_lazy_outlines.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import sys
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
def test_lazy_outlines(sample_regex):
|
||||||
|
"""If users don't use guided decoding, outlines should not be imported.
|
||||||
|
"""
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
|
||||||
|
llm = LLM(model="facebook/opt-125m",
|
||||||
|
enforce_eager=True,
|
||||||
|
gpu_memory_utilization=0.3)
|
||||||
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
# make sure outlines is not imported
|
||||||
|
assert 'outlines' not in sys.modules
|
||||||
|
|
||||||
|
llm = LLM(model="facebook/opt-125m",
|
||||||
|
enforce_eager=True,
|
||||||
|
guided_decoding_backend="lm-format-enforcer",
|
||||||
|
gpu_memory_utilization=0.3)
|
||||||
|
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts=[
|
||||||
|
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||||
|
] * 2,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True,
|
||||||
|
guided_options_request=dict(guided_regex=sample_regex))
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
prompt = output.prompt
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
# make sure outlines is not imported
|
||||||
|
assert 'outlines' not in sys.modules
|
||||||
@ -5,9 +5,6 @@ from vllm.entrypoints.openai.protocol import (
|
|||||||
CompletionRequest)
|
CompletionRequest)
|
||||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||||
GuidedDecodingRequest)
|
GuidedDecodingRequest)
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
|
||||||
get_local_outlines_guided_decoding_logits_processor,
|
|
||||||
get_outlines_guided_decoding_logits_processor)
|
|
||||||
from vllm.sampling_params import LogitsProcessor
|
from vllm.sampling_params import LogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +15,9 @@ async def get_guided_decoding_logits_processor(
|
|||||||
request = _adapt_request_for_tool_use(request)
|
request = _adapt_request_for_tool_use(request)
|
||||||
|
|
||||||
if guided_decoding_backend == 'outlines':
|
if guided_decoding_backend == 'outlines':
|
||||||
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
|
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||||
|
get_outlines_guided_decoding_logits_processor)
|
||||||
return await get_outlines_guided_decoding_logits_processor(
|
return await get_outlines_guided_decoding_logits_processor(
|
||||||
request, tokenizer)
|
request, tokenizer)
|
||||||
if guided_decoding_backend == 'lm-format-enforcer':
|
if guided_decoding_backend == 'lm-format-enforcer':
|
||||||
@ -37,6 +37,9 @@ def get_local_guided_decoding_logits_processor(
|
|||||||
# request = _adapt_request_for_tool_use(request)
|
# request = _adapt_request_for_tool_use(request)
|
||||||
|
|
||||||
if guided_decoding_backend == 'outlines':
|
if guided_decoding_backend == 'outlines':
|
||||||
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
|
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
|
||||||
|
get_local_outlines_guided_decoding_logits_processor)
|
||||||
return get_local_outlines_guided_decoding_logits_processor(
|
return get_local_outlines_guided_decoding_logits_processor(
|
||||||
guided_options, tokenizer)
|
guided_options, tokenizer)
|
||||||
if guided_decoding_backend == 'lm-format-enforcer':
|
if guided_decoding_backend == 'lm-format-enforcer':
|
||||||
|
|||||||
@ -14,9 +14,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
CompletionRequest)
|
CompletionRequest)
|
||||||
from vllm.model_executor.guided_decoding.guided_fields import (
|
from vllm.model_executor.guided_decoding.guided_fields import (
|
||||||
GuidedDecodingRequest)
|
GuidedDecodingRequest)
|
||||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
|
||||||
get_local_outlines_guided_decoding_logits_processor,
|
|
||||||
get_outlines_guided_decoding_logits_processor)
|
|
||||||
from vllm.sampling_params import LogitsProcessor
|
from vllm.sampling_params import LogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
@ -43,6 +40,10 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
|
|||||||
character_level_parser = RegexParser(request.guided_regex)
|
character_level_parser = RegexParser(request.guided_regex)
|
||||||
elif request.guided_grammar:
|
elif request.guided_grammar:
|
||||||
# CFG grammar not supported by LMFE, revert to outlines
|
# CFG grammar not supported by LMFE, revert to outlines
|
||||||
|
|
||||||
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
|
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||||
|
get_outlines_guided_decoding_logits_processor)
|
||||||
return await get_outlines_guided_decoding_logits_processor(
|
return await get_outlines_guided_decoding_logits_processor(
|
||||||
request, tokenizer)
|
request, tokenizer)
|
||||||
elif (request.response_format is not None
|
elif (request.response_format is not None
|
||||||
@ -87,6 +88,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
|
|||||||
character_level_parser = RegexParser(guided_options.guided_regex)
|
character_level_parser = RegexParser(guided_options.guided_regex)
|
||||||
elif guided_options.guided_grammar:
|
elif guided_options.guided_grammar:
|
||||||
# CFG grammar not supported by LMFE, revert to outlines
|
# CFG grammar not supported by LMFE, revert to outlines
|
||||||
|
|
||||||
|
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
|
||||||
|
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||||
|
get_local_outlines_guided_decoding_logits_processor)
|
||||||
return get_local_outlines_guided_decoding_logits_processor(
|
return get_local_outlines_guided_decoding_logits_processor(
|
||||||
guided_options, tokenizer)
|
guided_options, tokenizer)
|
||||||
elif guided_options.guided_json_object:
|
elif guided_options.guided_json_object:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user