diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 446de93cc430e..8f6022f13e0a1 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -7,6 +7,7 @@ from typing import Any, Optional import numpy as np import torch +import torch.nn as nn from vllm.config import VllmConfig from vllm.distributed import get_tp_group @@ -107,6 +108,9 @@ class GPUModelRunner: m.consumed_memory / GiB_bytes, time_after_load - time_before_load) + def get_model(self) -> nn.Module: + return self.model + def get_kv_cache_spec(self): return get_kv_cache_spec(self.vllm_config, self.kv_cache_dtype)