[Misc] Improve logging for dynamic shape cache compilation (#20573)

Signed-off-by: kyolebu <kyu@redhat.com>
This commit is contained in:
Kyle Yu 2025-07-07 20:48:09 -04:00 committed by GitHub
parent 14601f5fba
commit d2e841a10a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -120,10 +120,15 @@ class CompilerManager:
handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] handle = self.cache[(runtime_shape, graph_index, self.compiler.name)]
compiled_graph = self.compiler.load(handle, graph, example_inputs, compiled_graph = self.compiler.load(handle, graph, example_inputs,
graph_index, runtime_shape) graph_index, runtime_shape)
logger.debug( if runtime_shape is None:
"Directly load the %s-th graph for shape %s from %s via " logger.debug(
"handle %s", graph_index, str(runtime_shape), self.compiler.name, "Directly load the %s-th graph for dynamic shape from %s via "
handle) "handle %s", graph_index, self.compiler.name, handle)
else:
logger.debug(
"Directly load the %s-th graph for shape %s from %s via "
"handle %s", graph_index, str(runtime_shape),
self.compiler.name, handle)
return compiled_graph return compiled_graph
def compile(self, def compile(self,
@@ -152,9 +157,15 @@ class CompilerManager:
# there can be multiple graphs due to piecewise compilation. # there can be multiple graphs due to piecewise compilation.
now = time.time() now = time.time()
elapsed = now - compilation_start_time elapsed = now - compilation_start_time
logger.info( if runtime_shape is None:
"Directly load the compiled graph(s) for shape %s " logger.info(
"from the cache, took %.3f s", str(runtime_shape), elapsed) "Directly load the compiled graph(s) for dynamic shape "
"from the cache, took %.3f s", elapsed)
else:
logger.info(
"Directly load the compiled graph(s) for shape %s "
"from the cache, took %.3f s", str(runtime_shape),
elapsed)
return compiled_graph return compiled_graph
# no compiler cached the graph, or the cache is disabled, # no compiler cached the graph, or the cache is disabled,
@@ -178,11 +189,21 @@ class CompilerManager:
self.is_cache_updated = True self.is_cache_updated = True
if graph_index == 0: if graph_index == 0:
# adds some info logging for the first graph # adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use", if runtime_shape is None:
str(runtime_shape)) logger.info(
logger.debug( "Cache the graph for dynamic shape for later use")
"store the %s-th graph for shape %s from %s via handle %s", else:
graph_index, str(runtime_shape), self.compiler.name, handle) logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
if runtime_shape is None:
logger.debug(
"Store the %s-th graph for dynamic shape from %s via "
"handle %s", graph_index, self.compiler.name, handle)
else:
logger.debug(
"Store the %s-th graph for shape %s from %s via handle %s",
graph_index, str(runtime_shape), self.compiler.name,
handle)
# after compiling the last graph, record the end time # after compiling the last graph, record the end time
if graph_index == num_graphs - 1: if graph_index == num_graphs - 1:
@@ -190,7 +211,7 @@ class CompilerManager:
elapsed = now - compilation_start_time elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed compilation_config.compilation_time += elapsed
if runtime_shape is None: if runtime_shape is None:
logger.info("Compiling a graph for general shape takes %.2f s", logger.info("Compiling a graph for dynamic shape takes %.2f s",
elapsed) elapsed)
else: else:
logger.info("Compiling a graph for shape %s takes %.2f s", logger.info("Compiling a graph for shape %s takes %.2f s",
@@ -308,7 +329,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
i for i, x in enumerate(args) if isinstance(x, torch.SymInt) i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
] ]
global compilation_start_time global compilation_start_time
compiled_graph_for_general_shape = self.vllm_backend.\ compiled_graph_for_dynamic_shape = self.vllm_backend.\
compiler_manager.compile( compiler_manager.compile(
submod, submod,
args, args,
@@ -323,7 +344,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
self.module.__dict__[target] = piecewise_backend( self.module.__dict__[target] = piecewise_backend(
submod, self.vllm_config, self.graph_pool, index, submod, self.vllm_config, self.graph_pool, index,
len(self.compile_submod_names), sym_shape_indices, len(self.compile_submod_names), sym_shape_indices,
compiled_graph_for_general_shape, self.vllm_backend) compiled_graph_for_dynamic_shape, self.vllm_backend)
compilation_counter.num_piecewise_capturable_graphs_seen += 1 compilation_counter.num_piecewise_capturable_graphs_seen += 1