Add gpu_memory_utilization and swap_space to LLM (#1090)

This commit is contained in:
Woosuk Kwon 2023-09-19 22:16:04 -07:00 committed by GitHub
parent 400b8289f7
commit bc0644574c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -37,12 +37,22 @@ class LLM:
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
seed: The seed to initialize the random number generator for sampling.
quantization: The method used to quantize the model weights. Currently,
we support "awq". If None, we assume the model weights are not
quantized and use `dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause
out-of-memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, values that are too small may cause out-of-memory (OOM) errors.
"""
def __init__(
@@ -53,8 +63,11 @@ class LLM:
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
seed: int = 0,
quantization: Optional[str] = None,
revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
@@ -66,8 +79,11 @@ class LLM:
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
seed=seed,
quantization=quantization,
revision=revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)