diff --git a/vllm/compilation/base_static_graph.py b/vllm/compilation/base_static_graph.py index 161d066ce9fb8..6ee82e74963d9 100644 --- a/vllm/compilation/base_static_graph.py +++ b/vllm/compilation/base_static_graph.py @@ -12,8 +12,13 @@ class AbstractStaticGraphWrapper(Protocol): to be captured as a static graph. """ - def __init__(self, runnable: Callable, vllm_config: VllmConfig, - runtime_mode: CUDAGraphMode, **kwargs): + def __init__( + self, + runnable: Callable[..., Any], + vllm_config: VllmConfig, + runtime_mode: CUDAGraphMode, + **kwargs: Any, + ) -> None: """ Initializes the StaticGraphWrapper class with graph capturing and execution-related configurations. @@ -31,7 +36,7 @@ class AbstractStaticGraphWrapper(Protocol): """ raise NotImplementedError - def __call__(self, *args, **kwargs) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> Any: """ Executes the wrapped callable.