Remove default values from InitVars so that they're not stored (#29859)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-12-02 12:16:37 +00:00; committed by GitHub
parent d8c6210eea
commit 951445a52d
17 changed files with 139 additions and 77 deletions
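
Why: `SchedulerConfig.max_model_len` and `SchedulerConfig.is_encoder_decoder` are dataclass `InitVar`s — init-only values that belong to `ModelConfig` and are consumed only while validating the scheduler settings. Under the pydantic dataclass machinery these configs use, an `InitVar` with a default ends up stored as an attribute (see the new regression test below), so this change drops the defaults, makes the two values required at construction time, and adds `SchedulerConfig.default_factory()` for call sites that previously relied on a bare `SchedulerConfig()`. The following is a minimal sketch of the pattern using a plain stdlib dataclass; the class name and the checks in `__post_init__` are illustrative only, not vLLM's actual validation logic.

from dataclasses import InitVar, dataclass


@dataclass
class ToySchedulerConfig:
    # Init-only values, normally sourced from the model config; they are
    # handed to __post_init__ for validation and never stored on the instance.
    max_model_len: InitVar[int]
    is_encoder_decoder: InitVar[bool]

    max_num_batched_tokens: int = 2048
    enable_chunked_prefill: bool = True

    def __post_init__(self, max_model_len: int, is_encoder_decoder: bool) -> None:
        # Illustrative use of the init-only values; the real checks differ.
        if is_encoder_decoder:
            self.enable_chunked_prefill = False
        self.max_num_batched_tokens = min(self.max_num_batched_tokens, max_model_len)

    @staticmethod
    def default_factory(**kwargs):
        # Supplies the defaults the InitVars used to carry, for callers
        # (mostly tests) that have no ModelConfig in scope.
        kwargs.setdefault("max_model_len", 8192)
        kwargs.setdefault("is_encoder_decoder", False)
        return ToySchedulerConfig(**kwargs)


cfg = ToySchedulerConfig.default_factory(max_num_batched_tokens=512)
assert not hasattr(cfg, "max_model_len")  # init-only, so not stored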

View File

@@ -108,7 +108,10 @@ def benchmark_batched_propose(args):
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )
 
     # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group

View File

@@ -318,13 +318,18 @@ def test_attention_quant_pattern(
     torch.set_default_dtype(dtype)
     torch.manual_seed(42)
 
+    model_config = ModelConfig(
+        model=model_name,
+        max_model_len=2048,
+        dtype=dtype,
+    )
     vllm_config = VllmConfig(
-        model_config=ModelConfig(
-            model=model_name,
-            max_model_len=2048,
-            dtype=dtype,
-        ),
-        scheduler_config=SchedulerConfig(max_num_seqs=1024),
+        model_config=model_config,
+        scheduler_config=SchedulerConfig(
+            max_num_seqs=1024,
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
         compilation_config=CompilationConfig(
             mode=CompilationMode.VLLM_COMPILE,
             custom_ops=custom_ops_list,

View File

@@ -33,14 +33,16 @@ def test_worker_apply_lora(qwen3_lora_files):
         lora_requests, lora_mapping
     )
 
+    model_config = ModelConfig(
+        MODEL_PATH,
+        seed=0,
+        dtype="float16",
+        max_model_len=127,
+        enforce_eager=True,
+    )
     vllm_config = VllmConfig(
-        model_config=ModelConfig(
-            MODEL_PATH,
-            seed=0,
-            dtype="float16",
-            max_model_len=127,
-            enforce_eager=True,
-        ),
+        model_config=model_config,
         load_config=LoadConfig(
             download_dir=None,
             load_format="dummy",
@@ -50,7 +52,14 @@ def test_worker_apply_lora(qwen3_lora_files):
             tensor_parallel_size=1,
             data_parallel_size=1,
         ),
-        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+            runner_type="generate",
+            max_num_batched_tokens=32,
+            max_num_seqs=32,
+            max_num_partial_prefills=32,
+        ),
         device_config=DeviceConfig("cuda"),
         cache_config=CacheConfig(
             block_size=16,

View File

@@ -6,12 +6,14 @@ from dataclasses import MISSING, Field, asdict, dataclass, field
 from unittest.mock import patch
 
 import pytest
+from pydantic import ValidationError
 
 from vllm.compilation.backends import VllmBackend
 from vllm.config import (
     CompilationConfig,
     ModelConfig,
     PoolerConfig,
+    SchedulerConfig,
     VllmConfig,
     update_config,
 )
@@ -1095,3 +1097,14 @@ def test_vllm_config_explicit_overrides():
     # Other fields should still use defaults
     assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE
+
+
+def test_scheduler_config_init():
+    with pytest.raises(ValidationError):
+        # Positional InitVars missing
+        # (InitVars cannot have defaults otherwise they will become attributes)
+        SchedulerConfig()
+
+    with pytest.raises(AttributeError):
+        # InitVar does not become an attribute
+        print(SchedulerConfig.default_factory().max_model_len)

View File

@@ -185,6 +185,8 @@ def create_vllm_config(
         max_num_seqs=max_num_seqs,
         max_num_batched_tokens=max_num_batched_tokens,
         enable_chunked_prefill=enable_chunked_prefill,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
     device_config = DeviceConfig()

View File

@@ -1128,7 +1128,11 @@ def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len)
         dtype="float16",
         max_model_len=max_model_len,
     )
-    scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
+    scheduler_config = SchedulerConfig(
+        max_num_batched_tokens=32768,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
 
     vllm_config = VllmConfig(
         model_config=model_config,
@@ -1163,7 +1167,10 @@ def test_get_max_concurrency_for_kv_cache_config():
         max_model_len=max_model_len,
     )
     scheduler_config = SchedulerConfig(
-        max_num_batched_tokens=1024, enable_chunked_prefill=True
+        max_num_batched_tokens=1024,
+        enable_chunked_prefill=True,
+        max_model_len=model_config.max_model_len,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
 
     vllm_config = VllmConfig(

View File

@@ -1508,6 +1508,12 @@ def create_scheduler_with_priority(
     Returns:
       {class}`Scheduler` instance with priority scheduling
     """
+    model_config = ModelConfig(
+        model=model,
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+    )
     if max_model_len is None:
         max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(
@@ -1517,14 +1523,9 @@ def create_scheduler_with_priority(
         long_prefill_token_threshold=long_prefill_token_threshold,
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=True,
+        is_encoder_decoder=model_config.is_encoder_decoder,
         policy="priority",  # Enable priority scheduling
     )
-    model_config = ModelConfig(
-        model=model,
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-    )
     # Cache config, optionally force APC
     cache_config = CacheConfig(
         block_size=block_size,

View File

@@ -69,6 +69,13 @@ def create_scheduler(
     Returns:
       {class}`Scheduler` instance
     """
+    model_config = ModelConfig(
+        model=model,
+        trust_remote_code=True,
+        dtype="float16",
+        seed=42,
+        skip_tokenizer_init=skip_tokenizer_init,
+    )
     if max_model_len is None:
         max_model_len = max_num_batched_tokens
     scheduler_config = SchedulerConfig(
@@ -79,13 +86,7 @@ def create_scheduler(
         disable_chunked_mm_input=disable_chunked_mm_input,
         enable_chunked_prefill=enable_chunked_prefill,
         async_scheduling=async_scheduling,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
-    model_config = ModelConfig(
-        model=model,
-        trust_remote_code=True,
-        dtype="float16",
-        seed=42,
-        skip_tokenizer_init=skip_tokenizer_init,
-    )
     # Cache config, optionally force APC
     cache_config = CacheConfig(

View File

@@ -40,7 +40,9 @@ ) -> MagicMock:
     mock_config = MagicMock(spec=VllmConfig)
     mock_config.compilation_config = compilation_config
-    mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
+    mock_config.scheduler_config = SchedulerConfig.default_factory(
+        max_num_seqs=max_num_seqs,
+    )
     mock_config.parallel_config = ParallelConfig()
     mock_config.speculative_config = None  # No speculative decoding
     if not lora_config:

View File

@@ -484,12 +484,6 @@ def test_encoder_instance_zero_kv_cache(
     vision encoder, so they don't need KV cache for text generation.
     """
     # Form vllm config
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-        disable_hybrid_kv_cache_manager=True,
-    )
     model_config = ModelConfig(
         model="llava-hf/llava-1.5-7b-hf",  # Multimodal model
         enforce_eager=True,
@@ -497,6 +491,13 @@ def test_encoder_instance_zero_kv_cache(
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        disable_hybrid_kv_cache_manager=True,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=16,
         gpu_memory_utilization=gpu_memory_utilization,

View File

@@ -92,18 +92,19 @@ def create_vllm_config(
     enable_permute_local_kv: bool = False,
 ) -> VllmConfig:
     """Initialize VllmConfig For Testing."""
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=max_num_seqs,
-        max_num_batched_tokens=max_num_batched_tokens,
-        max_model_len=max_model_len,
-        enable_chunked_prefill=enable_chunked_prefill,
-    )
     model_config = ModelConfig(
         model=model,
         trust_remote_code=True,
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=max_num_seqs,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_model_len=max_model_len,
+        enable_chunked_prefill=enable_chunked_prefill,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     # Cache config, optionally force APC
     cache_config = CacheConfig(
         block_size=block_size,

View File

@@ -66,7 +66,10 @@ def _create_proposer(
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )
 
     return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)

View File

@@ -51,7 +51,10 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer:
         device_config=DeviceConfig(device=current_platform.device_type),
         parallel_config=ParallelConfig(),
         load_config=LoadConfig(),
-        scheduler_config=SchedulerConfig(),
+        scheduler_config=SchedulerConfig(
+            max_model_len=model_config.max_model_len,
+            is_encoder_decoder=model_config.is_encoder_decoder,
+        ),
     )
 
     return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)

View File

@@ -26,16 +26,17 @@ from vllm.v1.worker.tpu_model_runner import (
 
 
 def get_vllm_config():
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-    )
     model_config = ModelConfig(
         model="facebook/opt-125m",
         dtype="bfloat16",  # TPUs typically use bfloat16
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=16,
         gpu_memory_utilization=0.9,

View File

@@ -79,16 +79,17 @@ def initialize_kv_cache(runner: GPUModelRunner):
 
 
 def get_vllm_config():
-    scheduler_config = SchedulerConfig(
-        max_num_seqs=10,
-        max_num_batched_tokens=512,
-        max_model_len=512,
-    )
     model_config = ModelConfig(
         model="facebook/opt-125m",
         dtype="float16",
         seed=42,
     )
+    scheduler_config = SchedulerConfig(
+        max_num_seqs=10,
+        max_num_batched_tokens=512,
+        max_model_len=512,
+        is_encoder_decoder=model_config.is_encoder_decoder,
+    )
     cache_config = CacheConfig(
         block_size=BLOCK_SIZE,
         gpu_memory_utilization=0.9,
@@ -784,14 +785,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):
     initialize_model_parallel(tensor_model_parallel_size=1)
     torch.set_default_dtype(torch.float16)
 
+    model_config = ModelConfig(
+        model="ibm-granite/granite-4.0-tiny-preview",
+        dtype="float16",
+    )
     scheduler_config = SchedulerConfig(
         max_num_seqs=10,
         max_num_batched_tokens=512,
         max_model_len=512,
+        is_encoder_decoder=model_config.is_encoder_decoder,
     )
-    model_config = ModelConfig(
-        model="ibm-granite/granite-4.0-tiny-preview",
-        dtype="float16",
-    )
     cache_config = CacheConfig(
         block_size=BLOCK_SIZE,

View File

@@ -28,6 +28,19 @@ SchedulerPolicy = Literal["fcfs", "priority"]
 class SchedulerConfig:
     """Scheduler configuration."""
 
+    max_model_len: InitVar[int]
+    """Maximum length of a sequence (including prompt and generated text).
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    provide fallbacks and validate other attributes."""
+
+    is_encoder_decoder: InitVar[bool]
+    """True if the model is an encoder-decoder model.
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    disable chunked prefill and prefix caching for encoder-decoder models.
+    """
+
     DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048
     DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128
@@ -73,19 +86,6 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
 
-    max_model_len: InitVar[int] = 8192
-    """Maximum length of a sequence (including prompt and generated text).
-
-    Note: This is stored in the ModelConfig, and is used only here to
-    provide fallbacks and validate other attributes."""
-
-    is_encoder_decoder: InitVar[bool] = False
-    """True if the model is an encoder-decoder model.
-
-    Note: This is stored in the ModelConfig, and is used only here to
-    disable chunked prefill and prefix caching for encoder-decoder models.
-    """
-
     # TODO (ywang96): Make this configurable.
     max_num_encoder_input_tokens: int = Field(init=False)
     """Multimodal encoder compute budget, only used in V1.
@@ -141,6 +141,17 @@ class SchedulerConfig:
     while a larger value (e.g., 10) reduces host overhead and may increase throughput
     by batching multiple tokens before sending."""
 
+    @staticmethod
+    def default_factory(**kwargs):
+        """
+        Factory method to create `SchedulerConfig` with default values for `InitVar`s.
+        """
+        if "max_model_len" not in kwargs:
+            kwargs["max_model_len"] = 8192
+        if "is_encoder_decoder" not in kwargs:
+            kwargs["is_encoder_decoder"] = False
+        return SchedulerConfig(**kwargs)
+
     def get_scheduler_cls(self) -> type["SchedulerInterface"]:
         if self.scheduler_cls is None:
             if self.async_scheduling:
@@ -284,8 +295,3 @@ class SchedulerConfig:
         )
 
         return self
-
-    def __getattribute__(self, name: str) -> Any:
-        if name == "max_model_len" or name == "is_encoder_decoder":
-            raise AttributeError(f"{name} is an init-only parameter. ")
-        return object.__getattribute__(self, name)

View File

@@ -170,7 +170,9 @@ class VllmConfig:
     """Cache configuration."""
     parallel_config: ParallelConfig = Field(default_factory=ParallelConfig)
     """Parallel configuration."""
-    scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig)
+    scheduler_config: SchedulerConfig = Field(
+        default_factory=SchedulerConfig.default_factory,
+    )
     """Scheduler configuration."""
     device_config: DeviceConfig = Field(default_factory=DeviceConfig)
     """Device configuration."""