[3/N] model runner pass the whole config to model (#9958)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-11-02 12:08:49 -07:00, committed by GitHub
parent 74b529ceee
commit cea808f325
9 changed files with 87 additions and 140 deletions
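
At a glance, the refactor collapses the per-config keyword arguments threaded through the model-loading path into a single aggregated VllmConfig. The central signature change, condensed from the hunks below:

# Old entry point (condensed from the removed lines below): every sub-config
# travels as its own keyword argument.
def get_model(*, model_config, load_config, device_config, parallel_config,
              scheduler_config, lora_config, cache_config): ...

# New entry point: one aggregated config object.
def get_model(*, vllm_config): ...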

View File

@@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model
def get_model_patched(*, model_config, device_config, **kwargs):
kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
return get_model_old(model_config=model_config,
device_config=device_config,
**kwargs)
def get_model_patched(**kwargs):
kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
max_lora_rank=8)
return get_model_old(**kwargs)
with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
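
The hunk above shows the payoff on the test side: the patched helper no longer needs to know the full argument list, it only mutates one field on the config object it receives and forwards everything else. A minimal, self-contained sketch of that pattern; all names here (ToyVllmConfig, ToyLoRAConfig, load) are hypothetical stand-ins, not vLLM APIs:

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyLoRAConfig:
    max_loras: int = 1
    max_lora_rank: int = 8


@dataclass
class ToyVllmConfig:
    model: str = "dummy"
    lora_config: Optional[ToyLoRAConfig] = None


def load(*, vllm_config: ToyVllmConfig) -> str:
    # Stand-in for get_model(): reads only the aggregated config.
    return f"{vllm_config.model} (lora={vllm_config.lora_config})"


load_old = load


def load_patched(**kwargs):
    # Same trick as get_model_patched above: tweak one field, forward the rest.
    kwargs["vllm_config"].lora_config = ToyLoRAConfig(max_loras=4)
    return load_old(**kwargs)


print(load_patched(vllm_config=ToyVllmConfig(model="llama")))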

View File

@@ -1,27 +1,15 @@
from typing import Optional
from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader)
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture)
def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config)
def get_model(*, vllm_config: VllmConfig) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
return loader.load_model(vllm_config=vllm_config)
__all__ = [
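
With the new signature, every call site shrinks to a single keyword argument. A hedged sketch of a caller, assuming `vllm_config` is an already constructed vllm.config.VllmConfig (its construction is outside this diff):

from vllm.model_executor.model_loader import get_model

def load(vllm_config):
    # Replaces the old seven-keyword call.
    return get_model(vllm_config=vllm_config)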

View File

@@ -21,9 +21,9 @@ from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, PoolerConfig, SchedulerConfig)
from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PoolerConfig, SchedulerConfig, VllmConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
@@ -150,6 +150,7 @@ def _get_model_initialization_kwargs(
def build_model(model_class: Type[nn.Module],
vllm_config: VllmConfig,
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
@@ -166,23 +167,29 @@ def build_model(model_class: Type[nn.Module],
if prefix:
extra_kwargs["prefix"] = prefix
# TODO: unify all the module initialization code
# to only take the `VllmConfig` object as input
from vllm.plugins import set_vllm_config
set_vllm_config(vllm_config)
return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
**extra_kwargs)
def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
load_config = vllm_config.load_config
model_class, _ = get_model_architecture(model_config)
return build_model(
model_class,
vllm_config,
model_config.hf_config,
cache_config=cache_config,
quant_config=_get_quantization_config(model_config, load_config),
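
The TODO above states the eventual goal (model constructors take only VllmConfig); in the meantime, set_vllm_config makes the aggregated config globally visible right before the model class is instantiated. A hedged sketch of how a model module might read it back during this transition; ToyModel and its use of lora_config are hypothetical, while get_vllm_config is the accessor added to vllm.plugins later in this commit:

from torch import nn

from vllm.plugins import get_vllm_config


class ToyModel(nn.Module):

    def __init__(self, config, cache_config=None, quant_config=None):
        super().__init__()
        # build_model() registered the aggregated config just before this
        # constructor ran (see set_vllm_config above).
        vllm_config = get_vllm_config()
        self.lora_enabled = (vllm_config is not None
                             and vllm_config.lora_config is not None)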
@@ -205,12 +212,7 @@ class BaseModelLoader(ABC):
raise NotImplementedError
@abstractmethod
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
"""Load a model with the given configurations."""
raise NotImplementedError
@@ -396,18 +398,14 @@ class DefaultModelLoader(BaseModelLoader):
model_config.revision,
fall_back_to_pt=True)
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
model = _initialize_model(vllm_config=vllm_config)
model.load_weights(self._get_all_weights(model_config, model))
@@ -436,17 +434,12 @@ class DummyModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config,
scheduler_config)
model = _initialize_model(vllm_config=vllm_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
@@ -488,10 +481,7 @@ class TensorizerLoader(BaseModelLoader):
def _load_model_serialized_cpu(
self,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
vllm_config: VllmConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.
@@ -500,26 +490,30 @@ class TensorizerLoader(BaseModelLoader):
default HuggingFace loading, but will be slower than loading a
vLLM-tensorized model.
"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)
model.load_weights(self._get_weights_iterator())
return model.eval()
def _load_model_serialized(
self,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
cache_config: CacheConfig,
vllm_config: VllmConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer.
Expects a vLLM-tensorized model. See the
examples/tensorize_vllm_model.py example script
for serializing vLLM models."""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
cache_config = vllm_config.cache_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
@@ -544,12 +538,9 @@ class TensorizerLoader(BaseModelLoader):
with self.tensorizer_config.open_stream():
pass
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)
if parallel_config.tensor_parallel_size > 1:
@@ -559,10 +550,8 @@ class TensorizerLoader(BaseModelLoader):
% get_tensor_model_parallel_rank()
if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config, cache_config)
return self._load_model_serialized_cpu(model_config, device_config,
lora_config, cache_config)
return self._load_model_serialized(vllm_config=vllm_config)
return self._load_model_serialized_cpu(vllm_config=vllm_config)
@staticmethod
def save_model(
@@ -648,12 +637,9 @@ class ShardedStateLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank
@@ -663,8 +649,7 @@ class ShardedStateLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
@@ -1157,16 +1142,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)
self._load_weights(model_config, model)
@@ -1235,13 +1216,9 @@ class GGUFModelLoader(BaseModelLoader):
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
device_config = vllm_config.device_config
model_config = vllm_config.model_config
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
@@ -1251,8 +1228,7 @@ class GGUFModelLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, cache_config)
model = _initialize_model(vllm_config=vllm_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map))
return model
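
Every BaseModelLoader subclass above now implements the same one-keyword interface. A minimal sketch of a third-party loader written against it; NoopLoader is hypothetical, and the assumption that BaseModelLoader.__init__ takes a LoadConfig follows the existing loaders rather than anything shown in this diff:

from torch import nn

from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
                                                     _initialize_model)


class NoopLoader(BaseModelLoader):
    """Instantiates the model and leaves its weights uninitialized."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # nothing to fetch

    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        # Everything the loader needs now arrives as one object; dtype and
        # device handling is omitted here for brevity.
        model = _initialize_model(vllm_config=vllm_config)
        return model.eval()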

View File

@@ -1,8 +1,14 @@
import logging
from typing import Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Union
import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
if TYPE_CHECKING:
from vllm.compilation.config import CompilationConfig
from vllm.config import VllmConfig
else:
CompilationConfig = None
VllmConfig = None
logger = logging.getLogger(__name__)
@@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]):
def get_compilation_config() -> Optional[CompilationConfig]:
return _compilation_config
_vllm_config: Optional[VllmConfig] = None
def set_vllm_config(config: Optional[VllmConfig]):
global _vllm_config
_vllm_config = config
def get_vllm_config() -> Optional[VllmConfig]:
return _vllm_config
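
A hedged usage sketch of the accessors added above; `cfg` stands in for a real VllmConfig built elsewhere by the engine (its construction is not part of this diff):

from vllm.plugins import get_vllm_config, set_vllm_config

def publish(cfg) -> None:
    # build_model() calls this right before the model's __init__ runs.
    set_vllm_config(cfg)

# ...later, from any module in the same process:
current = get_vllm_config()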

View File

@@ -369,13 +369,7 @@ class GPUModelRunner:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
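
The remaining runner hunks below repeat the same change for the CPU, GPU, TPU, and XPU runners. A hedged sketch of the shared shape; AnyModelRunner is hypothetical, while self.vllm_config and the single-keyword get_model() call come from this commit and the earlier PRs in the series:

from vllm.model_executor.model_loader import get_model


class AnyModelRunner:

    def __init__(self, vllm_config):
        # Assumed to be handed the aggregated config by its worker.
        self.vllm_config = vllm_config

    def load_model(self) -> None:
        self.model = get_model(vllm_config=self.vllm_config)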

View File

@@ -453,13 +453,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
return uses_mrope(self.model_config.hf_config)
def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)
def make_model_input_from_broadcasted_tensor_dict(
self,

View File

@@ -1051,13 +1051,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",

View File

@@ -137,15 +137,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(
model_config=self.model_config,
load_config=self.load_config,
device_config=self.device_config,
parallel_config=self.parallel_config,
cache_config=self.cache_config,
scheduler_config=self.scheduler_config,
lora_config=None,
)
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
xm.wait_device_ops()
self.model = ModelWrapper(model)

View File

@@ -405,15 +405,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
def load_model(self) -> None:
with DeviceMemoryProfiler() as m:
self.model = get_model(
model_config=self.model_config,
device_config=self.device_config,
load_config=self.load_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config,
)
self.model = get_model(vllm_config=self.vllm_config)
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",