[BugFix][TritonMLA] Process weights after model loading for GGUF (#14555)

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
This commit is contained in:
TY-AMD 2025-03-13 11:14:36 +08:00 committed by GitHub
parent a94a699c3f
commit 128bf75283
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1330,11 +1330,14 @@ class GGUFModelLoader(BaseModelLoader):
local_model_path, gguf_weights_map):
model_config.hf_config.update({"tie_word_embeddings": True})
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map))
_process_weights_after_loading(model, model_config, target_device)
return model