mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
add compile
This commit is contained in:
parent
55712941e5
commit
617fb893d5
@ -787,6 +787,16 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
"provided. Defaulting to scaling factors of 1.0. "
|
||||
"This may lead to less accurate results!")
|
||||
|
||||
count = 0
|
||||
|
||||
def backend(gm, input):
|
||||
nonlocal count
|
||||
count += 1
|
||||
print(count)
|
||||
return gm.forward
|
||||
|
||||
self.model = torch.compile(self.model, backend=backend, fullgraph=True)
|
||||
|
||||
def save_sharded_state(
|
||||
self,
|
||||
path: str,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user