[3/N] model runner pass the whole config to model (#9958)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-02 12:08:49 -07:00 committed by GitHub
parent 74b529ceee
commit cea808f325
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 87 additions and 140 deletions

View File

@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True) cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model get_model_old = get_model
def get_model_patched(*, model_config, device_config, **kwargs): def get_model_patched(**kwargs):
kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8) kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
return get_model_old(model_config=model_config, max_lora_rank=8)
device_config=device_config, return get_model_old(**kwargs)
**kwargs)
with patch("vllm.worker.model_runner.get_model", get_model_patched): with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)

View File

@ -1,27 +1,15 @@
from typing import Optional
from torch import nn from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import VllmConfig
ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader, from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader) get_model_loader)
from vllm.model_executor.model_loader.utils import ( from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture) get_architecture_class_name, get_model_architecture)
def get_model(*, model_config: ModelConfig, load_config: LoadConfig, def get_model(*, vllm_config: VllmConfig) -> nn.Module:
device_config: DeviceConfig, parallel_config: ParallelConfig, loader = get_model_loader(vllm_config.load_config)
scheduler_config: SchedulerConfig, return loader.load_model(vllm_config=vllm_config)
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config)
__all__ = [ __all__ = [

View File

@ -21,9 +21,9 @@ from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig,
LoRAConfig, ModelConfig, MultiModalConfig, ModelConfig, MultiModalConfig, ParallelConfig,
ParallelConfig, PoolerConfig, SchedulerConfig) PoolerConfig, SchedulerConfig, VllmConfig)
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
@ -150,6 +150,7 @@ def _get_model_initialization_kwargs(
def build_model(model_class: Type[nn.Module], def build_model(model_class: Type[nn.Module],
vllm_config: VllmConfig,
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig], cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig], quant_config: Optional[QuantizationConfig],
@ -166,23 +167,29 @@ def build_model(model_class: Type[nn.Module],
if prefix: if prefix:
extra_kwargs["prefix"] = prefix extra_kwargs["prefix"] = prefix
# TODO: unify all the module initialization code
# to only take the `VllmConfig` object as input
from vllm.plugins import set_vllm_config
set_vllm_config(vllm_config)
return model_class(config=hf_config, return model_class(config=hf_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
**extra_kwargs) **extra_kwargs)
def _initialize_model( def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
model_config: ModelConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
load_config = vllm_config.load_config
model_class, _ = get_model_architecture(model_config) model_class, _ = get_model_architecture(model_config)
return build_model( return build_model(
model_class, model_class,
vllm_config,
model_config.hf_config, model_config.hf_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=_get_quantization_config(model_config, load_config), quant_config=_get_quantization_config(model_config, load_config),
@ -205,12 +212,7 @@ class BaseModelLoader(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
"""Load a model with the given configurations.""" """Load a model with the given configurations."""
raise NotImplementedError raise NotImplementedError
@ -396,18 +398,14 @@ class DefaultModelLoader(BaseModelLoader):
model_config.revision, model_config.revision,
fall_back_to_pt=True) fall_back_to_pt=True)
def load_model(self, *, model_config: ModelConfig, def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config: DeviceConfig, device_config = vllm_config.device_config
lora_config: Optional[LoRAConfig], model_config = vllm_config.model_config
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
target_device = torch.device(device_config.device) target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with target_device: with target_device:
model = _initialize_model(model_config, self.load_config, model = _initialize_model(vllm_config=vllm_config)
lora_config, cache_config,
scheduler_config)
model.load_weights(self._get_all_weights(model_config, model)) model.load_weights(self._get_all_weights(model_config, model))
@ -436,17 +434,12 @@ class DummyModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download pass # Nothing to download
def load_model(self, *, model_config: ModelConfig, def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config: DeviceConfig, device_config = vllm_config.device_config
lora_config: Optional[LoRAConfig], model_config = vllm_config.model_config
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(vllm_config=vllm_config)
lora_config, cache_config,
scheduler_config)
# NOTE(woosuk): For accurate performance evaluation, we assign # NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights. # random values to the weights.
initialize_dummy_weights(model) initialize_dummy_weights(model)
@ -488,10 +481,7 @@ class TensorizerLoader(BaseModelLoader):
def _load_model_serialized_cpu( def _load_model_serialized_cpu(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
) -> nn.Module: ) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU. """Load a serialized model with tensorizer to the CPU.
@ -500,26 +490,30 @@ class TensorizerLoader(BaseModelLoader):
default HuggingFace loading, but will be slower than loading a default HuggingFace loading, but will be slower than loading a
vLLM-tensorized model. vLLM-tensorized model.
""" """
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(vllm_config=vllm_config)
lora_config, cache_config)
model.load_weights(self._get_weights_iterator()) model.load_weights(self._get_weights_iterator())
return model.eval() return model.eval()
def _load_model_serialized( def _load_model_serialized(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
) -> nn.Module: ) -> nn.Module:
"""Load a serialized model with tensorizer. """Load a serialized model with tensorizer.
Expects a vLLM-tensorized model. See the Expects a vLLM-tensorized model. See the
examples/tensorize_vllm_model.py example script examples/tensorize_vllm_model.py example script
for serializing vLLM models.""" for serializing vLLM models."""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
cache_config = vllm_config.cache_config
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
@ -544,12 +538,9 @@ class TensorizerLoader(BaseModelLoader):
with self.tensorizer_config.open_stream(): with self.tensorizer_config.open_stream():
pass pass
def load_model(self, *, model_config: ModelConfig, def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config: DeviceConfig, model_config = vllm_config.model_config
lora_config: Optional[LoRAConfig], parallel_config = vllm_config.parallel_config
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
self._verify_config(model_config, parallel_config) self._verify_config(model_config, parallel_config)
if parallel_config.tensor_parallel_size > 1: if parallel_config.tensor_parallel_size > 1:
@ -559,10 +550,8 @@ class TensorizerLoader(BaseModelLoader):
% get_tensor_model_parallel_rank() % get_tensor_model_parallel_rank()
if is_vllm_tensorized(self.tensorizer_config): if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config, return self._load_model_serialized(vllm_config=vllm_config)
lora_config, cache_config) return self._load_model_serialized_cpu(vllm_config=vllm_config)
return self._load_model_serialized_cpu(model_config, device_config,
lora_config, cache_config)
@staticmethod @staticmethod
def save_model( def save_model(
@ -648,12 +637,9 @@ class ShardedStateLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision) self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, *, model_config: ModelConfig, def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config: DeviceConfig, device_config = vllm_config.device_config
lora_config: Optional[LoRAConfig], model_config = vllm_config.model_config
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
from safetensors.torch import safe_open from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
@ -663,8 +649,7 @@ class ShardedStateLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(vllm_config=vllm_config)
lora_config, cache_config)
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
if quant_method is not None: if quant_method is not None:
@ -1157,16 +1142,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision) self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, *, model_config: ModelConfig, def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config: DeviceConfig, device_config = vllm_config.device_config
lora_config: Optional[LoRAConfig], model_config = vllm_config.model_config
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(vllm_config=vllm_config)
lora_config, cache_config)
self._load_weights(model_config, model) self._load_weights(model_config, model)
@ -1235,13 +1216,9 @@ class GGUFModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model) self._prepare_weights(model_config.model)
def load_model(self, *, model_config: ModelConfig, def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config: DeviceConfig, device_config = vllm_config.device_config
lora_config: Optional[LoRAConfig], model_config = vllm_config.model_config
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
local_model_path = self._prepare_weights(model_config.model) local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights # we can only know if tie word embeddings after mapping weights
@ -1251,8 +1228,7 @@ class GGUFModelLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(vllm_config=vllm_config)
lora_config, cache_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map)) self._get_weights_iterator(local_model_path, gguf_weights_map))
return model return model

View File

@ -1,8 +1,14 @@
import logging import logging
from typing import Callable, Optional, Union from typing import TYPE_CHECKING, Callable, Optional, Union
import vllm.envs as envs import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
if TYPE_CHECKING:
from vllm.compilation.config import CompilationConfig
from vllm.config import VllmConfig
else:
CompilationConfig = None
VllmConfig = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]):
def get_compilation_config() -> Optional[CompilationConfig]: def get_compilation_config() -> Optional[CompilationConfig]:
return _compilation_config return _compilation_config
_vllm_config: Optional[VllmConfig] = None
def set_vllm_config(config: Optional[VllmConfig]):
global _vllm_config
_vllm_config = config
def get_vllm_config() -> Optional[VllmConfig]:
return _vllm_config

View File

@ -369,13 +369,7 @@ class GPUModelRunner:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117 with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler): with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
self.model = get_model(model_config=self.model_config, self.model = get_model(vllm_config=self.vllm_config)
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",

View File

@ -453,13 +453,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
return uses_mrope(self.model_config.hf_config) return uses_mrope(self.model_config.hf_config)
def load_model(self) -> None: def load_model(self) -> None:
self.model = get_model(model_config=self.model_config, self.model = get_model(vllm_config=self.vllm_config)
load_config=self.load_config,
device_config=self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
def make_model_input_from_broadcasted_tensor_dict( def make_model_input_from_broadcasted_tensor_dict(
self, self,

View File

@ -1051,13 +1051,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def load_model(self) -> None: def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model) logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config, self.model = get_model(vllm_config=self.vllm_config)
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",

View File

@ -137,15 +137,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
"vllm.model_executor.layers.vocab_parallel_embedding." "vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank", "get_tensor_model_parallel_rank",
return_value=xm_tp_rank): return_value=xm_tp_rank):
model = get_model( model = get_model(vllm_config=self.vllm_config)
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
parallel_config=self.parallel_config,
cache_config=self.cache_config,
scheduler_config=self.scheduler_config,
lora_config=None,
)
model = model.eval() model = model.eval()
xm.wait_device_ops() xm.wait_device_ops()
self.model = ModelWrapper(model) self.model = ModelWrapper(model)

View File

@ -405,15 +405,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
def load_model(self) -> None: def load_model(self) -> None:
with DeviceMemoryProfiler() as m: with DeviceMemoryProfiler() as m:
self.model = get_model( self.model = get_model(vllm_config=self.vllm_config)
model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
)
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB", logger.info("Loading model weights took %.4f GB",