mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 11:16:38 +08:00
[UX][Attention] Add attention_config argument to LLM() (#30710)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
parent
c01d589813
commit
a182be4308
@ -18,6 +18,7 @@ from vllm.beam_search import (
|
|||||||
create_sort_beams_key_function,
|
create_sort_beams_key_function,
|
||||||
)
|
)
|
||||||
from vllm.config import (
|
from vllm.config import (
|
||||||
|
AttentionConfig,
|
||||||
CompilationConfig,
|
CompilationConfig,
|
||||||
PoolerConfig,
|
PoolerConfig,
|
||||||
ProfilerConfig,
|
ProfilerConfig,
|
||||||
@ -175,6 +176,10 @@ class LLM:
|
|||||||
compilation_config: Either an integer or a dictionary. If it is an
|
compilation_config: Either an integer or a dictionary. If it is an
|
||||||
integer, it is used as the mode of compilation optimization. If it
|
integer, it is used as the mode of compilation optimization. If it
|
||||||
is a dictionary, it can specify the full compilation configuration.
|
is a dictionary, it can specify the full compilation configuration.
|
||||||
|
attention_config: Configuration for attention mechanisms. Can be a
|
||||||
|
dictionary or an AttentionConfig instance. If a dictionary, it will
|
||||||
|
be converted to an AttentionConfig. Allows specifying the attention
|
||||||
|
backend and other attention-related settings.
|
||||||
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
|
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
@ -213,6 +218,7 @@ class LLM:
|
|||||||
| StructuredOutputsConfig
|
| StructuredOutputsConfig
|
||||||
| None = None,
|
| None = None,
|
||||||
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
|
profiler_config: dict[str, Any] | ProfilerConfig | None = None,
|
||||||
|
attention_config: dict[str, Any] | AttentionConfig | None = None,
|
||||||
kv_cache_memory_bytes: int | None = None,
|
kv_cache_memory_bytes: int | None = None,
|
||||||
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
|
compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
|
||||||
logits_processors: list[str | type[LogitsProcessor]] | None = None,
|
logits_processors: list[str | type[LogitsProcessor]] | None = None,
|
||||||
@ -252,51 +258,28 @@ class LLM:
|
|||||||
if hf_overrides is None:
|
if hf_overrides is None:
|
||||||
hf_overrides = {}
|
hf_overrides = {}
|
||||||
|
|
||||||
if compilation_config is not None:
|
def _make_config(value: Any, cls: type[_R]) -> _R:
|
||||||
if isinstance(compilation_config, int):
|
"""Convert dict/None/instance to a config instance."""
|
||||||
compilation_config_instance = CompilationConfig(
|
if value is None:
|
||||||
mode=CompilationMode(compilation_config)
|
return cls()
|
||||||
)
|
if isinstance(value, dict):
|
||||||
elif isinstance(compilation_config, dict):
|
return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)}) # type: ignore[arg-type]
|
||||||
compilation_config_instance = CompilationConfig(
|
return value
|
||||||
**{
|
|
||||||
k: v
|
|
||||||
for k, v in compilation_config.items()
|
|
||||||
if is_init_field(CompilationConfig, k)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
compilation_config_instance = compilation_config
|
|
||||||
else:
|
|
||||||
compilation_config_instance = CompilationConfig()
|
|
||||||
|
|
||||||
if structured_outputs_config is not None:
|
if isinstance(compilation_config, int):
|
||||||
if isinstance(structured_outputs_config, dict):
|
compilation_config_instance = CompilationConfig(
|
||||||
structured_outputs_instance = StructuredOutputsConfig(
|
mode=CompilationMode(compilation_config)
|
||||||
**{
|
)
|
||||||
k: v
|
|
||||||
for k, v in structured_outputs_config.items()
|
|
||||||
if is_init_field(StructuredOutputsConfig, k)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
structured_outputs_instance = structured_outputs_config
|
|
||||||
else:
|
else:
|
||||||
structured_outputs_instance = StructuredOutputsConfig()
|
compilation_config_instance = _make_config(
|
||||||
|
compilation_config, CompilationConfig
|
||||||
|
)
|
||||||
|
|
||||||
if profiler_config is not None:
|
structured_outputs_instance = _make_config(
|
||||||
if isinstance(profiler_config, dict):
|
structured_outputs_config, StructuredOutputsConfig
|
||||||
profiler_config_instance = ProfilerConfig(
|
)
|
||||||
**{
|
profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
|
||||||
k: v
|
attention_config_instance = _make_config(attention_config, AttentionConfig)
|
||||||
for k, v in profiler_config.items()
|
|
||||||
if is_init_field(ProfilerConfig, k)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
profiler_config_instance = profiler_config
|
|
||||||
else:
|
|
||||||
profiler_config_instance = ProfilerConfig()
|
|
||||||
|
|
||||||
# warn about single-process data parallel usage.
|
# warn about single-process data parallel usage.
|
||||||
_dp_size = int(kwargs.get("data_parallel_size", 1))
|
_dp_size = int(kwargs.get("data_parallel_size", 1))
|
||||||
@ -341,6 +324,7 @@ class LLM:
|
|||||||
pooler_config=pooler_config,
|
pooler_config=pooler_config,
|
||||||
structured_outputs_config=structured_outputs_instance,
|
structured_outputs_config=structured_outputs_instance,
|
||||||
profiler_config=profiler_config_instance,
|
profiler_config=profiler_config_instance,
|
||||||
|
attention_config=attention_config_instance,
|
||||||
compilation_config=compilation_config_instance,
|
compilation_config=compilation_config_instance,
|
||||||
logits_processors=logits_processors,
|
logits_processors=logits_processors,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user