# SPDX-License-Identifier: Apache-2.0
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional

import torch
import torch.distributed
import torch.nn as nn

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.device_allocator.cumem import CuMemAllocator
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment,
                              set_custom_all_reduce)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase

logger = init_logger(__name__)

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

class Worker(WorkerBase):
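    """A GPU worker that runs the model on a single CUDA device.

    The worker owns one device and its share of the distributed state, and
    delegates model loading, KV-cache setup, and execution to a
    GPUModelRunner instance.
    """
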
    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ):

        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None
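
    # Illustrative profiler usage (the launch command is an assumption; only
    # VLLM_TORCH_PROFILER_DIR and the profile() method below come from vLLM):
    #   VLLM_TORCH_PROFILER_DIR=/tmp/vllm_traces <vLLM launch command>
    # Traces are collected between profile(is_start=True) and
    # profile(is_start=False) calls and written by the
    # tensorboard_trace_handler configured above.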

    def sleep(self, level: int = 1) -> None:
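        """Release GPU memory held by the CuMemAllocator memory pools.

        With ``level == 1`` tensors tagged ``"weights"`` are offloaded (so
        :meth:`wake_up` can restore them); with any other level no tags are
        offloaded and the pooled memory is simply freed.
        """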
        free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.cuda.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
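        """Reallocate the GPU memory released by :meth:`sleep`.

        When ``tags`` is ``None`` all allocator pools are woken up; otherwise
        only the pools registered under the given tags.
        """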
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up(tags)

    def init_device(self):
        if self.device_config.device.type == "cuda":
            # torch.distributed.all_reduce does not free the input tensor until
            # the synchronization point. This causes the memory usage to grow
            # as the number of all_reduce calls increases. This env var disables
            # this behavior.
            # Related issue:
            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

            # This env var set by Ray causes exceptions with graph building.
            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
            self.device = torch.device(f"cuda:{self.local_rank}")
            torch.cuda.set_device(self.device)

            _check_if_gpu_supports_dtype(self.model_config.dtype)
            gc.collect()
            torch.cuda.empty_cache()
            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        init_worker_distributed_environment(self.parallel_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

        # Construct the model runner
        self.model_runner: GPUModelRunner = GPUModelRunner(
            self.vllm_config, self.device)

    # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
    # to hijack tensor allocation.
    def load_model(self) -> None:
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            assert allocator.get_current_usage() == 0, (
                "Sleep mode can only be "
                "used for one instance per process.")
            context = allocator.use_memory_pool(tag="weights")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        with context:
            self.model_runner.load_model()

    @torch.inference_mode()
    def determine_available_memory(self) -> int:
        """Profiles the peak memory usage of the model to determine how much
        memory can be used for the KV cache without OOMs.

        The engine first profiles the existing memory usage, then calculates
        the free memory that can be used for the KV cache in bytes.

        .. tip::
            You may limit GPU memory usage by adjusting the
            `gpu_memory_utilization` parameter.
        """
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        _, total_gpu_memory = torch.cuda.mem_get_info()
        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        free_gpu_memory, _ = torch.cuda.mem_get_info()
        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        assert self.init_gpu_memory > free_gpu_memory, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        # Get the peak memory allocation recorded by torch
        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

        # Check for any memory left around that may have been allocated on the
        # GPU outside of `torch`. NCCL operations, for example, can use a few
        # GB during a forward pass.
        torch.cuda.empty_cache()
        torch_allocated_bytes = torch.cuda.memory_stats(
        )["allocated_bytes.all.current"]
        total_allocated_bytes = torch.cuda.mem_get_info(
        )[1] - torch.cuda.mem_get_info()[0]
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
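        # Worked example (illustrative numbers): on an 80 GiB GPU with
        # gpu_memory_utilization=0.9 and a profiled peak of 30 GiB
        # (torch + non-torch), the budget computed below is
        # 80 GiB * 0.9 - 30 GiB = 42 GiB for the KV cache.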
        available_kv_cache_memory = (
            total_gpu_memory * self.cache_config.gpu_memory_utilization -
            peak_memory)

        return int(available_kv_cache_memory)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        return self.model_runner.get_kv_cache_spec()

    def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
        """Allocate GPU KV cache with the specified kv_cache_config."""
        if self.vllm_config.model_config.enable_sleep_mode:
            allocator = CuMemAllocator.get_instance()
            context = allocator.use_memory_pool(tag="kv_cache")
        else:
            from contextlib import nullcontext
            context = nullcontext()
        with context:
            self.model_runner.initialize_kv_cache(kv_cache_config)

    def compile_or_warm_up_model(self) -> None:
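        """Warm up (and optionally compile) the model before serving.

        Runs dummy forward passes for the configured compile sizes, captures
        CUDA graphs unless eager mode is enforced, and warms up the sampler
        on the last pipeline-parallel rank.
        """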
        # Warm up sizes that are not in the cudagraph capture sizes but that
        # users still want compiled for better performance, e.g. the
        # max-num-batched-tokens size used in chunked prefill.
        warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy()
        if not self.model_config.enforce_eager:
            warmup_sizes = [
                x for x in warmup_sizes if x not in
                self.vllm_config.compilation_config.cudagraph_capture_sizes
            ]
        for size in sorted(warmup_sizes, reverse=True):
            logger.info("Compiling and warming up model for size %d", size)
            self.model_runner._dummy_run(size)
        if not self.model_config.enforce_eager:
            self.model_runner.capture_model()

        # Warm up the sampler and preallocate memory buffers for logits and
        # other sampling-related tensors of the maximum possible shape to
        # avoid memory fragmentation issues.
        # NOTE: This is called after `capture_model` on purpose to prevent
        # memory buffers from being cleared by `torch.cuda.empty_cache`.
        if get_pp_group().is_last_rank:
            max_num_reqs = min(self.scheduler_config.max_num_seqs,
                               self.scheduler_config.max_num_batched_tokens)
            self.model_runner._dummy_sampler_run(
                hidden_states=self.model_runner._dummy_run(
                    num_tokens=max_num_reqs))

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    def get_model(self) -> nn.Module:
        return self.model_runner.get_model()

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> Optional[ModelRunnerOutput]:
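        """Run the model for one scheduler step.

        Only the driver worker returns the :class:`ModelRunnerOutput`; other
        workers return ``None``.
        """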
        output = self.model_runner.execute_model(scheduler_output)
        return output if self.is_driver_worker else None

    def profile(self, is_start: bool = True):
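        """Start (``is_start=True``) or stop the torch profiler.

        Raises ``RuntimeError`` if profiling was not enabled via
        ``VLLM_TORCH_PROFILER_DIR``.
        """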
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        if is_start:
            self.profiler.start()
        else:
            self.profiler.stop()

    def execute_dummy_batch(self) -> None:
        self.model_runner._dummy_run(1)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.model_runner.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def check_health(self) -> None:
        # worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )


def init_worker_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
) -> None:
    """Initialize the distributed environment."""
    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

    init_distributed_environment(parallel_config.world_size, rank,
                                 distributed_init_method, local_rank)

    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)


def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:  # noqa: SIM102
        if not current_platform.has_device_capability(80):
            capability = current_platform.get_device_capability()
            gpu_name = current_platform.get_device_name()

            if capability is None:
                compute_str = "does not have a compute capability"
            else:
                version_str = capability.as_version_str()
                compute_str = f"has compute capability {version_str}"

            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
                "You can use float16 instead by explicitly setting the "
                "`dtype` flag in CLI, for example: --dtype=half.")