[UX][Attention] Add attention_config argument to LLM() (#30710)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni, 2025-12-15 17:29:09 -05:00 (committed by GitHub)
Parent: c01d589813
Commit: a182be4308


@@ -18,6 +18,7 @@ from vllm.beam_search import (
     create_sort_beams_key_function,
 )
 from vllm.config import (
+    AttentionConfig,
     CompilationConfig,
     PoolerConfig,
     ProfilerConfig,
@@ -175,6 +176,10 @@ class LLM:
         compilation_config: Either an integer or a dictionary. If it is an
             integer, it is used as the mode of compilation optimization. If it
             is a dictionary, it can specify the full compilation configuration.
+        attention_config: Configuration for attention mechanisms. Can be a
+            dictionary or an AttentionConfig instance. If a dictionary, it will
+            be converted to an AttentionConfig. Allows specifying the attention
+            backend and other attention-related settings.
         **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

     Note:
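Example usage of the new argument (a sketch only; the exact AttentionConfig fields, such as a "backend" option, are assumptions based on the docstring above and are not shown in this diff):

    from vllm import LLM
    from vllm.config import AttentionConfig

    # Dict form: keys are filtered to AttentionConfig's init fields and converted internally.
    # "backend" and "FLASH_ATTN" are assumed names; consult AttentionConfig for the real options.
    llm = LLM(model="facebook/opt-125m", attention_config={"backend": "FLASH_ATTN"})

    # Equivalent instance form.
    llm = LLM(
        model="facebook/opt-125m",
        attention_config=AttentionConfig(backend="FLASH_ATTN"),
    )

Either form reaches EngineArgs as an AttentionConfig instance; passing None keeps the default attention configuration.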
@@ -213,6 +218,7 @@ class LLM:
         | StructuredOutputsConfig
         | None = None,
         profiler_config: dict[str, Any] | ProfilerConfig | None = None,
+        attention_config: dict[str, Any] | AttentionConfig | None = None,
         kv_cache_memory_bytes: int | None = None,
         compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
         logits_processors: list[str | type[LogitsProcessor]] | None = None,
@@ -252,51 +258,28 @@ class LLM:
         if hf_overrides is None:
             hf_overrides = {}

-        if compilation_config is not None:
-            if isinstance(compilation_config, int):
-                compilation_config_instance = CompilationConfig(
-                    mode=CompilationMode(compilation_config)
-                )
-            elif isinstance(compilation_config, dict):
-                compilation_config_instance = CompilationConfig(
-                    **{
-                        k: v
-                        for k, v in compilation_config.items()
-                        if is_init_field(CompilationConfig, k)
-                    }
-                )
-            else:
-                compilation_config_instance = compilation_config
-        else:
-            compilation_config_instance = CompilationConfig()
-
-        if structured_outputs_config is not None:
-            if isinstance(structured_outputs_config, dict):
-                structured_outputs_instance = StructuredOutputsConfig(
-                    **{
-                        k: v
-                        for k, v in structured_outputs_config.items()
-                        if is_init_field(StructuredOutputsConfig, k)
-                    }
-                )
-            else:
-                structured_outputs_instance = structured_outputs_config
-        else:
-            structured_outputs_instance = StructuredOutputsConfig()
-
-        if profiler_config is not None:
-            if isinstance(profiler_config, dict):
-                profiler_config_instance = ProfilerConfig(
-                    **{
-                        k: v
-                        for k, v in profiler_config.items()
-                        if is_init_field(ProfilerConfig, k)
-                    }
-                )
-            else:
-                profiler_config_instance = profiler_config
-        else:
-            profiler_config_instance = ProfilerConfig()
+        def _make_config(value: Any, cls: type[_R]) -> _R:
+            """Convert dict/None/instance to a config instance."""
+            if value is None:
+                return cls()
+            if isinstance(value, dict):
+                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
+            return value
+
+        if isinstance(compilation_config, int):
+            compilation_config_instance = CompilationConfig(
+                mode=CompilationMode(compilation_config)
+            )
+        else:
+            compilation_config_instance = _make_config(
+                compilation_config, CompilationConfig
+            )
+
+        structured_outputs_instance = _make_config(
+            structured_outputs_config, StructuredOutputsConfig
+        )
+        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
+        attention_config_instance = _make_config(attention_config, AttentionConfig)

         # warn about single-process data parallel usage.
         _dp_size = int(kwargs.get("data_parallel_size", 1))
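The _make_config helper above collapses the three near-identical if/elif/else blocks it replaces into one normalization path. A minimal standalone sketch of that pattern (DemoConfig and the simplified is_init_field below are illustrative stand-ins, not vLLM code):

    from dataclasses import dataclass, fields
    from typing import Any, TypeVar

    _R = TypeVar("_R")

    def is_init_field(cls: type, name: str) -> bool:
        # Simplified stand-in: accept only keys that are dataclass init fields.
        return any(f.name == name and f.init for f in fields(cls))

    @dataclass
    class DemoConfig:  # stand-in for AttentionConfig, ProfilerConfig, etc.
        backend: str = "auto"

    def make_config(value: Any, cls: type[_R]) -> _R:
        """Convert dict/None/instance to a config instance, mirroring _make_config."""
        if value is None:
            return cls()  # None -> all defaults
        if isinstance(value, dict):
            # Unknown keys are silently dropped rather than raising TypeError.
            return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})
        return value  # already an instance: passed through unchanged

    print(make_config(None, DemoConfig))                             # DemoConfig(backend='auto')
    print(make_config({"backend": "flash", "junk": 1}, DemoConfig))  # 'junk' is dropped
    print(make_config(DemoConfig(backend="flex"), DemoConfig))       # passed through unchanged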
@@ -341,6 +324,7 @@ class LLM:
             pooler_config=pooler_config,
             structured_outputs_config=structured_outputs_instance,
             profiler_config=profiler_config_instance,
+            attention_config=attention_config_instance,
             compilation_config=compilation_config_instance,
             logits_processors=logits_processors,
             **kwargs,