[Misc] Improve logging for dynamic shape cache compilation (#20573)

Signed-off-by: kyolebu <kyu@redhat.com>
This commit is contained in:
Kyle Yu 2025-07-07 20:48:09 -04:00 committed by GitHub
parent 14601f5fba
commit d2e841a10a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -120,10 +120,15 @@ class CompilerManager:
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
compiled_graph = self.compiler.load(handle, graph, example_inputs,
graph_index, runtime_shape)
if runtime_shape is None:
logger.debug(
"Directly load the %s-th graph for dynamic shape from %s via "
"handle %s", graph_index, self.compiler.name, handle)
else:
logger.debug(
"Directly load the %s-th graph for shape %s from %s via "
"handle %s", graph_index, str(runtime_shape), self.compiler.name,
handle)
"handle %s", graph_index, str(runtime_shape),
self.compiler.name, handle)
return compiled_graph
def compile(self,
@ -152,9 +157,15 @@ class CompilerManager:
# there can be multiple graphs due to piecewise compilation.
now = time.time()
elapsed = now - compilation_start_time
if runtime_shape is None:
logger.info(
"Directly load the compiled graph(s) for dynamic shape "
"from the cache, took %.3f s", elapsed)
else:
logger.info(
"Directly load the compiled graph(s) for shape %s "
"from the cache, took %.3f s", str(runtime_shape), elapsed)
"from the cache, took %.3f s", str(runtime_shape),
elapsed)
return compiled_graph
# no compiler cached the graph, or the cache is disabled,
@ -178,11 +189,21 @@ class CompilerManager:
self.is_cache_updated = True
if graph_index == 0:
# adds some info logging for the first graph
if runtime_shape is None:
logger.info(
"Cache the graph for dynamic shape for later use")
else:
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
if runtime_shape is None:
logger.debug(
"store the %s-th graph for shape %s from %s via handle %s",
graph_index, str(runtime_shape), self.compiler.name, handle)
"Store the %s-th graph for dynamic shape from %s via "
"handle %s", graph_index, self.compiler.name, handle)
else:
logger.debug(
"Store the %s-th graph for shape %s from %s via handle %s",
graph_index, str(runtime_shape), self.compiler.name,
handle)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
@ -190,7 +211,7 @@ class CompilerManager:
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info("Compiling a graph for general shape takes %.2f s",
logger.info("Compiling a graph for dynamic shape takes %.2f s",
elapsed)
else:
logger.info("Compiling a graph for shape %s takes %.2f s",
@ -308,7 +329,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
]
global compilation_start_time
compiled_graph_for_general_shape = self.vllm_backend.\
compiled_graph_for_dynamic_shape = self.vllm_backend.\
compiler_manager.compile(
submod,
args,
@ -323,7 +344,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape, self.vllm_backend)
compiled_graph_for_dynamic_shape, self.vllm_backend)
compilation_counter.num_piecewise_capturable_graphs_seen += 1