Merge branch 'main' into woosuk-tpu

This commit is contained in:
Woosuk Kwon 2024-04-16 07:56:53 +00:00
commit d4adf92beb
167 changed files with 5410 additions and 1330 deletions

View File

@ -12,7 +12,13 @@ steps:
command: pytest -v -s async_engine command: pytest -v -s async_engine
- label: Basic Correctness Test - label: Basic Correctness Test
command: pytest -v -s basic_correctness commands:
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test - label: Core Test
command: pytest -v -s core command: pytest -v -s core
@ -29,6 +35,8 @@ steps:
- pytest -v -s test_pynccl.py - pytest -v -s test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
- label: Engine Test - label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py command: pytest -v -s engine tokenization test_sequence.py test_config.py
@ -83,6 +91,9 @@ steps:
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
parallelism: 4 parallelism: 4
- label: Tensorizer Test
command: apt-get install curl libsodium23 && pytest -v -s tensorizer
- label: Metrics Test - label: Metrics Test
command: pytest -v -s metrics command: pytest -v -s metrics

50
.github/workflows/mypy.yaml vendored Normal file
View File

@ -0,0 +1,50 @@
name: mypy
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main
jobs:
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy==1.9.0
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml

View File

@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.10"] python-version: ["3.8", "3.9", "3.10", "3.11"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

View File

@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.10"] python-version: ["3.8", "3.9", "3.10", "3.11"]
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}

View File

@ -91,7 +91,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
COPY ./ /app/vllm COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --upgrade pip numba
RUN python3 -m pip install xformers==0.0.23 --no-deps RUN python3 -m pip install xformers==0.0.23 --no-deps
RUN cd /app \ RUN cd /app \

View File

@ -72,7 +72,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.) - MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) - Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.) - OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)

View File

@ -27,8 +27,8 @@ class RequestFuncInput:
class RequestFuncOutput: class RequestFuncOutput:
generated_text: str = "" generated_text: str = ""
success: bool = False success: bool = False
latency: float = 0 latency: float = 0.0
ttft: float = 0 # Time to first token ttft: float = 0.0 # Time to first token
itl: List[float] = field( itl: List[float] = field(
default_factory=list) # List of inter-token latencies default_factory=list) # List of inter-token latencies
prompt_len: int = 0 prompt_len: int = 0
@ -58,23 +58,24 @@ async def async_request_tgi(
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(url=api_url, json=payload) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk) data = json.loads(chunk)
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
@ -119,23 +120,24 @@ async def async_request_trt_llm(
output = RequestFuncOutput() output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload) as response: async with session.post(url=api_url, json=payload) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data:") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data:")
data = json.loads(chunk) data = json.loads(chunk)
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
@ -151,7 +153,7 @@ async def async_request_trt_llm(
output.success = True output.success = True
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False
@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
output.generated_text = parsed_resp["text"][0] output.generated_text = parsed_resp["text"][0]
output.success = True output.success = True
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False
@ -234,19 +236,20 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st latency = time.perf_counter() - st
else: else:
@ -255,7 +258,7 @@ async def async_request_openai_completions(
if data["choices"][0]["text"]: if data["choices"][0]["text"]:
timestamp = time.perf_counter() timestamp = time.perf_counter()
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
output.prompt_len = request_func_input.prompt_len output.prompt_len = request_func_input.prompt_len
generated_text = "" generated_text = ""
ttft = 0 ttft = 0.0
st = time.perf_counter() st = time.perf_counter()
most_recent_timestamp = st most_recent_timestamp = st
try: try:
async with session.post(url=api_url, json=payload, async with session.post(url=api_url, json=payload,
headers=headers) as response: headers=headers) as response:
if response.status == 200: if response.status == 200:
async for chunk in response.content: async for chunk_bytes in response.content:
chunk = chunk.strip() chunk_bytes = chunk_bytes.strip()
if not chunk: if not chunk_bytes:
continue continue
chunk = remove_prefix(chunk.decode("utf-8"), "data: ") chunk = remove_prefix(chunk_bytes.decode("utf-8"),
"data: ")
if chunk == "[DONE]": if chunk == "[DONE]":
latency = time.perf_counter() - st latency = time.perf_counter() - st
else: else:
@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
delta = data["choices"][0]["delta"] delta = data["choices"][0]["delta"]
if delta.get("content", None): if delta.get("content", None):
# First token # First token
if ttft == 0: if ttft == 0.0:
ttft = time.perf_counter() - st ttft = time.perf_counter() - st
output.ttft = ttft output.ttft = ttft
@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
output.success = True output.success = True
output.latency = latency output.latency = latency
else: else:
output.error = response.reason output.error = response.reason or ""
output.success = False output.success = False
except Exception: except Exception:
output.success = False output.success = False

View File

@ -177,8 +177,7 @@ if __name__ == '__main__':
help='block size of key/value cache') help='block size of key/value cache')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
type=bool, action='store_true',
default=False,
help='If True, the prefill requests can be chunked based on the ' help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument( parser.add_argument(

View File

@ -74,25 +74,31 @@ def run_vllm(
quantization_param_path: Optional[str], quantization_param_path: Optional[str],
device: str, device: str,
enable_prefix_caching: bool, enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM(model=model, llm = LLM(
tokenizer=tokenizer, model=model,
quantization=quantization, tokenizer=tokenizer,
tensor_parallel_size=tensor_parallel_size, quantization=quantization,
seed=seed, tensor_parallel_size=tensor_parallel_size,
trust_remote_code=trust_remote_code, seed=seed,
dtype=dtype, trust_remote_code=trust_remote_code,
max_model_len=max_model_len, dtype=dtype,
gpu_memory_utilization=gpu_memory_utilization, max_model_len=max_model_len,
enforce_eager=enforce_eager, gpu_memory_utilization=gpu_memory_utilization,
kv_cache_dtype=kv_cache_dtype, enforce_eager=enforce_eager,
quantization_param_path=quantization_param_path, kv_cache_dtype=kv_cache_dtype,
device=device, quantization_param_path=quantization_param_path,
enable_prefix_caching=enable_prefix_caching, device=device,
download_dir=download_dir) enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
# Add the requests to the engine. # Add the requests to the engine.
for prompt, _, output_len in requests: for prompt, _, output_len in requests:
@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
args.output_len) args.output_len)
if args.backend == "vllm": if args.backend == "vllm":
elapsed_time = run_vllm(requests, args.model, args.tokenizer, elapsed_time = run_vllm(
args.quantization, args.tensor_parallel_size, requests, args.model, args.tokenizer, args.quantization,
args.seed, args.n, args.use_beam_search, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.trust_remote_code, args.dtype, args.max_model_len,
args.max_model_len, args.enforce_eager, args.enforce_eager, args.kv_cache_dtype,
args.kv_cache_dtype, args.quantization_param_path, args.device,
args.quantization_param_path, args.device, args.enable_prefix_caching, args.enable_chunked_prefill,
args.enable_prefix_caching, args.max_num_batched_tokens, args.gpu_memory_utilization,
args.gpu_memory_utilization, args.download_dir) args.download_dir)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -335,6 +341,14 @@ if __name__ == "__main__":
"--enable-prefix-caching", "--enable-prefix-caching",
action='store_true', action='store_true',
help="enable automatic prefix caching for vLLM backend.") help="enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir', parser.add_argument('--download-dir',
type=str, type=str,
default=None, default=None,

View File

@ -5,7 +5,7 @@ from typing import Optional
import torch import torch
from vllm._C import ops from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
NUM_BLOCKS = 1024 NUM_BLOCKS = 1024

View File

@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 128) \ f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \ f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \ f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 640) \
f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 768) \
f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1152) \ f(in_T, out_T, W_T, narrow, 1152) \
@ -46,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13696) \
f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \ f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 15360) \
f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 22016) \
@ -59,7 +61,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \ f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA f(in_T, out_T, W_T, narrow, 64000) \
f(in_T, out_T, W_T, narrow, 64256) \
f(in_T, out_T, W_T, narrow, 64512) \
f(in_T, out_T, W_T, narrow, 102400) \
f(in_T, out_T, W_T, narrow, 102656) \
f(in_T, out_T, W_T, narrow, 102912) \
f(in_T, out_T, W_T, narrow, 128000) \
f(in_T, out_T, W_T, narrow, 128256) \
f(in_T, out_T, W_T, narrow, 128512) \
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
// and vllm/tests/lora/test_punica.py
// Keep this in sync with vllm/config::LoRAConfig // Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \

View File

@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
} }
} }
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
return (uint32_t(a) << 16) | uint32_t(b); return (uint64_t(a) << 32) | uint64_t(b);
} }
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
template <typename in_T, typename out_T, typename W_T> template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
const int64_t *lora_indices, const int64_t *lora_indices,
uint16_t in_features, uint16_t out_features, uint32_t in_features, uint32_t out_features,
int64_t y_offset, int64_t full_y_size, int64_t y_offset, int64_t full_y_size,
int64_t batch_size, int64_t num_layers, int64_t batch_size, int64_t num_layers,
int64_t layer_idx, float scale) { int64_t layer_idx, float scale) {
switch (pack_u16(in_features, out_features)) { switch (pack_u32(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
case pack_u16(feat_in, feat_out): \ case pack_u32(feat_in, feat_out): \
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \ bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
full_y_size, batch_size, num_layers, \ full_y_size, batch_size, num_layers, \
layer_idx, scale); \ layer_idx, scale); \
@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ(y.size(0), x.size(0)); CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false; bool ok = false;
if (h_in < 65536 && h_out < 65536) { if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch // TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) { switch (x.scalar_type()) {
case at::ScalarType::Half: case at::ScalarType::Half:
@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
CHECK_EQ(y.size(0), x.size(0)); CHECK_EQ(y.size(0), x.size(0));
const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
bool ok = false; bool ok = false;
if (h_in < 65536 && h_out < 65536) { if (h_in <= 128512 && h_out <= 128512) {
// TODO: See if we can get rid of this massive nested switch // TODO: See if we can get rid of this massive nested switch
switch (x.scalar_type()) { switch (x.scalar_type()) {
case at::ScalarType::Half: case at::ScalarType::Half:

View File

@ -2067,7 +2067,7 @@ void gptq_shuffle
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
vllm::gptq::shuffle_exllama_weight( vllm::gptq::shuffle_exllama_weight(
(uint32_t*) q_weight.data_ptr(), (uint32_t*) q_weight.data_ptr(),
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
q_weight.size(0) * 32 / bit, q_weight.size(0) * 32 / bit,
q_weight.size(1), q_weight.size(1),
bit bit

View File

@ -12,6 +12,7 @@
import logging import logging
import sys import sys
from typing import List
from sphinx.ext import autodoc from sphinx.ext import autodoc
@ -45,7 +46,7 @@ templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path. # This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [] exclude_patterns: List[str] = []
# Exclude the prompt "$" when copying code # Exclude the prompt "$" when copying code
copybutton_prompt_text = r"\$ " copybutton_prompt_text = r"\$ "
@ -82,6 +83,7 @@ autodoc_mock_imports = [
"vllm._C", "vllm._C",
"numpy", "numpy",
"tqdm", "tqdm",
"tensorizer",
] ]
for mock_target in autodoc_mock_imports: for mock_target in autodoc_mock_imports:

View File

@ -85,13 +85,3 @@ You can also build and install vLLM from source:
$ nvcc --version # verify that nvcc is in your PATH $ nvcc --version # verify that nvcc is in your PATH
$ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME $ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME
.. note::
If you are developing the C++ backend of vLLM, consider building vLLM with
.. code-block:: console
$ python setup.py develop
since it will give you incremental builds. The downside is that this method
is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.

View File

@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM:
Directory to download and load the weights, default to the default cache dir of huggingface. Directory to download and load the weights, default to the default cache dir of huggingface.
.. option:: --load-format {auto,pt,safetensors,npcache,dummy} .. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}
The format of the model weights to load. The format of the model weights to load.
@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
* "safetensors" will load the weights in the safetensors format. * "safetensors" will load the weights in the safetensors format.
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading. * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
* "dummy" will initialize the weights with random values, mainly for profiling. * "dummy" will initialize the weights with random values, mainly for profiling.
* "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. <https://github.com/coreweave/tensorizer>`_ See `examples/tensorize_vllm_model.py <https://github.com/vllm-project/vllm/blob/main/examples/tensorize_vllm_model.py>`_ to serialize a vLLM model, and for more information.
.. option:: --dtype {auto,half,float16,bfloat16,float,float32} .. option:: --dtype {auto,half,float16,bfloat16,float,float32}

View File

@ -93,7 +93,7 @@ Alongside each architecture, we include some popular models that use it.
- ✅︎ - ✅︎
* - :code:`MixtralForCausalLM` * - :code:`MixtralForCausalLM`
- Mixtral-8x7B, Mixtral-8x7B-Instruct - Mixtral-8x7B, Mixtral-8x7B-Instruct
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc. - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc.
- ✅︎ - ✅︎
* - :code:`MPTForCausalLM` * - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
@ -168,3 +168,29 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
output = llm.generate("Hello, my name is") output = llm.generate("Hello, my name is")
print(output) print(output)
Model Support Policy
---------------------
At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Heres how we manage third-party model support:
1. **Community-Driven Support**: We encourage community contributions for adding new models. When a user requests support for a new model, we welcome pull requests (PRs) from the community. These contributions are evaluated primarily on the sensibility of the output they generate, rather than strict consistency with existing implementations such as those in transformers. **Call for contribution:** PRs coming directly from model vendors are greatly appreciated!
2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results.
3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback.
4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use.
5. **Selective Focus**: Our resources are primarily directed towards models with significant user interest and impact. Models that are less frequently used may receive less attention, and we rely on the community to play a more active role in their upkeep and improvement.
Through this approach, vLLM fosters a collaborative environment where both the core development team and the broader community contribute to the robustness and diversity of the third-party models supported in our ecosystem.
Note that, as an inference engine, vLLM does not introduce new models. Therefore, all models supported by vLLM are third-party models in this regard.
We have the following levels of testing for models:
1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `test_models.py <https://github.com/vllm-project/vllm/blob/main/tests/models/test_models.py>`_ and `test_big_models.py <https://github.com/vllm-project/vllm/blob/main/tests/models/test_big_models.py>`_ for the models that have passed this test.
2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test.
3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests <https://github.com/vllm-project/vllm/tree/main/tests>`_ and `examples <https://github.com/vllm-project/vllm/tree/main/examples>`_ for the models that have passed this test.
4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.

View File

@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat
You can start the server using Python, or using [Docker](deploying_with_docker.rst): You can start the server using Python, or using [Docker](deploying_with_docker.rst):
```bash ```bash
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123 python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123
``` ```
To call the server, you can use the official OpenAI Python client library, or any other HTTP client. To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
@ -16,9 +16,8 @@ client = OpenAI(
) )
completion = client.chat.completions.create( completion = client.chat.completions.create(
model="meta-llama/Llama-2-7b-hf", model="mistralai/Mistral-7B-Instruct-v0.2",
messages=[ messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"} {"role": "user", "content": "Hello!"}
] ]
) )
@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly
```python ```python
completion = client.chat.completions.create( completion = client.chat.completions.create(
model="meta-llama/Llama-2-7b-hf", model="mistralai/Mistral-7B-Instruct-v0.2",
messages=[ messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"} {"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
], ],
extra_body={ extra_body={
@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
specifies how are roles, messages, and other chat-specific tokens are encoded in the input. specifies how are roles, messages, and other chat-specific tokens are encoded in the input.
An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12) An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format)
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model, Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model,
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat

View File

@ -0,0 +1,282 @@
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type
import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig
from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs
# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
or locally. Tensor encryption and decryption is also supported, although
libsodium must be installed to use it. Install vllm with tensorizer support
using `pip install vllm[tensorizer]`.
To serialize a model, install vLLM from source, then run something
like this from the root level of this repository:
python -m examples.tensorize_vllm_model \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
serialize \
--serialized-directory s3://my-bucket/ \
--suffix vllm
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used. This
assumes your S3 credentials are specified as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
To provide S3 credentials directly, you can provide `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
script.
You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
To deserialize a model, you can run something like this from the root
level of this repository:
python -m examples.tensorize_vllm_model \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
deserialize \
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
Which downloads the model tensors from your S3 bucket and deserializes them.
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
For more information on the available arguments for serializing, run
`python -m examples.tensorize_vllm_model serialize --help`.
Or for deserializing:
`python -m examples.tensorize_vllm_model deserialize --help`.
Once a model is serialized, it can be used to load the model when running the
OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing
the `--tensorizer-uri` CLI argument that is functionally the same as the
`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to
signify that the model to be deserialized is a vLLM model, rather than a
HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer
in the same inference server, albeit without the speed optimizations. To
deserialize an encrypted file, the `--encryption-keyfile` argument can be used
to provide the path to the keyfile used to encrypt the model weights. For
information on all the arguments that can be used to configure tensorizer's
deserialization, check out the tensorizer options argument group in the
`vllm/entrypoints/openai/api_server.py` script with `--help`.
Tensorizer can also be invoked with the `LLM` class directly to load models:
llm = LLM(model="facebook/opt-125m",
load_format="tensorizer",
tensorizer_uri=path_to_opt_tensors,
num_readers=3,
vllm_tensorized=True)
"""
def parse_args():
parser = argparse.ArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "
"extremely quickly. Tensor encryption and decryption is "
"also supported, although libsodium must be installed to "
"use it.")
parser = EngineArgs.add_cli_args(parser)
subparsers = parser.add_subparsers(dest='command')
serialize_parser = subparsers.add_parser(
'serialize', help="Serialize a model to `--serialized-directory`")
serialize_parser.add_argument(
"--suffix",
type=str,
required=False,
help=(
"The suffix to append to the serialized model directory, which is "
"used to construct the location of the serialized model tensors, "
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
"`--suffix` is `v1`, the serialized model tensors will be "
"saved to "
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
"If none is provided, a random UUID will be used."))
serialize_parser.add_argument(
"--serialized-directory",
type=str,
required=True,
help="The directory to serialize the model to. "
"This can be a local directory or S3 URI. The path to where the "
"tensors are saved is a combination of the supplied `dir` and model "
"reference ID. For instance, if `dir` is the serialized directory, "
"and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
"be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
"where `suffix` is given by `--suffix` or a random UUID if not "
"provided.")
serialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Encrypt the model weights with a randomly-generated binary key,"
" and save the key at this path"))
deserialize_parser = subparsers.add_parser(
'deserialize',
help=("Deserialize a model from `--path-to-tensors`"
" to verify it can be loaded and used."))
deserialize_parser.add_argument(
"--path-to-tensors",
type=str,
required=True,
help="The local path or S3 URI to the model tensors to deserialize. ")
deserialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Path to a binary key to use to decrypt the model weights,"
" if the model was serialized with encryption"))
return parser.parse_args()
def make_model_contiguous(model):
# Ensure tensors are saved in memory contiguously
for param in model.parameters():
param.data = param.data.contiguous()
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def serialize():
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
model = (engine.model_executor.driver_worker.
model_runner.model)
encryption_params = EncryptionParams.random() if keyfile else None
if keyfile:
with _write_stream(keyfile) as stream:
stream.write(encryption_params.key)
with _write_stream(model_path) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
serializer.close()
print("Serialization complete. Model tensors saved to", model_path)
if keyfile:
print("Key saved to", keyfile)
def deserialize():
config = AutoConfig.from_pretrained(model_ref)
with no_init_or_tensor():
model_class = _get_vllm_model_architecture(config)
model = model_class(config)
before_mem = get_mem_usage()
start = time.time()
if keyfile:
with _read_stream(keyfile) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
tensorizer_args.deserializer_params['encryption'] = \
decryption_params
with (_read_stream(model_path)) as stream, TensorDeserializer(
stream, **tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(model)
end = time.time()
# Brag about how fast we are.
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
print(
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
)
print(f"Memory usage before: {before_mem}")
print(f"Memory usage after: {after_mem}")
return model
args = parse_args()
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
or None)
s3_secret_access_key = (args.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
_read_stream, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))
model_ref = args.model
model_name = model_ref.split("/")[1]
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()
keyfile = args.keyfile if args.keyfile else None
if args.command == "serialize":
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
serialize()
elif args.command == "deserialize":
tensorizer_args = TensorizerArgs.from_cli_args(args)
model_path = args.path_to_tensors
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")

View File

@ -93,9 +93,23 @@ fi
echo 'vLLM yapf: Done' echo 'vLLM yapf: Done'
# Run mypy # Run mypy
# TODO(zhuohan): Enable mypy echo 'vLLM mypy:'
# echo 'vLLM mypy:' mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
# mypy mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
CODESPELL_EXCLUDES=( CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**' '--skip' '*docs/source/_build/**'
@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
exit 1 exit 1
fi fi

View File

@ -46,10 +46,13 @@ ignore = [
python_version = "3.8" python_version = "3.8"
ignore_missing_imports = true ignore_missing_imports = true
check_untyped_defs = true
files = "vllm" files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/" exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
]
[tool.codespell] [tool.codespell]

View File

@ -11,4 +11,7 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server. pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0 prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.3
outlines == 0.0.34 # Requires torch >= 2.1.0 outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4

View File

@ -3,4 +3,4 @@
# Dependencies for x86_64 CPUs # Dependencies for x86_64 CPUs
torch == 2.2.1+cpu torch == 2.2.1+cpu
triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error. triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.

View File

@ -7,4 +7,3 @@ pynvml == 11.5.0
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.2.1 torch == 2.2.1
xformers == 0.0.25 # Requires PyTorch 2.2.1 xformers == 0.0.25 # Requires PyTorch 2.2.1
triton >= 2.1.0

View File

@ -7,13 +7,14 @@ codespell==2.2.6
isort==5.13.2 isort==5.13.2
# type checking # type checking
mypy==0.991 mypy==1.9.0
types-PyYAML types-PyYAML
types-requests types-requests
types-setuptools types-setuptools
# testing # testing
pytest pytest
tensorizer==2.9.0a0
pytest-forked pytest-forked
pytest-asyncio pytest-asyncio
pytest-rerunfailures pytest-rerunfailures

View File

@ -5,7 +5,7 @@ import re
import subprocess import subprocess
import sys import sys
from shutil import which from shutil import which
from typing import List from typing import Dict, List
import torch import torch
from packaging.version import Version, parse from packaging.version import Version, parse
@ -52,7 +52,7 @@ class CMakeExtension(Extension):
class cmake_build_ext(build_ext): class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured. # A dict of extension directories that have been configured.
did_config = {} did_config: Dict[str, bool] = {}
# #
# Determine number of compilation jobs and optionally nvcc compile threads. # Determine number of compilation jobs and optionally nvcc compile threads.
@ -269,6 +269,7 @@ def get_nvcc_cuda_version() -> Version:
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
""" """
assert CUDA_HOME is not None, "CUDA_HOME is not set"
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"], nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
universal_newlines=True) universal_newlines=True)
output = nvcc_output.split() output = nvcc_output.split()
@ -416,6 +417,9 @@ setup(
python_requires=">=3.8", python_requires=">=3.8",
install_requires=get_requirements(), install_requires=get_requirements(),
ext_modules=ext_modules, ext_modules=ext_modules,
extras_require={
"tensorizer": ["tensorizer==2.9.0a1"],
},
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {}, cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
package_data=package_data, package_data=package_data,
) )

View File

@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict:
@pytest.fixture @pytest.fixture
def api_server(tokenizer_pool_size: int): def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
worker_use_ray: bool):
script_path = Path(__file__).parent.joinpath( script_path = Path(__file__).parent.joinpath(
"api_server_async_engine.py").absolute() "api_server_async_engine.py").absolute()
uvicorn_process = subprocess.Popen([ commands = [
sys.executable, "-u", sys.executable, "-u",
str(script_path), "--model", "facebook/opt-125m", "--host", str(script_path), "--model", "facebook/opt-125m", "--host",
"127.0.0.1", "--tokenizer-pool-size", "127.0.0.1", "--tokenizer-pool-size",
str(tokenizer_pool_size) str(tokenizer_pool_size)
]) ]
if engine_use_ray:
commands.append("--engine-use-ray")
if worker_use_ray:
commands.append("--worker-use-ray")
uvicorn_process = subprocess.Popen(commands)
yield yield
uvicorn_process.terminate() uvicorn_process.terminate()
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) @pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
def test_api_server(api_server, tokenizer_pool_size: int): @pytest.mark.parametrize("worker_use_ray", [False, True])
@pytest.mark.parametrize("engine_use_ray", [False, True])
def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
engine_use_ray: bool):
""" """
Run the API server and test it. Run the API server and test it.

View File

@ -0,0 +1,66 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
It tests chunked prefill. Chunked prefill can be enabled by
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
prefill requests are chunked.
Run `pytest tests/models/test_chunked_prefill.py`.
"""
import pytest
MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
@pytest.mark.parametrize("enforce_eager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
enforce_eager: bool,
tensor_parallel_size: int,
) -> None:
max_num_seqs = min(chunked_prefill_token_size, 256)
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
tensor_parallel_size=tensor_parallel_size,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
print(vllm_outputs[0])
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

View File

@ -11,8 +11,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor,
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.distributed import destroy_model_parallel
destroy_model_parallel)
from vllm.sequence import MultiModalData from vllm.sequence import MultiModalData
from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.transformers_utils.tokenizer import get_tokenizer
@ -402,7 +401,7 @@ class VllmRunner:
cleanup() cleanup()
@pytest.fixture @pytest.fixture(scope="session")
def vllm_runner(): def vllm_runner():
return VllmRunner return VllmRunner

View File

@ -104,10 +104,10 @@ def test_chunk():
# One chunked prefill, and one decoding. # One chunked prefill, and one decoding.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running) assert set(get_sequence_groups(out)) == set(running)
# The first one is decoding. # The first one is prefill. Scheduler guarantees ordering.
assert seq_group_meta[0].token_chunk_size == 1 assert seq_group_meta[0].token_chunk_size == 56
# The second one is a chunked prefill. # The second one is a chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56 assert seq_group_meta[1].token_chunk_size == 1
assert out.num_prefill_groups == 1 assert out.num_prefill_groups == 1
assert out.num_batched_tokens == 57 assert out.num_batched_tokens == 57
@ -157,12 +157,12 @@ def test_complex():
# Decoding & chunked prefill & first chunk of 3rd request is scheduled. # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(get_sequence_groups(out)) == 3 assert len(get_sequence_groups(out)) == 3
# The first one is decoding. # The first one is the first chunked prefill.
assert seq_group_meta[0].token_chunk_size == 1 assert seq_group_meta[0].token_chunk_size == 7
# The second one is a chunked prefill. # The second one is the second new chunked prefill.
assert seq_group_meta[1].token_chunk_size == 56 assert seq_group_meta[1].token_chunk_size == 56
# The third one is also chunked. # The last one is decode.
assert seq_group_meta[2].token_chunk_size == 7 assert seq_group_meta[2].token_chunk_size == 1
# Two of them are in chunked prefill. # Two of them are in chunked prefill.
assert out.num_prefill_groups == 2 assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 64 assert out.num_batched_tokens == 64

View File

@ -33,11 +33,16 @@ def test_models(
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
) -> None: ) -> None:
hf_model = hf_runner(model, dtype=dtype) hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model del hf_model
vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2) vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model del vllm_model

View File

@ -0,0 +1,66 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_chunked_prefill_distributed.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
test_chunked_prefill_distributed.py
```
"""
import os
import pytest
import torch
MODELS = [
os.environ["TEST_DIST_MODEL"],
]
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
) -> None:
# Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256)
assert chunked_prefill_token_size != -1
enable_chunked_prefill = True
max_num_batched_tokens = chunked_prefill_token_size
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model
vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
max_num_seqs=max_num_seqs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

View File

@ -8,9 +8,9 @@ import pytest
import ray import ray
import torch import torch
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.distributed import (broadcast_tensor_dict,
broadcast_tensor_dict, tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment, from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel) multi_process_tensor_parallel)

View File

@ -6,9 +6,8 @@ import ray
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.distributed.device_communicators import custom_all_reduce
tensor_model_parallel_all_reduce)
from vllm.test_utils import (init_test_distributed_environment, from vllm.test_utils import (init_test_distributed_environment,
multi_process_tensor_parallel) multi_process_tensor_parallel)
@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
init_test_distributed_environment(1, world_size, rank, init_test_distributed_environment(1, world_size, rank,
distributed_init_port) distributed_init_port)
custom_ar.init_custom_ar() custom_all_reduce.init_custom_all_reduce()
for sz in test_sizes: for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]: for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with custom_ar.capture(): with custom_all_reduce.capture():
# use integers so result matches NCCL exactly # use integers so result matches NCCL exactly
inp1 = torch.randint(1, inp1 = torch.randint(1,
16, (sz, ), 16, (sz, ),
@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
distributed_init_port) distributed_init_port)
sz = 1024 sz = 1024
custom_ar.init_custom_ar() custom_all_reduce.init_custom_all_reduce()
fa = custom_ar.get_handle() fa = custom_all_reduce.get_handle()
inp = torch.ones(sz, dtype=torch.float32, device=device) inp = torch.ones(sz, dtype=torch.float32, device=device)
out = fa.all_reduce_unreg(inp) out = fa.all_reduce_unreg(inp)
assert torch.allclose(out, inp * world_size) assert torch.allclose(out, inp * world_size)

View File

@ -4,8 +4,8 @@ import os
import pytest import pytest
import torch import torch
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetUniqueId) ncclGetUniqueId)
def distributed_run(fn, world_size): def distributed_run(fn, world_size):

View File

@ -3,7 +3,7 @@
2. One of the provided stop tokens 2. One of the provided stop tokens
3. The EOS token 3. The EOS token
Run `pytest tests/samplers/test_stop_reason.py`. Run `pytest tests/engine/test_stop_reason.py`.
""" """
import pytest import pytest

View File

@ -0,0 +1,111 @@
from typing import Any, List, Optional
import pytest
from vllm import CompletionOutput, LLMEngine, SamplingParams
MODEL = "meta-llama/llama-2-7b-hf"
MAX_TOKENS = 200
@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
return vllm_runner(MODEL)
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")
@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")
@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
# token id 13013 => " organization"
_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)
_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)
def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)
output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason
assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason

View File

@ -1,11 +1,14 @@
# This unit test should be moved to a new # This unit test should be moved to a new
# tests/test_guided_decoding directory. # tests/test_guided_decoding directory.
import pytest
import torch import torch
from transformers import AutoTokenizer from transformers import AutoTokenizer
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor, from vllm.entrypoints.openai.protocol import CompletionRequest
RegexLogitsProcessor) from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, RegexLogitsProcessor)
TEST_SCHEMA = { TEST_SCHEMA = {
"type": "object", "type": "object",
@ -73,3 +76,36 @@ def test_guided_logits_processors():
json_LP(token_ids, tensor) json_LP(token_ids, tensor)
assert tensor.shape == original_tensor.shape assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor) assert not torch.allclose(tensor, original_tensor)
@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str):
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
token_ids = tokenizer.encode(
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
regex_request = CompletionRequest(model='test',
prompt=token_ids,
guided_regex=TEST_REGEX)
regex_lp = await get_guided_decoding_logits_processor(
backend, regex_request, tokenizer)
assert regex_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = regex_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)
token_ids = tokenizer.encode(
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
json_request = CompletionRequest(model='test',
prompt=token_ids,
guided_json=TEST_SCHEMA)
json_lp = await get_guided_decoding_logits_processor(
backend, json_request, tokenizer)
assert json_lp is not None
tensor = torch.rand(32000)
original_tensor = torch.clone(tensor)
tensor = json_lp(token_ids, tensor)
assert tensor.shape == original_tensor.shape
assert not torch.allclose(tensor, original_tensor)

View File

@ -141,7 +141,7 @@ def server(zephyr_lora_files):
"--max-cpu-loras", "--max-cpu-loras",
"2", "2",
"--max-num-seqs", "--max-num-seqs",
"128" "128",
]) ])
ray.get(server_runner.ready.remote()) ray.get(server_runner.ready.remote())
yield server_runner yield server_runner
@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text assert first_response != completion.choices[0].text
async def test_guided_json_completion(server, client: openai.AsyncOpenAI): @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create( completion = await client.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
prompt=f"Give an example JSON for an employee profile " prompt=f"Give an example JSON for an employee profile "
@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
n=3, n=3,
temperature=1.0, temperature=1.0,
max_tokens=500, max_tokens=500,
extra_body=dict(guided_json=TEST_SCHEMA)) extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3 assert completion.choices is not None and len(completion.choices) == 3
@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
async def test_guided_json_chat(server, client: openai.AsyncOpenAI): @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_tokens=500, max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA)) extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
assert message.content is not None assert message.content is not None
json1 = json.loads(message.content) json1 = json.loads(message.content)
@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
chat_completion = await client.chat.completions.create( chat_completion = await client.chat.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_tokens=500, max_tokens=1000,
extra_body=dict(guided_json=TEST_SCHEMA)) extra_body=dict(guided_json=TEST_SCHEMA,
guided_decoding_backend=guided_decoding_backend))
message = chat_completion.choices[0].message message = chat_completion.choices[0].message
assert message.content is not None assert message.content is not None
json2 = json.loads(message.content) json2 = json.loads(message.content)
@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
assert json1["age"] != json2["age"] assert json1["age"] != json2["age"]
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create( completion = await client.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
n=3, n=3,
temperature=1.0, temperature=1.0,
max_tokens=20, max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX)) extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 3 assert completion.choices is not None and len(completion.choices) == 3
@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_tokens=20, max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX)) extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
ip1 = chat_completion.choices[0].message.content ip1 = chat_completion.choices[0].message.content
assert ip1 is not None assert ip1 is not None
assert re.fullmatch(TEST_REGEX, ip1) is not None assert re.fullmatch(TEST_REGEX, ip1) is not None
@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_tokens=20, max_tokens=20,
extra_body=dict(guided_regex=TEST_REGEX)) extra_body=dict(guided_regex=TEST_REGEX,
guided_decoding_backend=guided_decoding_backend))
ip2 = chat_completion.choices[0].message.content ip2 = chat_completion.choices[0].message.content
assert ip2 is not None assert ip2 is not None
assert re.fullmatch(TEST_REGEX, ip2) is not None assert re.fullmatch(TEST_REGEX, ip2) is not None
assert ip1 != ip2 assert ip1 != ip2
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
completion = await client.completions.create( completion = await client.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
prompt="The best language for type-safe systems programming is ", prompt="The best language for type-safe systems programming is ",
n=2, n=2,
temperature=1.0, temperature=1.0,
max_tokens=10, max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE)) extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
assert completion.id is not None assert completion.id is not None
assert completion.choices is not None and len(completion.choices) == 2 assert completion.choices is not None and len(completion.choices) == 2
@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
assert completion.choices[i].text in TEST_CHOICE assert completion.choices[i].text in TEST_CHOICE
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
messages = [{ messages = [{
"role": "system", "role": "system",
"content": "you are a helpful assistant" "content": "you are a helpful assistant"
@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE)) extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
choice1 = chat_completion.choices[0].message.content choice1 = chat_completion.choices[0].message.content
assert choice1 in TEST_CHOICE assert choice1 in TEST_CHOICE
@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
model=MODEL_NAME, model=MODEL_NAME,
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
extra_body=dict(guided_choice=TEST_CHOICE)) extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
choice2 = chat_completion.choices[0].message.content choice2 = chat_completion.choices[0].message.content
assert choice2 in TEST_CHOICE assert choice2 in TEST_CHOICE
assert choice1 != choice2 assert choice1 != choice2
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): @pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
guided_decoding_backend: str):
with pytest.raises(openai.BadRequestError): with pytest.raises(openai.BadRequestError):
_ = await client.completions.create( _ = await client.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
prompt="Give an example JSON that fits this schema: 42", prompt="Give an example JSON that fits this schema: 42",
extra_body=dict(guided_json=42)) extra_body=dict(guided_json=42,
guided_decoding_backend=guided_decoding_backend))
messages = [{ messages = [{
"role": "system", "role": "system",
@ -742,5 +773,36 @@ number: "1" | "2"
assert content.strip() == ground_truth assert content.strip() == ground_truth
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
model_name: str):
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# test using text and token IDs
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
completion = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
echo=True,
logprobs=1)
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
list) else prompt
assert (completion.choices[0].text is not None
and re.search(r"^" + prompt_text, completion.choices[0].text))
logprobs = completion.choices[0].logprobs
assert logprobs is not None
assert len(logprobs.text_offset) > 5
assert (len(logprobs.token_logprobs) > 5
and logprobs.token_logprobs[0] is None)
assert (len(logprobs.top_logprobs) > 5
and logprobs.top_logprobs[0] is None)
assert len(logprobs.tokens) > 5
if __name__ == "__main__": if __name__ == "__main__":
pytest.main([__file__]) pytest.main([__file__])

View File

@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm._C import cache_ops, ops from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip from vllm.utils import get_max_shared_memory_bytes, is_hip
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
@ -237,14 +237,14 @@ def test_paged_attention(
dequantized_key_cache = torch.empty(size=key_cache_shape, dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype, dtype=dtype,
device=device) device=device)
cache_ops.convert_fp8(key_cache, dequantized_key_cache) ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape, dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype, dtype=dtype,
device=device) device=device)
cache_ops.convert_fp8(value_cache, dequantized_value_cache) ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache value_cache = dequantized_value_cache
ref_output = torch.empty_like(query) ref_output = torch.empty_like(query)

View File

@ -4,7 +4,7 @@ from typing import Tuple
import pytest import pytest
import torch import torch
from vllm._C import cache_ops from vllm import _custom_ops as ops
from vllm.utils import is_hip from vllm.utils import is_hip
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
@ -80,7 +80,7 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches] cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel. # Call the copy blocks kernel.
cache_ops.copy_blocks(key_caches, value_caches, block_mapping) ops.copy_blocks(key_caches, value_caches, block_mapping)
# Run the reference implementation. # Run the reference implementation.
for src, dsts in block_mapping.items(): for src, dsts in block_mapping.items():
@ -145,9 +145,9 @@ def test_reshape_and_cache(
# Clone the KV caches. # Clone the KV caches.
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
cache_ops.convert_fp8(key_cache, cloned_key_cache) ops.convert_fp8(key_cache, cloned_key_cache)
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16) cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
cache_ops.convert_fp8(value_cache, cloned_value_cache) ops.convert_fp8(value_cache, cloned_value_cache)
else: else:
cloned_key_cache = key_cache.clone() cloned_key_cache = key_cache.clone()
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache.clone()
@ -156,14 +156,14 @@ def test_reshape_and_cache(
kv_scale = 1.0 kv_scale = 1.0
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
slot_mapping, kv_cache_dtype, kv_scale) kv_cache_dtype, kv_scale)
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
cache_ops.convert_fp8(key_cache, result_key_cache) ops.convert_fp8(key_cache, result_key_cache)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
cache_ops.convert_fp8(value_cache, result_value_cache) ops.convert_fp8(value_cache, result_value_cache)
# Run the reference implementation. # Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
@ -251,9 +251,8 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone() src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel. # Call the swap_blocks kernel.
cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
block_mapping)
for src, dst in block_mapping.items(): for src, dst in block_mapping.items():
assert torch.allclose(src_key_caches_clone[src].cpu(), assert torch.allclose(src_key_caches_clone[src].cpu(),
@ -291,9 +290,9 @@ def test_fp8_conversion(
cache.uniform_(low, high) cache.uniform_(low, high)
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8) cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
cache_ops.convert_fp8(cache, cache_fp8) ops.convert_fp8(cache, cache_fp8)
converted_cache = torch.empty_like(cache) converted_cache = torch.empty_like(cache)
cache_ops.convert_fp8(cache_fp8, converted_cache) ops.convert_fp8(cache_fp8, converted_cache)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)

View File

@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
).cuda() ).cuda()
# Load the weights # Load the weights
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
for i in range(config.num_local_experts): for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data, weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data) hf_moe.experts[i].w3.weight.data)

View File

@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
import vllm import vllm
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import destroy_model_parallel, initialize_model_parallel
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
@ -19,8 +20,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils.parallel_state import (
destroy_model_parallel, initialize_model_parallel)
def cleanup(): def cleanup():
@ -144,6 +143,11 @@ def baichuan_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider") return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
@pytest.fixture(scope="session")
def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
@pytest.fixture @pytest.fixture
def llama_2_7b_engine_extra_embeddings() -> nn.Module: def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup() cleanup()

View File

@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files):
@pytest.mark.skip("Requires multiple GPUs") @pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(baichuan_lora_files): def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
# Cannot use as it will initialize torch.cuda too early... # Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4: # if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}") # pytest.skip(f"Not enough GPUs for tensor parallelism {4}")

View File

@ -170,7 +170,8 @@ def create_random_inputs(
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings(dist_init, num_loras, device) -> None: @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16) lora_dtype=torch.float16)
def create_random_embedding_layer(): def create_random_embedding_layer():
embedding = VocabParallelEmbedding(512, 256) embedding = VocabParallelEmbedding(vocab_size, 256)
embedding.weight.data = torch.rand_like(embedding.weight.data) embedding.weight.data = torch.rand_like(embedding.weight.data)
embedding.weight.data[512:, :] = 0 embedding.weight.data[vocab_size:, :] = 0
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
lora_embedding.create_lora_weights(max_loras, lora_config) lora_embedding.create_lora_weights(max_loras, lora_config)
@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids=list(lora_dict.keys()), active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3, num_inputs=num_loras * 3,
input_size=(200, ), input_size=(200, ),
input_range=(1, 512), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size) vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info) lora_embedding.set_mapping(*mapping_info)
lora_result = lora_embedding(torch.cat(inputs)) lora_result = lora_embedding(torch.cat(inputs))
@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids=[0], active_lora_ids=[0],
num_inputs=num_loras * 3, num_inputs=num_loras * 3,
input_size=(200, ), input_size=(200, ),
input_range=(1, 512), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size) vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, ) lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(inputs)) lora_result = lora_embedding(torch.cat(inputs))
@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
# reason="Fails when loras are in any slot other than the first.") # reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
vocab_size) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16) lora_dtype=torch.float16)
def create_random_embedding_layer(): def create_random_embedding_layer():
embedding = VocabParallelEmbedding(512, 256) embedding = VocabParallelEmbedding(vocab_size, 256)
embedding_data = torch.rand_like(embedding.weight.data) embedding_data = torch.rand_like(embedding.weight.data)
embedding.weight.data = embedding_data embedding.weight.data = embedding_data
embedding.weight.data[512:, :] = 0 embedding.weight.data[vocab_size:, :] = 0
expanded_embedding = VocabParallelEmbedding( expanded_embedding = VocabParallelEmbedding(
512 + lora_config.lora_extra_vocab_size * max_loras, vocab_size + lora_config.lora_extra_vocab_size * max_loras,
256, 256,
org_num_embeddings=512) org_num_embeddings=vocab_size)
expanded_embedding.weight.data[:512, :] = embedding_data expanded_embedding.weight.data[:vocab_size, :] = embedding_data
# We need to deepcopy the embedding as it will be modified # We need to deepcopy the embedding as it will be modified
# in place # in place
lora_embedding = VocabParallelEmbeddingWithLoRA( lora_embedding = VocabParallelEmbeddingWithLoRA(
@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
id_to_index, id_to_index,
layer=lora_embedding, layer=lora_embedding,
layer_weights=torch.zeros( layer_weights=torch.zeros(
(256, 512 + lora_config.lora_extra_vocab_size)), (256, vocab_size + lora_config.lora_extra_vocab_size)),
generate_embeddings_tensor=256, generate_embeddings_tensor=256,
) )
@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids=list(lora_dict.keys()), active_lora_ids=list(lora_dict.keys()),
num_inputs=num_loras * 3, num_inputs=num_loras * 3,
input_size=(200, ), input_size=(200, ),
input_range=(1, 512), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
for input_, original_input_, lora_id in zip(inputs, original_inputs, for input_, original_input_, lora_id in zip(inputs, original_inputs,
prompt_mapping): prompt_mapping):
embedding_id = lora_id - 1 embedding_id = lora_id - 1
input_[-1] = 512 + (embedding_id * embeddings_tensor_len) input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
original_input_[-1] = 512 original_input_[-1] = vocab_size
input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) input_[-2] = vocab_size + (
original_input_[-2] = 512 + embeddings_tensor_len - 1 (embedding_id + 1) * embeddings_tensor_len - 1)
original_input_[-2] = vocab_size + embeddings_tensor_len - 1
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size) vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, ) lora_embedding.set_mapping(*mapping_info, )
expanded_embedding.weight[512:512 + expanded_embedding.weight[vocab_size:vocab_size +
(embeddings_tensor_len * (embeddings_tensor_len *
max_loras)] = torch.cat(embeddings_tensors) max_loras)] = torch.cat(embeddings_tensors)
@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
active_lora_ids=[0], active_lora_ids=[0],
num_inputs=num_loras * 3, num_inputs=num_loras * 3,
input_size=(200, ), input_size=(200, ),
input_range=(1, 512), input_range=(1, vocab_size),
) )
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
original_inputs = deepcopy(inputs) original_inputs = deepcopy(inputs)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
512, lora_config.lora_extra_vocab_size) vocab_size,
lora_config.lora_extra_vocab_size)
lora_embedding.set_mapping(*mapping_info, ) lora_embedding.set_mapping(*mapping_info, )
lora_result = lora_embedding(torch.cat(original_inputs)) lora_result = lora_embedding(torch.cat(original_inputs))
@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@torch.inference_mode() @torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None: @pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
def test_lm_head_logits_processor(dist_init, num_loras, device,
vocab_size) -> None:
torch.set_default_device(device) torch.set_default_device(device)
max_loras = 8 max_loras = 8
@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_dtype=torch.float16) lora_dtype=torch.float16)
def _pretest(): def _pretest():
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
1024, 32000) 1024, vocab_size)
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
linear.weight.data[:, 32000:] = 0 linear.weight.data[:, vocab_size:] = 0
logits_processor = LogitsProcessor( logits_processor = LogitsProcessor(
32000 + lora_config.lora_extra_vocab_size, 32000) vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA( lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device) logits_processor, 1024, linear.weight.dtype, linear.weight.device)
lora_logits_processor.create_lora_weights(max_loras, lora_config) lora_logits_processor.create_lora_weights(max_loras, lora_config)
@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_mapping, lora_mapping,
id_to_index, id_to_index,
max_loras, max_loras,
32000, vocab_size,
lora_config.lora_extra_vocab_size, lora_config.lora_extra_vocab_size,
) )
lora_logits_processor.set_mapping(*mapping_info, ) lora_logits_processor.set_mapping(*mapping_info, )
@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
org_vocab_size:logits_processor.org_vocab_size + org_vocab_size:logits_processor.org_vocab_size +
embeddings_tensor_len] = embeddings_tensor embeddings_tensor_len] = embeddings_tensor
logits_processor.org_vocab_size = (32000 + logits_processor.org_vocab_size = (vocab_size +
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
expected_results = [] expected_results = []
for input_, lora_id in zip(inputs, prompt_mapping): for input_, lora_id in zip(inputs, prompt_mapping):
@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
result = logits_processor._get_logits(hidden_states=input_, result = logits_processor._get_logits(hidden_states=input_,
embedding=linear.weight, embedding=linear.weight,
embedding_bias=None) embedding_bias=None)
result[:, 32000 + embeddings_tensor_len:] = float("-inf") result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
expected_results.append(result) expected_results.append(result)
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = 32000 logits_processor.org_vocab_size = vocab_size
# Check that resetting the lora weights succeeds # Check that resetting the lora weights succeeds
@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
lora_mapping = LoRAMapping(index_mapping, prompt_mapping) lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
32000, vocab_size,
lora_config.lora_extra_vocab_size) lora_config.lora_extra_vocab_size)
lora_logits_processor.set_mapping(*mapping_info, ) lora_logits_processor.set_mapping(*mapping_info, )
lora_result = lora_logits_processor._get_logits( lora_result = lora_logits_processor._get_logits(
hidden_states=torch.cat(inputs), hidden_states=torch.cat(inputs),
embedding=original_weight, embedding=original_weight,
embedding_bias=None)[:, :32000] embedding_bias=None)[:, :vocab_size]
expected_result = logits_processor._get_logits( expected_result = logits_processor._get_logits(
hidden_states=torch.cat(inputs), hidden_states=torch.cat(inputs),
embedding=original_weight, embedding=original_weight,

View File

@ -43,10 +43,52 @@ def _lora_ref_impl(
H1 = H2 = [ H1 = H2 = [
128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456, 128,
3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216, 256,
10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512, 512,
32768, 33024 1024,
1152,
1280,
1536,
2048,
2304,
2560,
2752,
3072,
3456,
3584,
4096,
4608,
5120,
5504,
5632,
6144,
6848,
6912,
7168,
8192,
9216,
10240,
11008,
13824,
14336,
15360,
22016,
24576,
27392,
32000,
32256,
32512,
32768,
33024,
36864,
49152,
64000,
64256,
102400,
102656,
128000,
128256,
] ]
SEED = [0xabcdabcd987] SEED = [0xabcdabcd987]
CUDA_DEVICES = [ CUDA_DEVICES = [

View File

@ -0,0 +1,179 @@
# Adapted from
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
from dataclasses import dataclass
from typing import List
import pytest
import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup
@dataclass
class ModelWithQuantization:
model_path: str
quantization: str
MODELS: List[ModelWithQuantization] = [
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
quantization="AWQ"),
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
quantization="GPTQ"),
]
def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
raw_prompts = [
"Give me an orange-ish brown color",
"Give me a neon pink color",
]
def format_prompt_tuples(prompt):
return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
prompts = [format_prompt_tuples(p) for p in raw_prompts]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=max_tokens,
stop=["<|im_end|>"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
max_model_len=400,
tensor_parallel_size=tp_size,
quantization=model.quantization,
trust_remote_code=True)
if model.quantization is None:
expected_no_lora_output = [
"Here are some examples of orange-brown colors",
"I'm sorry, I don't have"
]
expected_lora_output = [
"#ff8050",
"#ff8080",
]
elif model.quantization == "AWQ":
expected_no_lora_output = [
"I'm sorry, I don't understand",
"I'm sorry, I don't understand",
]
expected_lora_output = [
"#f07700: A v",
"#f00000: A v",
]
elif model.quantization == "GPTQ":
expected_no_lora_output = [
"I'm sorry, I don't have",
"I'm sorry, I don't have",
]
expected_lora_output = [
"#f08800: This is",
"#f07788 \n#",
]
def expect_match(output, expected_output):
# HACK: GPTQ lora outputs are just incredibly unstable.
# Assert that the outputs changed.
if (model.quantization == "GPTQ"
and expected_output is expected_lora_output):
assert output != expected_no_lora_output
for i, o in enumerate(output):
assert o.startswith(
'#'), f"Expected example {i} to start with # but got {o}"
return
assert output == expected_output
max_tokens = 10
print("lora adapter created")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)
print("lora 1")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=1,
max_tokens=max_tokens)
expect_match(output, expected_lora_output)
print("no lora")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=0,
max_tokens=max_tokens)
expect_match(output, expected_no_lora_output)
print("lora 2")
output = do_sample(llm,
tinyllama_lora_files,
lora_id=2,
max_tokens=max_tokens)
expect_match(output, expected_lora_output)
print("removing lora")
del llm
cleanup()
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.skip("Requires multiple GPUs")
def test_quant_model_tp_equality(tinyllama_lora_files, model):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 2:
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
llm_tp1 = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1,
quantization=model.quantization,
trust_remote_code=True)
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
del llm_tp1
cleanup()
llm_tp2 = vllm.LLM(model=model.model_path,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=2,
quantization=model.quantization)
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
del llm_tp2
cleanup()
assert output_tp1 == output_tp2

View File

@ -12,7 +12,7 @@ MODELS = [
"gpt2", "gpt2",
"bigcode/tiny_starcoder_py", "bigcode/tiny_starcoder_py",
"EleutherAI/pythia-70m", "EleutherAI/pythia-70m",
"bigscience/bloom-560m", "bigscience/bloom-560m", # Testing alibi slopes.
"microsoft/phi-2", "microsoft/phi-2",
"stabilityai/stablelm-3b-4e1t", "stabilityai/stablelm-3b-4e1t",
# "allenai/OLMo-1B", # Broken # "allenai/OLMo-1B", # Broken

View File

@ -1,3 +1,4 @@
import itertools
import random import random
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from unittest.mock import patch from unittest.mock import patch
@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def create_sampling_params(min_tokens, def create_sampling_params(min_tokens,
eos_token_id=0, eos_token_id=0,
stop_token_ids=None): *,
stop_token_ids: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams( sampling_params = SamplingParams(
min_tokens=min_tokens, min_tokens=min_tokens,
max_tokens=9999, # keep higher than max of min_tokens max_tokens=9999, # keep higher than max of min_tokens
stop_token_ids=stop_token_ids, stop_token_ids=stop_token_ids,
# requesting prompt_logprobs changes the structure of `logits`
prompt_logprobs=prompt_logprobs,
) )
sampling_params.eos_token_id = eos_token_id sampling_params.eos_token_id = eos_token_id
return sampling_params return sampling_params
@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
expected_penalization = [] expected_penalization = []
sequence_metadata_list = [] sequence_metadata_list = []
# 20% chance to generate seq group metadata list with all prompts
is_prompt = random.random() < 0.2
while batch_size > 0: while batch_size > 0:
# 20% chance to generate prompt seq group with single sequence
is_prompt = random.random() < 0.2
num_seqs = 1 if is_prompt else random.randint(1, batch_size) num_seqs = 1 if is_prompt else random.randint(1, batch_size)
eos_token_id = random.randint(0, VOCAB_SIZE - 1) eos_token_id = random.randint(0, VOCAB_SIZE - 1)
@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
seq_group_penalization = [] seq_group_penalization = []
for _ in range(num_seqs): for _ in range(num_seqs):
num_input = random.randint(1, 100) num_input = random.randint(1, 100)
num_generated = random.randint(1, 100) if not is_prompt else 0 num_generated = 0 if is_prompt else random.randint(1, 100)
seq_data[next(seq_id_counter)] = create_sequence_data( seq_data[next(seq_id_counter)] = create_sequence_data(
num_input=num_input, num_generated=num_generated) num_input=num_input, num_generated=num_generated)
seq_group_penalization.append(num_generated < min_tokens) seq_group_penalization.append(num_generated < min_tokens)
@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
] ]
} }
prompt_with_penalization_and_prompt_logprobs = {
"expected_penalization": [False, False, True],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(num_input=3),
},
sampling_params=create_sampling_params(1, prompt_logprobs=3),
block_tables={},
),
]
}
stop_penalizing_after_min_tokens = { stop_penalizing_after_min_tokens = {
"expected_penalization": [False], "expected_penalization": [False],
"seq_group_metadata_list": [ "seq_group_metadata_list": [
@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
} }
stop_token_ids = [42, 99, 42, 0] # intentional duplication stop_token_ids = [42, 99, 42, 0] # intentional duplication
simple_combination = { prompt_combination = {
"expected_penalization": [True, False, False], "expected_penalization": [False, True, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_2",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(num_input=2),
},
sampling_params=create_sampling_params(1, prompt_logprobs=3),
block_tables={},
),
SequenceGroupMetadata(
request_id="test_3",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(
0, stop_token_ids=stop_token_ids),
block_tables={},
)
]
}
stop_token_ids = [1, 999, 37, 37] # intentional duplication
decode_combination = {
"expected_penalization": [True, False, False, True, False],
"seq_group_metadata_list": [ "seq_group_metadata_list": [
SequenceGroupMetadata( SequenceGroupMetadata(
request_id="test_1", request_id="test_1",
@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
), ),
SequenceGroupMetadata( SequenceGroupMetadata(
request_id="test_2", request_id="test_2",
is_prompt=True, is_prompt=False,
seq_data={ seq_data={
next(seq_id_counter): create_sequence_data(), next(seq_id_counter):
create_sequence_data(num_generated=20),
next(seq_id_counter):
create_sequence_data(num_generated=1),
next(seq_id_counter):
create_sequence_data(num_generated=10),
}, },
sampling_params=create_sampling_params( sampling_params=create_sampling_params(
0, stop_token_ids=stop_token_ids), 10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
block_tables={}, block_tables={},
) ),
] ]
} }
@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
test_cases = [ test_cases = [
prompt_without_penalization, prompt_without_penalization,
prompt_with_penalization, prompt_with_penalization,
prompt_with_penalization_and_prompt_logprobs,
stop_penalizing_after_min_tokens, stop_penalizing_after_min_tokens,
simple_combination, prompt_combination,
decode_combination,
] ]
else: else:
test_cases = [generate_test_case()] test_cases = [generate_test_case()]
@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
def run_test_case(*, def run_test_case(*,
expected_penalization=None, expected_penalization=None,
seq_group_metadata_list=None): seq_group_metadata_list=None):
assert expected_penalization, "Invalid test case" assert expected_penalization, \
assert seq_group_metadata_list, "Invalid test case" "Invalid test case, need expected_penalization"
assert seq_group_metadata_list, \
"Invalid test case, need seq_group_metadata_list"
batch_size = 0 batch_size = 0
prompt_lens = [] prompt_lens = []
sampling_params_per_seq = [] sampling_params_per_row = []
for sgm in seq_group_metadata_list: for sgm in seq_group_metadata_list:
num_seqs = len(sgm.seq_data)
batch_size += num_seqs
sampling_params = sgm.sampling_params sampling_params = sgm.sampling_params
for seq_id in sgm.seq_data:
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len()) num_rows = len(sgm.seq_data)
sampling_params_per_seq.append(sampling_params) if sgm.is_prompt:
# a prompt seq_group has only one sequence
seq_data = next(iter(sgm.seq_data.values()))
prompt_len = seq_data.get_prompt_len()
prompt_lens.append(prompt_len)
if sgm.sampling_params.prompt_logprobs:
# with prompt_logprobs each token in the prompt has a row in
# logits
num_rows = prompt_len
batch_size += num_rows
sampling_params_per_row.extend(
itertools.repeat(sampling_params, num_rows))
assert len(
expected_penalization
) == batch_size, \
("Invalid test case, expected_penalization does not match computed"
"batch size")
_, fake_logits, sampler, model_runner = _prepare_test(batch_size) _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample( sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list, seq_group_metadata_list,
prompt_lens=prompt_lens, prompt_lens=prompt_lens if prompt_lens else None,
subquery_lens=prompt_lens) subquery_lens=prompt_lens if prompt_lens else None)
# the logits tensor is modified in-place by the sampler # the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
for logits_idx, (should_penalize, sampling_params) in enumerate( for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_seq)): zip(expected_penalization, sampling_params_per_row)):
tokens_to_check = [sampling_params.eos_token_id] tokens_to_check = [sampling_params.eos_token_id]
if sampling_params.stop_token_ids: if sampling_params.stop_token_ids:

View File

@ -0,0 +1,245 @@
import argparse
import dataclasses
import os
import time
import uuid
from functools import partial
from typing import Type
import torch
import torch.nn as nn
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
TensorSerializer, stream_io)
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
from transformers import AutoConfig, PretrainedConfig
from vllm.distributed import initialize_model_parallel
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.tensorizer_loader import TensorizerArgs
# yapf conflicts with isort for this docstring
# yapf: disable
"""
tensorize_vllm_model.py is a script that can be used to serialize and
deserialize vLLM models. These models can be loaded using tensorizer directly
to the GPU extremely quickly. Tensor encryption and decryption is also
supported, although libsodium must be installed to use it. Install
vllm with tensorizer support using `pip install vllm[tensorizer]`.
To serialize a model, you can run something like this:
python tensorize_vllm_model.py \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
serialize \
--serialized-directory s3://my-bucket/ \
--suffix vllm
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
and saves it to your S3 bucket. A local directory can also be used.
You can also encrypt the model weights with a randomly-generated key by
providing a `--keyfile` argument.
To deserialize a model, you can run something like this:
python tensorize_vllm_model.py \
--model EleutherAI/gpt-j-6B \
--dtype float16 \
deserialize \
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
Which downloads the model tensors from your S3 bucket and deserializes them.
To provide S3 credentials, you can provide `--s3-access-key-id` and
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
the OpenAI entrypoint, as arguments for LLM(), or as environment variables
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
You can also provide a `--keyfile` argument to decrypt the model weights if
they were serialized with encryption.
For more information on the available arguments, run
`python tensorize_vllm_model.py --help`.
"""
def parse_args():
parser = argparse.ArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "
"extremely quickly. Tensor encryption and decryption is "
"also supported, although libsodium must be installed to "
"use it.")
parser = EngineArgs.add_cli_args(parser)
subparsers = parser.add_subparsers(dest='command')
serialize_parser = subparsers.add_parser(
'serialize', help="Serialize a model to `--serialized-directory`")
serialize_parser.add_argument(
"--suffix",
type=str,
required=False,
help=(
"The suffix to append to the serialized model directory, which is "
"used to construct the location of the serialized model tensors, "
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
"`--suffix` is `v1`, the serialized model tensors will be "
"saved to "
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
"If none is provided, a random UUID will be used."))
serialize_parser.add_argument(
"--serialized-directory",
type=str,
required=True)
serialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Encrypt the model weights with a randomly-generated binary key,"
" and save the key at this path"))
deserialize_parser = subparsers.add_parser(
'deserialize',
help=("Deserialize a model from `--path-to-tensors`"
" to verify it can be loaded and used."))
deserialize_parser.add_argument(
"--path-to-tensors",
type=str,
required=True,
help="The local path or S3 URI to the model tensors to deserialize. ")
deserialize_parser.add_argument(
"--keyfile",
type=str,
required=False,
help=("Path to a binary key to use to decrypt the model weights,"
" if the model was serialized with encryption"))
return parser.parse_args()
def make_model_contiguous(model):
# Ensure tensors are saved in memory contiguously
for param in model.parameters():
param.data = param.data.contiguous()
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def serialize():
eng_args_dict = {f.name: getattr(args, f.name) for f in
dataclasses.fields(EngineArgs)}
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
engine = LLMEngine.from_engine_args(engine_args)
model = (engine.model_executor.driver_worker.
model_runner.model)
encryption_params = EncryptionParams.random() if keyfile else None
if keyfile:
with _write_stream(keyfile) as stream:
stream.write(encryption_params.key)
with _write_stream(model_path) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
serializer.close()
print("Serialization complete. Model tensors saved to", model_path)
if keyfile:
print("Key saved to", keyfile)
def deserialize():
config = AutoConfig.from_pretrained(model_ref)
with no_init_or_tensor():
model_class = _get_vllm_model_architecture(config)
model = model_class(config)
before_mem = get_mem_usage()
start = time.time()
if keyfile:
with _read_stream(keyfile) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
tensorizer_args.deserializer_params['encryption'] = \
decryption_params
with (_read_stream(model_path)) as stream, TensorDeserializer(
stream, **tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(model)
end = time.time()
# Brag about how fast we are.
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
print(
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
)
print(f"Memory usage before: {before_mem}")
print(f"Memory usage after: {after_mem}")
return model
args = parse_args()
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
or None)
s3_secret_access_key = (args.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
_read_stream, _write_stream = (partial(
stream_io.open_stream,
mode=mode,
s3_access_key_id=s3_access_key_id,
s3_secret_access_key=s3_secret_access_key,
s3_endpoint=s3_endpoint,
) for mode in ("rb", "wb+"))
model_ref = args.model
model_name = model_ref.split("/")[1]
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8080"
torch.distributed.init_process_group(world_size=1, rank=0)
initialize_model_parallel()
keyfile = args.keyfile if args.keyfile else None
if args.command == "serialize":
input_dir = args.serialized_directory.rstrip('/')
suffix = args.suffix if args.suffix else uuid.uuid4().hex
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
model_path = f"{base_path}/model.tensors"
serialize()
elif args.command == "deserialize":
tensorizer_args = TensorizerArgs.from_cli_args(args)
model_path = args.path_to_tensors
deserialize()
else:
raise ValueError("Either serialize or deserialize must be specified.")

View File

@ -0,0 +1,302 @@
import gc
import subprocess
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.entrypoints.test_openai_server import ServerRunner
from vllm import SamplingParams
from vllm.config import TensorizerConfig
from vllm.model_executor.tensorizer_loader import (
EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
load_with_tensorizer, open_stream)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
model_ref = "facebook/opt-125m"
def is_curl_installed():
try:
subprocess.check_call(['curl', '--version'])
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
@pytest.fixture(autouse=True)
def tensorizer_config():
config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
return config
@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_linear_method = MagicMock()
mock_agent_instance = mock_agent.return_value
mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config,
linear_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config,
linear_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value
def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
tensorizer_config.vllm_tensorized = True
result = is_vllm_serialized_tensorizer(tensorizer_config)
assert result is True
def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
tensorizer_config.vllm_tensorized = False
result = is_vllm_serialized_tensorizer(tensorizer_config)
assert result is False
def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
vllm_model = vllm_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
outputs = vllm_model.generate(prompts, sampling_params)
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(model)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(model_ref,
load_format="tensorizer",
tensorizer_uri=model_path,
num_readers=1,
vllm_tensorized=True)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
# Assumes SamplingParams being seeded ensures the outputs are deterministic
assert outputs == deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_can_deserialize_s3(vllm_runner):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
loaded_hf_model = vllm_runner(
model_ref,
tensorizer_uri=tensorized_path,
load_format="tensorizer",
num_readers=1,
vllm_tensorized=False,
s3_endpoint="object.ord1.coreweave.com",
)
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
assert deserialized_outputs
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_deserialized_encrypted_vllm_model_has_same_outputs(
vllm_runner, tmp_path):
vllm_model = vllm_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
key_path = tmp_path / (model_ref + ".key")
outputs = vllm_model.generate(prompts, sampling_params)
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
encryption_params = EncryptionParams.random()
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model)
with open_stream(key_path, "wb+") as stream:
stream.write(encryption_params.key)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(model_ref,
tensorizer_uri=model_path,
load_format="tensorizer",
encryption_keyfile=key_path,
num_readers=1,
vllm_tensorized=True)
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
# Assumes SamplingParams being seeded ensures the outputs are deterministic
assert outputs == deserialized_outputs
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
tmp_path):
hf_model = hf_runner(model_ref)
model_path = tmp_path / (model_ref + ".tensors")
max_tokens = 50
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(hf_model.model)
del hf_model
gc.collect()
torch.cuda.empty_cache()
loaded_hf_model = vllm_runner(model_ref,
tensorizer_uri=model_path,
load_format="tensorizer",
num_readers=1,
vllm_tensorized=False)
deserialized_outputs = loaded_hf_model.generate_greedy(
prompts, max_tokens=max_tokens)
assert outputs == deserialized_outputs
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
from huggingface_hub import snapshot_download
from examples.multilora_inference import (create_test_prompts,
process_requests)
model_ref = "meta-llama/Llama-2-7b-hf"
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
test_prompts = create_test_prompts(lora_path)
# Serialize model before deserializing and binding LoRA adapters
vllm_model = vllm_runner(model_ref, )
model_path = tmp_path / (model_ref + ".tensors")
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
with open_stream(model_path, "wb+") as stream:
serializer = TensorSerializer(stream)
serializer.write_module(model)
del vllm_model, model
gc.collect()
torch.cuda.empty_cache()
loaded_vllm_model = vllm_runner(
model_ref,
tensorizer_uri=model_path,
load_format="tensorizer",
num_readers=1,
vllm_tensorized=True,
enable_lora=True,
max_loras=1,
max_lora_rank=8,
max_cpu_loras=2,
max_num_seqs=50,
max_model_len=1000,
)
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
assert loaded_vllm_model
def test_load_without_tensorizer_load_format(vllm_runner):
with pytest.raises(ValueError):
vllm_runner(model_ref, tensorizer_uri="test")
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorize_vllm_model(tmp_path):
# Test serialize command
serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
tmp_path, "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command
assert result.returncode == 0, (f"Serialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
# Test deserialize command
deserialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors",
path_to_tensors
]
result = subprocess.run(deserialize_args, capture_output=True, text=True)
assert result.returncode == 0, (f"Deserialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_openai_apiserver_with_tensorizer(tmp_path):
## Serialize model
serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
tmp_path, "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
print(result.stdout) # Print the output of the serialize command
assert result.returncode == 0, (f"Serialize command failed with output:"
f"\n{result.stdout}\n{result.stderr}")
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
## Start OpenAI API server
openai_args = [
"--model", model_ref, "--dtype", "float16", "--load-format",
"tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
"--port", "8000"
]
server = ServerRunner.remote(openai_args)
print("Server ready.")
assert server.ready.remote()
def test_raise_value_error_on_invalid_load_format(vllm_runner):
with pytest.raises(ValueError):
vllm_runner(model_ref,
load_format="safetensors",
tensorizer_uri="test")
def test_tensorizer_with_tp(vllm_runner):
with pytest.raises(ValueError):
model_ref = "EleutherAI/pythia-1.4b"
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
vllm_runner(
model_ref,
tensorizer_uri=tensorized_path,
load_format="tensorizer",
num_readers=1,
vllm_tensorized=False,
s3_endpoint="object.ord1.coreweave.com",
tensor_parallel_size=2,
)
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
def test_tensorizer_warn_quant(tmp_path):
model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
serialize_args = [
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
"serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
]
result = subprocess.run(serialize_args, capture_output=True, text=True)
assert 'PerformanceWarning' in result.stderr

View File

@ -1,14 +1,18 @@
import pytest import pytest
import torch import torch
from vllm.config import ModelConfig from vllm.config import ModelConfig, SchedulerConfig
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size): def test_prepare_prompt(batch_size):
model_runner = ModelRunner(None, None, None, None, None) scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=False)
model_runner = ModelRunner(None, None, scheduler_config, None, None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] prompt_lens = []
@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size):
prompt_len - 1) prompt_len - 1)
selected_token_start_idx += prompt_len selected_token_start_idx += prompt_len
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _, (input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list)) _, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens assert return_prompt_lens == prompt_lens
assert len(slot_mapping) == len(input_tokens)
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.prompt_lens_tensor, assert torch.allclose(attn_metadata.prompt_lens_tensor,
torch.tensor(prompt_lens, device=device)) torch.tensor(prompt_lens, device=device))
assert attn_metadata.prompt_lens == prompt_lens assert attn_metadata.prompt_lens == prompt_lens
assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
assert attn_metadata.num_generation_tokens == 0
assert attn_metadata.max_prompt_len == max(prompt_lens) assert attn_metadata.max_prompt_len == max(prompt_lens)
# Test subquery start locs. # Test subquery start locs.
@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size):
assert torch.allclose(attn_metadata.block_tables, expected) assert torch.allclose(attn_metadata.block_tables, expected)
# Cuda graph should not be used for prerill. # Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is False assert attn_metadata.use_cuda_graph is False
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (sum(prompt_lens), ) assert len(input_tokens) == sum(prompt_lens)
assert input_positions.shape == (sum(prompt_lens), ) assert len(input_positions) == sum(prompt_lens)
torch.testing.assert_close(input_tokens, input_positions) torch.testing.assert_close(input_tokens, input_positions)
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens, prompt_lens,
subquery_lens=prompt_lens) subquery_lens=prompt_lens)
assert input_tokens.shape == (sum(prompt_lens), ) assert len(input_tokens) == sum(prompt_lens)
assert input_positions.shape == (sum(prompt_lens), ) assert len(input_positions) == sum(prompt_lens)
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
device=actual.device, device=actual.device,
dtype=actual.dtype) dtype=actual.dtype)
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
torch.testing.assert_close(input_tokens, input_positions) assert input_tokens == input_positions
actual = sampling_metadata.selected_token_indices actual = sampling_metadata.selected_token_indices
expected = torch.tensor(expected_selected_token_indices, expected = torch.tensor(expected_selected_token_indices,
@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size):
revision=None, revision=None,
enforce_eager=False, enforce_eager=False,
) )
model_runner = ModelRunner(model_config, None, None, None, None) scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=False)
model_runner = ModelRunner(model_config, None, scheduler_config, None,
None)
model_runner.set_block_size(16) model_runner.set_block_size(16)
prompt_lens = [] prompt_lens = []
@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size):
assert seq_group_metadata.token_chunk_size == 1 assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata) seq_group_metadata_list.append(seq_group_metadata)
input_tokens, input_positions, attn_metadata, _, _, _ = ( input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list)) model_runner._prepare_decode(seq_group_metadata_list))
assert len(slot_mapping) == len(input_tokens)
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts. # Verify input metadata is correct for prompts.
device = model_runner.device device = model_runner.device
assert attn_metadata.is_prompt is False assert attn_metadata.is_prompt is False
assert attn_metadata.prompt_lens is None assert attn_metadata.prompt_lens is None
assert attn_metadata.num_prompt_tokens == 0
assert attn_metadata.num_generation_tokens == expected_bs
assert attn_metadata.max_prompt_len is None assert attn_metadata.max_prompt_len is None
assert attn_metadata.subquery_start_loc is None assert attn_metadata.subquery_start_loc is None
assert attn_metadata.seq_start_loc is None assert attn_metadata.seq_start_loc is None
@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size):
model_runner.get_max_block_per_batch()) model_runner.get_max_block_per_batch())
# Cuda graph should not be used for prerill. # Cuda graph should not be used for prerill.
assert attn_metadata.use_cuda_graph is True assert attn_metadata.use_cuda_graph is True
assert attn_metadata.kv_cache_dtype == "auto"
assert input_tokens.shape == (expected_bs, ) assert len(input_tokens) == expected_bs
assert input_positions.shape == (expected_bs, ) assert len(input_positions) == expected_bs
torch.testing.assert_close(input_tokens, input_positions) assert input_tokens == input_positions
# Verify Sampling # Verify Sampling
expected_selected_token_indices = [] expected_selected_token_indices = []
@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size):
device=actual.device, device=actual.device,
dtype=actual.dtype) dtype=actual.dtype)
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
def test_empty_seq_group():
"""Verify prepare prompt and decode returns empty output."""
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=False,
)
model_runner = ModelRunner(model_config, None, None, None, None)
model_runner.set_block_size(16)
seq_group_metadata_list = []
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
model_runner._prepare_decode(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
_, _,
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
assert len(input_tokens) == 0
assert len(input_positions) == 0
assert attn_metadata is None
assert len(slot_mapping) == 0
assert len(return_prompt_lens) == 0
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
def get_world_size(group=None):
return 1
def mock_get_process_group_ranks(group=None):
return [0]
monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
mock_get_process_group_ranks)
model_config = ModelConfig(
"facebook/opt-125m",
"facebook/opt-125m",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
enforce_eager=enforce_eager,
)
scheduler_config = SchedulerConfig(100000,
100000,
100000,
enable_chunked_prefill=True)
model_runner = ModelRunner(model_config,
None,
scheduler_config,
None,
None,
is_driver_worker=True)
model_runner.set_block_size(16)
# Add prefill requests.
prompt_lens = []
seq_group_metadata_list = []
prefill_metadata_list = []
decode_metadata_list = []
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_lens.append(prompt_len)
seq_data = SequenceData(list(range(prompt_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
)
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
seq_group_metadata_list.append(seq_group_metadata)
prefill_metadata_list.append(seq_group_metadata)
# Add decode requests
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
prompt_len = i % (model_runner.block_size - 1) + 1
prompt_toks = list(range(prompt_len))
seq_data = SequenceData(prompt_toks)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables={0: [1]},
)
assert seq_group_metadata.token_chunk_size == 1
seq_group_metadata_list.append(seq_group_metadata)
decode_metadata_list.append(seq_group_metadata)
(input_tokens, input_positions, attn_metadata, _, _, _,
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
assert len(attn_metadata.slot_mapping) == len(input_tokens)
assert len(input_positions) == len(input_tokens)
assert attn_metadata.kv_cache_dtype == "auto"
assert attn_metadata.num_prefills == prefill_batch_size
if enforce_eager:
assert attn_metadata.num_decode_tokens == decode_batch_size
else:
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
decode_batch_size)
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.
prefill_meta = model_runner._prepare_prompt(
prefill_metadata_list).attn_metadata
decode_meta = model_runner._prepare_decode(
decode_metadata_list).attn_metadata
for attr_expected, attr_actual in zip(vars(prefill_meta),
vars(prefill_meta_actual)):
assert attr_expected[1] == attr_actual[1]
for attr_expected, attr_actual in zip(vars(decode_meta),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]

193
vllm/_custom_ops.py Normal file
View File

@ -0,0 +1,193 @@
from typing import Dict, Optional
import torch
try:
from vllm._C import cache_ops as vllm_cache_ops
from vllm._C import ops as vllm_ops
except ImportError:
pass
# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.silu_and_mul(out, x)
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_and_mul(out, x)
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_tanh_and_mul(out, x)
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_fast(out, x)
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
vllm_ops.gelu_new(out, x)
# page attention ops
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
num_kv_heads, scale, block_tables,
context_lens, block_size, max_context_len,
alibi_slopes, kv_cache_dtype, kv_scale)
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
max_logits: torch.Tensor,
tmp_out: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
num_kv_heads: int,
scale: float,
block_tables: torch.Tensor,
context_lens: torch.Tensor,
block_size: int,
max_context_len: int,
alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes, kv_cache_dtype,
kv_scale)
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
is_neox)
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
# layer norm ops
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
vllm_ops.rms_norm(out, input, weight, epsilon)
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
# squeezellm
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
lookup_table: torch.Tensor) -> None:
vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
# marlin
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
size_n, size_k)
# moe
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
block_size: int, sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor) -> None:
vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
sorted_token_ids, experts_ids,
num_tokens_post_pad)
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
kv_scale: float,
) -> None:
vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, kv_scale)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
block_mapping: torch.Tensor) -> None:
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: Dict[int, int]) -> None:
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
vllm_cache_ops.convert_fp8(output, input)
#TODO: cuda_utils, custom_ar

View File

@ -1,5 +1,6 @@
from vllm.attention.backends.abstract import (AttentionBackend, from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
@ -8,4 +9,5 @@ __all__ = [
"AttentionMetadata", "AttentionMetadata",
"Attention", "Attention",
"get_attn_backend", "get_attn_backend",
"AttentionMetadataPerStage",
] ]

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
import torch import torch
@ -47,7 +47,8 @@ class AttentionBackend(ABC):
@dataclass @dataclass
class AttentionMetadata: class AttentionMetadataPerStage:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def asdict_zerocopy(self) -> Dict[str, Any]: def asdict_zerocopy(self) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying.""" """Similar to dataclasses.asdict, but avoids deepcopying."""
@ -59,6 +60,41 @@ class AttentionMetadata:
} }
T = TypeVar("T", bound=AttentionMetadataPerStage)
@dataclass
class AttentionMetadata(Generic[T]):
"""Attention metadata for prefill and decode batched together."""
# Total number of prefill requests.
num_prefills: int
# Number of prefill tokens.
num_prefill_tokens: int
# Number of decode tokens. Note that it is equivalent to the number of
# decode requests.
num_decode_tokens: int
# The attention metadata for prefill requests in a batch.
# None if there's no prefill requests in a batch.
prefill_metadata: Optional[T]
# The attention metadata for decode requests in a batch.
# None if there's no decode requests in a batch.
decode_metadata: Optional[T]
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# The kv cache's data type.
kv_cache_dtype: str
def __post_init__(self):
if self.num_prefill_tokens > 0:
assert self.num_prefills > 0
assert self.prefill_metadata is not None
if self.num_decode_tokens > 0:
assert self.decode_metadata is not None
class AttentionImpl(ABC): class AttentionImpl(ABC):
@abstractmethod @abstractmethod
@ -80,7 +116,7 @@ class AttentionImpl(ABC):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError

View File

@ -11,7 +11,8 @@ import torch
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend):
@dataclass @dataclass
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class FlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend. """Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is NOTE: Any python object stored here is not updated when it is
@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens: Optional[List[int]] prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor. # prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor] prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
class FlashAttentionImpl(AttentionImpl): class FlashAttentionImpl(AttentionImpl):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens -------------->| |<--------------- num_prefill_tokens ----------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows: Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->| |<----------------- num_decode_tokens ------------------>|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used. Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding. Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens The prompts might have different lengths, while the generation tokens
always have length 1. always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
""" """
def __init__( def __init__(
@ -155,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: FlashAttentionMetadata, attn_metadata: AttentionMetadata[FlashAttentionMetadata],
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
@ -188,52 +195,70 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
kv_scale) kv_scale)
if attn_metadata.is_prompt: num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0: if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
output = flash_attn_varlen_func( out = flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
cu_seqlens_q=attn_metadata.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len, max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len, max_seqlen_k=prefill_meta.max_prompt_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to # TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache, # deal with different data types between KV and FP8 KV cache,
# to be addressed separately. # to be addressed separately.
output = PagedAttention.forward_prefix( output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query, query,
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, prefill_meta.block_tables,
attn_metadata.subquery_start_loc, prefill_meta.subquery_start_loc,
attn_metadata.prompt_lens_tensor, prefill_meta.prompt_lens_tensor,
attn_metadata.context_lens, prefill_meta.context_lens,
attn_metadata.max_subquery_len, prefill_meta.max_subquery_len,
self.alibi_slopes, self.alibi_slopes,
) )
else: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
output = PagedAttention.forward_decode( output[num_prefill_tokens:] = PagedAttention.forward_decode(
query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, decode_meta.block_tables,
attn_metadata.context_lens, decode_meta.context_lens,
attn_metadata.max_context_len, decode_meta.max_context_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,

View File

@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
import torch import torch
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -51,7 +52,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@dataclass @dataclass
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
PagedAttentionMetadata):
"""Metadata for FlashAttentionBackend. """Metadata for FlashAttentionBackend.
NOTE: Any python object stored here is not updated when it is NOTE: Any python object stored here is not updated when it is
@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
prompt_lens: Optional[List[int]] prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor. # prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor] prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
The prompts might have different lengths, while the generation tokens The prompts might have different lengths, while the generation tokens
always have length 1. always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
""" """
def __init__( def __init__(
@ -155,7 +162,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# AMD Radeon 7900 series (gfx1100) currently does not support # AMD Radeon 7900 series (gfx1100) currently does not support
# xFormers nor FlashAttention. As a temporary workaround, we use # xFormers nor FlashAttention. As a temporary workaround, we use
# naive PyTorch implementation of attention. # naive PyTorch implementation of attention.
self.attn_fuc = _naive_attention() self.attn_fuc = _naive_attention
logger.debug("Using naive attention in ROCmBackend") logger.debug("Using naive attention in ROCmBackend")
elif self.use_triton_flash_attn: elif self.use_triton_flash_attn:
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
@ -181,7 +188,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: ROCmFlashAttentionMetadata, attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
kv_scale: float = 1.0, kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention. """Forward pass with FlashAttention and PagedAttention.
@ -218,9 +225,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
kv_scale, kv_scale,
) )
if attn_metadata.is_prompt: num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0: if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention # triton attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
@ -230,63 +253,69 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key = self.repeat_kv(key, self.num_queries_per_kv) key = self.repeat_kv(key, self.num_queries_per_kv)
value = self.repeat_kv(value, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv)
if self.use_naive_attn: if self.use_naive_attn:
output = self.attn_fuc( out = self.attn_fuc(
query, query,
key, key,
value, value,
attn_metadata.prompt_lens, prefill_meta.prompt_lens,
self.scale, self.scale,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
output, _ = self.attn_func( out, _ = self.attn_func(
query, query,
key, key,
value, value,
None, None,
attn_metadata.seq_start_loc, prefill_meta.seq_start_loc,
attn_metadata.seq_start_loc, prefill_meta.seq_start_loc,
attn_metadata.max_prompt_len, prefill_meta.max_prompt_len,
attn_metadata.max_prompt_len, prefill_meta.max_prompt_len,
True, True,
self.scale, self.scale,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
output = self.attn_func( out = self.attn_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
cu_seqlens_q=attn_metadata.seq_start_loc, cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc, cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len, max_seqlen_q=prefill_meta.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len, max_seqlen_k=prefill_meta.max_prompt_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
output = PagedAttention.forward_prefix( output[:num_prefill_tokens] = PagedAttention.forward_prefix(
query, query,
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, prefill_meta.block_tables,
attn_metadata.subquery_start_loc, prefill_meta.subquery_start_loc,
attn_metadata.prompt_lens_tensor, prefill_meta.prompt_lens_tensor,
attn_metadata.context_lens, prefill_meta.context_lens,
attn_metadata.max_subquery_len, prefill_meta.max_subquery_len,
self.alibi_slopes, self.alibi_slopes,
) )
else:
if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
output = PagedAttention.forward_decode( output[num_prefill_tokens:] = PagedAttention.forward_decode(
query, decode_query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, decode_meta.block_tables,
attn_metadata.context_lens, decode_meta.context_lens,
attn_metadata.max_context_len, decode_meta.max_context_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
@ -305,26 +334,21 @@ def _naive_attention(
prompt_lens: List[int], prompt_lens: List[int],
scale: float, scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = query.shape[0]
output = torch.empty_like(query) output = torch.empty_like(query)
start = 0 start = 0
for _, prompt_len in enumerate(prompt_lens): for _, prompt_len in enumerate(prompt_lens):
end = start + prompt_len end = start + prompt_len
out = _naive_masked_attention( out = _naive_masked_attention(
query[None, start:end], query[start:end],
key[None, start:end], key[start:end],
value[None, start:end], value[start:end],
scale, scale,
) )
# TODO(woosuk): Unnecessary copy. Optimize. # TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out) output[start:end].copy_(out)
start += prompt_len start += prompt_len
# Using view got RuntimeError: view size is not compatible return output
# with input tensor's size and stride (at least one
# dimension spans across two contiguous subspaces).
# Use reshape instead.
return output.reshape(num_tokens, -1)
def _naive_masked_attention( def _naive_masked_attention(
@ -333,14 +357,13 @@ def _naive_masked_attention(
value: torch.Tensor, value: torch.Tensor,
scale: float, scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
seq_len, _, _ = query.shape seq_len, head_size, head_dim = query.shape
attn_mask = torch.triu(torch.ones(seq_len, attn_mask = torch.triu(torch.ones(seq_len,
seq_len, seq_len,
dtype=query.dtype, dtype=query.dtype,
device=query.device), device=query.device),
diagonal=1) diagonal=1)
attn_mask = attn_mask * torch.finfo(query.dtype).min attn_mask = attn_mask * torch.finfo(query.dtype).min
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
attn_weights = attn_weights + attn_mask.float() attn_weights = attn_weights + attn_mask.float()
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)

View File

@ -7,7 +7,8 @@ import torch
from torch.nn.functional import scaled_dot_product_attention from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
@ -49,7 +50,8 @@ class TorchSDPABackend(AttentionBackend):
@dataclass @dataclass
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
AttentionMetadataPerStage):
"""Metadata for TorchSDPABackend. """Metadata for TorchSDPABackend.
""" """
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
@ -57,15 +59,6 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
is_prompt: bool is_prompt: bool
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
prompt_lens: Optional[List[int]] prompt_lens: Optional[List[int]]
prompt_lens_tensor: Optional[torch.Tensor]
num_prompt_tokens: int
num_generation_tokens: int
max_subquery_len: Optional[int] = None
max_prompt_len: Optional[int] = None
subquery_start_loc: Optional[torch.Tensor] = None
seq_start_loc: Optional[torch.Tensor] = None
use_cuda_graph: bool = False
def __post_init__(self): def __post_init__(self):
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
@ -224,7 +217,7 @@ def _make_alibi_bias(
bias = bias[None, :] - bias[:, None] bias = bias[None, :] - bias[:, None]
num_heads = alibi_slopes.shape[0] num_heads = alibi_slopes.shape[0]
bias = bias[None, :].expand(num_heads, prompt_len, prompt_len) bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]) bias.mul_(alibi_slopes[:, None, None])
inf_mask = torch.empty( inf_mask = torch.empty(
(1, prompt_len, prompt_len), (1, prompt_len, prompt_len),

View File

@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias) LowerTriangularMaskWithTensorBias)
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
@ -54,7 +55,7 @@ class XFormersBackend(AttentionBackend):
@dataclass @dataclass
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
"""Metadata for XFormersbackend. """Metadata for XFormersbackend.
NOTE: Any python object stored here is not updated when it is NOTE: Any python object stored here is not updated when it is
@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
# Currently, input sequences can only contain all prompts # Currently, input sequences can only contain all prompts
# or all decoding. True if all sequences are prompts. # or all decoding. True if all sequences are prompts.
is_prompt: bool is_prompt: bool
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The prompt length per sequence. None if it is a decoding. # (batch_size,). The prompt length per sequence. None if it is a decoding.
prompt_lens: Optional[List[int]] prompt_lens: Optional[List[int]]
# prompt_lens stored as a tensor. # prompt_lens stored as a tensor.
prompt_lens_tensor: Optional[torch.Tensor] prompt_lens_tensor: Optional[torch.Tensor]
# The number of prompt tokens. Doesn't include padding.
num_prompt_tokens: int
# The number of generation tokens. Doesn't include padding.
num_generation_tokens: int
# NOTE(sang): Definition of context_len, subquery_len, and seqlen. # NOTE(sang): Definition of context_len, subquery_len, and seqlen.
# |---------- N-1 iteration --------| # |---------- N-1 iteration --------|
@ -123,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
class XFormersImpl(AttentionImpl): class XFormersImpl(AttentionImpl):
""" """
If the input tensors contain prompt tokens, the layout is as follows: If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prompt_tokens --------------->| |<--------------- num_prefill_tokens ----------------->|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->| |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
Otherwise, the layout is as follows: Otherwise, the layout is as follows:
|<------------------ num_generation_tokens (M) ----------------->| |<----------------- num_decode_tokens ------------------>|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
Generation tokens can contain padding when cuda-graph is used. Generation tokens can contain padding when cuda-graph is used.
Currently, prompt tokens don't contain any padding. Currently, prompt tokens don't contain any padding.
The prompts might have different lengths, while the generation tokens The prompts might have different lengths, while the generation tokens
always have length 1. always have length 1.
If chunked prefill is enabled, prefill tokens and decode tokens can be
batched together in a flattened 1D query.
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
Currently, cuda graph is disabled for chunked prefill, meaning there's no
padding between prefill and decode tokens.
""" """
def __init__( def __init__(
@ -170,7 +171,7 @@ class XFormersImpl(AttentionImpl):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: XFormersMetadata, attn_metadata: AttentionMetadata[XFormersMetadata],
kv_scale: float, kv_scale: float,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention. """Forward pass with xFormers and PagedAttention.
@ -202,59 +203,61 @@ class XFormersImpl(AttentionImpl):
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
kv_scale) kv_scale)
if attn_metadata.is_prompt: num_prefill_tokens = attn_metadata.num_prefill_tokens
num_decode_tokens = attn_metadata.num_decode_tokens
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
output = torch.empty_like(query)
# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]
assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens
if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run. # Prompt run.
if kv_cache is None or attn_metadata.block_tables.numel() == 0: if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# normal attention. # normal attention.
# block tables are empty if the prompt does not have a cached # block tables are empty if the prompt does not have a cached
# prefix. # prefix.
if self.num_kv_heads != self.num_heads: out = self._run_memory_efficient_xformers_forward(
# As of Nov 2023, xformers only supports MHA. For MQA/GQA, query, key, value, prefill_meta)
# project the key and value tensors to the desired number of assert out.shape == output[:num_prefill_tokens].shape
# heads. output[:num_prefill_tokens] = out
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
output = self._run_memory_efficient_xformers_forward(
query, key, value, attn_metadata)
else: else:
# prefix-enabled attention # prefix-enabled attention
# TODO(Hai) this triton kernel has regression issue (broke) to # TODO(Hai) this triton kernel has regression issue (broke) to
# deal with different data types between KV and FP8 KV cache, # deal with different data types between KV and FP8 KV cache,
# to be addressed separately. # to be addressed separately.
output = PagedAttention.forward_prefix( out = PagedAttention.forward_prefix(
query, query,
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, prefill_meta.block_tables,
attn_metadata.subquery_start_loc, prefill_meta.subquery_start_loc,
attn_metadata.prompt_lens_tensor, prefill_meta.prompt_lens_tensor,
attn_metadata.context_lens, prefill_meta.context_lens,
attn_metadata.max_subquery_len, prefill_meta.max_subquery_len,
self.alibi_slopes, self.alibi_slopes,
) )
else: assert output[:num_prefill_tokens].shape == out.shape
# Decoding run. output[:num_prefill_tokens] = out
output = PagedAttention.forward_decode(
query, if decode_meta := attn_metadata.decode_metadata:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.block_tables, decode_meta.block_tables,
attn_metadata.context_lens, decode_meta.context_lens,
attn_metadata.max_context_len, decode_meta.max_context_len,
attn_metadata.kv_cache_dtype, attn_metadata.kv_cache_dtype,
self.num_kv_heads, self.num_kv_heads,
self.scale, self.scale,
@ -275,13 +278,30 @@ class XFormersImpl(AttentionImpl):
"""Attention for 1D query of multiple prompts. Multiple prompt """Attention for 1D query of multiple prompts. Multiple prompt
tokens are flattened in to `query` input. tokens are flattened in to `query` input.
See https://facebookresearch.github.io/xformers/components/ops.html
for API spec.
Args: Args:
output: shape = [num_prompt_tokens, num_heads, head_size] output: shape = [num_prefill_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size] query: shape = [num_prefill_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_kv_heads, head_size] key: shape = [num_prefill_tokens, num_kv_heads, head_size]
value: shape = [num_prompt_tokens, num_kv_heads, head_size] value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention. attn_metadata: Metadata for attention.
""" """
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
# Note that the output also has the same shape (which is different
# from a spec from the doc).
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv, key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at # Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration. # the very attention layer of every iteration.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
@ -302,6 +322,7 @@ class XFormersImpl(AttentionImpl):
# TODO(woosuk): Too many view operations. Let's try to reduce # TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability. # them in the future for code readability.
if self.alibi_slopes is None: if self.alibi_slopes is None:
# Add the batch dimension.
query = query.unsqueeze(0) query = query.unsqueeze(0)
key = key.unsqueeze(0) key = key.unsqueeze(0)
value = value.unsqueeze(0) value = value.unsqueeze(0)
@ -312,14 +333,13 @@ class XFormersImpl(AttentionImpl):
attn_bias=attn_metadata.attn_bias[0], attn_bias=attn_metadata.attn_bias[0],
p=0.0, p=0.0,
scale=self.scale) scale=self.scale)
return out.view_as(original_query)
return out.view_as(query)
# Attention with alibi slopes. # Attention with alibi slopes.
# FIXME(woosuk): Because xformers does not support dynamic sequence # FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by # lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts. # one. This is inefficient, especially when we have many short prompts.
output = torch.empty_like(query) output = torch.empty_like(original_query)
start = 0 start = 0
for i, prompt_len in enumerate(attn_metadata.prompt_lens): for i, prompt_len in enumerate(attn_metadata.prompt_lens):
end = start + prompt_len end = start + prompt_len
@ -331,7 +351,7 @@ class XFormersImpl(AttentionImpl):
p=0.0, p=0.0,
scale=self.scale) scale=self.scale)
# TODO(woosuk): Unnecessary copy. Optimize. # TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0)) output[start:end].copy_(out.view_as(original_query[start:end]))
start += prompt_len start += prompt_len
return output return output

View File

@ -4,7 +4,8 @@ from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import (AttentionMetadata,
AttentionMetadataPerStage)
from vllm.attention.selector import get_attn_backend from vllm.attention.selector import get_attn_backend
@ -41,7 +42,7 @@ class Attention(nn.Module):
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
kv_cache: Optional[torch.Tensor], kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
kv_scale: float = 1.0, kv_scale: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata, return self.impl.forward(query, key, value, kv_cache, attn_metadata,

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
from vllm._C import cache_ops, ops from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@ -13,11 +13,6 @@ _PARTITION_SIZE = 512
@dataclass @dataclass
class PagedAttentionMetadata: class PagedAttentionMetadata:
"""Metadata for PagedAttention.""" """Metadata for PagedAttention."""
# (num_tokens,). The indices of the token slots that input tokens will be
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor
# (batch_size,). The length of context (tokens stored in KV cache) per # (batch_size,). The length of context (tokens stored in KV cache) per
# sequence. WARNING: When it is a prefill request, it doesn't include new # sequence. WARNING: When it is a prefill request, it doesn't include new
# tokens. When it is for decoding, it includes a new token. # tokens. When it is for decoding, it includes a new token.
@ -31,7 +26,6 @@ class PagedAttentionMetadata:
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured. # captured.
block_tables: Optional[torch.Tensor] block_tables: Optional[torch.Tensor]
kv_cache_dtype: str
class PagedAttention: class PagedAttention:
@ -75,7 +69,7 @@ class PagedAttention:
kv_cache_dtype: str, kv_cache_dtype: str,
kv_scale: float, kv_scale: float,
) -> None: ) -> None:
cache_ops.reshape_and_cache( ops.reshape_and_cache(
key, key,
value, value,
key_cache, key_cache,
@ -205,11 +199,11 @@ class PagedAttention:
) -> None: ) -> None:
src_key_cache = src_kv_cache[0] src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0] dst_key_cache = dst_kv_cache[0]
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
src_value_cache = src_kv_cache[1] src_value_cache = src_kv_cache[1]
dst_value_cache = dst_kv_cache[1] dst_value_cache = dst_kv_cache[1]
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
@ -218,4 +212,4 @@ class PagedAttention:
) -> None: ) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches] key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches]
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) ops.copy_blocks(key_caches, value_caches, src_to_dists)

View File

@ -415,7 +415,11 @@ def attn_fwd(
return return
is_mqa = hq != hk is_mqa = hq != hk
off_h_k = off_h_q % hk if is_mqa else off_h_q if is_mqa: # noqa: SIM108
off_h_k = off_h_q % hk
else:
off_h_k = off_h_q
n_extra_tokens = 0 n_extra_tokens = 0
if seqlen_k < BLOCK_N: if seqlen_k < BLOCK_N:
n_extra_tokens = BLOCK_N - seqlen_k n_extra_tokens = BLOCK_N - seqlen_k
@ -677,8 +681,7 @@ def check_args(
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
# TODO: Change assert if we support qkl f8 and v f16 # TODO: Change assert if we support qkl f8 and v f16
assert q.dtype == k.dtype and q.dtype == v.dtype assert q.dtype == k.dtype and q.dtype == v.dtype
# TODO: Fix assert to check head size <=256 once supported assert head_size <= 256
assert head_size <= 128
assert o.shape == q.shape assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0 assert (nheads_q % nheads_k) == 0
@ -729,7 +732,7 @@ class _attention(torch.autograd.Function):
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
# Get closest power of 2 over or equal to 32. # Get closest power of 2 over or equal to 32.
unpadded_head_dims = {32, 64, 128} unpadded_head_dims = {32, 64, 128, 256}
if head_size not in unpadded_head_dims: if head_size not in unpadded_head_dims:
padded_d_model = None padded_d_model = None
for i in unpadded_head_dims: for i in unpadded_head_dims:

View File

@ -1,4 +1,5 @@
import enum import enum
import os
from functools import lru_cache from functools import lru_cache
from typing import Type from typing import Type
@ -10,6 +11,8 @@ from vllm.utils import is_cpu, is_hip, is_tpu
logger = init_logger(__name__) logger = init_logger(__name__)
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
class _Backend(enum.Enum): class _Backend(enum.Enum):
FLASH_ATTN = enum.auto() FLASH_ATTN = enum.auto()
@ -83,4 +86,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
"Cannot use FlashAttention backend because the flash_attn package " "Cannot use FlashAttention backend because the flash_attn package "
"is not found. Please install it for better performance.") "is not found. Please install it for better performance.")
return _Backend.XFORMERS return _Backend.XFORMERS
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var is not None:
return _Backend[backend_by_env_var]
# Default case.
return _Backend.FLASH_ATTN return _Backend.FLASH_ATTN

View File

@ -1,8 +1,10 @@
import enum import enum
import io
import json import json
import os import os
import typing
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
import torch import torch
from packaging.version import Version from packaging.version import Version
@ -16,6 +18,8 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
from vllm.model_executor.tensorizer_loader import TensorizerArgs
logger = init_logger(__name__) logger = init_logger(__name__)
_GB = 1 << 30 _GB = 1 << 30
@ -139,13 +143,14 @@ class ModelConfig:
def _verify_load_format(self) -> None: def _verify_load_format(self) -> None:
load_format = self.load_format.lower() load_format = self.load_format.lower()
supported_load_format = [ supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy" "auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
] ]
rocm_not_supported_load_format = [] rocm_not_supported_load_format: List[str] = []
if load_format not in supported_load_format: if load_format not in supported_load_format:
raise ValueError( raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of " f"Unknown load format: {self.load_format}. Must be one of "
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") "'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
"'dummy'.")
if is_hip() and load_format in rocm_not_supported_load_format: if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [ rocm_supported_load_format = [
f for f in supported_load_format f for f in supported_load_format
@ -158,7 +163,9 @@ class ModelConfig:
# TODO: Remove this check once HF updates the pt weights of Mixtral. # TODO: Remove this check once HF updates the pt weights of Mixtral.
architectures = getattr(self.hf_config, "architectures", []) architectures = getattr(self.hf_config, "architectures", [])
if "MixtralForCausalLM" in architectures and load_format == "pt": # architectures can be None instead of []
if architectures and "MixtralForCausalLM" in architectures \
and load_format == "pt":
raise ValueError( raise ValueError(
"Currently, the 'pt' format is not supported for Mixtral. " "Currently, the 'pt' format is not supported for Mixtral. "
"Please use the 'safetensors' format instead. ") "Please use the 'safetensors' format instead. ")
@ -563,9 +570,16 @@ class SchedulerConfig:
if max_num_batched_tokens is not None: if max_num_batched_tokens is not None:
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
else: else:
# If max_model_len is too short, use 2048 as the default value for if enable_chunked_prefill:
# higher throughput. # For chunked prefill, choose the well-tuned batch size.
self.max_num_batched_tokens = max(max_model_len, 2048) self.max_num_batched_tokens = 768
else:
# If max_model_len is too short, use 2048 as the default value
# for higher throughput.
self.max_num_batched_tokens = max(max_model_len, 2048)
if enable_chunked_prefill:
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
self.max_num_seqs = max_num_seqs self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.use_v2_block_manager = use_v2_block_manager self.use_v2_block_manager = use_v2_block_manager
@ -675,6 +689,9 @@ class SpeculativeConfig:
"num_speculative_tokens to be provided, but found " "num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.") f"{speculative_model=} and {num_speculative_tokens=}.")
assert (speculative_model is not None
and num_speculative_tokens is not None)
# TODO: The user should be able to specify revision/quantization/max # TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported. # model len for the draft model. It is not currently supported.
draft_revision = None draft_revision = None
@ -818,9 +835,12 @@ class LoRAConfig:
self.lora_dtype = model_config.dtype self.lora_dtype = model_config.dtype
elif isinstance(self.lora_dtype, str): elif isinstance(self.lora_dtype, str):
self.lora_dtype = getattr(torch, self.lora_dtype) self.lora_dtype = getattr(torch, self.lora_dtype)
if model_config.quantization is not None: if model_config.quantization and model_config.quantization not in [
raise ValueError( "awq", "gptq"
"LoRA is not supported with quantized models yet.") ]:
# TODO support marlin and squeezellm
logger.warning(f"{model_config.quantization} quantization is not "
"tested with LoRA yet.")
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
if scheduler_config.max_num_batched_tokens > 65528: if scheduler_config.max_num_batched_tokens > 65528:
@ -872,6 +892,65 @@ class VisionLanguageConfig:
f"{[x.name for x in cls.ImageInputType]}.") from e f"{[x.name for x in cls.ImageInputType]}.") from e
@dataclass
class TensorizerConfig:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
str, bytes, os.PathLike, int]
vllm_tensorized: bool
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
model_class: Optional[torch.nn.Module] = None
hf_config: Optional[PretrainedConfig] = None
dtype: Union[str, torch.dtype] = None
def _construct_tensorizer_args(self) -> "TensorizerArgs":
from vllm.model_executor.tensorizer_loader import TensorizerArgs
tensorizer_args = {
"tensorizer_uri": self.tensorizer_uri,
"vllm_tensorized": self.vllm_tensorized,
"verify_hash": self.verify_hash,
"num_readers": self.num_readers,
"encryption_keyfile": self.encryption_keyfile,
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
return TensorizerArgs(**tensorizer_args)
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if (parallel_config.tensor_parallel_size > 1
and self.tensorizer_uri is not None):
raise ValueError(
"Loading to multiple GPUs is not currently supported with "
"vLLM-serialized models. Please set tensor_parallel_size=1."
" or use a non-vLLM-serialized model, such as a "
"serialized Hugging Face `PretrainedModel`.")
def verify_with_model_config(self, model_config) -> None:
if (model_config.quantization is not None
and self.tensorizer_uri is not None):
from vllm.model_executor.tensorizer_loader import (
tensorizer_warning)
tensorizer_warning(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors.")
if (model_config.load_format != "tensorizer"
and self.tensorizer_uri is not None):
raise ValueError(
"A tensorizer uri was passed for tensorizer loading, but the "
f"load format was set to {model_config.load_format}. "
"Please set the load format to 'tensorizer' to use "
f"tensorizer args.")
_STR_DTYPE_TO_TORCH_DTYPE = { _STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16, "half": torch.float16,
"float16": torch.float16, "float16": torch.float16,
@ -986,7 +1065,7 @@ def _get_and_verify_max_len(
derived_max_model_len *= scaling_factor derived_max_model_len *= scaling_factor
if max_model_len is None: if max_model_len is None:
max_model_len = derived_max_model_len max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len: elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length # Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input # that will be bigger than derived_max_model_len. We compare user input
@ -1005,6 +1084,21 @@ def _get_and_verify_max_len(
return int(max_model_len) return int(max_model_len)
@dataclass
class DecodingConfig:
"""Dataclass which contains the decoding strategy of the engine"""
# Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
guided_decoding_backend: str = 'outlines'
def __post_init__(self):
valid_guided_backends = ['outlines', 'lm-format-enforcer']
backend = self.guided_decoding_backend
if backend not in valid_guided_backends:
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
f"must be one of {valid_guided_backends}")
@dataclass(frozen=True) @dataclass(frozen=True)
class EngineConfig: class EngineConfig:
"""Dataclass which contains all engine-related configuration. This """Dataclass which contains all engine-related configuration. This
@ -1019,6 +1113,8 @@ class EngineConfig:
lora_config: Optional[LoRAConfig] lora_config: Optional[LoRAConfig]
vision_language_config: Optional[VisionLanguageConfig] vision_language_config: Optional[VisionLanguageConfig]
speculative_config: Optional[SpeculativeConfig] speculative_config: Optional[SpeculativeConfig]
decoding_config: Optional[DecodingConfig]
tensorizer_config: Optional[TensorizerConfig]
def __post_init__(self): def __post_init__(self):
"""Verify configs are valid & consistent with each other. """Verify configs are valid & consistent with each other.
@ -1026,6 +1122,11 @@ class EngineConfig:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.tensorizer_config:
self.tensorizer_config.verify_with_parallel_config(
self.parallel_config)
self.tensorizer_config.verify_with_model_config(self.model_config)
if self.lora_config: if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod, abstractproperty from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Protocol from typing import Dict, FrozenSet, List, Optional, Protocol
from vllm.utils import Device from vllm.utils import Device
@ -10,23 +10,28 @@ class Block(ABC):
def append_token_ids(self, token_ids: List[int]) -> None: def append_token_ids(self, token_ids: List[int]) -> None:
pass pass
@abstractproperty @property
@abstractmethod
def block_id(self) -> Optional[int]: def block_id(self) -> Optional[int]:
pass pass
@abstractproperty @property
@abstractmethod
def token_ids(self) -> List[int]: def token_ids(self) -> List[int]:
pass pass
@abstractproperty @property
@abstractmethod
def num_empty_slots(self) -> int: def num_empty_slots(self) -> int:
pass pass
@abstractproperty @property
@abstractmethod
def is_full(self) -> bool: def is_full(self) -> bool:
pass pass
@abstractproperty @property
@abstractmethod
def prev_block(self) -> Optional["Block"]: def prev_block(self) -> Optional["Block"]:
pass pass
@ -47,12 +52,13 @@ class Block(ABC):
class BlockAllocator(ABC): class BlockAllocator(ABC):
@abstractmethod @abstractmethod
def allocate_mutable(self, prev_block: Optional[Block]) -> Block: def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
pass pass
@abstractmethod @abstractmethod
def allocate_immutable(self, prev_block: Optional[Block], def allocate_immutable(self, prev_block: Optional[Block],
token_ids: List[int]) -> Block: token_ids: List[int], device: Device) -> Block:
pass pass
@abstractmethod @abstractmethod
@ -64,11 +70,12 @@ class BlockAllocator(ABC):
pass pass
@abstractmethod @abstractmethod
def get_num_free_blocks(self) -> int: def get_num_free_blocks(self, device: Device) -> int:
pass pass
@abstractproperty @property
def all_block_ids(self) -> frozenset[int]: @abstractmethod
def all_block_ids(self) -> FrozenSet[int]:
pass pass
@abstractmethod @abstractmethod

View File

@ -2,7 +2,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from itertools import count, takewhile from itertools import count, takewhile
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from typing import Set
from vllm.block import BlockTable, PhysicalTokenBlock from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
@ -231,10 +233,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if self.enable_caching: if self.enable_caching:
logger.info("Automatic prefix caching is enabled.") logger.info("Automatic prefix caching is enabled.")
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size, self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
num_gpu_blocks) Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size, self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
num_cpu_blocks) Device.CPU, block_size, num_cpu_blocks)
else: else:
self.gpu_allocator = UncachedBlockAllocator( self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks) Device.GPU, block_size, num_gpu_blocks)
@ -588,7 +590,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for b in takewhile(lambda b: b.computed, block_table[:-1]) for b in takewhile(lambda b: b.computed, block_table[:-1])
] ]
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Return the block ids that are common for a given sequence group. """Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks). Used in prefill (can skip prefill of some blocks).

View File

@ -1,5 +1,6 @@
"""A block manager that manages token blocks.""" """A block manager that manages token blocks."""
from typing import Dict, List, Optional from typing import Dict, List, Optional
from typing import Sequence as GenericSequence
from vllm.core.block.block_table import BlockTable from vllm.core.block.block_table import BlockTable
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
# as computed. # as computed.
self.block_allocator.mark_blocks_as_computed() self.block_allocator.mark_blocks_as_computed()
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Determine which blocks for which we skip prefill. """Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks. With prefix caching we can skip prefill for previously-generated blocks.

View File

@ -1,6 +1,7 @@
import enum import enum
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List from typing import Dict, List
from typing import Sequence as GenericSequence
from vllm.sequence import Sequence, SequenceGroup from vllm.sequence import Sequence, SequenceGroup
@ -103,7 +104,8 @@ class BlockSpaceManager(ABC):
pass pass
@abstractmethod @abstractmethod
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]: def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
pass pass
@abstractmethod @abstractmethod

View File

@ -42,8 +42,8 @@ class SchedulingBudget:
""" """
token_budget: int token_budget: int
max_num_seqs: int max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set) _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set) _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0 _num_batched_tokens: int = 0
_num_curr_seqs: int = 0 _num_curr_seqs: int = 0
@ -133,14 +133,18 @@ class SchedulerOutputs:
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy) and not self.blocks_to_swap_out and not self.blocks_to_copy)
def _sort_by_lora_ids(self) -> bool: def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted( self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups, self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@property @property
def lora_requests(self) -> Set[LoRARequest]: def lora_requests(self) -> Set[LoRARequest]:
return {g.seq_group.lora_request for g in self.scheduled_seq_groups} return {
g.seq_group.lora_request
for g in self.scheduled_seq_groups
if g.seq_group.lora_request is not None
}
@dataclass @dataclass
@ -333,7 +337,8 @@ class Scheduler:
self.free_seq(seq) self.free_seq(seq)
def has_unfinished_seqs(self) -> bool: def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running or self.swapped return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
def get_num_unfinished_seq_groups(self) -> int: def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)
@ -400,7 +405,7 @@ class Scheduler:
budget.subtract_num_seqs(seq_group.request_id, budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs) num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0: if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.pop(seq_group.lora_int_id) curr_loras.remove(seq_group.lora_int_id)
if running_queue: if running_queue:
# Preempt the lowest-priority sequence groups. # Preempt the lowest-priority sequence groups.
@ -492,7 +497,7 @@ class Scheduler:
now = time.time() now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue) swapped_queue = policy.sort_by_priority(now, swapped_queue)
leftover_swapped = deque() leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue: while swapped_queue:
seq_group = swapped_queue[0] seq_group = swapped_queue[0]
@ -503,7 +508,9 @@ class Scheduler:
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras assert curr_loras is not None
assert self.lora_config is not None
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
and len(curr_loras) >= self.lora_config.max_loras): and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so # We don't have a space for another LoRA, so
# we ignore this request for now. # we ignore this request for now.
@ -589,7 +596,7 @@ class Scheduler:
# Copy the queue so that the input queue is not modified. # Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue]) waiting_queue = deque([s for s in waiting_queue])
leftover_waiting_sequences = deque() leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue: while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0] seq_group = waiting_queue[0]
@ -631,6 +638,8 @@ class Scheduler:
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
assert curr_loras is not None
assert self.lora_config is not None
if (self.lora_enabled and lora_int_id > 0 if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras): and len(curr_loras) >= self.lora_config.max_loras):
@ -670,7 +679,7 @@ class Scheduler:
def _schedule_default(self) -> SchedulerOutputs: def _schedule_default(self) -> SchedulerOutputs:
"""Schedule queued requests. """Schedule queued requests.
The current policy is designed to opimimize the throughput. First, The current policy is designed to optimize the throughput. First,
it batches as many prefill requests as possible. And it schedules it batches as many prefill requests as possible. And it schedules
decodes. If there's a pressure on GPU memory, decode requests can decodes. If there's a pressure on GPU memory, decode requests can
be swapped or preempted. be swapped or preempted.
@ -776,7 +785,7 @@ class Scheduler:
token_budget=self.scheduler_config.max_num_batched_tokens, token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs, max_num_seqs=self.scheduler_config.max_num_seqs,
) )
curr_loras = set() curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting, remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty()) SchedulerPrefillOutputs.create_empty())
@ -826,13 +835,12 @@ class Scheduler:
# Update swapped requests. # Update swapped requests.
self.swapped = remaining_swapped self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out) self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs( return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups + scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.decode_seq_groups +
running_scheduled.prefill_seq_groups + running_scheduled.prefill_seq_groups +
swapped_in.decode_seq_groups + swapped_in.prefill_seq_groups +
swapped_in.prefill_seq_groups), running_scheduled.decode_seq_groups +
swapped_in.decode_seq_groups),
num_prefill_groups=(len(prefills.seq_groups) + num_prefill_groups=(len(prefills.seq_groups) +
len(swapped_in.prefill_seq_groups) + len(swapped_in.prefill_seq_groups) +
len(running_scheduled.prefill_seq_groups)), len(running_scheduled.prefill_seq_groups)),
@ -907,7 +915,7 @@ class Scheduler:
# It assumes the scheduled_seq_groups is ordered by # It assumes the scheduled_seq_groups is ordered by
# prefill < decoding. # prefill < decoding.
is_prompt = i < scheduler_outputs.num_prefill_groups is_prompt = seq_group.is_prefill()
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id, request_id=seq_group.request_id,
is_prompt=is_prompt, is_prompt=is_prompt,
@ -1084,7 +1092,7 @@ class Scheduler:
def _get_num_new_tokens(self, seq_group: SequenceGroup, def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool, status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> Tuple[int, bool]: budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group """Get the next new tokens to compute for a given sequence group
that's in a given `status`. that's in a given `status`.

View File

@ -0,0 +1,3 @@
from .communication_op import *
from .parallel_state import *
from .utils import *

View File

@ -1,15 +1,13 @@
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from vllm.model_executor.parallel_utils import pynccl_utils from .parallel_state import (get_tensor_model_parallel_group,
from vllm.model_executor.parallel_utils.custom_all_reduce import ( get_tensor_model_parallel_rank,
custom_all_reduce) get_tensor_model_parallel_world_size,
from vllm.model_executor.parallel_utils.parallel_state import ( is_pynccl_enabled_for_all_reduce)
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return TLDR: always assume this function modifies its input, but use the return
value as the output. value as the output.
""" """
from vllm.distributed.device_communicators import pynccl_utils
from vllm.distributed.device_communicators.custom_all_reduce import (
custom_all_reduce)
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1: if get_tensor_model_parallel_world_size() == 1:
return input_ return input_
@ -142,7 +144,7 @@ def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0, src: int = 0,
group: Optional[ProcessGroup] = None, group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]: ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.""" """Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group) ranks = torch.distributed.get_process_group_ranks(group)
@ -155,10 +157,10 @@ def broadcast_tensor_dict(
rank = torch.distributed.get_rank() rank = torch.distributed.get_rank()
if rank == src: if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance( assert isinstance(
tensor_dict, tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}") dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items(): for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
assert value.is_cuda, ( assert value.is_cuda, (
@ -171,19 +173,27 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list([metadata_list], torch.distributed.broadcast_object_list([metadata_list],
src=src, src=src,
group=group) group=group)
async_handles = []
for key, value in metadata_list: for key, value in metadata_list:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = tensor_dict[key] tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src, group=group) async_handles.append(
torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True))
for async_handle in async_handles:
async_handle.wait()
else: else:
recv_metadata_list = [None] recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list, torch.distributed.broadcast_object_list(recv_metadata_list,
src=src, src=src,
group=group) group=group)
metadata_list = recv_metadata_list[0] assert recv_metadata_list[0] is not None
tensor_dict = {} tensor_dict = {}
async_handles = [] async_handles = []
for key, value in metadata_list: for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata): if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size, tensor = torch.empty(value.size,
dtype=value.dtype, dtype=value.dtype,

View File

@ -5,8 +5,6 @@ import torch
import torch.distributed as dist import torch.distributed as dist
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
try: try:
import pynvml import pynvml
@ -25,6 +23,9 @@ _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None: def init_custom_ar() -> None:
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
global _CA_HANDLE global _CA_HANDLE
if _CA_HANDLE is not None: if _CA_HANDLE is not None:
return return
@ -41,12 +42,17 @@ def init_custom_ar() -> None:
" disable_custom_all_reduce=True explicitly.", world_size, " disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES)) str(_SUPPORTED_WORLD_SIZES))
return return
if not _can_p2p(rank, world_size): num_dev = torch.cuda.device_count()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
logger.warn( logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P" "Cannot test GPU P2P because not all GPUs are visible to the "
" capability or P2P test failed. To silence this warning, specify" "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" disable_custom_all_reduce=True explicitly.") " is set.")
return return False
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
full_nvlink = _is_full_nvlink(rank, world_size) full_nvlink = _is_full_nvlink(rank, world_size)
if world_size > 2 and not full_nvlink: if world_size > 2 and not full_nvlink:
logger.warn( logger.warn(
@ -54,6 +60,15 @@ def init_custom_ar() -> None:
" than two PCIe-only GPUs. To silence this warning, specify" " than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.") " disable_custom_all_reduce=True explicitly.")
return return
# test P2P capability
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink) _CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
@ -142,40 +157,15 @@ def _is_full_nvlink(rank, world_size):
def _can_p2p(rank: int, world_size: int) -> bool: def _can_p2p(rank: int, world_size: int) -> bool:
num_dev = torch.cuda.device_count() from vllm.distributed.utils import gpu_p2p_access_check
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
logger.warn(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return False
for i in range(world_size): for i in range(world_size):
if i == rank: if i == rank:
continue continue
if not torch.cuda.can_device_access_peer(rank, i): if not gpu_p2p_access_check(rank, i):
return False
# on some platforms, P2P support might be buggy and we need
# additional checks. See also:
# https://github.com/vllm-project/vllm/issues/2728
if not _can_actually_p2p(rank, i):
return False return False
return True return True
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
dev_i = f"cuda:{idx_a}"
dev_j = f"cuda:{idx_b}"
a = torch.randn(5, device=dev_i) + 123.0
b = a.to(dev_j)
c = b.to(dev_i)
return torch.all(a == c)
class CustomAllreduce: class CustomAllreduce:
# max_size: max supported allreduce size # max_size: max supported allreduce size

View File

@ -9,8 +9,8 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
try: try:
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator, from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
ncclGetVersion) ncclGetVersion)
except Exception as e: except Exception as e:
# in non-NVIDIA environments, we can't import the nccl module # in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs # e.g. when running on machines with AMD GPUs

View File

@ -8,7 +8,9 @@ from typing import Optional
import torch import torch
from vllm.model_executor.parallel_utils import pynccl_utils from vllm.logger import init_logger
logger = init_logger(__name__)
# Tensor model parallel group that the current rank belongs to. # Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None _TENSOR_MODEL_PARALLEL_GROUP = None
@ -39,14 +41,23 @@ _CPU_WORLD_GROUP = None
# source rank when broadcasting from the first or last pipeline stage. # source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
_LOCAL_RANK = -1
def get_local_rank():
global _LOCAL_RANK
return _LOCAL_RANK
def init_distributed_environment( def init_distributed_environment(
world_size: int, world_size: int = -1,
rank: int, rank: int = -1,
distributed_init_method: Optional[str] = None, distributed_init_method: str = "env://",
local_rank: int = -1, local_rank: int = -1,
backend: str = "nccl", backend: str = "nccl",
): ):
logger.debug(f"{world_size=} {rank=} {local_rank=} "
f"{distributed_init_method=} {backend=}")
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
assert distributed_init_method is not None, ( assert distributed_init_method is not None, (
"distributed_init_method must be provided when initializing " "distributed_init_method must be provided when initializing "
@ -62,6 +73,8 @@ def init_distributed_environment(
ranks = list(range(torch.distributed.get_world_size())) ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks, _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo") backend="gloo")
global _LOCAL_RANK
_LOCAL_RANK = local_rank
def initialize_model_parallel( def initialize_model_parallel(
@ -266,6 +279,7 @@ def destroy_model_parallel():
_PIPELINE_MODEL_PARALLEL_GROUP = None _PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
from vllm.distributed.device_communicators import pynccl_utils
# Destroy the pynccl states if any. # Destroy the pynccl states if any.
pynccl_utils.destroy_process_group() pynccl_utils.destroy_process_group()
@ -279,6 +293,7 @@ _ENABLE_PYNCCL_FOR_ALL_REDUCE = False
@contextlib.contextmanager @contextlib.contextmanager
def with_pynccl_for_all_reduce(): def with_pynccl_for_all_reduce():
from vllm.distributed.device_communicators import pynccl_utils
"""use pynccl instead of torch.distributed for all reduce""" """use pynccl instead of torch.distributed for all reduce"""
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if tp_size == 1: if tp_size == 1:

133
vllm/distributed/utils.py Normal file
View File

@ -0,0 +1,133 @@
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import json
import os
from typing import Dict, Optional, Sequence
import torch
import torch.distributed as dist
from vllm.logger import init_logger
from .parallel_state import get_cpu_world_group, get_local_rank
logger = init_logger(__name__)
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator)
def divide(numerator, denominator):
"""Ensure that numerator is divisible by the denominator and return
the division value."""
ensure_divisibility(numerator, denominator)
return numerator // denominator
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
# Get the size and dimension.
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
# Split.
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
# NOTE: torch.split does not create contiguous tensors by default.
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def _can_actually_p2p(idx_a, idx_b):
dev_i = f"cuda:{idx_a}"
dev_j = f"cuda:{idx_b}"
a = torch.randn(5, device=dev_i) + 123.0
b = a.to(dev_j)
c = b.to(dev_i)
return torch.all(a == c).cpu().item()
# why do we need this cache?
# 1. we can have runtime checks for P2P access, where every process checks
# P2P access to all other GPUs. Unfortunately, the test might cost many
# (world_size * world_size) cuda context, and reduce the memory available
# for the model. see https://github.com/vllm-project/vllm/issues/3821
# 2. alternatively, we can have a p2p map that is generated by the master
# process and broadcasted to all other processes. This still requires
# #world_size of cuda context, belonging to the master process, on each GPU.
# 3. we can have a cache file, that records the p2p access status. The first
# time the master process checks the p2p access, it will generate the cache
# file, at the cost of #world_size of cuda context. Later on, all processes
# can read the cache file to check the p2p access status without any cost of
# additional cuda context.
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
# e.g. used by different vllm engines. The device id in the cache file is a
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
# of visible devices in the vllm engine.
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
def gpu_p2p_access_check(i: int, j: int) -> bool:
"""Check if GPU i can access GPU j."""
# if the cache variable is already calculated,
# read from the cache instead of checking it again
global _gpu_p2p_access_cache
if _gpu_p2p_access_cache is not None:
return _gpu_p2p_access_cache[f"{i}->{j}"]
is_distributed = dist.is_initialized()
num_dev = torch.cuda.device_count()
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
path = os.path.expanduser(
f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
os.makedirs(os.path.dirname(path), exist_ok=True)
if (not is_distributed or get_local_rank() == 0) \
and (not os.path.exists(path)):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
logger.info(f"generating GPU P2P access cache for in {path}")
cache = {}
for _i in range(num_dev):
for _j in range(num_dev):
# on some platforms, P2P support might be buggy and we need
# additional checks. See also:
# https://github.com/vllm-project/vllm/issues/2728
cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
_i, _j) and _can_actually_p2p(_i, _j)
with open(path, "w") as f:
json.dump(cache, f, indent=4)
if is_distributed:
cpu_world_group = get_cpu_world_group()
dist.barrier(cpu_world_group)
logger.info(f"reading GPU P2P access cache from {path}")
with open(path, "r") as f:
cache = json.load(f)
_gpu_p2p_access_cache = cache
return _gpu_p2p_access_cache[f"{i}->{j}"]

View File

@ -1,12 +1,15 @@
import argparse import argparse
import dataclasses import dataclasses
import io
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import BinaryIO, Optional, Union
from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
ModelConfig, ParallelConfig, SchedulerConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
SpeculativeConfig, TokenizerPoolConfig, SchedulerConfig, SpeculativeConfig, TensorizerConfig,
VisionLanguageConfig) TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple from vllm.utils import str_to_int_tuple
@ -58,15 +61,26 @@ class EngineArgs:
num_gpu_blocks_override: Optional[int] = None num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0 num_lookahead_slots: int = 0
# Tensorizer configuration parameters
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
bytes, os.PathLike, int] = None
vllm_tensorized: bool = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
image_input_type: Optional[str] = None image_input_type: Optional[str] = None
image_token_id: Optional[int] = None image_token_id: Optional[int] = None
image_input_shape: Optional[str] = None image_input_shape: Optional[str] = None
image_feature_size: Optional[int] = None image_feature_size: Optional[int] = None
scheduler_delay_factor: float = 0.0 scheduler_delay_factor: float = 0.0
enable_chunked_prefill: bool = False enable_chunked_prefill: bool = False
guided_decoding_backend: str = 'outlines'
# Speculative decoding configuration. # Speculative decoding configuration.
speculative_model: Optional[str] = None speculative_model: Optional[str] = None
num_speculative_tokens: Optional[int] = None num_speculative_tokens: Optional[int] = None
@ -135,7 +149,9 @@ class EngineArgs:
'--load-format', '--load-format',
type=str, type=str,
default=EngineArgs.load_format, default=EngineArgs.load_format,
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
],
help='The format of the model weights to load. ' help='The format of the model weights to load. '
'"auto" will try to load the weights in the safetensors format ' '"auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format ' 'and fall back to the pytorch bin format if safetensors format '
@ -145,7 +161,10 @@ class EngineArgs:
'"npcache" will load the weights in pytorch format and store ' '"npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading. ' 'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, ' '"dummy" will initialize the weights with random values, '
'which is mainly for profiling.') 'which is mainly for profiling.'
'"tensorizer" will load the weights using tensorizer from CoreWeave'
'which assumes tensorizer_uri is set to the location of the '
'serialized weights.')
parser.add_argument( parser.add_argument(
'--dtype', '--dtype',
type=str, type=str,
@ -182,6 +201,13 @@ class EngineArgs:
default=EngineArgs.max_model_len, default=EngineArgs.max_model_len,
help='model context length. If unspecified, ' help='model context length. If unspecified, '
'will be automatically derived from the model.') 'will be automatically derived from the model.')
parser.add_argument(
'--guided-decoding-backend',
type=str,
default='outlines',
choices=['outlines', 'lm-format-enforcer'],
help='Which engine will be used for guided decoding'
' (JSON schema / regex etc)')
# Parallel arguments # Parallel arguments
parser.add_argument('--worker-use-ray', parser.add_argument('--worker-use-ray',
action='store_true', action='store_true',
@ -386,9 +412,8 @@ class EngineArgs:
'prompt latency) before scheduling next prompt.') 'prompt latency) before scheduling next prompt.')
parser.add_argument( parser.add_argument(
'--enable-chunked-prefill', '--enable-chunked-prefill',
type=bool, action='store_true',
default=False, help='If set, the prefill requests can be chunked based on the '
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens') 'max_num_batched_tokens')
parser.add_argument( parser.add_argument(
@ -404,6 +429,7 @@ class EngineArgs:
default=None, default=None,
help='The number of speculative tokens to sample from ' help='The number of speculative tokens to sample from '
'the draft model in speculative decoding') 'the draft model in speculative decoding')
parser = TensorizerArgs.add_cli_args(parser)
return parser return parser
@classmethod @classmethod
@ -466,6 +492,17 @@ class EngineArgs:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None and self.max_cpu_loras > 0 else None) if self.enable_lora else None
tensorizer_config = TensorizerConfig(
tensorizer_uri=self.tensorizer_uri,
vllm_tensorized=self.vllm_tensorized,
verify_hash=self.verify_hash,
num_readers=self.num_readers,
encryption_keyfile=self.encryption_keyfile,
s3_access_key_id=self.s3_access_key_id,
s3_secret_access_key=self.s3_secret_access_key,
s3_endpoint=self.s3_endpoint,
)
if self.image_input_type: if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size): or not self.image_feature_size):
@ -482,6 +519,9 @@ class EngineArgs:
else: else:
vision_language_config = None vision_language_config = None
decoding_config = DecodingConfig(
guided_decoding_backend=self.guided_decoding_backend)
return EngineConfig(model_config=model_config, return EngineConfig(model_config=model_config,
cache_config=cache_config, cache_config=cache_config,
parallel_config=parallel_config, parallel_config=parallel_config,
@ -489,7 +529,9 @@ class EngineArgs:
device_config=device_config, device_config=device_config,
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, vision_language_config=vision_language_config,
speculative_config=speculative_config) speculative_config=speculative_config,
decoding_config=decoding_config,
tensorizer_config=tensorizer_config)
@dataclass @dataclass

View File

@ -333,8 +333,7 @@ class AsyncLLMEngine:
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
raise NotImplementedError("Neuron is not supported for " raise NotImplementedError("Neuron is not supported for "
"async engine yet.") "async engine yet.")
elif (engine_config.parallel_config.worker_use_ray elif engine_config.parallel_config.worker_use_ray:
or engine_args.engine_use_ray):
initialize_ray_cluster(engine_config.parallel_config) initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync executor_class = RayGPUExecutorAsync
@ -410,8 +409,8 @@ class AsyncLLMEngine:
else: else:
# FIXME(woosuk): This is a bit hacky. Be careful when changing the # FIXME(woosuk): This is a bit hacky. Be careful when changing the
# order of the arguments. # order of the arguments.
cache_config = args[1] cache_config = kwargs["cache_config"]
parallel_config = args[2] parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1: if parallel_config.tensor_parallel_size == 1:
num_gpus = cache_config.gpu_memory_utilization num_gpus = cache_config.gpu_memory_utilization
else: else:

View File

@ -4,8 +4,9 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
import vllm import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig, TensorizerConfig,
VisionLanguageConfig) VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
@ -74,6 +75,8 @@ class LLMEngine:
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
tensorizer_config: Optional[TensorizerConfig],
executor_class: Type[ExecutorBase], executor_class: Type[ExecutorBase],
log_stats: bool, log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
@ -99,6 +102,7 @@ class LLMEngine:
f"kv_cache_dtype={cache_config.cache_dtype}, " f"kv_cache_dtype={cache_config.cache_dtype}, "
f"quantization_param_path={model_config.quantization_param_path}, " f"quantization_param_path={model_config.quantization_param_path}, "
f"device_config={device_config.device}, " f"device_config={device_config.device}, "
f"decoding_config={decoding_config!r}, "
f"seed={model_config.seed})") f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode. # TODO(woosuk): Print more configs in debug mode.
@ -110,6 +114,8 @@ class LLMEngine:
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.device_config = device_config self.device_config = device_config
self.speculative_config = speculative_config self.speculative_config = speculative_config
self.decoding_config = decoding_config or DecodingConfig()
self.tensorizer_config = tensorizer_config
self.log_stats = log_stats self.log_stats = log_stats
self._init_tokenizer() self._init_tokenizer()
@ -125,6 +131,7 @@ class LLMEngine:
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, vision_language_config=vision_language_config,
speculative_config=speculative_config, speculative_config=speculative_config,
tensorizer_config=tensorizer_config,
) )
self._initialize_kv_caches() self._initialize_kv_caches()
@ -267,6 +274,9 @@ class LLMEngine:
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
self.cache_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config)
if self.tensorizer_config:
self.tensorizer_config.verify_with_parallel_config(
self.parallel_config)
if self.lora_config: if self.lora_config:
self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_model_config(self.model_config)
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
@ -504,9 +514,11 @@ class LLMEngine:
for seq, _ in child_seqs: for seq, _ in child_seqs:
if seq_group.sampling_params.detokenize: if seq_group.sampling_params.detokenize:
self.detokenizer.decode_sequence_inplace( new_char_count = self.detokenizer.decode_sequence_inplace(
seq, seq_group.sampling_params) seq, seq_group.sampling_params)
self._check_stop(seq, seq_group.sampling_params) else:
new_char_count = 0
self._check_stop(seq, new_char_count, seq_group.sampling_params)
# Non-beam search case # Non-beam search case
if not seq_group.sampling_params.use_beam_search: if not seq_group.sampling_params.use_beam_search:
@ -636,7 +648,10 @@ class LLMEngine:
seq_group = scheduled_seq_group.seq_group seq_group = scheduled_seq_group.seq_group
seq_group.update_num_computed_tokens( seq_group.update_num_computed_tokens(
scheduled_seq_group.token_chunk_size) scheduled_seq_group.token_chunk_size)
self._process_sequence_group_outputs(seq_group, outputs) # If uncomputed tokens > 0, it means prefill is chunked.
# We don't need to process outputs in that case.
if seq_group.get_num_uncomputed_tokens() == 0:
self._process_sequence_group_outputs(seq_group, outputs)
# Free the finished sequence groups. # Free the finished sequence groups.
self.scheduler.free_finished_seq_groups() self.scheduler.free_finished_seq_groups()
@ -798,9 +813,45 @@ class LLMEngine:
time_e2e_requests=time_e2e_requests, time_e2e_requests=time_e2e_requests,
) )
def _check_stop(self, seq: Sequence, def _check_stop(self, seq: Sequence, new_char_count: int,
sampling_params: SamplingParams) -> None: sampling_params: SamplingParams) -> None:
"""Stop the finished sequences.""" """Stop the finished sequences.
new_char_count is the number of chars added to the
sequence's output text for the newly generated token
"""
# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return
# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
seq.status = SequenceStatus.FINISHED_STOPPED
return
# Check if a stop token was encountered.
# This assumes a single token produced per step.
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
if new_char_count and (
not sampling_params.include_stop_str_in_output):
# Remove last token
seq.output_text = seq.output_text[:-new_char_count]
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if any stop strings are matched.
stop_str = self._check_stop_strings(seq, new_char_count,
sampling_params)
if stop_str is not None:
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
# Check if the sequence has reached max_model_len. # Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len: if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
@ -811,43 +862,37 @@ class LLMEngine:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return return
# Check if the minimum number of tokens has been generated yet; @staticmethod
# skip the stop string/token checks if not def _check_stop_strings(seq: Sequence, new_char_count: int,
if seq.get_output_len() < sampling_params.min_tokens: sampling_params: SamplingParams) -> Optional[str]:
return """Check if any stop strings are matched and truncate sequence
output text accordingly.
if sampling_params.detokenize: Returns the stop string if matched or else None.
for stop_str in sampling_params.stop: """
if seq.output_text.endswith(stop_str): if not new_char_count:
self._finalize_sequence(seq, sampling_params, stop_str) return None
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = stop_str
return
last_token_id = seq.get_last_token_id()
if last_token_id in sampling_params.stop_token_ids:
stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
last_token_id)
self._finalize_sequence(seq, sampling_params, stop_str)
seq.status = SequenceStatus.FINISHED_STOPPED
seq.stop_reason = last_token_id
return
# Check if the sequence has generated the EOS token. for stop_str in sampling_params.stop:
if ((not sampling_params.ignore_eos) stop_string_len = len(stop_str)
and seq.get_last_token_id() == seq.eos_token_id): # Avoid searching already-searched text.
seq.status = SequenceStatus.FINISHED_STOPPED stop_index = seq.output_text.find(
return stop_str, -new_char_count - stop_string_len)
if stop_index == -1:
continue
def _finalize_sequence(self, seq: Sequence, if sampling_params.include_stop_str_in_output:
sampling_params: SamplingParams, # Truncate to end of stop string.
stop_string: str) -> None: stop_index += stop_string_len
if sampling_params.include_stop_str_in_output: if stop_index >= len(seq.output_text):
return # No truncation required.
return stop_str
if stop_string and seq.output_text.endswith(stop_string): # Truncate the output text to either the beginning
# Truncate the output text so that the stop string is # or end of the stop string.
# not included in the output. seq.output_text = seq.output_text[:stop_index]
seq.output_text = seq.output_text[:-len(stop_string)] return stop_str
return None
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request) return self.model_executor.add_lora(lora_request)

View File

@ -1,6 +1,6 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List from typing import Dict, List, Protocol
import numpy as np import numpy as np
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
@ -119,12 +119,18 @@ class Stats:
time_e2e_requests: List[float] time_e2e_requests: List[float]
class SupportsMetricsInfo(Protocol):
def metrics_info(self) -> Dict[str, str]:
...
class StatLogger: class StatLogger:
"""StatLogger is used LLMEngine to log to Promethus and Stdout.""" """StatLogger is used LLMEngine to log to Promethus and Stdout."""
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
# Metadata for logging locally. # Metadata for logging locally.
self.last_local_log = time.monotonic() self.last_local_log = time.time()
self.local_interval = local_interval self.local_interval = local_interval
# Tracked stats over current local logging interval. # Tracked stats over current local logging interval.
@ -135,7 +141,7 @@ class StatLogger:
self.labels = labels self.labels = labels
self.metrics = Metrics(labelnames=list(labels.keys())) self.metrics = Metrics(labelnames=list(labels.keys()))
def info(self, type: str, obj: object) -> None: def info(self, type: str, obj: SupportsMetricsInfo) -> None:
if type == "cache_config": if type == "cache_config":
self.metrics.info_cache_config.info(obj.metrics_info()) self.metrics.info_cache_config.info(obj.metrics_info())

View File

@ -1,9 +1,10 @@
import pickle import pickle
from typing import List, Optional, Tuple from typing import Callable, List, Optional, Tuple
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
from vllm.worker.worker import Worker
logger = init_logger(__name__) logger = init_logger(__name__)
@ -18,15 +19,20 @@ try:
if init_cached_hf_modules: if init_cached_hf_modules:
from transformers.dynamic_module_utils import init_hf_modules from transformers.dynamic_module_utils import init_hf_modules
init_hf_modules() init_hf_modules()
self.worker = None self._worker: Optional[Worker] = None
# Since the compiled DAG runs a main execution # Since the compiled DAG runs a main execution
# in a different thread that calls cuda.set_device. # in a different thread that calls cuda.set_device.
# The flag indicates is set_device is called on # The flag indicates is set_device is called on
# that thread. # that thread.
self.compiled_dag_cuda_device_set = False self.compiled_dag_cuda_device_set = False
def init_worker(self, worker_init_fn): def init_worker(self, worker_init_fn: Callable[[], Worker]):
self.worker = worker_init_fn() self._worker = worker_init_fn()
@property
def worker(self) -> Worker:
assert self._worker is not None
return self._worker
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.worker, name) return getattr(self.worker, name)
@ -70,8 +76,8 @@ except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. " logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with " "For distributed inference, please install Ray with "
"`pip install ray`.") "`pip install ray`.")
ray = None ray = None # type: ignore
RayWorkerVllm = None RayWorkerVllm = None # type: ignore
def initialize_ray_cluster( def initialize_ray_cluster(

View File

@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
sampling_params = SamplingParams(**request_dict) sampling_params = SamplingParams(**request_dict)
request_id = random_uuid() request_id = random_uuid()
assert engine is not None
results_generator = engine.generate(prompt, sampling_params, request_id) results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case # Streaming case

View File

@ -86,7 +86,7 @@ class LLM:
swap_space: int = 4, swap_space: int = 4,
enforce_eager: bool = False, enforce_eager: bool = False,
max_context_len_to_capture: int = 8192, max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = True, disable_custom_all_reduce: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
@ -170,8 +170,12 @@ class LLM:
multi_modal_data.data = multi_modal_data.data.to(torch.float16) multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine. # Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len( if prompts is not None:
prompt_token_ids) num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids)
for i in range(num_requests): for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[ token_ids = None if prompt_token_ids is None else prompt_token_ids[

View File

@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel):
description=( description=(
"If specified, the output will follow the context free grammar."), "If specified, the output will follow the context free grammar."),
) )
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be either "
"'outlines' / 'lm-format-enforcer'"))
# doc: end-chat-completion-extra-params # doc: end-chat-completion-extra-params
@ -265,6 +271,12 @@ class CompletionRequest(BaseModel):
description=( description=(
"If specified, the output will follow the context free grammar."), "If specified, the output will follow the context free grammar."),
) )
guided_decoding_backend: Optional[str] = Field(
default=None,
description=(
"If specified, will override the default guided decoding backend "
"of the server for this specific request. If set, must be one of "
"'outlines' / 'lm-format-enforcer'"))
# doc: end-completion-extra-params # doc: end-completion-extra-params

View File

@ -63,13 +63,18 @@ class OpenAIServingChat(OpenAIServing):
request_id = f"cmpl-{random_uuid()}" request_id = f"cmpl-{random_uuid()}"
try: try:
token_ids = self._validate_prompt_and_tokenize(request, # Tokenize/detokenize depending on prompt format (string/token list)
prompt=prompt) prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt=prompt)
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
decoding_config = self.engine.engine.decoding_config
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logits_processor = ( guided_decode_logits_processor = (
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
request, await self.engine.get_tokenizer())) guided_decoding_backend, request, await
self.engine.get_tokenizer()))
if guided_decode_logits_processor: if guided_decode_logits_processor:
if sampling_params.logits_processors is None: if sampling_params.logits_processors is None:
sampling_params.logits_processors = [] sampling_params.logits_processors = []
@ -78,8 +83,8 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e: except ValueError as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt, sampling_params, result_generator = self.engine.generate(prompt_text, sampling_params,
request_id, token_ids, request_id, prompt_ids,
lora_request) lora_request)
# Streaming response # Streaming response
if request.stream: if request.stream:

View File

@ -1,4 +1,3 @@
import asyncio
import time import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple) Optional, Tuple)
@ -17,7 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import ( from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor) get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.utils import random_uuid from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__) logger = init_logger(__name__)
@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
return prompt_is_tokens, prompts return prompt_is_tokens, prompts
def merge_async_iterators(*iterators):
"""Merge multiple asynchronous iterators into a single iterator.
This method handle the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue = asyncio.Queue()
finished = [False] * len(iterators)
async def producer(i, iterator):
try:
async for item in iterator:
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
while not all(finished) or not queue.empty():
item = await queue.get()
if isinstance(item, Exception):
raise item
yield item
await asyncio.gather(*_tasks)
return consumer()
class OpenAIServingCompletion(OpenAIServing): class OpenAIServingCompletion(OpenAIServing):
def __init__(self, def __init__(self,
@ -124,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing):
try: try:
sampling_params = request.to_sampling_params() sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request) lora_request = self._maybe_get_lora(request)
decoding_config = self.engine.engine.decoding_config
guided_decoding_backend = request.guided_decoding_backend \
or decoding_config.guided_decoding_backend
guided_decode_logit_processor = ( guided_decode_logit_processor = (
await get_guided_decoding_logits_processor( await get_guided_decoding_logits_processor(
request, await self.engine.get_tokenizer())) guided_decoding_backend, request, await
self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None: if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None: if sampling_params.logits_processors is None:
sampling_params.logits_processors = [] sampling_params.logits_processors = []
@ -136,23 +104,24 @@ class OpenAIServingCompletion(OpenAIServing):
for i, prompt in enumerate(prompts): for i, prompt in enumerate(prompts):
if prompt_is_tokens: if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize( prompt_formats = self._validate_prompt_and_tokenize(
request, request,
prompt_ids=prompt, prompt_ids=prompt,
truncate_prompt_tokens=sampling_params. truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens) truncate_prompt_tokens)
else: else:
input_ids = self._validate_prompt_and_tokenize( prompt_formats = self._validate_prompt_and_tokenize(
request, request,
prompt=prompt, prompt=prompt,
truncate_prompt_tokens=sampling_params. truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens) truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
generators.append( generators.append(
self.engine.generate(prompt, self.engine.generate(prompt_text,
sampling_params, sampling_params,
f"{request_id}-{i}", f"{request_id}-{i}",
prompt_token_ids=input_ids, prompt_token_ids=prompt_ids,
lora_request=lora_request)) lora_request=lora_request))
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
@ -326,7 +295,8 @@ class OpenAIServingCompletion(OpenAIServing):
output_text = prompt_text output_text = prompt_text
elif request.echo and request.max_tokens > 0: elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids token_ids = prompt_token_ids + output.token_ids
top_logprobs = prompt_logprobs + output.logprobs top_logprobs = (prompt_logprobs + output.logprobs
if request.logprobs else None)
output_text = prompt_text + output.text output_text = prompt_text + output.text
else: else:
token_ids = output.token_ids token_ids = output.token_ids
@ -334,6 +304,9 @@ class OpenAIServingCompletion(OpenAIServing):
output_text = output.text output_text = output.text
if request.logprobs is not None: if request.logprobs is not None:
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
logprobs = self._create_logprobs( logprobs = self._create_logprobs(
token_ids=token_ids, token_ids=token_ids,
top_logprobs=top_logprobs, top_logprobs=top_logprobs,

View File

@ -2,7 +2,7 @@ import asyncio
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Tuple, Union
from pydantic import conint from pydantic import conint
@ -99,27 +99,32 @@ class OpenAIServing:
last_token_len = 0 last_token_len = 0
if num_output_top_logprobs: if num_output_top_logprobs:
logprobs.top_logprobs = [] logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids): for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i] step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None: if step_top_logprobs is None:
token_logprob = step_top_logprobs[token_id].logprob token = self.tokenizer.decode(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(None)
logprobs.top_logprobs.append(None)
else: else:
token_logprob = None token_logprob = step_top_logprobs[token_id].logprob
token = step_top_logprobs[token_id].decoded_token token = step_top_logprobs[token_id].decoded_token
logprobs.tokens.append(token) logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob) logprobs.token_logprobs.append(token_logprob)
if num_output_top_logprobs:
logprobs.top_logprobs.append({
p.decoded_token: p.logprob
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
if len(logprobs.text_offset) == 0: if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset) logprobs.text_offset.append(initial_text_offset)
else: else:
logprobs.text_offset.append(logprobs.text_offset[-1] + logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len) last_token_len)
last_token_len = len(token) last_token_len = len(token)
if num_output_top_logprobs:
logprobs.top_logprobs.append({
p.decoded_token: p.logprob
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
return logprobs return logprobs
def create_error_response( def create_error_response(
@ -164,12 +169,12 @@ class OpenAIServing:
raise ValueError("The model `{request.model}` does not exist.") raise ValueError("The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize( def _validate_prompt_and_tokenize(
self, self,
request: Union[ChatCompletionRequest, CompletionRequest], request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None, prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None, prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None truncate_prompt_tokens: Optional[conint(ge=1)] = None
) -> List[int]: ) -> Tuple[List[int], str]:
if not (prompt or prompt_ids): if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.") raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids): if (prompt and prompt_ids):
@ -187,6 +192,8 @@ class OpenAIServing:
else: else:
input_ids = prompt_ids input_ids = prompt_ids
input_text = prompt if prompt is not None else self.tokenizer.decode(
prompt_ids)
token_num = len(input_ids) token_num = len(input_ids)
if request.max_tokens is None: if request.max_tokens is None:
@ -201,4 +208,4 @@ class OpenAIServing:
f"{request.max_tokens} in the completion). " f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", ) f"Please reduce the length of the messages or completion.", )
else: else:
return input_ids return input_ids, input_text

View File

@ -1,10 +1,9 @@
import os import os
from typing import Dict, List, Optional from typing import Dict, List, Set, Tuple
import torch import torch
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
ParallelConfig, SchedulerConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -16,22 +15,13 @@ logger = init_logger(__name__)
class CPUExecutor(ExecutorBase): class CPUExecutor(ExecutorBase):
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, def _init_executor(self) -> None:
parallel_config: ParallelConfig, assert self.device_config.device_type == "cpu"
scheduler_config: SchedulerConfig, assert self.lora_config is None, "cpu backend doesn't support LoRA"
device_config: DeviceConfig, self.model_config = _verify_and_get_model_config(self.model_config)
lora_config: Optional[LoRAConfig], *args, **kwargs) -> None: self.cache_config = _verify_and_get_cache_config(self.cache_config)
assert device_config.device_type == "cpu" self.scheduler_config = _verify_and_get_scheduler_config(
assert lora_config is None, "cpu backend doesn't support LoRA" self.scheduler_config)
model_config = _verify_and_get_model_config(model_config)
cache_config = _verify_and_get_cache_config(cache_config)
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
# Instantiate the worker and load the model to CPU. # Instantiate the worker and load the model to CPU.
self._init_worker() self._init_worker()
@ -60,7 +50,7 @@ class CPUExecutor(ExecutorBase):
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
""" """
@ -73,7 +63,10 @@ class CPUExecutor(ExecutorBase):
# NOTE: We log here to avoid multiple logs when number of workers is # NOTE: We log here to avoid multiple logs when number of workers is
# greater than one. We could log in the engine, but not all executors # greater than one. We could log in the engine, but not all executors
# have GPUs. # have GPUs.
logger.info(f"# CPU blocks: {num_cpu_blocks}") # NOTE: `cpu block` for CPU backend is located on CPU memory but is
# referred as `gpu block`. Because we want to reuse the existing block
# management procedure.
logger.info(f"# CPU blocks: {num_gpu_blocks}")
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def execute_model(self, def execute_model(self,
@ -95,7 +88,7 @@ class CPUExecutor(ExecutorBase):
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id) return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def check_health(self) -> None: def check_health(self) -> None:
@ -116,6 +109,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
return config return config
def _verify_and_get_scheduler_config(
config: SchedulerConfig) -> SchedulerConfig:
if config.chunked_prefill_enabled:
logger.warning("Chunked prefill is not supported on CPU, disable it.")
config.chunked_prefill_enabled = False
return config
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
_GB = 1 << 30 _GB = 1 << 30
if config.enable_prefix_caching: if config.enable_prefix_caching:

View File

@ -1,9 +1,9 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig) TensorizerConfig, VisionLanguageConfig)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@ -16,7 +16,6 @@ class ExecutorBase(ABC):
that can execute the model on multiple devices. that can execute the model on multiple devices.
""" """
@abstractmethod
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
@ -27,11 +26,26 @@ class ExecutorBase(ABC):
lora_config: Optional[LoRAConfig], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig], vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig], speculative_config: Optional[SpeculativeConfig],
tensorizer_config: Optional[TensorizerConfig],
) -> None: ) -> None:
raise NotImplementedError self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
self.speculative_config = speculative_config
self.tensorizer_config = tensorizer_config
self._init_executor()
@abstractmethod @abstractmethod
def determine_num_available_blocks(self) -> tuple[int, int]: def _init_executor(self) -> None:
pass
@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and """Determine the number of available blocks for the GPU KV cache and
swappable CPU KV cache. swappable CPU KV cache.
@ -39,7 +53,7 @@ class ExecutorBase(ABC):
ExecutorBase may require modification of the result, e.g. to ensure the ExecutorBase may require modification of the result, e.g. to ensure the
selected cache sizes are compatible with all workers. selected cache sizes are compatible with all workers.
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
are blocks that are "active" on the device and can be appended to. are blocks that are "active" on the device and can be appended to.
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
appended to. appended to.
@ -71,7 +85,7 @@ class ExecutorBase(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@ -94,8 +108,7 @@ class ExecutorAsyncBase(ExecutorBase):
"""Executes one model step on the given sequences.""" """Executes one model step on the given sequences."""
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def check_health_async(self) -> None: async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an """Checks if the executor is healthy. If not, it should raise an
exception.""" exception."""
raise NotImplementedError self.check_health()

View File

@ -1,8 +1,5 @@
from typing import Dict, List, Optional from typing import Dict, List, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -15,26 +12,8 @@ logger = init_logger(__name__)
class GPUExecutor(ExecutorBase): class GPUExecutor(ExecutorBase):
def __init__( def _init_executor(self) -> None:
self, assert (not self.speculative_config
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
assert (not speculative_config
), "Speculative decoding not yet supported for GPU backend" ), "Speculative decoding not yet supported for GPU backend"
# Instantiate the worker and load the model to GPU. # Instantiate the worker and load the model to GPU.
@ -61,12 +40,13 @@ class GPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
tensorizer_config=self.tensorizer_config,
is_driver_worker=True, is_driver_worker=True,
) )
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
""" """
@ -104,7 +84,7 @@ class GPUExecutor(ExecutorBase):
assert lora_id > 0, "lora_id must be greater than 0." assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id) return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def check_health(self) -> None: def check_health(self) -> None:
@ -128,8 +108,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy) blocks_to_copy=blocks_to_copy)
return output return output
async def check_health_async(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return

View File

@ -1,8 +1,5 @@
from typing import Dict, List, Optional from typing import Dict, List, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -13,23 +10,10 @@ logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase): class NeuronExecutor(ExecutorBase):
def __init__( def _init_executor(self) -> None:
self, assert (self.lora_config is
model_config: ModelConfig, None), "LoRA is not supported for Neuron backend."
cache_config: CacheConfig, assert (not self.speculative_config
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
assert lora_config is None, "LoRA is not supported for Neuron backend."
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
assert (not speculative_config
), "Speculative decoding not yet supported for Neuron backend." ), "Speculative decoding not yet supported for Neuron backend."
# Instantiate the worker and load the model to the device. # Instantiate the worker and load the model to the device.
@ -43,11 +27,12 @@ class NeuronExecutor(ExecutorBase):
self.parallel_config, self.parallel_config,
self.scheduler_config, self.scheduler_config,
self.device_config, self.device_config,
self.cache_config,
) )
self.driver_worker.init_device() self.driver_worker.init_device()
self.driver_worker.load_model() self.driver_worker.load_model()
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks by invoking the """Determine the number of available KV blocks by invoking the
underlying worker. underlying worker.
""" """
@ -78,7 +63,7 @@ class NeuronExecutor(ExecutorBase):
def remove_lora(self, lora_id: int) -> bool: def remove_lora(self, lora_id: int) -> bool:
return self.driver_worker.remove_lora(lora_id) return self.driver_worker.remove_lora(lora_id)
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self.driver_worker.list_loras() return self.driver_worker.list_loras()
def check_health(self) -> None: def check_health(self) -> None:

View File

@ -3,11 +3,8 @@ import copy
import os import os
import pickle import pickle
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig, SpeculativeConfig,
VisionLanguageConfig)
from vllm.engine.ray_utils import RayWorkerVllm, ray from vllm.engine.ray_utils import RayWorkerVllm, ray
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
@ -32,25 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
class RayGPUExecutor(ExecutorBase): class RayGPUExecutor(ExecutorBase):
def __init__( def _init_executor(self) -> None:
self, assert (not self.speculative_config
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
speculative_config: Optional[SpeculativeConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.vision_language_config = vision_language_config
assert (not speculative_config
), "Speculative decoding not yet supported for RayGPU backend." ), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.worker_use_ray assert self.parallel_config.worker_use_ray
@ -68,6 +48,21 @@ class RayGPUExecutor(ExecutorBase):
if USE_RAY_COMPILED_DAG: if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag() self.forward_dag = self._compiled_ray_dag()
def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.update({
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
})
return ray_remote_kwargs
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1: if self.parallel_config.tensor_parallel_size == 1:
@ -83,6 +78,10 @@ class RayGPUExecutor(ExecutorBase):
# The remaining workers are the actual ray actors. # The remaining workers are the actual ray actors.
self.workers: List[RayWorkerVllm] = [] self.workers: List[RayWorkerVllm] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
ray_remote_kwargs)
# Create the workers. # Create the workers.
driver_ip = get_ip() driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs): for bundle_id, bundle in enumerate(placement_group.bundle_specs):
@ -171,6 +170,7 @@ class RayGPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=lora_config, lora_config=lora_config,
vision_language_config=vision_language_config, vision_language_config=vision_language_config,
tensorizer_config=self.tensorizer_config,
)) ))
# Initialize the driver worker with the Worker class. # Initialize the driver worker with the Worker class.
@ -187,6 +187,7 @@ class RayGPUExecutor(ExecutorBase):
distributed_init_method=distributed_init_method, distributed_init_method=distributed_init_method,
lora_config=self.lora_config, lora_config=self.lora_config,
vision_language_config=self.vision_language_config, vision_language_config=self.vision_language_config,
tensorizer_config=self.tensorizer_config,
is_driver_worker=True, is_driver_worker=True,
) )
@ -197,7 +198,7 @@ class RayGPUExecutor(ExecutorBase):
max_parallel_loading_workers, max_parallel_loading_workers,
) )
def determine_num_available_blocks(self) -> tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available KV blocks. """Determine the number of available KV blocks.
This invokes `determine_num_available_blocks` on each worker and takes This invokes `determine_num_available_blocks` on each worker and takes
@ -205,7 +206,7 @@ class RayGPUExecutor(ExecutorBase):
compatible with all workers. compatible with all workers.
Returns: Returns:
- tuple[num_gpu_blocks, num_cpu_blocks] - Tuple[num_gpu_blocks, num_cpu_blocks]
""" """
# Get the maximum number of blocks that can be allocated on GPU and CPU. # Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers("determine_num_available_blocks", ) num_blocks = self._run_workers("determine_num_available_blocks", )
@ -269,14 +270,14 @@ class RayGPUExecutor(ExecutorBase):
lora_id=lora_id, lora_id=lora_id,
) )
def list_loras(self) -> List[int]: def list_loras(self) -> Set[int]:
return self._run_workers("list_loras") return self._run_workers("list_loras")
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[List[Any]] = None, driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False, use_ray_compiled_dag: bool = False,
@ -291,6 +292,7 @@ class RayGPUExecutor(ExecutorBase):
if use_ray_compiled_dag: if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single # Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it. # input. TODO(sang): Fix it.
assert self.forward_dag is not None
output_channels = self.forward_dag.execute(1) output_channels = self.forward_dag.execute(1)
else: else:
# Start the ray workers first. # Start the ray workers first.
@ -369,7 +371,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
self, self,
method: str, method: str,
*args, *args,
driver_args: Optional[List[Any]] = None, driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None, driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
@ -411,7 +413,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
# Only the driver worker returns the sampling results. # Only the driver worker returns the sampling results.
output = all_outputs[0] output = all_outputs[0]
return output return output
async def check_health_async(self) -> None:
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()

View File

@ -4,6 +4,7 @@
import logging import logging
import os import os
import sys import sys
from typing import Optional
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter):
_root_logger = logging.getLogger("vllm") _root_logger = logging.getLogger("vllm")
_default_handler = None _default_handler: Optional[logging.Handler] = None
def _setup_logger(): def _setup_logger():
@ -55,7 +56,12 @@ def init_logger(name: str):
# Use the same settings as above for root logger # Use the same settings as above for root logger
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
if VLLM_CONFIGURE_LOGGING: if VLLM_CONFIGURE_LOGGING:
if _default_handler is None:
raise ValueError(
"_default_handler is not set up. This should never happen!"
" Please open an issue on Github.")
logger.addHandler(_default_handler) logger.addHandler(_default_handler)
logger.propagate = False logger.propagate = False
return logger return logger

View File

@ -10,6 +10,12 @@ import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.lora.punica import add_lora, add_lora_slice, bgmv
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear, MergedColumnParallelLinear,
@ -18,18 +24,27 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce,
tensor_model_parallel_gather)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import (
split_tensor_along_last_dim)
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
"""Returns the device for where to place the LoRA tensors."""
# unquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
# GPTQ/AWQ/SqueezeLLM
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# marlin
elif hasattr(base_layer, "B"):
return base_layer.B.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _apply_lora( def _apply_lora(
x: torch.Tensor, x: torch.Tensor,
lora_a_stacked: torch.Tensor, lora_a_stacked: torch.Tensor,
@ -268,12 +283,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = x > self.base_layer.org_vocab_size - 1 added_tokens_mask = x > self.base_layer.org_vocab_size - 1
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) embedding_len = self.indices_len[3]
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
full_lora_a_embeddings = F.embedding( full_lora_a_embeddings = F.embedding(
x + indices, x + indices,
self.lora_a_stacked_2d, self.lora_a_stacked_2d,
) )
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) indices = self.embeddings_indices[0][:embedding_len].view_as(x)
full_output = self.base_layer.forward( full_output = self.base_layer.forward(
x.add_(indices * added_tokens_mask)) x.add_(indices * added_tokens_mask))
@ -302,6 +318,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size = self.base_layer.input_size
self.output_size = self.base_layer.output_size_per_partition
self.device = _get_lora_device(self.base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
@ -312,17 +331,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) )
self.lora_b_stacked = torch.zeros( self.lora_b_stacked = torch.zeros(
max_loras, max_loras,
1, 1,
self.base_layer.weight.shape[0], self.output_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) )
self.indices: Optional[torch.Tensor] = None self.indices: Optional[torch.Tensor] = None
@ -368,7 +387,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def apply_weights(self, x: torch.Tensor, def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias) self.base_layer, x, bias)
_apply_lora( _apply_lora(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
@ -402,10 +421,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if self.base_layer.skip_bias_add else None) if self.base_layer.skip_bias_add else None)
return output, output_bias return output, output_bias
@property
def linear_weights(self):
return self.base_layer.linear_weights
@classmethod @classmethod
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List, lora_config: LoRAConfig, packed_modules_list: List,
@ -446,18 +461,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) for _ in range(n_slices)) ) for _ in range(n_slices))
self.lora_b_stacked = tuple( self.lora_b_stacked = tuple(
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
self.base_layer.weight.shape[0] // 2, self.output_size // 2,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) for _ in range(n_slices)) ) for _ in range(n_slices))
self.indices: Optional[torch.Tensor] = None self.indices: Optional[torch.Tensor] = None
@ -505,7 +520,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def apply_weights(self, x: torch.Tensor, def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias) self.base_layer, x, bias)
_apply_lora_packed_nslice( _apply_lora_packed_nslice(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
@ -623,25 +638,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
torch.zeros( torch.zeros(
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
) )
self.lora_b_stacked = ( self.lora_b_stacked = (
@ -651,7 +666,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.q_proj_shard_size, self.q_proj_shard_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
torch.zeros( torch.zeros(
max_loras, max_loras,
@ -659,7 +674,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_shard_size, self.kv_proj_shard_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
torch.zeros( torch.zeros(
max_loras, max_loras,
@ -667,7 +682,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.kv_proj_shard_size, self.kv_proj_shard_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
), ),
) )
@ -746,7 +761,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def apply_weights(self, x: torch.Tensor, def apply_weights(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x, bias) self.base_layer, x, bias)
_apply_lora_packed_nslice( _apply_lora_packed_nslice(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
@ -770,6 +785,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None: def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__() super().__init__()
self.base_layer = base_layer self.base_layer = base_layer
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size
self.device = _get_lora_device(self.base_layer)
def create_lora_weights( def create_lora_weights(
self, self,
@ -781,20 +799,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras, max_loras,
1, 1,
lora_config.max_lora_rank, lora_config.max_lora_rank,
self.base_layer.weight.shape[1], self.input_size,
), ),
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) )
self.lora_b_stacked = torch.zeros( self.lora_b_stacked = torch.zeros(
( (
max_loras, max_loras,
1, 1,
self.base_layer.weight.shape[0], self.output_size,
lora_config.max_lora_rank, lora_config.max_lora_rank,
), ),
dtype=lora_config.lora_dtype, dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device, device=self.device,
) )
self.indices: Optional[torch.Tensor] = None self.indices: Optional[torch.Tensor] = None
self.indices_len: Optional[List[int]] = None self.indices_len: Optional[List[int]] = None
@ -813,7 +831,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.reset_lora(index) self.reset_lora(index)
if self.base_layer.tp_size > 1: if self.base_layer.tp_size > 1:
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.base_layer.weight.shape[1] shard_size = self.input_size
start_idx = tensor_model_parallel_rank * shard_size start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size end_idx = (tensor_model_parallel_rank + 1) * shard_size
lora_a = lora_a[start_idx:end_idx, :] lora_a = lora_a[start_idx:end_idx, :]
@ -838,7 +856,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def apply_weights(self, x: torch.Tensor) -> torch.Tensor: def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.linear_method.apply_weights( output = self.base_layer.linear_method.apply_weights(
self.base_layer.linear_weights, x) self.base_layer, x)
_apply_lora( _apply_lora(
x, x,
self.lora_a_stacked, self.lora_a_stacked,
@ -888,7 +906,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
@property @property
def weight(self): def weight(self):
return self.base_layer.weight
return self.base_layer.weight if hasattr(
self.base_layer, "weight") else self.base_layer.qweight
@classmethod @classmethod
def can_replace_layer(cls, source_layer: nn.Module, def can_replace_layer(cls, source_layer: nn.Module,
@ -939,9 +959,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: Optional[PretrainedConfig] = None, model_config: Optional[PretrainedConfig] = None,
) -> None: ) -> None:
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
if 32000 < self.base_layer.vocab_size > 33024: if 32000 < self.base_layer.vocab_size > 128512:
raise ValueError("When using LoRA, vocab size must be " raise ValueError("When using LoRA, vocab size must be "
"32000 >= vocab_size <= 33024") "32000 >= vocab_size <= 128512")
self.lora_a_stacked = torch.zeros( self.lora_a_stacked = torch.zeros(
( (
max_loras, max_loras,

View File

@ -0,0 +1,25 @@
from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
async def get_guided_decoding_logits_processor(
guided_decoding_backend: str, request: Union[CompletionRequest,
ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
if guided_decoding_backend == 'outlines':
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)
raise ValueError(
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
"Must be one of 'outlines, 'lm-format-enforcer'")

View File

@ -0,0 +1,69 @@
from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union
from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
RegexParser, StringParser,
TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
async def get_lm_format_enforcer_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Optional[LogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer)
character_level_parser: CharacterLevelParser
if request.guided_json:
schema = _normalize_json_schema_object(request.guided_json)
character_level_parser = JsonSchemaParser(schema)
elif request.guided_choice:
character_level_parser = UnionParser(
[StringParser(choice) for choice in request.guided_choice])
elif request.guided_regex:
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
else:
return None
logits_processor = build_vllm_logits_processor(tokenizer_data,
character_level_parser)
return logits_processor
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
if isinstance(schema, str):
return json_loads(schema)
if isinstance(schema, dict):
return schema
if isinstance(schema, BaseModel):
return schema.model_json_schema()
@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
return build_vllm_token_enforcer_tokenizer_data(tokenizer)

View File

@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor, from vllm.model_executor.guided_decoding.outlines_logits_processors import (
JSONLogitsProcessor, CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
RegexLogitsProcessor)
class GuidedDecodingMode(Enum): class GuidedDecodingMode(Enum):
@ -54,7 +53,7 @@ pair : UNESCAPED_STRING ":" value
global_thread_pool = None # used for generating logits processor fsm global_thread_pool = None # used for generating logits processor fsm
async def get_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest], request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
""" """
@ -91,7 +90,7 @@ def _get_guide_and_mode(
json = request.guided_json json = request.guided_json
if isinstance(json, dict): if isinstance(json, dict):
# turn dict into hashable string # turn dict into hashable string
json = json_dumps(json, sort_keys=True) json = json_dumps(json)
elif isinstance(json, BaseModel): elif isinstance(json, BaseModel):
# use pydantic signature so that different model classes # use pydantic signature so that different model classes
# with the same fields will get hashed the same # with the same fields will get hashed the same

View File

@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import json import json
import math import math
from collections import defaultdict from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union from typing import Callable, DefaultDict, Dict, List, Optional, Union
import torch import torch
@ -27,50 +29,6 @@ from transformers import PreTrainedTokenizerBase
class BaseLogitsProcessor: class BaseLogitsProcessor:
def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if getattr(tokenizer, "_outlines_adapted", False):
return tokenizer
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def change_decoder(
decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def new_decoder(inp_tokens: List[int]) -> List[str]:
return [decoder(inp_tokens)]
return new_decoder
tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
return tokenizer
def init_state(self): def init_state(self):
"""Initialize the FSM states.""" """Initialize the FSM states."""
self.fsm_state: DefaultDict[int, int] = defaultdict(int) self.fsm_state: DefaultDict[int, int] = defaultdict(int)
@ -78,7 +36,6 @@ class BaseLogitsProcessor:
def __call__(self, input_ids: List[int], def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor: scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token.""" """Use the FSM to bias the logits before sampling the next token."""
seq_id = hash(tuple(input_ids)) seq_id = hash(tuple(input_ids))
if len(input_ids) == 0: if len(input_ids) == 0:
@ -96,7 +53,6 @@ class BaseLogitsProcessor:
device=scores.device) device=scores.device)
mask[allowed_tokens] = 0 mask[allowed_tokens] = 0
scores.add_(mask) scores.add_(mask)
return scores return scores
@ -113,7 +69,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer The model's tokenizer
""" """
tokenizer = self.adapt_tokenizer(tokenizer) tokenizer = _adapt_tokenizer(tokenizer)
fsm = RegexFSM(regex_string, tokenizer) fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm self.fsm = fsm
@ -167,6 +123,54 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
The model's tokenizer The model's tokenizer
""" """
tokenizer = self.adapt_tokenizer(tokenizer) tokenizer = _adapt_tokenizer(tokenizer)
fsm = CFGFSM(cfg, tokenizer) fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm self.fsm = fsm
@lru_cache
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. The decoder of outlines, returns a list whereas
the decode of vLLM returns an str. To sync the vLLM decoder with
outlines internal api, the decoder should be adapted. In addition
we need to handle the missing spaces to Llama's tokenizer to be
able to compile FSMs for this model.
"""
if getattr(tokenizer, "_outlines_adapted", False):
return tokenizer
tokenizer = copy.deepcopy(tokenizer)
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle missing spaces to HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
def change_decoder(
decoder: Callable[[List[int]],
str]) -> Callable[[List[int]], List[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""
def new_decoder(inp_tokens: List[int]) -> List[str]:
return [decoder(inp_tokens)]
return new_decoder
tokenizer.convert_token_to_string = convert_token_to_string
tokenizer.decode = change_decoder(tokenizer.decode)
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
return tokenizer

Some files were not shown because too many files have changed in this diff Show More