Set torch_dtype in TransformersModel (#13088)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-02-11 15:51:19 +00:00 committed by GitHub
parent 75e6e14516
commit ad9776353e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -143,6 +143,7 @@ class TransformersModel(nn.Module):
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
attn_implementation="vllm",
torch_dtype=vllm_config.model_config.dtype,
trust_remote_code=vllm_config.model_config.trust_remote_code,
)
prefix = self.model.base_model_prefix