diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 31319cf64aeb8..2768e267f4837 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -18,6 +18,7 @@ from vllm.beam_search import (
     create_sort_beams_key_function,
 )
 from vllm.config import (
+    AttentionConfig,
     CompilationConfig,
     PoolerConfig,
     ProfilerConfig,
@@ -175,6 +176,10 @@ class LLM:
         compilation_config: Either an integer or a dictionary. If it is an
             integer, it is used as the mode of compilation optimization. If it
             is a dictionary, it can specify the full compilation configuration.
+        attention_config: Configuration for attention mechanisms. Can be a
+            dictionary or an AttentionConfig instance. If a dictionary, it will
+            be converted to an AttentionConfig. Allows specifying the attention
+            backend and other attention-related settings.
         **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
 
     Note:
@@ -213,6 +218,7 @@ class LLM:
         | StructuredOutputsConfig
         | None = None,
         profiler_config: dict[str, Any] | ProfilerConfig | None = None,
+        attention_config: dict[str, Any] | AttentionConfig | None = None,
         kv_cache_memory_bytes: int | None = None,
         compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
         logits_processors: list[str | type[LogitsProcessor]] | None = None,
@@ -252,51 +258,28 @@ class LLM:
         if hf_overrides is None:
             hf_overrides = {}
 
-        if compilation_config is not None:
-            if isinstance(compilation_config, int):
-                compilation_config_instance = CompilationConfig(
-                    mode=CompilationMode(compilation_config)
-                )
-            elif isinstance(compilation_config, dict):
-                compilation_config_instance = CompilationConfig(
-                    **{
-                        k: v
-                        for k, v in compilation_config.items()
-                        if is_init_field(CompilationConfig, k)
-                    }
-                )
-            else:
-                compilation_config_instance = compilation_config
-        else:
-            compilation_config_instance = CompilationConfig()
+        def _make_config(value: Any, cls: type[_R]) -> _R:
+            """Convert dict/None/instance to a config instance."""
+            if value is None:
+                return cls()
+            if isinstance(value, dict):
+                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
+            return value
 
-        if structured_outputs_config is not None:
-            if isinstance(structured_outputs_config, dict):
-                structured_outputs_instance = StructuredOutputsConfig(
-                    **{
-                        k: v
-                        for k, v in structured_outputs_config.items()
-                        if is_init_field(StructuredOutputsConfig, k)
-                    }
-                )
-            else:
-                structured_outputs_instance = structured_outputs_config
+        if isinstance(compilation_config, int):
+            compilation_config_instance = CompilationConfig(
+                mode=CompilationMode(compilation_config)
+            )
         else:
-            structured_outputs_instance = StructuredOutputsConfig()
+            compilation_config_instance = _make_config(
+                compilation_config, CompilationConfig
+            )
 
-        if profiler_config is not None:
-            if isinstance(profiler_config, dict):
-                profiler_config_instance = ProfilerConfig(
-                    **{
-                        k: v
-                        for k, v in profiler_config.items()
-                        if is_init_field(ProfilerConfig, k)
-                    }
-                )
-            else:
-                profiler_config_instance = profiler_config
-        else:
-            profiler_config_instance = ProfilerConfig()
+        structured_outputs_instance = _make_config(
+            structured_outputs_config, StructuredOutputsConfig
+        )
+        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
+        attention_config_instance = _make_config(attention_config, AttentionConfig)
 
         # warn about single-process data parallel usage.
         _dp_size = int(kwargs.get("data_parallel_size", 1))
@@ -341,6 +324,7 @@ class LLM:
             pooler_config=pooler_config,
             structured_outputs_config=structured_outputs_instance,
             profiler_config=profiler_config_instance,
+            attention_config=attention_config_instance,
             compilation_config=compilation_config_instance,
             logits_processors=logits_processors,
             **kwargs,
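
For reviewers, here is the coercion contract the new `_make_config` helper implements, as a standalone sketch. `DemoConfig` and the `dataclasses.fields`-based filter are stand-ins for vLLM's config classes and `is_init_field`; only the None/dict/instance dispatch mirrors the diff.

```python
# Standalone sketch of the _make_config pattern; DemoConfig and the
# fields-based key filter are illustrative stand-ins, not vLLM code.
from dataclasses import dataclass, fields
from typing import Any, TypeVar

_R = TypeVar("_R")


@dataclass
class DemoConfig:
    backend: str = "auto"


def _make_config(value: Any, cls: type[_R]) -> _R:
    """Coerce None -> cls(), dict -> cls(**known_keys), instance -> itself."""
    if value is None:
        return cls()
    if isinstance(value, dict):
        init_names = {f.name for f in fields(cls) if f.init}
        # Unknown keys are dropped silently, mirroring the is_init_field filter.
        return cls(**{k: v for k, v in value.items() if k in init_names})
    return value


assert _make_config(None, DemoConfig) == DemoConfig()
assert _make_config({"backend": "flash", "bogus": 1}, DemoConfig).backend == "flash"
cfg = DemoConfig(backend="xformers")
assert _make_config(cfg, DemoConfig) is cfg
```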
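
And a usage sketch of the new keyword from the caller's side. The `backend` key is an assumption: the docstring only says `AttentionConfig` covers the attention backend and related settings, so check `vllm.config.AttentionConfig` for the real field names.

```python
from vllm import LLM

# Hypothetical field name "backend"; dicts are coerced to AttentionConfig by
# _make_config, and keys that fail is_init_field are dropped silently.
llm = LLM(
    model="facebook/opt-125m",
    attention_config={"backend": "FLASH_ATTN"},
)
```

One behavioral consequence of the `is_init_field` filter worth flagging in review: a misspelled key in the dict form is silently ignored rather than raising a `TypeError`.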