diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py
index 5b57df2d472c8..2b54228e0a4ee 100644
--- a/vllm/v1/worker/cpu_worker.py
+++ b/vllm/v1/worker/cpu_worker.py
@@ -37,6 +37,28 @@ class CPUWorker(Worker):
 
         self.parallel_config.disable_custom_all_reduce = True
 
+        if envs.VLLM_TORCH_PROFILER_DIR:
+            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
+            worker_name = f"{vllm_config.instance_id}-rank-{self.rank}"
+            logger.info(
+                "Profiling enabled. Traces will be saved to: %s",
+                torch_profiler_trace_dir,
+            )
+            self.profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                ],
+                record_shapes=envs.VLLM_TORCH_PROFILER_RECORD_SHAPES,
+                profile_memory=envs.VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY,
+                with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
+                with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    torch_profiler_trace_dir, worker_name=worker_name, use_gzip=False
+                ),
+            )
+        else:
+            self.profiler = None
+
     def init_device(self):
         # Setup OpenMP threads affinity.
         omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
@@ -166,3 +188,17 @@ class CPUWorker(Worker):
             [(x.id, x.physical_core) for x in logical_cpu_list],
         )
         return ",".join([str(x.id) for x in logical_cpu_list])
+
+    def profile(self, is_start: bool = True):
+        if self.profiler is None:
+            raise RuntimeError("Profiler is not enabled.")
+        if is_start:
+            self.profiler.start()
+        else:
+            self.profiler.stop()
+            if self.local_rank == 0:
+                logger.info(
+                    self.profiler.key_averages().table(
+                        sort_by="self_cpu_time_total", row_limit=50
+                    )
+                )
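
For context, a minimal sketch of how this CPU profiling path could be exercised from the offline API. It assumes `VLLM_TORCH_PROFILER_DIR` is set before the engine (and its workers) is constructed, and that `LLM.start_profile()` / `LLM.stop_profile()` dispatch to the worker's `profile()` hook; the model name and output directory below are placeholders, not part of this diff.

```python
# Sketch only: drive the new CPUWorker profiler hook from the offline LLM API.
import os

# Must be set before the engine is created so CPUWorker.__init__ builds the profiler.
os.environ["VLLM_TORCH_PROFILER_DIR"] = "/tmp/vllm_cpu_profile"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration

llm.start_profile()  # -> CPUWorker.profile(is_start=True): profiler.start()
llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
llm.stop_profile()   # -> CPUWorker.profile(is_start=False): profiler.stop();
                     #    rank 0 logs a key_averages() table sorted by self CPU time
```

With this flow, `tensorboard_trace_handler` writes one trace per worker (named via `worker_name`) into `VLLM_TORCH_PROFILER_DIR`, and rank 0 additionally logs the aggregated operator table when profiling stops.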