diff --git a/examples/offline_inference/neuron_multimodal.py b/examples/offline_inference/neuron_multimodal.py index 6ff8faabd748b..26f7505f2fa53 100644 --- a/examples/offline_inference/neuron_multimodal.py +++ b/examples/offline_inference/neuron_multimodal.py @@ -64,7 +64,7 @@ def print_outputs(outputs): print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") -if __name__ == "__main__": +def main(): assert ( len(PROMPTS) == len(IMAGES) == len(SAMPLING_PARAMS) ), f"""Text, image prompts and sampling parameters should have the @@ -104,3 +104,7 @@ if __name__ == "__main__": # test batch-size = 4 outputs = llm.generate(batched_inputs, batched_sample_params) print_outputs(outputs) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/profiling_tpu/profiling.py b/examples/offline_inference/profiling_tpu/profiling.py index 5200be82694ab..dfcbd8c8d3605 100644 --- a/examples/offline_inference/profiling_tpu/profiling.py +++ b/examples/offline_inference/profiling_tpu/profiling.py @@ -70,7 +70,7 @@ def main(args: argparse.Namespace): return -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser( description="Benchmark the latency of processing a single batch of " "requests till completion." @@ -102,5 +102,9 @@ if __name__ == "__main__": ) parser = EngineArgs.add_cli_args(parser) - args = parser.parse_args() + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() main(args)