mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 15:11:21 +08:00
Combine loader _process_weights_after_loading
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
d56ef8b685
commit
09318caeba
@ -153,11 +153,11 @@ def _initialize_model(
|
|||||||
return model_class(**kwargs)
|
return model_class(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _process_attention_weights_after_loading(
|
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
|
||||||
model: nn.Module, model_config: ModelConfig) -> None:
|
target_device: torch.device) -> None:
|
||||||
# Currently only used by MLA.
|
# Currently only used by MLA.
|
||||||
# NOTE: This intentionally happens before other modules so
|
# NOTE: This intentionally happens before other modules so we can easily
|
||||||
# we can easily decompress the weights for MLA.
|
# decompress the weights for MLA.
|
||||||
for _, module in model.named_modules():
|
for _, module in model.named_modules():
|
||||||
if isinstance(module, Attention) and \
|
if isinstance(module, Attention) and \
|
||||||
hasattr(module, "process_weights_after_loading"):
|
hasattr(module, "process_weights_after_loading"):
|
||||||
@ -165,6 +165,17 @@ def _process_attention_weights_after_loading(
|
|||||||
# of process_weights_after_loading
|
# of process_weights_after_loading
|
||||||
module.process_weights_after_loading(model_config.dtype)
|
module.process_weights_after_loading(model_config.dtype)
|
||||||
|
|
||||||
|
for _, module in model.named_modules():
|
||||||
|
quant_method = getattr(module, "quant_method", None)
|
||||||
|
if isinstance(quant_method, QuantizeMethodBase):
|
||||||
|
# When quant methods need to process weights after loading
|
||||||
|
# (for repacking, quantizing, etc), they expect parameters
|
||||||
|
# to be on the global target device. This scope is for the
|
||||||
|
# case where cpu offloading is used, where we will move the
|
||||||
|
# parameters onto device for processing and back off after.
|
||||||
|
with device_loading_context(module, target_device):
|
||||||
|
quant_method.process_weights_after_loading(module)
|
||||||
|
|
||||||
|
|
||||||
class BaseModelLoader(ABC):
|
class BaseModelLoader(ABC):
|
||||||
"""Base class for model loaders."""
|
"""Base class for model loaders."""
|
||||||
@ -389,7 +400,6 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
||||||
device_config = vllm_config.device_config
|
device_config = vllm_config.device_config
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
|
|
||||||
target_device = torch.device(device_config.device)
|
target_device = torch.device(device_config.device)
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with target_device:
|
with target_device:
|
||||||
@ -407,18 +417,8 @@ class DefaultModelLoader(BaseModelLoader):
|
|||||||
"Following weights were not initialized from "
|
"Following weights were not initialized from "
|
||||||
f"checkpoint: {weights_not_loaded}")
|
f"checkpoint: {weights_not_loaded}")
|
||||||
|
|
||||||
_process_attention_weights_after_loading(model, model_config)
|
_process_weights_after_loading(model, model_config, target_device)
|
||||||
|
|
||||||
for _, module in model.named_modules():
|
|
||||||
quant_method = getattr(module, "quant_method", None)
|
|
||||||
if isinstance(quant_method, QuantizeMethodBase):
|
|
||||||
# When quant methods need to process weights after loading
|
|
||||||
# (for repacking, quantizing, etc), they expect parameters
|
|
||||||
# to be on the global target device. This scope is for the
|
|
||||||
# case where cpu offloading is used, where we will move the
|
|
||||||
# parameters onto device for processing and back off after.
|
|
||||||
with device_loading_context(module, target_device):
|
|
||||||
quant_method.process_weights_after_loading(module)
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
@ -437,26 +437,15 @@ class DummyModelLoader(BaseModelLoader):
|
|||||||
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
||||||
device_config = vllm_config.device_config
|
device_config = vllm_config.device_config
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
|
target_device = torch.device(device_config.device)
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with target_device:
|
||||||
model = _initialize_model(vllm_config=vllm_config)
|
model = _initialize_model(vllm_config=vllm_config)
|
||||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||||
# random values to the weights.
|
# random values to the weights.
|
||||||
initialize_dummy_weights(model)
|
initialize_dummy_weights(model)
|
||||||
|
|
||||||
_process_attention_weights_after_loading(model, model_config)
|
_process_weights_after_loading(model, model_config, target_device)
|
||||||
|
|
||||||
for _, module in model.named_modules():
|
|
||||||
quant_method = getattr(module, "quant_method", None)
|
|
||||||
if quant_method is not None:
|
|
||||||
# When quant methods need to process weights after loading
|
|
||||||
# (for repacking, quantizing, etc), they expect parameters
|
|
||||||
# to be on the global target device. This scope is for the
|
|
||||||
# case where cpu offloading is used, where we will move the
|
|
||||||
# parameters onto device for processing and back off after.
|
|
||||||
with device_loading_context(
|
|
||||||
module, torch.device(device_config.device)):
|
|
||||||
quant_method.process_weights_after_loading(module)
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
@ -637,6 +626,7 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
def load_model(self, vllm_config: VllmConfig) -> nn.Module:
|
||||||
device_config = vllm_config.device_config
|
device_config = vllm_config.device_config
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
|
target_device = torch.device(device_config.device)
|
||||||
from safetensors.torch import safe_open
|
from safetensors.torch import safe_open
|
||||||
|
|
||||||
from vllm.distributed import get_tensor_model_parallel_rank
|
from vllm.distributed import get_tensor_model_parallel_rank
|
||||||
@ -645,13 +635,10 @@ class ShardedStateLoader(BaseModelLoader):
|
|||||||
model_config.revision)
|
model_config.revision)
|
||||||
|
|
||||||
with set_default_torch_dtype(model_config.dtype):
|
with set_default_torch_dtype(model_config.dtype):
|
||||||
with torch.device(device_config.device):
|
with target_device:
|
||||||
model = _initialize_model(vllm_config=vllm_config)
|
model = _initialize_model(vllm_config=vllm_config)
|
||||||
_process_attention_weights_after_loading(model, model_config)
|
_process_weights_after_loading(model, model_config,
|
||||||
for _, module in model.named_modules():
|
target_device)
|
||||||
quant_method = getattr(module, "quant_method", None)
|
|
||||||
if quant_method is not None:
|
|
||||||
quant_method.process_weights_after_loading(module)
|
|
||||||
rank = get_tensor_model_parallel_rank()
|
rank = get_tensor_model_parallel_rank()
|
||||||
pattern = os.path.join(
|
pattern = os.path.join(
|
||||||
local_model_path,
|
local_model_path,
|
||||||
@ -1401,12 +1388,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
|
|||||||
self._get_weights_iterator(model_weights,
|
self._get_weights_iterator(model_weights,
|
||||||
model_config.revision))
|
model_config.revision))
|
||||||
|
|
||||||
_process_attention_weights_after_loading(model, model_config)
|
_process_weights_after_loading(model, model_config, target_device)
|
||||||
for _, module in model.named_modules():
|
|
||||||
quant_method = getattr(module, "quant_method", None)
|
|
||||||
if quant_method is not None:
|
|
||||||
with device_loading_context(module, target_device):
|
|
||||||
quant_method.process_weights_after_loading(module)
|
|
||||||
return model.eval()
|
return model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user