mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 16:25:48 +08:00
[3/N][torch.compile] consolidate custom op logging (#10399)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
96d999fbe8
commit
a03ea40792
@ -4,8 +4,9 @@ import json
|
|||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, Final, List,
|
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
|
||||||
Literal, Mapping, Optional, Set, Tuple, Type, Union)
|
Final, List, Literal, Mapping, Optional, Set, Tuple, Type,
|
||||||
|
Union)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
@ -2169,6 +2170,10 @@ class CompilationConfig(BaseModel):
|
|||||||
compile_sizes: List[int] = PrivateAttr
|
compile_sizes: List[int] = PrivateAttr
|
||||||
capture_sizes: List[int] = PrivateAttr
|
capture_sizes: List[int] = PrivateAttr
|
||||||
|
|
||||||
|
# keep track of enabled and disabled custom ops
|
||||||
|
enabled_custom_ops: Counter[str] = PrivateAttr
|
||||||
|
disabled_custom_ops: Counter[str] = PrivateAttr
|
||||||
|
|
||||||
def model_post_init(self, __context: Any) -> None:
|
def model_post_init(self, __context: Any) -> None:
|
||||||
self.level = envs.VLLM_TORCH_COMPILE_LEVEL
|
self.level = envs.VLLM_TORCH_COMPILE_LEVEL
|
||||||
|
|
||||||
@ -2190,6 +2195,9 @@ class CompilationConfig(BaseModel):
|
|||||||
func = __import__(module).__dict__[func_name]
|
func = __import__(module).__dict__[func_name]
|
||||||
self.inductor_compile_config[k] = func
|
self.inductor_compile_config[k] = func
|
||||||
|
|
||||||
|
self.enabled_custom_ops = Counter()
|
||||||
|
self.disabled_custom_ops = Counter()
|
||||||
|
|
||||||
def init_backend(self) -> Union[str, Callable]:
|
def init_backend(self) -> Union[str, Callable]:
|
||||||
if self.level == CompilationLevel.NO_COMPILATION:
|
if self.level == CompilationLevel.NO_COMPILATION:
|
||||||
raise ValueError("No compilation level is set.")
|
raise ValueError("No compilation level is set.")
|
||||||
|
|||||||
@ -61,10 +61,13 @@ class CustomOp(nn.Module):
|
|||||||
def dispatch_forward(self):
|
def dispatch_forward(self):
|
||||||
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
||||||
# specific backend. Currently, we do not support dynamic dispatching.
|
# specific backend. Currently, we do not support dynamic dispatching.
|
||||||
|
compilation_config = get_current_vllm_config().compilation_config
|
||||||
enabled = self.enabled()
|
enabled = self.enabled()
|
||||||
logger.debug("custom op %s %s", self.__class__.name,
|
if enabled:
|
||||||
"enabled" if enabled else "disabled")
|
compilation_config.enabled_custom_ops.update([self.__class__.name])
|
||||||
|
else:
|
||||||
|
compilation_config.disabled_custom_ops.update(
|
||||||
|
[self.__class__.name])
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return self.forward_native
|
return self.forward_native
|
||||||
|
|||||||
@ -80,6 +80,10 @@ def set_current_vllm_config(vllm_config: VllmConfig):
|
|||||||
_current_vllm_config = vllm_config
|
_current_vllm_config = vllm_config
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
|
logger.debug("enabled custom ops: %s",
|
||||||
|
vllm_config.compilation_config.enabled_custom_ops)
|
||||||
|
logger.debug("disabled custom ops: %s",
|
||||||
|
vllm_config.compilation_config.disabled_custom_ops)
|
||||||
_current_vllm_config = old_vllm_config
|
_current_vllm_config = old_vllm_config
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user