[3/N] model runner pass the whole config to model (#9958)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 74b529ceee
commit cea808f325
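The change in a nutshell: every entry point that previously took a handful of separate config objects (ModelConfig, DeviceConfig, LoadConfig, LoRAConfig, ParallelConfig, SchedulerConfig, CacheConfig) now takes a single VllmConfig and unpacks the fields it needs. A minimal sketch of the two calling conventions, using a trimmed, hypothetical stand-in for VllmConfig (the real class lives in vllm.config and carries more fields):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class VllmConfigSketch:
    """Hypothetical, trimmed stand-in for vllm.config.VllmConfig."""
    model_config: Any
    device_config: Any
    load_config: Any
    lora_config: Optional[Any]
    parallel_config: Any
    scheduler_config: Any
    cache_config: Any

def get_model_before(*, model_config, load_config, device_config,
                     parallel_config, scheduler_config, lora_config,
                     cache_config):
    ...  # every caller had to enumerate all seven configs

def get_model_after(*, vllm_config: VllmConfigSketch):
    model_config = vllm_config.model_config  # callees unpack on demand
    ...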
@@ -248,11 +248,10 @@ def llama_2_7b_engine_extra_embeddings():
     cleanup_dist_env_and_memory(shutdown_ray=True)
     get_model_old = get_model

-    def get_model_patched(*, model_config, device_config, **kwargs):
-        kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
-        return get_model_old(model_config=model_config,
-                             device_config=device_config,
-                             **kwargs)
+    def get_model_patched(**kwargs):
+        kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
+                                                       max_lora_rank=8)
+        return get_model_old(**kwargs)

     with patch("vllm.worker.model_runner.get_model", get_model_patched):
         engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
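The patched test shim above relies on VllmConfig being passed by reference: mutating one field on the received object is enough, and everything else is forwarded untouched. An annotated copy of the shim (names as in the hunk above; the LoRAConfig values follow the test):

def get_model_patched(**kwargs):
    # The model runner now calls get_model(vllm_config=...), so the whole
    # config travels inside kwargs; override just the LoRA field in place.
    kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
                                                   max_lora_rank=8)
    return get_model_old(**kwargs)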
@@ -1,27 +1,15 @@
-from typing import Optional
-
 from torch import nn

-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, SchedulerConfig)
+from vllm.config import VllmConfig
 from vllm.model_executor.model_loader.loader import (BaseModelLoader,
                                                      get_model_loader)
 from vllm.model_executor.model_loader.utils import (
     get_architecture_class_name, get_model_architecture)


-def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
-              device_config: DeviceConfig, parallel_config: ParallelConfig,
-              scheduler_config: SchedulerConfig,
-              lora_config: Optional[LoRAConfig],
-              cache_config: CacheConfig) -> nn.Module:
-    loader = get_model_loader(load_config)
-    return loader.load_model(model_config=model_config,
-                             device_config=device_config,
-                             lora_config=lora_config,
-                             parallel_config=parallel_config,
-                             scheduler_config=scheduler_config,
-                             cache_config=cache_config)
+def get_model(*, vllm_config: VllmConfig) -> nn.Module:
+    loader = get_model_loader(vllm_config.load_config)
+    return loader.load_model(vllm_config=vllm_config)


 __all__ = [
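Call sites shrink to a single keyword argument. A sketch of driving the new entry point; how the VllmConfig itself gets built is not shown in this diff, so the assignment below is an assumption:

from vllm.model_executor.model_loader import get_model

vllm_config = ...  # assumed: built upstream (e.g. by the engine/model runner)
model = get_model(vllm_config=vllm_config)  # keyword-only by design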
@@ -21,9 +21,9 @@ from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

-from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
-                         LoRAConfig, ModelConfig, MultiModalConfig,
-                         ParallelConfig, PoolerConfig, SchedulerConfig)
+from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig,
+                         ModelConfig, MultiModalConfig, ParallelConfig,
+                         PoolerConfig, SchedulerConfig, VllmConfig)
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
@@ -150,6 +150,7 @@ def _get_model_initialization_kwargs(


 def build_model(model_class: Type[nn.Module],
+                vllm_config: VllmConfig,
                 hf_config: PretrainedConfig,
                 cache_config: Optional[CacheConfig],
                 quant_config: Optional[QuantizationConfig],
@@ -166,23 +167,29 @@ def build_model(model_class: Type[nn.Module],
     if prefix:
         extra_kwargs["prefix"] = prefix

+    # TODO: unify all the module initialization code
+    # to only take the `VllmConfig` object as input
+    from vllm.plugins import set_vllm_config
+    set_vllm_config(vllm_config)
+
     return model_class(config=hf_config,
                        cache_config=cache_config,
                        quant_config=quant_config,
                        **extra_kwargs)


-def _initialize_model(
-        model_config: ModelConfig,
-        load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        cache_config: CacheConfig,
-        scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
+def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
     """Initialize a model with the given configurations."""
+    model_config = vllm_config.model_config
+    lora_config = vllm_config.lora_config
+    scheduler_config = vllm_config.scheduler_config
+    cache_config = vllm_config.cache_config
+    load_config = vllm_config.load_config
     model_class, _ = get_model_architecture(model_config)

     return build_model(
         model_class,
+        vllm_config,
         model_config.hf_config,
         cache_config=cache_config,
         quant_config=_get_quantization_config(model_config, load_config),
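Two things happen in this hunk. First, _initialize_model keeps its body intact by unpacking the individual configs at the top, so only the signature and the first few lines change. Second, set_vllm_config is a transitional bridge: per the TODO, model classes still receive individual configs, so the full VllmConfig is parked in a module-level slot (added in the vllm/plugins hunk below) for code that needs it during construction. A sketch of that bridge pattern, with an illustrative consumer:

from vllm.plugins import get_vllm_config, set_vllm_config

def build_model_sketch(model_class, vllm_config, hf_config, **kwargs):
    set_vllm_config(vllm_config)  # park the handle before construction
    return model_class(config=hf_config, **kwargs)

# Inside a not-yet-migrated model constructor (illustrative):
#     cfg = get_vllm_config()  # retrieve the parked handle
#     if cfg is not None: ...  # read whatever fields are needed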
@@ -205,12 +212,7 @@ class BaseModelLoader(ABC):
         raise NotImplementedError

     @abstractmethod
-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
+    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
         """Load a model with the given configurations."""
         raise NotImplementedError
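One detail worth flagging: the abstract load_model keeps the bare * (keyword-only), while the concrete loaders below drop it; since every call in this diff passes load_model(vllm_config=...), both signatures are compatible. For an out-of-tree loader, the subclass now looks roughly like this (the class and its body are illustrative, not part of the diff):

from torch import nn
from vllm.config import VllmConfig
from vllm.model_executor.model_loader.loader import BaseModelLoader

class CustomLoaderSketch(BaseModelLoader):
    """Illustrative subclass; weight loading is elided."""

    def download_model(self, model_config) -> None:
        pass  # nothing to fetch in this sketch

    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        # Unpack only what this loader needs, mirroring the in-tree loaders.
        model_config = vllm_config.model_config
        raise NotImplementedError("weight loading elided in this sketch")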
@@ -396,18 +398,14 @@ class DefaultModelLoader(BaseModelLoader):
                                    model_config.revision,
                                    fall_back_to_pt=True)

-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
+
         target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
             with target_device:
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, cache_config,
-                                          scheduler_config)
+                model = _initialize_model(vllm_config=vllm_config)

                 model.load_weights(self._get_all_weights(model_config, model))
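DefaultModelLoader leans on two context managers: set_default_torch_dtype for the parameter dtype, and the torch.device context (available since PyTorch 2.0), which makes tensors created inside the block default to that device, so the model is materialized directly on the target device. A standalone illustration of the device part:

import torch

# Entering a torch.device context changes the default device for tensor
# construction within the block (requires PyTorch 2.0+).
with torch.device("cpu"):  # a CUDA device in the real loader path
    t = torch.empty(2, 2)
assert t.device.type == "cpu"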
@@ -436,17 +434,12 @@ class DummyModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         pass  # Nothing to download

-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, cache_config,
-                                          scheduler_config)
+                model = _initialize_model(vllm_config=vllm_config)
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
         initialize_dummy_weights(model)
@@ -488,10 +481,7 @@ class TensorizerLoader(BaseModelLoader):

     def _load_model_serialized_cpu(
         self,
-        model_config: ModelConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        cache_config: CacheConfig,
+        vllm_config: VllmConfig,
     ) -> nn.Module:
         """Load a serialized model with tensorizer to the CPU.
@@ -500,26 +490,30 @@ class TensorizerLoader(BaseModelLoader):
         default HuggingFace loading, but will be slower than loading a
         vLLM-tensorized model.
         """
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, cache_config)
+                model = _initialize_model(vllm_config=vllm_config)

         model.load_weights(self._get_weights_iterator())
         return model.eval()

     def _load_model_serialized(
         self,
-        model_config: ModelConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        cache_config: CacheConfig,
+        vllm_config: VllmConfig,
     ) -> nn.Module:
         """Load a serialized model with tensorizer.

         Expects a vLLM-tensorized model. See the
         examples/tensorize_vllm_model.py example script
         for serializing vLLM models."""
+
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
+        lora_config = vllm_config.lora_config
+        cache_config = vllm_config.cache_config
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model_class = get_model_architecture(model_config)[0]
@@ -544,12 +538,9 @@ class TensorizerLoader(BaseModelLoader):
         with self.tensorizer_config.open_stream():
             pass

-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
         self._verify_config(model_config, parallel_config)

         if parallel_config.tensor_parallel_size > 1:
@@ -559,10 +550,8 @@ class TensorizerLoader(BaseModelLoader):
                 % get_tensor_model_parallel_rank()

         if is_vllm_tensorized(self.tensorizer_config):
-            return self._load_model_serialized(model_config, device_config,
-                                               lora_config, cache_config)
-        return self._load_model_serialized_cpu(model_config, device_config,
-                                               lora_config, cache_config)
+            return self._load_model_serialized(vllm_config=vllm_config)
+        return self._load_model_serialized_cpu(vllm_config=vllm_config)

     @staticmethod
     def save_model(
@@ -648,12 +637,9 @@ class ShardedStateLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)

-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
         from safetensors.torch import safe_open

         from vllm.distributed import get_tensor_model_parallel_rank
@@ -663,8 +649,7 @@ class ShardedStateLoader(BaseModelLoader):

         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, cache_config)
+                model = _initialize_model(vllm_config=vllm_config)
                 for _, module in model.named_modules():
                     quant_method = getattr(module, "quant_method", None)
                     if quant_method is not None:
@@ -1157,16 +1142,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model, model_config.revision)

-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
+    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, cache_config)
+                model = _initialize_model(vllm_config=vllm_config)

                 self._load_weights(model_config, model)
@@ -1235,13 +1216,9 @@ class GGUFModelLoader(BaseModelLoader):
     def download_model(self, model_config: ModelConfig) -> None:
         self._prepare_weights(model_config.model)

-    def load_model(self, *, model_config: ModelConfig,
-                   device_config: DeviceConfig,
-                   lora_config: Optional[LoRAConfig],
-                   parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig,
-                   cache_config: CacheConfig) -> nn.Module:
-
+    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
+        device_config = vllm_config.device_config
+        model_config = vllm_config.model_config
         local_model_path = self._prepare_weights(model_config.model)
         gguf_weights_map = self._get_gguf_weights_map(model_config)
         # we can only know if tie word embeddings after mapping weights
@@ -1251,8 +1228,7 @@ class GGUFModelLoader(BaseModelLoader):

         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(model_config, self.load_config,
-                                          lora_config, cache_config)
+                model = _initialize_model(vllm_config=vllm_config)
             model.load_weights(
                 self._get_weights_iterator(local_model_path, gguf_weights_map))
         return model
@@ -1,8 +1,14 @@
 import logging
-from typing import Callable, Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Union

 import vllm.envs as envs
-from vllm.compilation.config import CompilationConfig
+
+if TYPE_CHECKING:
+    from vllm.compilation.config import CompilationConfig
+    from vllm.config import VllmConfig
+else:
+    CompilationConfig = None
+    VllmConfig = None

 logger = logging.getLogger(__name__)
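The import dance here avoids a circular import at runtime: under TYPE_CHECKING the real classes are visible to static type checkers only, and at runtime the names are bound to None so that unquoted annotations such as Optional[CompilationConfig] still evaluate without importing anything. The general pattern, as a standalone sketch with a hypothetical heavy module:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from heavy_module import HeavyClass  # hypothetical; type-check time only
else:
    HeavyClass = None  # runtime placeholder so the name still resolves

def configure(obj: Optional[HeavyClass]) -> None:
    ...  # annotation evaluates at runtime without touching heavy_module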
@@ -55,3 +61,15 @@ def set_compilation_config(config: Optional[CompilationConfig]):

 def get_compilation_config() -> Optional[CompilationConfig]:
     return _compilation_config
+
+
+_vllm_config: Optional[VllmConfig] = None
+
+
+def set_vllm_config(config: Optional[VllmConfig]):
+    global _vllm_config
+    _vllm_config = config
+
+
+def get_vllm_config() -> Optional[VllmConfig]:
+    return _vllm_config
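set_vllm_config/get_vllm_config deliberately mirror the set_compilation_config/get_compilation_config pair above: a module-level Optional slot plus a setter and getter. Intended use during the transition, sketched (where the consumer lives is an assumption):

from vllm.plugins import get_vllm_config, set_vllm_config

def stash(vllm_config):
    set_vllm_config(vllm_config)  # producer: build_model does this in the diff

def read_lora_settings():
    cfg = get_vllm_config()  # consumer: e.g. a not-yet-migrated constructor
    if cfg is not None and cfg.lora_config is not None:
        ...  # read LoRA settings from the shared handle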
@@ -369,13 +369,7 @@ class GPUModelRunner:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
-                self.model = get_model(model_config=self.model_config,
-                                       device_config=self.device_config,
-                                       load_config=self.load_config,
-                                       lora_config=self.lora_config,
-                                       parallel_config=self.parallel_config,
-                                       scheduler_config=self.scheduler_config,
-                                       cache_config=self.cache_config)
+                self.model = get_model(vllm_config=self.vllm_config)

         self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
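From here down, every model runner hunk is the same one-line collapse: each runner already holds self.vllm_config, so the enumerated keyword list becomes get_model(vllm_config=self.vllm_config). The abstracted shape of the change (the runner class is a stand-in):

class RunnerSketch:
    # Illustrative stand-in for the GPU/CPU/TPU/XPU model runners.
    def __init__(self, vllm_config, get_model_fn):
        self.vllm_config = vllm_config
        self._get_model = get_model_fn

    def load_model(self) -> None:
        # One uniform call replaces seven per-runner keyword arguments.
        self.model = self._get_model(vllm_config=self.vllm_config)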
@@ -453,13 +453,7 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
         return uses_mrope(self.model_config.hf_config)

     def load_model(self) -> None:
-        self.model = get_model(model_config=self.model_config,
-                               load_config=self.load_config,
-                               device_config=self.device_config,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config,
-                               cache_config=self.cache_config)
+        self.model = get_model(vllm_config=self.vllm_config)

     def make_model_input_from_broadcasted_tensor_dict(
         self,
@@ -1051,13 +1051,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
         with DeviceMemoryProfiler() as m:
-            self.model = get_model(model_config=self.model_config,
-                                   device_config=self.device_config,
-                                   load_config=self.load_config,
-                                   lora_config=self.lora_config,
-                                   parallel_config=self.parallel_config,
-                                   scheduler_config=self.scheduler_config,
-                                   cache_config=self.cache_config)
+            self.model = get_model(vllm_config=self.vllm_config)

         self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",
@@ -137,15 +137,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                 "vllm.model_executor.layers.vocab_parallel_embedding."
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
-            model = get_model(
-                model_config=self.model_config,
-                load_config=self.load_config,
-                device_config=self.device_config,
-                parallel_config=self.parallel_config,
-                cache_config=self.cache_config,
-                scheduler_config=self.scheduler_config,
-                lora_config=None,
-            )
+            model = get_model(vllm_config=self.vllm_config)
         model = model.eval()
         xm.wait_device_ops()
         self.model = ModelWrapper(model)
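One behavioral nuance in the TPU hunk: the old call pinned lora_config=None explicitly, whereas the new call forwards whatever self.vllm_config carries. If a TPU runner were ever constructed with a non-None lora_config, the loader would now see it; presumably that is intended or unreachable here, but it is a real change in what reaches get_model.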
@@ -405,15 +405,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):

     def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:
-            self.model = get_model(
-                model_config=self.model_config,
-                device_config=self.device_config,
-                load_config=self.load_config,
-                lora_config=self.lora_config,
-                parallel_config=self.parallel_config,
-                scheduler_config=self.scheduler_config,
-                cache_config=self.cache_config,
-            )
+            self.model = get_model(vllm_config=self.vllm_config)

         self.model_memory_usage = m.consumed_memory
         logger.info("Loading model weights took %.4f GB",