[Core] add an option to log every function call to for debugging hang/crash in distributed inference (#4079)

Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
youkaichao 2024-04-18 16:15:12 -07:00 committed by GitHub
parent 8f9c28fd40
commit 8a7a3e4436
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 120 additions and 8 deletions

View File

@ -40,7 +40,7 @@ steps:
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
- label: Engine Test - label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py command: pytest -v -s engine tokenization test_sequence.py test_config.py test_logger.py
- label: Entrypoints Test - label: Entrypoints Test
commands: commands:

View File

@ -57,6 +57,8 @@ body:
If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
If you experienced crashes or hangs, it would be helpful to run vllm with `export VLLM_TRACE_FUNCTION=1` . All the function calls in vllm will be recorded. Inspect these log files, and tell which function crashes or hangs.
placeholder: | placeholder: |
A clear and concise description of what the bug is. A clear and concise description of what the bug is.

27
tests/test_logger.py Normal file
View File

@ -0,0 +1,27 @@
import os
import sys
import tempfile
from vllm.logger import enable_trace_function_call
def f1(x):
return f2(x)
def f2(x):
return x
def test_trace_function_call():
fd, path = tempfile.mkstemp()
cur_dir = os.path.dirname(__file__)
enable_trace_function_call(path, cur_dir)
f1(1)
with open(path, 'r') as f:
content = f.read()
assert "f1" in content
assert "f2" in content
sys.settrace(None)
os.remove(path)

View File

@ -10,7 +10,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async) get_vllm_instance_id, make_async)
if ray is not None: if ray is not None:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@ -133,12 +133,18 @@ class RayGPUExecutor(ExecutorBase):
for node_id, gpu_ids in node_gpus.items(): for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids) node_gpus[node_id] = sorted(gpu_ids)
# Set CUDA_VISIBLE_DEVICES for the driver and workers. VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = [] all_args_to_update_environment_variables = []
for (node_id, _) in worker_node_and_gpu_ids: for (node_id, _) in worker_node_and_gpu_ids:
all_args_to_update_environment_variables.append([{ all_args_to_update_environment_variables.append([{
"CUDA_VISIBLE_DEVICES": "CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])) ",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
os.getenv("VLLM_TRACE_FUNCTION", "0"),
}]) }])
self._run_workers("update_environment_variables", self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables) all_args=all_args_to_update_environment_variables)

View File

@ -1,9 +1,11 @@
# Adapted from # Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM.""" """Logging configuration for vLLM."""
import datetime
import logging import logging
import os import os
import sys import sys
from functools import partial
from typing import Optional from typing import Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
@ -65,3 +67,53 @@ def init_logger(name: str):
logger.addHandler(_default_handler) logger.addHandler(_default_handler)
logger.propagate = False logger.propagate = False
return logger return logger
logger = init_logger(__name__)
def _trace_calls(log_path, root_dir, frame, event, arg=None):
if event in ['call', 'return']:
# Extract the filename, line number, function name, and the code object
filename = frame.f_code.co_filename
lineno = frame.f_lineno
func_name = frame.f_code.co_name
if not filename.startswith(root_dir):
# only log the functions in the vllm root_dir
return
# Log every function call or return
try:
with open(log_path, 'a') as f:
if event == 'call':
f.write(f"{datetime.datetime.now()} Call to"
f" {func_name} in {filename}:{lineno}\n")
else:
f.write(f"{datetime.datetime.now()} Return from"
f" {func_name} in {filename}:{lineno}\n")
except NameError:
# modules are deleted during shutdown
pass
return partial(_trace_calls, log_path, root_dir)
def enable_trace_function_call(log_file_path: str,
root_dir: Optional[str] = None):
"""
Enable tracing of every function call in code under `root_dir`.
This is useful for debugging hangs or crashes.
`log_file_path` is the path to the log file.
`root_dir` is the root directory of the code to trace. If None, it is the
vllm root directory.
Note that this call is thread-level, any threads calling this function
will have the trace enabled. Other threads will not be affected.
"""
logger.warning(
"VLLM_TRACE_FUNCTION is enabled. It will record every"
" function executed by Python. This will slow down the code. It "
"is suggested to be used for debugging hang or crashes only.")
logger.info(f"Trace frame log is saved to {log_file_path}")
if root_dir is None:
# by default, this is the vllm root directory
root_dir = os.path.dirname(os.path.dirname(__file__))
sys.settrace(partial(_trace_calls, log_file_path, root_dir))

View File

@ -163,6 +163,17 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex) return str(uuid.uuid4().hex)
@lru_cache(maxsize=None)
def get_vllm_instance_id():
"""
If the environment variable VLLM_INSTANCE_ID is set, return it.
Otherwise, return a random UUID.
Instance id represents an instance of the VLLM. All processes in the same
instance should have the same instance id.
"""
return os.environ.get("VLLM_INSTANCE_ID", f"vllm-instance-{random_uuid()}")
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
def in_wsl() -> bool: def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071 # Reference: https://github.com/microsoft/WSL/issues/4071
@ -274,7 +285,7 @@ def get_open_port() -> int:
def update_environment_variables(envs: Dict[str, str]): def update_environment_variables(envs: Dict[str, str]):
for k, v in envs.items(): for k, v in envs.items():
if k in os.environ: if k in os.environ and os.environ[k] != v:
logger.warning(f"Overwriting environment variable {k} " logger.warning(f"Overwriting environment variable {k} "
f"from '{os.environ[k]}' to '{v}'") f"from '{os.environ[k]}' to '{v}'")
os.environ[k] = v os.environ[k] = v

View File

@ -1,12 +1,15 @@
import datetime
import importlib import importlib
import os import os
import tempfile
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple from typing import Dict, List, Set, Tuple
from vllm.logger import init_logger from vllm.logger import enable_trace_function_call, init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import update_environment_variables from vllm.utils import get_vllm_instance_id, update_environment_variables
logger = init_logger(__name__) logger = init_logger(__name__)
@ -115,9 +118,20 @@ class WorkerWrapperBase:
def init_worker(self, *args, **kwargs): def init_worker(self, *args, **kwargs):
""" """
Actual initialization of the worker class. Actual initialization of the worker class, and set up
function tracing if required.
Arguments are passed to the worker class constructor. Arguments are passed to the worker class constructor.
""" """
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
mod = importlib.import_module(self.worker_module_name) mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name) worker_class = getattr(mod, self.worker_class_name)
self.worker = worker_class(*args, **kwargs) self.worker = worker_class(*args, **kwargs)