[Core] Consolidate GB constant and enable float GB arguments (#7416)

This commit is contained in:
Cyrus Leung 2024-08-13 05:14:14 +08:00 committed by GitHub
parent 6aa33cb2dd
commit 4ddc4743d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 21 additions and 21 deletions

View File

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.tracing import is_otel_installed from vllm.tracing import is_otel_installed
from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
cuda_device_count_stateless, get_cpu_memory, is_cpu, cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_tpu, is_xpu, is_hip, is_neuron, is_openvino, is_tpu, is_xpu,
print_warning_once) print_warning_once)
@@ -27,7 +27,6 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
_GB = 1 << 30
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_PP_SUPPORTED_MODELS = [ _PP_SUPPORTED_MODELS = [
@@ -492,7 +491,7 @@ class CacheConfig:
self, self,
block_size: int, block_size: int,
gpu_memory_utilization: float, gpu_memory_utilization: float,
swap_space: int, swap_space: float,
cache_dtype: str, cache_dtype: str,
num_gpu_blocks_override: Optional[int] = None, num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None, sliding_window: Optional[int] = None,
@@ -501,7 +500,7 @@ class CacheConfig:
) -> None: ) -> None:
self.block_size = block_size self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB self.swap_space_bytes = swap_space * GiB_bytes
self.num_gpu_blocks_override = num_gpu_blocks_override self.num_gpu_blocks_override = num_gpu_blocks_override
self.cache_dtype = cache_dtype self.cache_dtype = cache_dtype
self.sliding_window = sliding_window self.sliding_window = sliding_window
@@ -561,9 +560,9 @@ class CacheConfig:
num_gpus_per_node = parallel_config.tensor_parallel_size num_gpus_per_node = parallel_config.tensor_parallel_size
cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of " msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the "
f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory "
"allocated for the swap space.") "is allocated for the swap space.")
if cpu_memory_usage > 0.7 * total_cpu_memory: if cpu_memory_usage > 0.7 * total_cpu_memory:
raise ValueError("Too large swap space. " + msg) raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory: elif cpu_memory_usage > 0.4 * total_cpu_memory:

View File

@@ -58,8 +58,8 @@ class EngineArgs:
enable_prefix_caching: bool = False enable_prefix_caching: bool = False
disable_sliding_window: bool = False disable_sliding_window: bool = False
use_v2_block_manager: bool = False use_v2_block_manager: bool = False
swap_space: int = 4 # GiB swap_space: float = 4 # GiB
cpu_offload_gb: int = 0 # GiB cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90 gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256 max_num_seqs: int = 256
@@ -321,7 +321,7 @@ class EngineArgs:
default=EngineArgs.seed, default=EngineArgs.seed,
help='Random seed for operations.') help='Random seed for operations.')
parser.add_argument('--swap-space', parser.add_argument('--swap-space',
type=int, type=float,
default=EngineArgs.swap_space, default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.') help='CPU swap space size (GiB) per GPU.')
parser.add_argument( parser.add_argument(

View File

@@ -13,7 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_open_port, from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
@@ -332,7 +332,6 @@ def _verify_and_get_scheduler_config(
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
_GB = 1 << 30
if config.enable_prefix_caching: if config.enable_prefix_caching:
logger.warning("Prefix caching is not supported on CPU, disable it.") logger.warning("Prefix caching is not supported on CPU, disable it.")
config.enable_prefix_caching = False config.enable_prefix_caching = False
@@ -341,11 +340,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
if kv_cache_space >= 0: if kv_cache_space >= 0:
if kv_cache_space == 0: if kv_cache_space == 0:
config.cpu_kvcache_space_bytes = 4 * _GB # type: ignore config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
"for CPU backend is not set, using 4 by default.") "for CPU backend is not set, using 4 by default.")
else: else:
config.cpu_kvcache_space_bytes = kv_cache_space * _GB # type: ignore config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
else: else:
raise RuntimeError( raise RuntimeError(
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE" "Invalid environment variable VLLM_CPU_KVCACHE_SPACE"

View File

@@ -10,8 +10,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip,
make_async) get_open_port, make_async)
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -165,14 +165,13 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE
if kv_cache_space >= 0: if kv_cache_space >= 0:
_GB = 1 << 30
if kv_cache_space == 0: if kv_cache_space == 0:
config.openvino_kvcache_space_bytes = 4 * _GB # type: ignore config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
logger.warning( logger.warning(
"Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) "
"for OpenVINO backend is not set, using 4 by default.") "for OpenVINO backend is not set, using 4 by default.")
else: else:
config.openvino_kvcache_space_bytes = kv_cache_space * _GB # type: ignore config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
else: else:
raise RuntimeError( raise RuntimeError(
"Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE"

View File

@@ -115,6 +115,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID" STR_INVALID_VAL: str = "INVALID"
GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB)."""
STR_DTYPE_TO_TORCH_DTYPE = { STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.half, "half": torch.half,
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,

View File

@@ -143,7 +143,7 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
# Calculate the CPU KV cache size based on the config. # Calculate the CPU KV cache size based on the config.
num_cpu_blocks = (self.cache_config.swap_space_bytes // num_cpu_blocks = int(self.cache_config.swap_space_bytes //
block_size_bytes) block_size_bytes)
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
return num_tpu_blocks, num_cpu_blocks return num_tpu_blocks, num_cpu_blocks