mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 01:03:44 +08:00
Merge branch 'main' into woosuk-tpu
This commit is contained in:
commit
d4adf92beb
@ -12,7 +12,13 @@ steps:
|
||||
command: pytest -v -s async_engine
|
||||
|
||||
- label: Basic Correctness Test
|
||||
command: pytest -v -s basic_correctness
|
||||
commands:
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
|
||||
- label: Core Test
|
||||
command: pytest -v -s core
|
||||
@ -29,6 +35,8 @@ steps:
|
||||
- pytest -v -s test_pynccl.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
|
||||
|
||||
- label: Engine Test
|
||||
command: pytest -v -s engine tokenization test_sequence.py test_config.py
|
||||
@ -83,6 +91,9 @@ steps:
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
|
||||
- label: Tensorizer Test
|
||||
command: apt-get install curl libsodium23 && pytest -v -s tensorizer
|
||||
|
||||
- label: Metrics Test
|
||||
command: pytest -v -s metrics
|
||||
|
||||
|
||||
50
.github/workflows/mypy.yaml
vendored
Normal file
50
.github/workflows/mypy.yaml
vendored
Normal file
@ -0,0 +1,50 @@
|
||||
name: mypy
|
||||
|
||||
on:
|
||||
# Trigger the workflow on push or pull request,
|
||||
# but only for the main branch
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
jobs:
|
||||
ruff:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install mypy==1.9.0
|
||||
pip install types-setuptools
|
||||
pip install types-PyYAML
|
||||
pip install types-requests
|
||||
pip install types-setuptools
|
||||
- name: Mypy
|
||||
run: |
|
||||
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
# TODO(sang): Follow up
|
||||
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
2
.github/workflows/ruff.yml
vendored
2
.github/workflows/ruff.yml
vendored
@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
||||
2
.github/workflows/yapf.yml
vendored
2
.github/workflows/yapf.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
||||
@ -91,7 +91,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
|
||||
|
||||
COPY ./ /app/vllm
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --upgrade pip numba
|
||||
RUN python3 -m pip install xformers==0.0.23 --no-deps
|
||||
|
||||
RUN cd /app \
|
||||
|
||||
@ -72,7 +72,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
|
||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
|
||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
|
||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
|
||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
|
||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||
|
||||
@ -27,8 +27,8 @@ class RequestFuncInput:
|
||||
class RequestFuncOutput:
|
||||
generated_text: str = ""
|
||||
success: bool = False
|
||||
latency: float = 0
|
||||
ttft: float = 0 # Time to first token
|
||||
latency: float = 0.0
|
||||
ttft: float = 0.0 # Time to first token
|
||||
itl: List[float] = field(
|
||||
default_factory=list) # List of inter-token latencies
|
||||
prompt_len: int = 0
|
||||
@ -58,23 +58,24 @@ async def async_request_tgi(
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||
"data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
@ -119,23 +120,24 @@ async def async_request_trt_llm(
|
||||
output = RequestFuncOutput()
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
ttft = 0
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data:")
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||
"data:")
|
||||
|
||||
data = json.loads(chunk)
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
@ -151,7 +153,7 @@ async def async_request_trt_llm(
|
||||
output.success = True
|
||||
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
@ -195,7 +197,7 @@ async def async_request_deepspeed_mii(
|
||||
output.generated_text = parsed_resp["text"][0]
|
||||
output.success = True
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
@ -234,19 +236,20 @@ async def async_request_openai_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
@ -255,7 +258,7 @@ async def async_request_openai_completions(
|
||||
if data["choices"][0]["text"]:
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0:
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
@ -315,19 +318,20 @@ async def async_request_openai_chat_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
ttft = 0
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
try:
|
||||
async with session.post(url=api_url, json=payload,
|
||||
headers=headers) as response:
|
||||
if response.status == 200:
|
||||
async for chunk in response.content:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk.decode("utf-8"), "data: ")
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
||||
"data: ")
|
||||
if chunk == "[DONE]":
|
||||
latency = time.perf_counter() - st
|
||||
else:
|
||||
@ -337,7 +341,7 @@ async def async_request_openai_chat_completions(
|
||||
delta = data["choices"][0]["delta"]
|
||||
if delta.get("content", None):
|
||||
# First token
|
||||
if ttft == 0:
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
|
||||
@ -354,7 +358,7 @@ async def async_request_openai_chat_completions(
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
else:
|
||||
output.error = response.reason
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception:
|
||||
output.success = False
|
||||
|
||||
@ -177,8 +177,7 @@ if __name__ == '__main__':
|
||||
help='block size of key/value cache')
|
||||
parser.add_argument(
|
||||
'--enable-chunked-prefill',
|
||||
type=bool,
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='If True, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens')
|
||||
parser.add_argument(
|
||||
|
||||
@ -74,25 +74,31 @@ def run_vllm(
|
||||
quantization_param_path: Optional[str],
|
||||
device: str,
|
||||
enable_prefix_caching: bool,
|
||||
enable_chunked_prefill: bool,
|
||||
max_num_batched_tokens: int,
|
||||
gpu_memory_utilization: float = 0.9,
|
||||
download_dir: Optional[str] = None,
|
||||
) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
quantization_param_path=quantization_param_path,
|
||||
device=device,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
download_dir=download_dir)
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
max_model_len=max_model_len,
|
||||
gpu_memory_utilization=gpu_memory_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
quantization_param_path=quantization_param_path,
|
||||
device=device,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
download_dir=download_dir,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
|
||||
# Add the requests to the engine.
|
||||
for prompt, _, output_len in requests:
|
||||
@ -213,15 +219,15 @@ def main(args: argparse.Namespace):
|
||||
args.output_len)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype,
|
||||
args.max_model_len, args.enforce_eager,
|
||||
args.kv_cache_dtype,
|
||||
args.quantization_param_path, args.device,
|
||||
args.enable_prefix_caching,
|
||||
args.gpu_memory_utilization, args.download_dir)
|
||||
elapsed_time = run_vllm(
|
||||
requests, args.model, args.tokenizer, args.quantization,
|
||||
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
||||
args.enforce_eager, args.kv_cache_dtype,
|
||||
args.quantization_param_path, args.device,
|
||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
||||
args.max_num_batched_tokens, args.gpu_memory_utilization,
|
||||
args.download_dir)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
@ -335,6 +341,14 @@ if __name__ == "__main__":
|
||||
"--enable-prefix-caching",
|
||||
action='store_true',
|
||||
help="enable automatic prefix caching for vLLM backend.")
|
||||
parser.add_argument("--enable-chunked-prefill",
|
||||
action='store_true',
|
||||
help="enable chunked prefill for vLLM backend.")
|
||||
parser.add_argument('--max-num-batched-tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help='maximum number of batched tokens per '
|
||||
'iteration')
|
||||
parser.add_argument('--download-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._C import ops
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
||||
|
||||
NUM_BLOCKS = 1024
|
||||
|
||||
@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 128) \
|
||||
f(in_T, out_T, W_T, narrow, 256) \
|
||||
f(in_T, out_T, W_T, narrow, 512) \
|
||||
f(in_T, out_T, W_T, narrow, 640) \
|
||||
f(in_T, out_T, W_T, narrow, 768) \
|
||||
f(in_T, out_T, W_T, narrow, 1024) \
|
||||
f(in_T, out_T, W_T, narrow, 1152) \
|
||||
@ -46,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 13696) \
|
||||
f(in_T, out_T, W_T, narrow, 13824) \
|
||||
f(in_T, out_T, W_T, narrow, 14336) \
|
||||
f(in_T, out_T, W_T, narrow, 15360) \
|
||||
f(in_T, out_T, W_T, narrow, 16384) \
|
||||
f(in_T, out_T, W_T, narrow, 20480) \
|
||||
f(in_T, out_T, W_T, narrow, 22016) \
|
||||
@ -59,7 +61,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
f(in_T, out_T, W_T, narrow, 33024) \
|
||||
f(in_T, out_T, W_T, narrow, 36864) \
|
||||
f(in_T, out_T, W_T, narrow, 49152) \
|
||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||
f(in_T, out_T, W_T, narrow, 64000) \
|
||||
f(in_T, out_T, W_T, narrow, 64256) \
|
||||
f(in_T, out_T, W_T, narrow, 64512) \
|
||||
f(in_T, out_T, W_T, narrow, 102400) \
|
||||
f(in_T, out_T, W_T, narrow, 102656) \
|
||||
f(in_T, out_T, W_T, narrow, 102912) \
|
||||
f(in_T, out_T, W_T, narrow, 128000) \
|
||||
f(in_T, out_T, W_T, narrow, 128256) \
|
||||
f(in_T, out_T, W_T, narrow, 128512) \
|
||||
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
||||
// and vllm/tests/lora/test_punica.py
|
||||
|
||||
// Keep this in sync with vllm/config::LoRAConfig
|
||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
|
||||
@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||
return (uint32_t(a) << 16) | uint32_t(b);
|
||||
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
|
||||
return (uint64_t(a) << 32) | uint64_t(b);
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
|
||||
template <typename in_T, typename out_T, typename W_T>
|
||||
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||
const int64_t *lora_indices,
|
||||
uint16_t in_features, uint16_t out_features,
|
||||
uint32_t in_features, uint32_t out_features,
|
||||
int64_t y_offset, int64_t full_y_size,
|
||||
int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
switch (pack_u16(in_features, out_features)) {
|
||||
switch (pack_u32(in_features, out_features)) {
|
||||
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||
case pack_u16(feat_in, feat_out): \
|
||||
case pack_u32(feat_in, feat_out): \
|
||||
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||
full_y_size, batch_size, num_layers, \
|
||||
layer_idx, scale); \
|
||||
@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in < 65536 && h_out < 65536) {
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in < 65536 && h_out < 65536) {
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
|
||||
@ -2067,7 +2067,7 @@ void gptq_shuffle
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
||||
vllm::gptq::shuffle_exllama_weight(
|
||||
(uint32_t*) q_weight.data_ptr(),
|
||||
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
|
||||
q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
|
||||
q_weight.size(0) * 32 / bit,
|
||||
q_weight.size(1),
|
||||
bit
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from sphinx.ext import autodoc
|
||||
|
||||
@ -45,7 +46,7 @@ templates_path = ['_templates']
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = []
|
||||
exclude_patterns: List[str] = []
|
||||
|
||||
# Exclude the prompt "$" when copying code
|
||||
copybutton_prompt_text = r"\$ "
|
||||
@ -82,6 +83,7 @@ autodoc_mock_imports = [
|
||||
"vllm._C",
|
||||
"numpy",
|
||||
"tqdm",
|
||||
"tensorizer",
|
||||
]
|
||||
|
||||
for mock_target in autodoc_mock_imports:
|
||||
|
||||
@ -85,13 +85,3 @@ You can also build and install vLLM from source:
|
||||
|
||||
$ nvcc --version # verify that nvcc is in your PATH
|
||||
$ ${CUDA_HOME}/bin/nvcc --version # verify that nvcc is in your CUDA_HOME
|
||||
|
||||
.. note::
|
||||
If you are developing the C++ backend of vLLM, consider building vLLM with
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ python setup.py develop
|
||||
|
||||
since it will give you incremental builds. The downside is that this method
|
||||
is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.
|
||||
|
||||
@ -36,7 +36,7 @@ Below, you can find an explanation of every engine argument for vLLM:
|
||||
|
||||
Directory to download and load the weights, default to the default cache dir of huggingface.
|
||||
|
||||
.. option:: --load-format {auto,pt,safetensors,npcache,dummy}
|
||||
.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}
|
||||
|
||||
The format of the model weights to load.
|
||||
|
||||
@ -45,6 +45,7 @@ Below, you can find an explanation of every engine argument for vLLM:
|
||||
* "safetensors" will load the weights in the safetensors format.
|
||||
* "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
|
||||
* "dummy" will initialize the weights with random values, mainly for profiling.
|
||||
* "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. <https://github.com/coreweave/tensorizer>`_ See `examples/tensorize_vllm_model.py <https://github.com/vllm-project/vllm/blob/main/examples/tensorize_vllm_model.py>`_ to serialize a vLLM model, and for more information.
|
||||
|
||||
.. option:: --dtype {auto,half,float16,bfloat16,float,float32}
|
||||
|
||||
|
||||
@ -30,23 +30,23 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`CohereForCausalLM`
|
||||
- Command-R
|
||||
- :code:`CohereForAI/c4ai-command-r-v01`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`DbrxForCausalLM`
|
||||
- DBRX
|
||||
- :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`DeciLMForCausalLM`
|
||||
- DeciLM
|
||||
- :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`BloomForCausalLM`
|
||||
- BLOOM, BLOOMZ, BLOOMChat
|
||||
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`FalconForCausalLM`
|
||||
- Falcon
|
||||
- :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`GemmaForCausalLM`
|
||||
- Gemma
|
||||
- :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
|
||||
@ -54,19 +54,19 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`GPT2LMHeadModel`
|
||||
- GPT-2
|
||||
- :code:`gpt2`, :code:`gpt2-xl`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`GPTBigCodeForCausalLM`
|
||||
- StarCoder, SantaCoder, WizardCoder
|
||||
- :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`GPTJForCausalLM`
|
||||
- GPT-J
|
||||
- :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`GPTNeoXForCausalLM`
|
||||
- GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM
|
||||
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`InternLMForCausalLM`
|
||||
- InternLM
|
||||
- :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
|
||||
@ -93,32 +93,32 @@ Alongside each architecture, we include some popular models that use it.
|
||||
- ✅︎
|
||||
* - :code:`MixtralForCausalLM`
|
||||
- Mixtral-8x7B, Mixtral-8x7B-Instruct
|
||||
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
|
||||
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc.
|
||||
- ✅︎
|
||||
* - :code:`MPTForCausalLM`
|
||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`OLMoForCausalLM`
|
||||
- OLMo
|
||||
- :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`OPTForCausalLM`
|
||||
- OPT, OPT-IML
|
||||
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`OrionForCausalLM`
|
||||
- Orion
|
||||
- :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`PhiForCausalLM`
|
||||
- Phi
|
||||
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`QWenLMHeadModel`
|
||||
- Qwen
|
||||
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`Qwen2ForCausalLM`
|
||||
- Qwen2
|
||||
- :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
|
||||
@ -126,11 +126,11 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`Qwen2MoeForCausalLM`
|
||||
- Qwen2MoE
|
||||
- :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.
|
||||
-
|
||||
-
|
||||
* - :code:`StableLmForCausalLM`
|
||||
- StableLM
|
||||
- :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
|
||||
-
|
||||
-
|
||||
|
||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
|
||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
|
||||
@ -168,3 +168,29 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
|
||||
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
|
||||
output = llm.generate("Hello, my name is")
|
||||
print(output)
|
||||
|
||||
Model Support Policy
|
||||
---------------------
|
||||
|
||||
At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support:
|
||||
|
||||
1. **Community-Driven Support**: We encourage community contributions for adding new models. When a user requests support for a new model, we welcome pull requests (PRs) from the community. These contributions are evaluated primarily on the sensibility of the output they generate, rather than strict consistency with existing implementations such as those in transformers. **Call for contribution:** PRs coming directly from model vendors are greatly appreciated!
|
||||
|
||||
2. **Best-Effort Consistency**: While we aim to maintain a level of consistency between the models implemented in vLLM and other frameworks like transformers, complete alignment is not always feasible. Factors like acceleration techniques and the use of low-precision computations can introduce discrepancies. Our commitment is to ensure that the implemented models are functional and produce sensible results.
|
||||
|
||||
3. **Issue Resolution and Model Updates**: Users are encouraged to report any bugs or issues they encounter with third-party models. Proposed fixes should be submitted via PRs, with a clear explanation of the problem and the rationale behind the proposed solution. If a fix for one model impacts another, we rely on the community to highlight and address these cross-model dependencies. Note: for bugfix PRs, it is good etiquette to inform the original author to seek their feedback.
|
||||
|
||||
4. **Monitoring and Updates**: Users interested in specific models should monitor the commit history for those models (e.g., by tracking changes in the main/vllm/model_executor/models directory). This proactive approach helps users stay informed about updates and changes that may affect the models they use.
|
||||
|
||||
5. **Selective Focus**: Our resources are primarily directed towards models with significant user interest and impact. Models that are less frequently used may receive less attention, and we rely on the community to play a more active role in their upkeep and improvement.
|
||||
|
||||
Through this approach, vLLM fosters a collaborative environment where both the core development team and the broader community contribute to the robustness and diversity of the third-party models supported in our ecosystem.
|
||||
|
||||
Note that, as an inference engine, vLLM does not introduce new models. Therefore, all models supported by vLLM are third-party models in this regard.
|
||||
|
||||
We have the following levels of testing for models:
|
||||
|
||||
1. **Strict Consistency**: We compare the output of the model with the output of the model in the HuggingFace Transformers library under greedy decoding. This is the most stringent test. Please refer to `test_models.py <https://github.com/vllm-project/vllm/blob/main/tests/models/test_models.py>`_ and `test_big_models.py <https://github.com/vllm-project/vllm/blob/main/tests/models/test_big_models.py>`_ for the models that have passed this test.
|
||||
2. **Output Sensibility**: We check if the output of the model is sensible and coherent, by measuring the perplexity of the output and checking for any obvious errors. This is a less stringent test.
|
||||
3. **Runtime Functionality**: We check if the model can be loaded and run without errors. This is the least stringent test. Please refer to `functionality tests <https://github.com/vllm-project/vllm/tree/main/tests>`_ and `examples <https://github.com/vllm-project/vllm/tree/main/examples>`_ for the models that have passed this test.
|
||||
4. **Community Feedback**: We rely on the community to provide feedback on the models. If a model is broken or not working as expected, we encourage users to raise issues to report it or open pull requests to fix it. The rest of the models fall under this category.
|
||||
|
||||
@ -4,7 +4,7 @@ vLLM provides an HTTP server that implements OpenAI's [Completions](https://plat
|
||||
|
||||
You can start the server using Python, or using [Docker](deploying_with_docker.rst):
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-hf --dtype float32 --api-key token-abc123
|
||||
python -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-Instruct-v0.2 --dtype auto --api-key token-abc123
|
||||
```
|
||||
|
||||
To call the server, you can use the official OpenAI Python client library, or any other HTTP client.
|
||||
@ -16,9 +16,8 @@ client = OpenAI(
|
||||
)
|
||||
|
||||
completion = client.chat.completions.create(
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello!"}
|
||||
]
|
||||
)
|
||||
@ -38,9 +37,8 @@ Or directly merge them into the JSON payload if you are using HTTP call directly
|
||||
|
||||
```python
|
||||
completion = client.chat.completions.create(
|
||||
model="meta-llama/Llama-2-7b-hf",
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Classify this sentiment: vLLM is wonderful!"}
|
||||
],
|
||||
extra_body={
|
||||
@ -89,7 +87,7 @@ In order for the language model to support chat protocol, vLLM requires the mode
|
||||
a chat template in its tokenizer configuration. The chat template is a Jinja2 template that
|
||||
specifies how are roles, messages, and other chat-specific tokens are encoded in the input.
|
||||
|
||||
An example chat template for `meta-llama/Llama-2-7b-chat-hf` can be found [here](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/09bd0f49e16738cdfaa6e615203e126038736eb0/tokenizer_config.json#L12)
|
||||
An example chat template for `mistralai/Mistral-7B-Instruct-v0.2` can be found [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2#instruction-format)
|
||||
|
||||
Some models do not provide a chat template even though they are instruction/chat fine-tuned. For those model,
|
||||
you can manually specify their chat template in the `--chat-template` parameter with the file path to the chat
|
||||
|
||||
282
examples/tensorize_vllm_model.py
Normal file
282
examples/tensorize_vllm_model.py
Normal file
@ -0,0 +1,282 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
|
||||
TensorSerializer, stream_io)
|
||||
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from vllm.distributed import initialize_model_parallel
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
"""
|
||||
tensorize_vllm_model.py is a script that can be used to serialize and
|
||||
deserialize vLLM models. These models can be loaded using tensorizer
|
||||
to the GPU extremely quickly over an HTTP/HTTPS endpoint, an S3 endpoint,
|
||||
or locally. Tensor encryption and decryption is also supported, although
|
||||
libsodium must be installed to use it. Install vllm with tensorizer support
|
||||
using `pip install vllm[tensorizer]`.
|
||||
|
||||
To serialize a model, install vLLM from source, then run something
|
||||
like this from the root level of this repository:
|
||||
|
||||
python -m examples.tensorize_vllm_model \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
serialize \
|
||||
--serialized-directory s3://my-bucket/ \
|
||||
--suffix vllm
|
||||
|
||||
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
||||
and saves it to your S3 bucket. A local directory can also be used. This
|
||||
assumes your S3 credentials are specified as environment variables
|
||||
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
|
||||
To provide S3 credentials directly, you can provide `--s3-access-key-id` and
|
||||
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this
|
||||
script.
|
||||
|
||||
You can also encrypt the model weights with a randomly-generated key by
|
||||
providing a `--keyfile` argument.
|
||||
|
||||
To deserialize a model, you can run something like this from the root
|
||||
level of this repository:
|
||||
|
||||
python -m examples.tensorize_vllm_model \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
deserialize \
|
||||
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
|
||||
|
||||
Which downloads the model tensors from your S3 bucket and deserializes them.
|
||||
|
||||
You can also provide a `--keyfile` argument to decrypt the model weights if
|
||||
they were serialized with encryption.
|
||||
|
||||
For more information on the available arguments for serializing, run
|
||||
`python -m examples.tensorize_vllm_model serialize --help`.
|
||||
|
||||
Or for deserializing:
|
||||
|
||||
`python -m examples.tensorize_vllm_model deserialize --help`.
|
||||
|
||||
Once a model is serialized, it can be used to load the model when running the
|
||||
OpenAI inference client at `vllm/entrypoints/openai/api_server.py` by providing
|
||||
the `--tensorizer-uri` CLI argument that is functionally the same as the
|
||||
`--path-to-tensors` argument in this script, along with `--vllm-tensorized`, to
|
||||
signify that the model to be deserialized is a vLLM model, rather than a
|
||||
HuggingFace `PreTrainedModel`, which can also be deserialized using tensorizer
|
||||
in the same inference server, albeit without the speed optimizations. To
|
||||
deserialize an encrypted file, the `--encryption-keyfile` argument can be used
|
||||
to provide the path to the keyfile used to encrypt the model weights. For
|
||||
information on all the arguments that can be used to configure tensorizer's
|
||||
deserialization, check out the tensorizer options argument group in the
|
||||
`vllm/entrypoints/openai/api_server.py` script with `--help`.
|
||||
|
||||
Tensorizer can also be invoked with the `LLM` class directly to load models:
|
||||
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
load_format="tensorizer",
|
||||
tensorizer_uri=path_to_opt_tensors,
|
||||
num_readers=3,
|
||||
vllm_tensorized=True)
|
||||
"""
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="An example script that can be used to serialize and "
|
||||
"deserialize vLLM models. These models "
|
||||
"can be loaded using tensorizer directly to the GPU "
|
||||
"extremely quickly. Tensor encryption and decryption is "
|
||||
"also supported, although libsodium must be installed to "
|
||||
"use it.")
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
subparsers = parser.add_subparsers(dest='command')
|
||||
|
||||
serialize_parser = subparsers.add_parser(
|
||||
'serialize', help="Serialize a model to `--serialized-directory`")
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--suffix",
|
||||
type=str,
|
||||
required=False,
|
||||
help=(
|
||||
"The suffix to append to the serialized model directory, which is "
|
||||
"used to construct the location of the serialized model tensors, "
|
||||
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
|
||||
"`--suffix` is `v1`, the serialized model tensors will be "
|
||||
"saved to "
|
||||
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
|
||||
"If none is provided, a random UUID will be used."))
|
||||
serialize_parser.add_argument(
|
||||
"--serialized-directory",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The directory to serialize the model to. "
|
||||
"This can be a local directory or S3 URI. The path to where the "
|
||||
"tensors are saved is a combination of the supplied `dir` and model "
|
||||
"reference ID. For instance, if `dir` is the serialized directory, "
|
||||
"and the model HuggingFace ID is `EleutherAI/gpt-j-6B`, tensors will "
|
||||
"be saved to `dir/vllm/EleutherAI/gpt-j-6B/suffix/model.tensors`, "
|
||||
"where `suffix` is given by `--suffix` or a random UUID if not "
|
||||
"provided.")
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
required=False,
|
||||
help=("Encrypt the model weights with a randomly-generated binary key,"
|
||||
" and save the key at this path"))
|
||||
|
||||
deserialize_parser = subparsers.add_parser(
|
||||
'deserialize',
|
||||
help=("Deserialize a model from `--path-to-tensors`"
|
||||
" to verify it can be loaded and used."))
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--path-to-tensors",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The local path or S3 URI to the model tensors to deserialize. ")
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
required=False,
|
||||
help=("Path to a binary key to use to decrypt the model weights,"
|
||||
" if the model was serialized with encryption"))
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_model_contiguous(model):
|
||||
# Ensure tensors are saved in memory contiguously
|
||||
for param in model.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
|
||||
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return model_cls
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def serialize():
|
||||
|
||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
||||
dataclasses.fields(EngineArgs)}
|
||||
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
model = (engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
|
||||
encryption_params = EncryptionParams.random() if keyfile else None
|
||||
if keyfile:
|
||||
with _write_stream(keyfile) as stream:
|
||||
stream.write(encryption_params.key)
|
||||
|
||||
with _write_stream(model_path) as stream:
|
||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||
serializer.write_module(model)
|
||||
serializer.close()
|
||||
|
||||
print("Serialization complete. Model tensors saved to", model_path)
|
||||
if keyfile:
|
||||
print("Key saved to", keyfile)
|
||||
|
||||
|
||||
def deserialize():
|
||||
config = AutoConfig.from_pretrained(model_ref)
|
||||
|
||||
with no_init_or_tensor():
|
||||
model_class = _get_vllm_model_architecture(config)
|
||||
model = model_class(config)
|
||||
|
||||
before_mem = get_mem_usage()
|
||||
start = time.time()
|
||||
|
||||
if keyfile:
|
||||
with _read_stream(keyfile) as stream:
|
||||
key = stream.read()
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
tensorizer_args.deserializer_params['encryption'] = \
|
||||
decryption_params
|
||||
|
||||
with (_read_stream(model_path)) as stream, TensorDeserializer(
|
||||
stream, **tensorizer_args.deserializer_params) as deserializer:
|
||||
deserializer.load_into_module(model)
|
||||
end = time.time()
|
||||
|
||||
# Brag about how fast we are.
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
print(
|
||||
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
|
||||
)
|
||||
print(f"Memory usage before: {before_mem}")
|
||||
print(f"Memory usage after: {after_mem}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
args = parse_args()
|
||||
|
||||
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
|
||||
or None)
|
||||
s3_secret_access_key = (args.s3_secret_access_key
|
||||
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
|
||||
|
||||
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
|
||||
|
||||
_read_stream, _write_stream = (partial(
|
||||
stream_io.open_stream,
|
||||
mode=mode,
|
||||
s3_access_key_id=s3_access_key_id,
|
||||
s3_secret_access_key=s3_secret_access_key,
|
||||
s3_endpoint=s3_endpoint,
|
||||
) for mode in ("rb", "wb+"))
|
||||
|
||||
model_ref = args.model
|
||||
|
||||
model_name = model_ref.split("/")[1]
|
||||
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "8080"
|
||||
|
||||
torch.distributed.init_process_group(world_size=1, rank=0)
|
||||
initialize_model_parallel()
|
||||
|
||||
keyfile = args.keyfile if args.keyfile else None
|
||||
|
||||
if args.command == "serialize":
|
||||
input_dir = args.serialized_directory.rstrip('/')
|
||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
||||
model_path = f"{base_path}/model.tensors"
|
||||
serialize()
|
||||
elif args.command == "deserialize":
|
||||
tensorizer_args = TensorizerArgs.from_cli_args(args)
|
||||
model_path = args.path_to_tensors
|
||||
deserialize()
|
||||
else:
|
||||
raise ValueError("Either serialize or deserialize must be specified.")
|
||||
22
format.sh
22
format.sh
@ -93,9 +93,23 @@ fi
|
||||
echo 'vLLM yapf: Done'
|
||||
|
||||
# Run mypy
|
||||
# TODO(zhuohan): Enable mypy
|
||||
# echo 'vLLM mypy:'
|
||||
# mypy
|
||||
echo 'vLLM mypy:'
|
||||
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
# TODO(sang): Follow up
|
||||
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
|
||||
|
||||
|
||||
CODESPELL_EXCLUDES=(
|
||||
'--skip' '*docs/source/_build/**'
|
||||
@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
|
||||
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
|
||||
@ -46,10 +46,13 @@ ignore = [
|
||||
python_version = "3.8"
|
||||
|
||||
ignore_missing_imports = true
|
||||
check_untyped_defs = true
|
||||
|
||||
files = "vllm"
|
||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
|
||||
exclude = [
|
||||
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
|
||||
]
|
||||
|
||||
|
||||
[tool.codespell]
|
||||
|
||||
@ -11,4 +11,7 @@ uvicorn[standard]
|
||||
pydantic >= 2.0 # Required for OpenAI server.
|
||||
prometheus_client >= 0.18.0
|
||||
tiktoken == 0.6.0 # Required for DBRX tokenizer
|
||||
outlines == 0.0.34 # Requires torch >= 2.1.0
|
||||
lm-format-enforcer == 0.9.3
|
||||
outlines == 0.0.34 # Requires torch >= 2.1.0
|
||||
typing_extensions
|
||||
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
|
||||
|
||||
@ -3,4 +3,4 @@
|
||||
|
||||
# Dependencies for x86_64 CPUs
|
||||
torch == 2.2.1+cpu
|
||||
triton >= 2.1.0 # FIXME(woosuk): This is a hack to avoid import error.
|
||||
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
|
||||
@ -7,4 +7,3 @@ pynvml == 11.5.0
|
||||
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
|
||||
torch == 2.2.1
|
||||
xformers == 0.0.25 # Requires PyTorch 2.2.1
|
||||
triton >= 2.1.0
|
||||
|
||||
@ -7,13 +7,14 @@ codespell==2.2.6
|
||||
isort==5.13.2
|
||||
|
||||
# type checking
|
||||
mypy==0.991
|
||||
mypy==1.9.0
|
||||
types-PyYAML
|
||||
types-requests
|
||||
types-setuptools
|
||||
|
||||
# testing
|
||||
pytest
|
||||
tensorizer==2.9.0a0
|
||||
pytest-forked
|
||||
pytest-asyncio
|
||||
pytest-rerunfailures
|
||||
|
||||
8
setup.py
8
setup.py
@ -5,7 +5,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
from shutil import which
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from packaging.version import Version, parse
|
||||
@ -52,7 +52,7 @@ class CMakeExtension(Extension):
|
||||
|
||||
class cmake_build_ext(build_ext):
|
||||
# A dict of extension directories that have been configured.
|
||||
did_config = {}
|
||||
did_config: Dict[str, bool] = {}
|
||||
|
||||
#
|
||||
# Determine number of compilation jobs and optionally nvcc compile threads.
|
||||
@ -269,6 +269,7 @@ def get_nvcc_cuda_version() -> Version:
|
||||
|
||||
Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py
|
||||
"""
|
||||
assert CUDA_HOME is not None, "CUDA_HOME is not set"
|
||||
nvcc_output = subprocess.check_output([CUDA_HOME + "/bin/nvcc", "-V"],
|
||||
universal_newlines=True)
|
||||
output = nvcc_output.split()
|
||||
@ -416,6 +417,9 @@ setup(
|
||||
python_requires=">=3.8",
|
||||
install_requires=get_requirements(),
|
||||
ext_modules=ext_modules,
|
||||
extras_require={
|
||||
"tensorizer": ["tensorizer==2.9.0a1"],
|
||||
},
|
||||
cmdclass={"build_ext": cmake_build_ext} if _build_custom_ops() else {},
|
||||
package_data=package_data,
|
||||
)
|
||||
|
||||
@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_server(tokenizer_pool_size: int):
|
||||
def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
|
||||
worker_use_ray: bool):
|
||||
script_path = Path(__file__).parent.joinpath(
|
||||
"api_server_async_engine.py").absolute()
|
||||
uvicorn_process = subprocess.Popen([
|
||||
commands = [
|
||||
sys.executable, "-u",
|
||||
str(script_path), "--model", "facebook/opt-125m", "--host",
|
||||
"127.0.0.1", "--tokenizer-pool-size",
|
||||
str(tokenizer_pool_size)
|
||||
])
|
||||
]
|
||||
if engine_use_ray:
|
||||
commands.append("--engine-use-ray")
|
||||
if worker_use_ray:
|
||||
commands.append("--worker-use-ray")
|
||||
uvicorn_process = subprocess.Popen(commands)
|
||||
yield
|
||||
uvicorn_process.terminate()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
|
||||
def test_api_server(api_server, tokenizer_pool_size: int):
|
||||
@pytest.mark.parametrize("worker_use_ray", [False, True])
|
||||
@pytest.mark.parametrize("engine_use_ray", [False, True])
|
||||
def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
|
||||
engine_use_ray: bool):
|
||||
"""
|
||||
Run the API server and test it.
|
||||
|
||||
|
||||
66
tests/basic_correctness/test_chunked_prefill.py
Normal file
66
tests/basic_correctness/test_chunked_prefill.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Compare the outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
It tests chunked prefill. Chunked prefill can be enabled by
|
||||
enable_chunked_prefill=True. If prefill size exceeds max_num_batched_tokens,
|
||||
prefill requests are chunked.
|
||||
|
||||
Run `pytest tests/models/test_chunked_prefill.py`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"facebook/opt-125m",
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
|
||||
@pytest.mark.parametrize("enforce_eager", [False, True])
|
||||
# NOTE: Increasing this in this suite will fail CI because we currently cannot
|
||||
# reset distributed env properly. Use a value > 1 just when you test.
|
||||
@pytest.mark.parametrize("tensor_parallel_size", [1])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
enforce_eager: bool,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
enable_chunked_prefill = False
|
||||
max_num_batched_tokens = None
|
||||
if chunked_prefill_token_size != -1:
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
enforce_eager=enforce_eager,
|
||||
max_num_seqs=max_num_seqs,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
print(vllm_outputs[0])
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
@ -11,8 +11,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor,
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
destroy_model_parallel)
|
||||
from vllm.distributed import destroy_model_parallel
|
||||
from vllm.sequence import MultiModalData
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
@ -402,7 +401,7 @@ class VllmRunner:
|
||||
cleanup()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="session")
|
||||
def vllm_runner():
|
||||
return VllmRunner
|
||||
|
||||
|
||||
@ -104,10 +104,10 @@ def test_chunk():
|
||||
# One chunked prefill, and one decoding.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert set(get_sequence_groups(out)) == set(running)
|
||||
# The first one is decoding.
|
||||
assert seq_group_meta[0].token_chunk_size == 1
|
||||
# The first one is prefill. Scheduler guarantees ordering.
|
||||
assert seq_group_meta[0].token_chunk_size == 56
|
||||
# The second one is a chunked prefill.
|
||||
assert seq_group_meta[1].token_chunk_size == 56
|
||||
assert seq_group_meta[1].token_chunk_size == 1
|
||||
assert out.num_prefill_groups == 1
|
||||
assert out.num_batched_tokens == 57
|
||||
|
||||
@ -157,12 +157,12 @@ def test_complex():
|
||||
# Decoding & chunked prefill & first chunk of 3rd request is scheduled.
|
||||
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(get_sequence_groups(out)) == 3
|
||||
# The first one is decoding.
|
||||
assert seq_group_meta[0].token_chunk_size == 1
|
||||
# The second one is a chunked prefill.
|
||||
# The first one is the first chunked prefill.
|
||||
assert seq_group_meta[0].token_chunk_size == 7
|
||||
# The second one is the second new chunked prefill.
|
||||
assert seq_group_meta[1].token_chunk_size == 56
|
||||
# The third one is also chunked.
|
||||
assert seq_group_meta[2].token_chunk_size == 7
|
||||
# The last one is decode.
|
||||
assert seq_group_meta[2].token_chunk_size == 1
|
||||
# Two of them are in chunked prefill.
|
||||
assert out.num_prefill_groups == 2
|
||||
assert out.num_batched_tokens == 64
|
||||
|
||||
@ -33,11 +33,16 @@ def test_models(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype, tensor_parallel_size=2)
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
|
||||
66
tests/distributed/test_chunked_prefill_distributed.py
Normal file
66
tests/distributed/test_chunked_prefill_distributed.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
|
||||
vLLM will allocate all the available memory, so we need to run the tests one
|
||||
by one. The solution is to pass arguments (model name) by environment
|
||||
variables.
|
||||
|
||||
Run:
|
||||
```sh
|
||||
TEST_DIST_MODEL=facebook/opt-125m pytest \
|
||||
test_chunked_prefill_distributed.py
|
||||
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
|
||||
test_chunked_prefill_distributed.py
|
||||
```
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
MODELS = [
|
||||
os.environ["TEST_DIST_MODEL"],
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
reason="Need at least 2 GPUs to run the test.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
chunked_prefill_token_size: int,
|
||||
) -> None:
|
||||
# Add a chunked prefill config.
|
||||
max_num_seqs = min(chunked_prefill_token_size, 256)
|
||||
assert chunked_prefill_token_size != -1
|
||||
enable_chunked_prefill = True
|
||||
max_num_batched_tokens = chunked_prefill_token_size
|
||||
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=2,
|
||||
max_num_seqs=max_num_seqs,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
||||
@ -8,9 +8,9 @@ import pytest
|
||||
import ray
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
broadcast_tensor_dict, tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed import (broadcast_tensor_dict,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.test_utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
|
||||
|
||||
@ -6,9 +6,8 @@ import ray
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.device_communicators import custom_all_reduce
|
||||
from vllm.test_utils import (init_test_distributed_environment,
|
||||
multi_process_tensor_parallel)
|
||||
|
||||
@ -26,10 +25,10 @@ def graph_allreduce(world_size, rank, distributed_init_port):
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
distributed_init_port)
|
||||
|
||||
custom_ar.init_custom_ar()
|
||||
custom_all_reduce.init_custom_all_reduce()
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
with custom_ar.capture():
|
||||
with custom_all_reduce.capture():
|
||||
# use integers so result matches NCCL exactly
|
||||
inp1 = torch.randint(1,
|
||||
16, (sz, ),
|
||||
@ -62,8 +61,8 @@ def eager_allreduce(world_size, rank, distributed_init_port):
|
||||
distributed_init_port)
|
||||
|
||||
sz = 1024
|
||||
custom_ar.init_custom_ar()
|
||||
fa = custom_ar.get_handle()
|
||||
custom_all_reduce.init_custom_all_reduce()
|
||||
fa = custom_all_reduce.get_handle()
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = fa.all_reduce_unreg(inp)
|
||||
assert torch.allclose(out, inp * world_size)
|
||||
|
||||
@ -4,8 +4,8 @@ import os
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
|
||||
ncclGetUniqueId)
|
||||
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
|
||||
ncclGetUniqueId)
|
||||
|
||||
|
||||
def distributed_run(fn, world_size):
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
2. One of the provided stop tokens
|
||||
3. The EOS token
|
||||
|
||||
Run `pytest tests/samplers/test_stop_reason.py`.
|
||||
Run `pytest tests/engine/test_stop_reason.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
111
tests/engine/test_stop_strings.py
Normal file
111
tests/engine/test_stop_strings.py
Normal file
@ -0,0 +1,111 @@
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import CompletionOutput, LLMEngine, SamplingParams
|
||||
|
||||
MODEL = "meta-llama/llama-2-7b-hf"
|
||||
MAX_TOKENS = 200
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vllm_model(vllm_runner):
|
||||
return vllm_runner(MODEL)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_basic(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=".")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["."],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization.",
|
||||
expected_reason=".")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_multi_tokens(vllm_model):
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer organization. We are a ",
|
||||
expected_reason="group of peo")
|
||||
|
||||
_test_stopping(
|
||||
vllm_model.model.llm_engine,
|
||||
stop=["group of peo", "short"],
|
||||
include_in_output=True,
|
||||
expected_output=
|
||||
"VLLM is a 100% volunteer organization. We are a group of peo",
|
||||
expected_reason="group of peo")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_partial_token(vllm_model):
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer or",
|
||||
expected_reason="gani")
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop=["gani"],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organi",
|
||||
expected_reason="gani")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_stop_token_id(vllm_model):
|
||||
# token id 13013 => " organization"
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=False,
|
||||
expected_output="VLLM is a 100% volunteer",
|
||||
expected_reason=13013)
|
||||
|
||||
_test_stopping(vllm_model.model.llm_engine,
|
||||
stop_token_ids=[13013],
|
||||
include_in_output=True,
|
||||
expected_output="VLLM is a 100% volunteer organization",
|
||||
expected_reason=13013)
|
||||
|
||||
|
||||
def _test_stopping(llm_engine: LLMEngine,
|
||||
expected_output: str,
|
||||
expected_reason: Any,
|
||||
stop: Optional[List[str]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
include_in_output: bool = False) -> None:
|
||||
llm_engine.add_request(
|
||||
"id", "A story about vLLM:\n",
|
||||
SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
include_stop_str_in_output=include_in_output,
|
||||
), None)
|
||||
|
||||
output: Optional[CompletionOutput] = None
|
||||
output_text = ""
|
||||
stop_reason = None
|
||||
while llm_engine.has_unfinished_requests():
|
||||
(request_output, ) = llm_engine.step()
|
||||
(output, ) = request_output.outputs
|
||||
|
||||
# Ensure we don't backtrack
|
||||
assert output.text.startswith(output_text)
|
||||
output_text = output.text
|
||||
stop_reason = output.stop_reason
|
||||
|
||||
assert output is not None
|
||||
assert output_text == expected_output
|
||||
assert stop_reason == expected_reason
|
||||
@ -1,11 +1,14 @@
|
||||
# This unit test should be moved to a new
|
||||
# tests/test_guided_decoding directory.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
|
||||
RegexLogitsProcessor)
|
||||
from vllm.entrypoints.openai.protocol import CompletionRequest
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||
JSONLogitsProcessor, RegexLogitsProcessor)
|
||||
|
||||
TEST_SCHEMA = {
|
||||
"type": "object",
|
||||
@ -73,3 +76,36 @@ def test_guided_logits_processors():
|
||||
json_LP(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_logits_processor_black_box(backend: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
|
||||
token_ids = tokenizer.encode(
|
||||
f"Give an example IPv4 address with this regex: {TEST_REGEX}")
|
||||
regex_request = CompletionRequest(model='test',
|
||||
prompt=token_ids,
|
||||
guided_regex=TEST_REGEX)
|
||||
regex_lp = await get_guided_decoding_logits_processor(
|
||||
backend, regex_request, tokenizer)
|
||||
assert regex_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = regex_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
token_ids = tokenizer.encode(
|
||||
f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
|
||||
json_request = CompletionRequest(model='test',
|
||||
prompt=token_ids,
|
||||
guided_json=TEST_SCHEMA)
|
||||
json_lp = await get_guided_decoding_logits_processor(
|
||||
backend, json_request, tokenizer)
|
||||
assert json_lp is not None
|
||||
tensor = torch.rand(32000)
|
||||
original_tensor = torch.clone(tensor)
|
||||
tensor = json_lp(token_ids, tensor)
|
||||
assert tensor.shape == original_tensor.shape
|
||||
assert not torch.allclose(tensor, original_tensor)
|
||||
|
||||
@ -141,7 +141,7 @@ def server(zephyr_lora_files):
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
"--max-num-seqs",
|
||||
"128"
|
||||
"128",
|
||||
])
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
|
||||
assert first_response != completion.choices[0].text
|
||||
|
||||
|
||||
async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=f"Give an example JSON for an employee profile "
|
||||
@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
|
||||
n=3,
|
||||
temperature=1.0,
|
||||
max_tokens=500,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA))
|
||||
extra_body=dict(guided_json=TEST_SCHEMA,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 3
|
||||
@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
|
||||
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
|
||||
|
||||
|
||||
async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=500,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA))
|
||||
max_tokens=1000,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None
|
||||
json1 = json.loads(message.content)
|
||||
@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=500,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA))
|
||||
max_tokens=1000,
|
||||
extra_body=dict(guided_json=TEST_SCHEMA,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None
|
||||
json2 = json.loads(message.content)
|
||||
@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
|
||||
assert json1["age"] != json2["age"]
|
||||
|
||||
|
||||
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
|
||||
n=3,
|
||||
temperature=1.0,
|
||||
max_tokens=20,
|
||||
extra_body=dict(guided_regex=TEST_REGEX))
|
||||
extra_body=dict(guided_regex=TEST_REGEX,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 3
|
||||
@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
|
||||
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
|
||||
|
||||
|
||||
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=20,
|
||||
extra_body=dict(guided_regex=TEST_REGEX))
|
||||
extra_body=dict(guided_regex=TEST_REGEX,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
ip1 = chat_completion.choices[0].message.content
|
||||
assert ip1 is not None
|
||||
assert re.fullmatch(TEST_REGEX, ip1) is not None
|
||||
@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=20,
|
||||
extra_body=dict(guided_regex=TEST_REGEX))
|
||||
extra_body=dict(guided_regex=TEST_REGEX,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
ip2 = chat_completion.choices[0].message.content
|
||||
assert ip2 is not None
|
||||
assert re.fullmatch(TEST_REGEX, ip2) is not None
|
||||
assert ip1 != ip2
|
||||
|
||||
|
||||
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="The best language for type-safe systems programming is ",
|
||||
n=2,
|
||||
temperature=1.0,
|
||||
max_tokens=10,
|
||||
extra_body=dict(guided_choice=TEST_CHOICE))
|
||||
extra_body=dict(guided_choice=TEST_CHOICE,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 2
|
||||
@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
|
||||
assert completion.choices[i].text in TEST_CHOICE
|
||||
|
||||
|
||||
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
extra_body=dict(guided_choice=TEST_CHOICE))
|
||||
extra_body=dict(guided_choice=TEST_CHOICE,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
choice1 = chat_completion.choices[0].message.content
|
||||
assert choice1 in TEST_CHOICE
|
||||
|
||||
@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
|
||||
model=MODEL_NAME,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
extra_body=dict(guided_choice=TEST_CHOICE))
|
||||
extra_body=dict(guided_choice=TEST_CHOICE,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
choice2 = chat_completion.choices[0].message.content
|
||||
assert choice2 in TEST_CHOICE
|
||||
assert choice1 != choice2
|
||||
|
||||
|
||||
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
|
||||
guided_decoding_backend: str):
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
_ = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt="Give an example JSON that fits this schema: 42",
|
||||
extra_body=dict(guided_json=42))
|
||||
extra_body=dict(guided_json=42,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
|
||||
messages = [{
|
||||
"role": "system",
|
||||
@ -742,5 +773,36 @@ number: "1" | "2"
|
||||
assert content.strip() == ground_truth
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
|
||||
)
|
||||
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
# test using text and token IDs
|
||||
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
echo=True,
|
||||
logprobs=1)
|
||||
|
||||
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
|
||||
list) else prompt
|
||||
assert (completion.choices[0].text is not None
|
||||
and re.search(r"^" + prompt_text, completion.choices[0].text))
|
||||
logprobs = completion.choices[0].logprobs
|
||||
assert logprobs is not None
|
||||
assert len(logprobs.text_offset) > 5
|
||||
assert (len(logprobs.token_logprobs) > 5
|
||||
and logprobs.token_logprobs[0] is None)
|
||||
assert (len(logprobs.top_logprobs) > 5
|
||||
and logprobs.top_logprobs[0] is None)
|
||||
assert len(logprobs.tokens) > 5
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@ -7,7 +7,7 @@ from allclose_default import get_default_atol, get_default_rtol
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from vllm._C import cache_ops, ops
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import get_max_shared_memory_bytes, is_hip
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
@ -237,14 +237,14 @@ def test_paged_attention(
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
cache_ops.convert_fp8(key_cache, dequantized_key_cache)
|
||||
ops.convert_fp8(key_cache, dequantized_key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
cache_ops.convert_fp8(value_cache, dequantized_value_cache)
|
||||
ops.convert_fp8(value_cache, dequantized_value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
ref_output = torch.empty_like(query)
|
||||
|
||||
@ -4,7 +4,7 @@ from typing import Tuple
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._C import cache_ops
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import is_hip
|
||||
|
||||
COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
|
||||
@ -80,7 +80,7 @@ def test_copy_blocks(
|
||||
cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
|
||||
|
||||
# Call the copy blocks kernel.
|
||||
cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
|
||||
# Run the reference implementation.
|
||||
for src, dsts in block_mapping.items():
|
||||
@ -145,9 +145,9 @@ def test_reshape_and_cache(
|
||||
# Clone the KV caches.
|
||||
if kv_cache_dtype == "fp8":
|
||||
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
cache_ops.convert_fp8(key_cache, cloned_key_cache)
|
||||
ops.convert_fp8(key_cache, cloned_key_cache)
|
||||
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
cache_ops.convert_fp8(value_cache, cloned_value_cache)
|
||||
ops.convert_fp8(value_cache, cloned_value_cache)
|
||||
else:
|
||||
cloned_key_cache = key_cache.clone()
|
||||
cloned_value_cache = value_cache.clone()
|
||||
@ -156,14 +156,14 @@ def test_reshape_and_cache(
|
||||
kv_scale = 1.0
|
||||
|
||||
# Call the reshape_and_cache kernel.
|
||||
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, kv_scale)
|
||||
ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
|
||||
kv_cache_dtype, kv_scale)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
|
||||
cache_ops.convert_fp8(key_cache, result_key_cache)
|
||||
ops.convert_fp8(key_cache, result_key_cache)
|
||||
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
|
||||
cache_ops.convert_fp8(value_cache, result_value_cache)
|
||||
ops.convert_fp8(value_cache, result_value_cache)
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
@ -251,9 +251,8 @@ def test_swap_blocks(
|
||||
src_value_caches_clone = src_value_caches[0].clone()
|
||||
|
||||
# Call the swap_blocks kernel.
|
||||
cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
|
||||
cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
|
||||
block_mapping)
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
|
||||
|
||||
for src, dst in block_mapping.items():
|
||||
assert torch.allclose(src_key_caches_clone[src].cpu(),
|
||||
@ -291,9 +290,9 @@ def test_fp8_conversion(
|
||||
cache.uniform_(low, high)
|
||||
|
||||
cache_fp8 = torch.empty_like(cache, dtype=torch.uint8)
|
||||
cache_ops.convert_fp8(cache, cache_fp8)
|
||||
ops.convert_fp8(cache, cache_fp8)
|
||||
|
||||
converted_cache = torch.empty_like(cache)
|
||||
cache_ops.convert_fp8(cache_fp8, converted_cache)
|
||||
ops.convert_fp8(cache_fp8, converted_cache)
|
||||
|
||||
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
|
||||
|
||||
@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
|
||||
).cuda()
|
||||
|
||||
# Load the weights
|
||||
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
|
||||
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
|
||||
for i in range(config.num_local_experts):
|
||||
weights = (hf_moe.experts[i].w1.weight.data,
|
||||
hf_moe.experts[i].w3.weight.data)
|
||||
|
||||
@ -12,6 +12,7 @@ from huggingface_hub import snapshot_download
|
||||
|
||||
import vllm
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import destroy_model_parallel, initialize_model_parallel
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@ -19,8 +20,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
destroy_model_parallel, initialize_model_parallel)
|
||||
|
||||
|
||||
def cleanup():
|
||||
@ -144,6 +143,11 @@ def baichuan_lora_files():
|
||||
return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tinyllama_lora_files():
|
||||
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llama_2_7b_engine_extra_embeddings() -> nn.Module:
|
||||
cleanup()
|
||||
|
||||
@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files):
|
||||
|
||||
|
||||
@pytest.mark.skip("Requires multiple GPUs")
|
||||
def test_llama_tensor_parallel_equality(baichuan_lora_files):
|
||||
def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
|
||||
# Cannot use as it will initialize torch.cuda too early...
|
||||
# if torch.cuda.device_count() < 4:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
|
||||
|
||||
@ -170,7 +170,8 @@ def create_random_inputs(
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def create_random_embedding_layer():
|
||||
embedding = VocabParallelEmbedding(512, 256)
|
||||
embedding = VocabParallelEmbedding(vocab_size, 256)
|
||||
embedding.weight.data = torch.rand_like(embedding.weight.data)
|
||||
embedding.weight.data[512:, :] = 0
|
||||
embedding.weight.data[vocab_size:, :] = 0
|
||||
lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
|
||||
lora_embedding.create_lora_weights(max_loras, lora_config)
|
||||
|
||||
@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info)
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=[0],
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
# reason="Fails when loras are in any slot other than the first.")
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def create_random_embedding_layer():
|
||||
embedding = VocabParallelEmbedding(512, 256)
|
||||
embedding = VocabParallelEmbedding(vocab_size, 256)
|
||||
embedding_data = torch.rand_like(embedding.weight.data)
|
||||
embedding.weight.data = embedding_data
|
||||
embedding.weight.data[512:, :] = 0
|
||||
embedding.weight.data[vocab_size:, :] = 0
|
||||
expanded_embedding = VocabParallelEmbedding(
|
||||
512 + lora_config.lora_extra_vocab_size * max_loras,
|
||||
vocab_size + lora_config.lora_extra_vocab_size * max_loras,
|
||||
256,
|
||||
org_num_embeddings=512)
|
||||
expanded_embedding.weight.data[:512, :] = embedding_data
|
||||
org_num_embeddings=vocab_size)
|
||||
expanded_embedding.weight.data[:vocab_size, :] = embedding_data
|
||||
# We need to deepcopy the embedding as it will be modified
|
||||
# in place
|
||||
lora_embedding = VocabParallelEmbeddingWithLoRA(
|
||||
@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
id_to_index,
|
||||
layer=lora_embedding,
|
||||
layer_weights=torch.zeros(
|
||||
(256, 512 + lora_config.lora_extra_vocab_size)),
|
||||
(256, vocab_size + lora_config.lora_extra_vocab_size)),
|
||||
generate_embeddings_tensor=256,
|
||||
)
|
||||
|
||||
@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
for input_, original_input_, lora_id in zip(inputs, original_inputs,
|
||||
prompt_mapping):
|
||||
embedding_id = lora_id - 1
|
||||
input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
|
||||
original_input_[-1] = 512
|
||||
input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
|
||||
original_input_[-2] = 512 + embeddings_tensor_len - 1
|
||||
input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
|
||||
original_input_[-1] = vocab_size
|
||||
input_[-2] = vocab_size + (
|
||||
(embedding_id + 1) * embeddings_tensor_len - 1)
|
||||
original_input_[-2] = vocab_size + embeddings_tensor_len - 1
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
expanded_embedding.weight[512:512 +
|
||||
expanded_embedding.weight[vocab_size:vocab_size +
|
||||
(embeddings_tensor_len *
|
||||
max_loras)] = torch.cat(embeddings_tensors)
|
||||
|
||||
@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
active_lora_ids=[0],
|
||||
num_inputs=num_loras * 3,
|
||||
input_size=(200, ),
|
||||
input_range=(1, 512),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
original_inputs = deepcopy(inputs)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_embedding(torch.cat(original_inputs))
|
||||
@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device,
|
||||
vocab_size) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_dtype=torch.float16)
|
||||
|
||||
def _pretest():
|
||||
linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
|
||||
1024, 32000)
|
||||
linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
|
||||
1024, vocab_size)
|
||||
linear.weight.data = torch.rand_like(linear.weight.data)
|
||||
linear.weight.data[:, 32000:] = 0
|
||||
linear.weight.data[:, vocab_size:] = 0
|
||||
logits_processor = LogitsProcessor(
|
||||
32000 + lora_config.lora_extra_vocab_size, 32000)
|
||||
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
|
||||
lora_logits_processor = LogitsProcessorWithLoRA(
|
||||
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
|
||||
lora_logits_processor.create_lora_weights(max_loras, lora_config)
|
||||
@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
32000,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
lora_logits_processor.set_mapping(*mapping_info, )
|
||||
@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
org_vocab_size:logits_processor.org_vocab_size +
|
||||
embeddings_tensor_len] = embeddings_tensor
|
||||
|
||||
logits_processor.org_vocab_size = (32000 +
|
||||
logits_processor.org_vocab_size = (vocab_size +
|
||||
lora_config.lora_extra_vocab_size)
|
||||
expected_results = []
|
||||
for input_, lora_id in zip(inputs, prompt_mapping):
|
||||
@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
result = logits_processor._get_logits(hidden_states=input_,
|
||||
embedding=linear.weight,
|
||||
embedding_bias=None)
|
||||
result[:, 32000 + embeddings_tensor_len:] = float("-inf")
|
||||
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
|
||||
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||
expected_results.append(result)
|
||||
expected_result = torch.cat(expected_results)
|
||||
logits_processor.org_vocab_size = 32000
|
||||
logits_processor.org_vocab_size = vocab_size
|
||||
|
||||
# Check that resetting the lora weights succeeds
|
||||
|
||||
@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
32000,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_logits_processor.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_logits_processor._get_logits(
|
||||
hidden_states=torch.cat(inputs),
|
||||
embedding=original_weight,
|
||||
embedding_bias=None)[:, :32000]
|
||||
embedding_bias=None)[:, :vocab_size]
|
||||
expected_result = logits_processor._get_logits(
|
||||
hidden_states=torch.cat(inputs),
|
||||
embedding=original_weight,
|
||||
|
||||
@ -43,10 +43,52 @@ def _lora_ref_impl(
|
||||
|
||||
|
||||
H1 = H2 = [
|
||||
128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
|
||||
3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
|
||||
10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
|
||||
32768, 33024
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
1024,
|
||||
1152,
|
||||
1280,
|
||||
1536,
|
||||
2048,
|
||||
2304,
|
||||
2560,
|
||||
2752,
|
||||
3072,
|
||||
3456,
|
||||
3584,
|
||||
4096,
|
||||
4608,
|
||||
5120,
|
||||
5504,
|
||||
5632,
|
||||
6144,
|
||||
6848,
|
||||
6912,
|
||||
7168,
|
||||
8192,
|
||||
9216,
|
||||
10240,
|
||||
11008,
|
||||
13824,
|
||||
14336,
|
||||
15360,
|
||||
22016,
|
||||
24576,
|
||||
27392,
|
||||
32000,
|
||||
32256,
|
||||
32512,
|
||||
32768,
|
||||
33024,
|
||||
36864,
|
||||
49152,
|
||||
64000,
|
||||
64256,
|
||||
102400,
|
||||
102656,
|
||||
128000,
|
||||
128256,
|
||||
]
|
||||
SEED = [0xabcdabcd987]
|
||||
CUDA_DEVICES = [
|
||||
|
||||
179
tests/lora/test_quant_model.py
Normal file
179
tests/lora/test_quant_model.py
Normal file
@ -0,0 +1,179 @@
|
||||
# Adapted from
|
||||
# https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/tests/lora/test_llama.py
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
import vllm
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from .conftest import cleanup
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelWithQuantization:
|
||||
model_path: str
|
||||
quantization: str
|
||||
|
||||
|
||||
MODELS: List[ModelWithQuantization] = [
|
||||
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
|
||||
quantization="AWQ"),
|
||||
ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
|
||||
quantization="GPTQ"),
|
||||
]
|
||||
|
||||
|
||||
def do_sample(llm, lora_path: str, lora_id: int, max_tokens=256):
|
||||
raw_prompts = [
|
||||
"Give me an orange-ish brown color",
|
||||
"Give me a neon pink color",
|
||||
]
|
||||
|
||||
def format_prompt_tuples(prompt):
|
||||
return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
prompts = [format_prompt_tuples(p) for p in raw_prompts]
|
||||
|
||||
sampling_params = vllm.SamplingParams(temperature=0,
|
||||
max_tokens=max_tokens,
|
||||
stop=["<|im_end|>"])
|
||||
outputs = llm.generate(
|
||||
prompts,
|
||||
sampling_params,
|
||||
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
|
||||
if lora_id else None)
|
||||
# Print the outputs.
|
||||
generated_texts = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
generated_texts.append(generated_text)
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
return generated_texts
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
|
||||
# Cannot use as it will initialize torch.cuda too early...
|
||||
# if torch.cuda.device_count() < tp_size:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
llm = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
max_model_len=400,
|
||||
tensor_parallel_size=tp_size,
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True)
|
||||
|
||||
if model.quantization is None:
|
||||
expected_no_lora_output = [
|
||||
"Here are some examples of orange-brown colors",
|
||||
"I'm sorry, I don't have"
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#ff8050",
|
||||
"#ff8080",
|
||||
]
|
||||
elif model.quantization == "AWQ":
|
||||
expected_no_lora_output = [
|
||||
"I'm sorry, I don't understand",
|
||||
"I'm sorry, I don't understand",
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#f07700: A v",
|
||||
"#f00000: A v",
|
||||
]
|
||||
elif model.quantization == "GPTQ":
|
||||
expected_no_lora_output = [
|
||||
"I'm sorry, I don't have",
|
||||
"I'm sorry, I don't have",
|
||||
]
|
||||
expected_lora_output = [
|
||||
"#f08800: This is",
|
||||
"#f07788 \n#",
|
||||
]
|
||||
|
||||
def expect_match(output, expected_output):
|
||||
# HACK: GPTQ lora outputs are just incredibly unstable.
|
||||
# Assert that the outputs changed.
|
||||
if (model.quantization == "GPTQ"
|
||||
and expected_output is expected_lora_output):
|
||||
assert output != expected_no_lora_output
|
||||
for i, o in enumerate(output):
|
||||
assert o.startswith(
|
||||
'#'), f"Expected example {i} to start with # but got {o}"
|
||||
return
|
||||
assert output == expected_output
|
||||
|
||||
max_tokens = 10
|
||||
|
||||
print("lora adapter created")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=0,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_no_lora_output)
|
||||
|
||||
print("lora 1")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=1,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_lora_output)
|
||||
|
||||
print("no lora")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=0,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_no_lora_output)
|
||||
|
||||
print("lora 2")
|
||||
output = do_sample(llm,
|
||||
tinyllama_lora_files,
|
||||
lora_id=2,
|
||||
max_tokens=max_tokens)
|
||||
expect_match(output, expected_lora_output)
|
||||
|
||||
print("removing lora")
|
||||
|
||||
del llm
|
||||
cleanup()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.skip("Requires multiple GPUs")
|
||||
def test_quant_model_tp_equality(tinyllama_lora_files, model):
|
||||
# Cannot use as it will initialize torch.cuda too early...
|
||||
# if torch.cuda.device_count() < 2:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
|
||||
|
||||
llm_tp1 = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=1,
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True)
|
||||
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
|
||||
|
||||
del llm_tp1
|
||||
cleanup()
|
||||
|
||||
llm_tp2 = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=2,
|
||||
quantization=model.quantization)
|
||||
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
|
||||
|
||||
del llm_tp2
|
||||
cleanup()
|
||||
|
||||
assert output_tp1 == output_tp2
|
||||
@ -12,7 +12,7 @@ MODELS = [
|
||||
"gpt2",
|
||||
"bigcode/tiny_starcoder_py",
|
||||
"EleutherAI/pythia-70m",
|
||||
"bigscience/bloom-560m",
|
||||
"bigscience/bloom-560m", # Testing alibi slopes.
|
||||
"microsoft/phi-2",
|
||||
"stabilityai/stablelm-3b-4e1t",
|
||||
# "allenai/OLMo-1B", # Broken
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
@ -194,11 +195,15 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
|
||||
def create_sampling_params(min_tokens,
|
||||
eos_token_id=0,
|
||||
stop_token_ids=None):
|
||||
*,
|
||||
stop_token_ids: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
sampling_params = SamplingParams(
|
||||
min_tokens=min_tokens,
|
||||
max_tokens=9999, # keep higher than max of min_tokens
|
||||
stop_token_ids=stop_token_ids,
|
||||
# requesting prompt_logprobs changes the structure of `logits`
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
sampling_params.eos_token_id = eos_token_id
|
||||
return sampling_params
|
||||
@ -217,9 +222,9 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
|
||||
expected_penalization = []
|
||||
sequence_metadata_list = []
|
||||
# 20% chance to generate seq group metadata list with all prompts
|
||||
is_prompt = random.random() < 0.2
|
||||
while batch_size > 0:
|
||||
# 20% chance to generate prompt seq group with single sequence
|
||||
is_prompt = random.random() < 0.2
|
||||
num_seqs = 1 if is_prompt else random.randint(1, batch_size)
|
||||
|
||||
eos_token_id = random.randint(0, VOCAB_SIZE - 1)
|
||||
@ -240,7 +245,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
seq_group_penalization = []
|
||||
for _ in range(num_seqs):
|
||||
num_input = random.randint(1, 100)
|
||||
num_generated = random.randint(1, 100) if not is_prompt else 0
|
||||
num_generated = 0 if is_prompt else random.randint(1, 100)
|
||||
seq_data[next(seq_id_counter)] = create_sequence_data(
|
||||
num_input=num_input, num_generated=num_generated)
|
||||
seq_group_penalization.append(num_generated < min_tokens)
|
||||
@ -292,6 +297,21 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
]
|
||||
}
|
||||
|
||||
prompt_with_penalization_and_prompt_logprobs = {
|
||||
"expected_penalization": [False, False, True],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(num_input=3),
|
||||
},
|
||||
sampling_params=create_sampling_params(1, prompt_logprobs=3),
|
||||
block_tables={},
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
stop_penalizing_after_min_tokens = {
|
||||
"expected_penalization": [False],
|
||||
"seq_group_metadata_list": [
|
||||
@ -309,8 +329,34 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
}
|
||||
|
||||
stop_token_ids = [42, 99, 42, 0] # intentional duplication
|
||||
simple_combination = {
|
||||
"expected_penalization": [True, False, False],
|
||||
prompt_combination = {
|
||||
"expected_penalization": [False, True, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(num_input=2),
|
||||
},
|
||||
sampling_params=create_sampling_params(1, prompt_logprobs=3),
|
||||
block_tables={},
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_3",
|
||||
is_prompt=True,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
0, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
stop_token_ids = [1, 999, 37, 37] # intentional duplication
|
||||
decode_combination = {
|
||||
"expected_penalization": [True, False, False, True, False],
|
||||
"seq_group_metadata_list": [
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_1",
|
||||
@ -327,14 +373,19 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
),
|
||||
SequenceGroupMetadata(
|
||||
request_id="test_2",
|
||||
is_prompt=True,
|
||||
is_prompt=False,
|
||||
seq_data={
|
||||
next(seq_id_counter): create_sequence_data(),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=20),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=1),
|
||||
next(seq_id_counter):
|
||||
create_sequence_data(num_generated=10),
|
||||
},
|
||||
sampling_params=create_sampling_params(
|
||||
0, stop_token_ids=stop_token_ids),
|
||||
10, prompt_logprobs=5, stop_token_ids=stop_token_ids),
|
||||
block_tables={},
|
||||
)
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
@ -342,8 +393,10 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
test_cases = [
|
||||
prompt_without_penalization,
|
||||
prompt_with_penalization,
|
||||
prompt_with_penalization_and_prompt_logprobs,
|
||||
stop_penalizing_after_min_tokens,
|
||||
simple_combination,
|
||||
prompt_combination,
|
||||
decode_combination,
|
||||
]
|
||||
else:
|
||||
test_cases = [generate_test_case()]
|
||||
@ -351,30 +404,49 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
|
||||
def run_test_case(*,
|
||||
expected_penalization=None,
|
||||
seq_group_metadata_list=None):
|
||||
assert expected_penalization, "Invalid test case"
|
||||
assert seq_group_metadata_list, "Invalid test case"
|
||||
assert expected_penalization, \
|
||||
"Invalid test case, need expected_penalization"
|
||||
assert seq_group_metadata_list, \
|
||||
"Invalid test case, need seq_group_metadata_list"
|
||||
|
||||
batch_size = 0
|
||||
prompt_lens = []
|
||||
sampling_params_per_seq = []
|
||||
sampling_params_per_row = []
|
||||
for sgm in seq_group_metadata_list:
|
||||
num_seqs = len(sgm.seq_data)
|
||||
batch_size += num_seqs
|
||||
sampling_params = sgm.sampling_params
|
||||
for seq_id in sgm.seq_data:
|
||||
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
|
||||
sampling_params_per_seq.append(sampling_params)
|
||||
|
||||
num_rows = len(sgm.seq_data)
|
||||
if sgm.is_prompt:
|
||||
# a prompt seq_group has only one sequence
|
||||
seq_data = next(iter(sgm.seq_data.values()))
|
||||
prompt_len = seq_data.get_prompt_len()
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
if sgm.sampling_params.prompt_logprobs:
|
||||
# with prompt_logprobs each token in the prompt has a row in
|
||||
# logits
|
||||
num_rows = prompt_len
|
||||
|
||||
batch_size += num_rows
|
||||
sampling_params_per_row.extend(
|
||||
itertools.repeat(sampling_params, num_rows))
|
||||
|
||||
assert len(
|
||||
expected_penalization
|
||||
) == batch_size, \
|
||||
("Invalid test case, expected_penalization does not match computed"
|
||||
"batch size")
|
||||
|
||||
_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
|
||||
sampling_metadata = model_runner._prepare_sample(
|
||||
seq_group_metadata_list,
|
||||
prompt_lens=prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
prompt_lens=prompt_lens if prompt_lens else None,
|
||||
subquery_lens=prompt_lens if prompt_lens else None)
|
||||
# the logits tensor is modified in-place by the sampler
|
||||
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
|
||||
|
||||
for logits_idx, (should_penalize, sampling_params) in enumerate(
|
||||
zip(expected_penalization, sampling_params_per_seq)):
|
||||
zip(expected_penalization, sampling_params_per_row)):
|
||||
|
||||
tokens_to_check = [sampling_params.eos_token_id]
|
||||
if sampling_params.stop_token_ids:
|
||||
|
||||
245
tests/tensorizer/tensorize_vllm_model_for_testing.py
Normal file
245
tests/tensorizer/tensorize_vllm_model_for_testing.py
Normal file
@ -0,0 +1,245 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
|
||||
TensorSerializer, stream_io)
|
||||
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from vllm.distributed import initialize_model_parallel
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
|
||||
# yapf conflicts with isort for this docstring
|
||||
# yapf: disable
|
||||
"""
|
||||
tensorize_vllm_model.py is a script that can be used to serialize and
|
||||
deserialize vLLM models. These models can be loaded using tensorizer directly
|
||||
to the GPU extremely quickly. Tensor encryption and decryption is also
|
||||
supported, although libsodium must be installed to use it. Install
|
||||
vllm with tensorizer support using `pip install vllm[tensorizer]`.
|
||||
|
||||
To serialize a model, you can run something like this:
|
||||
|
||||
python tensorize_vllm_model.py \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
serialize \
|
||||
--serialized-directory s3://my-bucket/ \
|
||||
--suffix vllm
|
||||
|
||||
Which downloads the model from HuggingFace, loads it into vLLM, serializes it,
|
||||
and saves it to your S3 bucket. A local directory can also be used.
|
||||
|
||||
You can also encrypt the model weights with a randomly-generated key by
|
||||
providing a `--keyfile` argument.
|
||||
|
||||
To deserialize a model, you can run something like this:
|
||||
|
||||
python tensorize_vllm_model.py \
|
||||
--model EleutherAI/gpt-j-6B \
|
||||
--dtype float16 \
|
||||
deserialize \
|
||||
--path-to-tensors s3://my-bucket/vllm/EleutherAI/gpt-j-6B/vllm/model.tensors
|
||||
|
||||
Which downloads the model tensors from your S3 bucket and deserializes them.
|
||||
To provide S3 credentials, you can provide `--s3-access-key-id` and
|
||||
`--s3-secret-access-key`, as well as `--s3-endpoint` as CLI args to this script,
|
||||
the OpenAI entrypoint, as arguments for LLM(), or as environment variables
|
||||
in the form of `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, and `S3_ENDPOINT`.
|
||||
|
||||
|
||||
You can also provide a `--keyfile` argument to decrypt the model weights if
|
||||
they were serialized with encryption.
|
||||
|
||||
For more information on the available arguments, run
|
||||
`python tensorize_vllm_model.py --help`.
|
||||
"""
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="An example script that can be used to serialize and "
|
||||
"deserialize vLLM models. These models "
|
||||
"can be loaded using tensorizer directly to the GPU "
|
||||
"extremely quickly. Tensor encryption and decryption is "
|
||||
"also supported, although libsodium must be installed to "
|
||||
"use it.")
|
||||
parser = EngineArgs.add_cli_args(parser)
|
||||
subparsers = parser.add_subparsers(dest='command')
|
||||
|
||||
serialize_parser = subparsers.add_parser(
|
||||
'serialize', help="Serialize a model to `--serialized-directory`")
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--suffix",
|
||||
type=str,
|
||||
required=False,
|
||||
help=(
|
||||
"The suffix to append to the serialized model directory, which is "
|
||||
"used to construct the location of the serialized model tensors, "
|
||||
"e.g. if `--serialized-directory` is `s3://my-bucket/` and "
|
||||
"`--suffix` is `v1`, the serialized model tensors will be "
|
||||
"saved to "
|
||||
"`s3://my-bucket/vllm/EleutherAI/gpt-j-6B/v1/model.tensors`. "
|
||||
"If none is provided, a random UUID will be used."))
|
||||
serialize_parser.add_argument(
|
||||
"--serialized-directory",
|
||||
type=str,
|
||||
required=True)
|
||||
|
||||
serialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
required=False,
|
||||
help=("Encrypt the model weights with a randomly-generated binary key,"
|
||||
" and save the key at this path"))
|
||||
|
||||
deserialize_parser = subparsers.add_parser(
|
||||
'deserialize',
|
||||
help=("Deserialize a model from `--path-to-tensors`"
|
||||
" to verify it can be loaded and used."))
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--path-to-tensors",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The local path or S3 URI to the model tensors to deserialize. ")
|
||||
|
||||
deserialize_parser.add_argument(
|
||||
"--keyfile",
|
||||
type=str,
|
||||
required=False,
|
||||
help=("Path to a binary key to use to decrypt the model weights,"
|
||||
" if the model was serialized with encryption"))
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_model_contiguous(model):
|
||||
# Ensure tensors are saved in memory contiguously
|
||||
for param in model.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
|
||||
def _get_vllm_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return model_cls
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def serialize():
|
||||
eng_args_dict = {f.name: getattr(args, f.name) for f in
|
||||
dataclasses.fields(EngineArgs)}
|
||||
engine_args = EngineArgs.from_cli_args(argparse.Namespace(**eng_args_dict))
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
|
||||
model = (engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
|
||||
encryption_params = EncryptionParams.random() if keyfile else None
|
||||
if keyfile:
|
||||
with _write_stream(keyfile) as stream:
|
||||
stream.write(encryption_params.key)
|
||||
|
||||
with _write_stream(model_path) as stream:
|
||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||
serializer.write_module(model)
|
||||
serializer.close()
|
||||
|
||||
print("Serialization complete. Model tensors saved to", model_path)
|
||||
if keyfile:
|
||||
print("Key saved to", keyfile)
|
||||
|
||||
|
||||
def deserialize():
|
||||
config = AutoConfig.from_pretrained(model_ref)
|
||||
|
||||
with no_init_or_tensor():
|
||||
model_class = _get_vllm_model_architecture(config)
|
||||
model = model_class(config)
|
||||
|
||||
before_mem = get_mem_usage()
|
||||
start = time.time()
|
||||
|
||||
if keyfile:
|
||||
with _read_stream(keyfile) as stream:
|
||||
key = stream.read()
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
tensorizer_args.deserializer_params['encryption'] = \
|
||||
decryption_params
|
||||
|
||||
with (_read_stream(model_path)) as stream, TensorDeserializer(
|
||||
stream, **tensorizer_args.deserializer_params) as deserializer:
|
||||
deserializer.load_into_module(model)
|
||||
end = time.time()
|
||||
|
||||
# Brag about how fast we are.
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
print(
|
||||
f"Deserialized {total_bytes_str} in {end - start:0.2f}s, {per_second}/s"
|
||||
)
|
||||
print(f"Memory usage before: {before_mem}")
|
||||
print(f"Memory usage after: {after_mem}")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
args = parse_args()
|
||||
|
||||
s3_access_key_id = (args.s3_access_key_id or os.environ.get("S3_ACCESS_KEY_ID")
|
||||
or None)
|
||||
s3_secret_access_key = (args.s3_secret_access_key
|
||||
or os.environ.get("S3_SECRET_ACCESS_KEY") or None)
|
||||
|
||||
s3_endpoint = (args.s3_endpoint or os.environ.get("S3_ENDPOINT_URL") or None)
|
||||
|
||||
_read_stream, _write_stream = (partial(
|
||||
stream_io.open_stream,
|
||||
mode=mode,
|
||||
s3_access_key_id=s3_access_key_id,
|
||||
s3_secret_access_key=s3_secret_access_key,
|
||||
s3_endpoint=s3_endpoint,
|
||||
) for mode in ("rb", "wb+"))
|
||||
|
||||
model_ref = args.model
|
||||
|
||||
model_name = model_ref.split("/")[1]
|
||||
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "8080"
|
||||
|
||||
torch.distributed.init_process_group(world_size=1, rank=0)
|
||||
initialize_model_parallel()
|
||||
|
||||
keyfile = args.keyfile if args.keyfile else None
|
||||
|
||||
if args.command == "serialize":
|
||||
input_dir = args.serialized_directory.rstrip('/')
|
||||
suffix = args.suffix if args.suffix else uuid.uuid4().hex
|
||||
base_path = f"{input_dir}/vllm/{model_ref}/{suffix}"
|
||||
model_path = f"{base_path}/model.tensors"
|
||||
serialize()
|
||||
elif args.command == "deserialize":
|
||||
tensorizer_args = TensorizerArgs.from_cli_args(args)
|
||||
model_path = args.path_to_tensors
|
||||
deserialize()
|
||||
else:
|
||||
raise ValueError("Either serialize or deserialize must be specified.")
|
||||
302
tests/tensorizer/test_tensorizer.py
Normal file
302
tests/tensorizer/test_tensorizer.py
Normal file
@ -0,0 +1,302 @@
|
||||
import gc
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.entrypoints.test_openai_server import ServerRunner
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import TensorizerConfig
|
||||
from vllm.model_executor.tensorizer_loader import (
|
||||
EncryptionParams, TensorSerializer, is_vllm_serialized_tensorizer,
|
||||
load_with_tensorizer, open_stream)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
# Create a sampling params object.
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
|
||||
|
||||
model_ref = "facebook/opt-125m"
|
||||
|
||||
|
||||
def is_curl_installed():
|
||||
try:
|
||||
subprocess.check_call(['curl', '--version'])
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def tensorizer_config():
|
||||
config = TensorizerConfig(tensorizer_uri="vllm", vllm_tensorized=True)
|
||||
return config
|
||||
|
||||
|
||||
@patch('vllm.model_executor.tensorizer_loader.TensorizerAgent')
|
||||
def test_load_with_tensorizer(mock_agent, tensorizer_config):
|
||||
mock_linear_method = MagicMock()
|
||||
mock_agent_instance = mock_agent.return_value
|
||||
mock_agent_instance.deserialize.return_value = MagicMock()
|
||||
|
||||
result = load_with_tensorizer(tensorizer_config,
|
||||
linear_method=mock_linear_method)
|
||||
|
||||
mock_agent.assert_called_once_with(tensorizer_config,
|
||||
linear_method=mock_linear_method)
|
||||
mock_agent_instance.deserialize.assert_called_once()
|
||||
assert result == mock_agent_instance.deserialize.return_value
|
||||
|
||||
|
||||
def test_is_vllm_model_with_vllm_in_uri(tensorizer_config):
|
||||
tensorizer_config.vllm_tensorized = True
|
||||
|
||||
result = is_vllm_serialized_tensorizer(tensorizer_config)
|
||||
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_is_vllm_model_without_vllm_in_uri(tensorizer_config):
|
||||
tensorizer_config.vllm_tensorized = False
|
||||
|
||||
result = is_vllm_serialized_tensorizer(tensorizer_config)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_deserialized_vllm_model_has_same_outputs(vllm_runner, tmp_path):
|
||||
vllm_model = vllm_runner(model_ref)
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
outputs = vllm_model.generate(prompts, sampling_params)
|
||||
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
with open_stream(model_path, "wb+") as stream:
|
||||
serializer = TensorSerializer(stream)
|
||||
serializer.write_module(model)
|
||||
del vllm_model, model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loaded_vllm_model = vllm_runner(model_ref,
|
||||
load_format="tensorizer",
|
||||
tensorizer_uri=model_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=True)
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
||||
|
||||
# Assumes SamplingParams being seeded ensures the outputs are deterministic
|
||||
assert outputs == deserialized_outputs
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_can_deserialize_s3(vllm_runner):
|
||||
model_ref = "EleutherAI/pythia-1.4b"
|
||||
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
|
||||
|
||||
loaded_hf_model = vllm_runner(
|
||||
model_ref,
|
||||
tensorizer_uri=tensorized_path,
|
||||
load_format="tensorizer",
|
||||
num_readers=1,
|
||||
vllm_tensorized=False,
|
||||
s3_endpoint="object.ord1.coreweave.com",
|
||||
)
|
||||
|
||||
deserialized_outputs = loaded_hf_model.generate(prompts, sampling_params)
|
||||
|
||||
assert deserialized_outputs
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
||||
vllm_runner, tmp_path):
|
||||
vllm_model = vllm_runner(model_ref)
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
key_path = tmp_path / (model_ref + ".key")
|
||||
outputs = vllm_model.generate(prompts, sampling_params)
|
||||
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
|
||||
encryption_params = EncryptionParams.random()
|
||||
with open_stream(model_path, "wb+") as stream:
|
||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||
serializer.write_module(model)
|
||||
with open_stream(key_path, "wb+") as stream:
|
||||
stream.write(encryption_params.key)
|
||||
del vllm_model, model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loaded_vllm_model = vllm_runner(model_ref,
|
||||
tensorizer_uri=model_path,
|
||||
load_format="tensorizer",
|
||||
encryption_keyfile=key_path,
|
||||
num_readers=1,
|
||||
vllm_tensorized=True)
|
||||
|
||||
deserialized_outputs = loaded_vllm_model.generate(prompts, sampling_params)
|
||||
|
||||
# Assumes SamplingParams being seeded ensures the outputs are deterministic
|
||||
assert outputs == deserialized_outputs
|
||||
|
||||
|
||||
def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner,
|
||||
tmp_path):
|
||||
hf_model = hf_runner(model_ref)
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
max_tokens = 50
|
||||
outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens)
|
||||
with open_stream(model_path, "wb+") as stream:
|
||||
serializer = TensorSerializer(stream)
|
||||
serializer.write_module(hf_model.model)
|
||||
del hf_model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loaded_hf_model = vllm_runner(model_ref,
|
||||
tensorizer_uri=model_path,
|
||||
load_format="tensorizer",
|
||||
num_readers=1,
|
||||
vllm_tensorized=False)
|
||||
|
||||
deserialized_outputs = loaded_hf_model.generate_greedy(
|
||||
prompts, max_tokens=max_tokens)
|
||||
|
||||
assert outputs == deserialized_outputs
|
||||
|
||||
|
||||
def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from examples.multilora_inference import (create_test_prompts,
|
||||
process_requests)
|
||||
|
||||
model_ref = "meta-llama/Llama-2-7b-hf"
|
||||
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
|
||||
test_prompts = create_test_prompts(lora_path)
|
||||
|
||||
# Serialize model before deserializing and binding LoRA adapters
|
||||
vllm_model = vllm_runner(model_ref, )
|
||||
model_path = tmp_path / (model_ref + ".tensors")
|
||||
model = (vllm_model.model.llm_engine.model_executor.driver_worker.
|
||||
model_runner.model)
|
||||
with open_stream(model_path, "wb+") as stream:
|
||||
serializer = TensorSerializer(stream)
|
||||
serializer.write_module(model)
|
||||
del vllm_model, model
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
loaded_vllm_model = vllm_runner(
|
||||
model_ref,
|
||||
tensorizer_uri=model_path,
|
||||
load_format="tensorizer",
|
||||
num_readers=1,
|
||||
vllm_tensorized=True,
|
||||
enable_lora=True,
|
||||
max_loras=1,
|
||||
max_lora_rank=8,
|
||||
max_cpu_loras=2,
|
||||
max_num_seqs=50,
|
||||
max_model_len=1000,
|
||||
)
|
||||
process_requests(loaded_vllm_model.model.llm_engine, test_prompts)
|
||||
|
||||
assert loaded_vllm_model
|
||||
|
||||
|
||||
def test_load_without_tensorizer_load_format(vllm_runner):
|
||||
with pytest.raises(ValueError):
|
||||
vllm_runner(model_ref, tensorizer_uri="test")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_tensorize_vllm_model(tmp_path):
|
||||
# Test serialize command
|
||||
serialize_args = [
|
||||
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
|
||||
model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
|
||||
tmp_path, "--suffix", "tests"
|
||||
]
|
||||
result = subprocess.run(serialize_args, capture_output=True, text=True)
|
||||
print(result.stdout) # Print the output of the serialize command
|
||||
|
||||
assert result.returncode == 0, (f"Serialize command failed with output:"
|
||||
f"\n{result.stdout}\n{result.stderr}")
|
||||
|
||||
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
|
||||
|
||||
# Test deserialize command
|
||||
deserialize_args = [
|
||||
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
|
||||
model_ref, "--dtype", "float16", "deserialize", "--path-to-tensors",
|
||||
path_to_tensors
|
||||
]
|
||||
result = subprocess.run(deserialize_args, capture_output=True, text=True)
|
||||
assert result.returncode == 0, (f"Deserialize command failed with output:"
|
||||
f"\n{result.stdout}\n{result.stderr}")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_openai_apiserver_with_tensorizer(tmp_path):
|
||||
## Serialize model
|
||||
serialize_args = [
|
||||
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
|
||||
model_ref, "--dtype", "float16", "serialize", "--serialized-directory",
|
||||
tmp_path, "--suffix", "tests"
|
||||
]
|
||||
result = subprocess.run(serialize_args, capture_output=True, text=True)
|
||||
print(result.stdout) # Print the output of the serialize command
|
||||
|
||||
assert result.returncode == 0, (f"Serialize command failed with output:"
|
||||
f"\n{result.stdout}\n{result.stderr}")
|
||||
|
||||
path_to_tensors = f"{tmp_path}/vllm/{model_ref}/tests/model.tensors"
|
||||
|
||||
## Start OpenAI API server
|
||||
openai_args = [
|
||||
"--model", model_ref, "--dtype", "float16", "--load-format",
|
||||
"tensorizer", "--tensorizer-uri", path_to_tensors, "--vllm-tensorized",
|
||||
"--port", "8000"
|
||||
]
|
||||
|
||||
server = ServerRunner.remote(openai_args)
|
||||
|
||||
print("Server ready.")
|
||||
assert server.ready.remote()
|
||||
|
||||
|
||||
def test_raise_value_error_on_invalid_load_format(vllm_runner):
|
||||
with pytest.raises(ValueError):
|
||||
vllm_runner(model_ref,
|
||||
load_format="safetensors",
|
||||
tensorizer_uri="test")
|
||||
|
||||
|
||||
def test_tensorizer_with_tp(vllm_runner):
|
||||
with pytest.raises(ValueError):
|
||||
model_ref = "EleutherAI/pythia-1.4b"
|
||||
tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors"
|
||||
|
||||
vllm_runner(
|
||||
model_ref,
|
||||
tensorizer_uri=tensorized_path,
|
||||
load_format="tensorizer",
|
||||
num_readers=1,
|
||||
vllm_tensorized=False,
|
||||
s3_endpoint="object.ord1.coreweave.com",
|
||||
tensor_parallel_size=2,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed")
|
||||
def test_tensorizer_warn_quant(tmp_path):
|
||||
model_ref = "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"
|
||||
serialize_args = [
|
||||
"python3", "tensorizer/tensorize_vllm_model_for_testing.py", "--model",
|
||||
model_ref, "--quantization", "gptq", "--tensorizer-uri", "test",
|
||||
"serialize", "--serialized-directory", tmp_path, "--suffix", "tests"
|
||||
]
|
||||
result = subprocess.run(serialize_args, capture_output=True, text=True)
|
||||
assert 'PerformanceWarning' in result.stderr
|
||||
@ -1,14 +1,18 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import ModelConfig, SchedulerConfig
|
||||
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
|
||||
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
|
||||
def test_prepare_prompt(batch_size):
|
||||
model_runner = ModelRunner(None, None, None, None, None)
|
||||
scheduler_config = SchedulerConfig(100000,
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=False)
|
||||
model_runner = ModelRunner(None, None, scheduler_config, None, None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
@ -36,8 +40,10 @@ def test_prepare_prompt(batch_size):
|
||||
prompt_len - 1)
|
||||
selected_token_start_idx += prompt_len
|
||||
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
|
||||
_, _) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
_, _,
|
||||
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert return_prompt_lens == prompt_lens
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
# Verify input metadata is correct for prompts.
|
||||
device = model_runner.device
|
||||
@ -45,8 +51,6 @@ def test_prepare_prompt(batch_size):
|
||||
assert torch.allclose(attn_metadata.prompt_lens_tensor,
|
||||
torch.tensor(prompt_lens, device=device))
|
||||
assert attn_metadata.prompt_lens == prompt_lens
|
||||
assert attn_metadata.num_prompt_tokens == sum(prompt_lens)
|
||||
assert attn_metadata.num_generation_tokens == 0
|
||||
assert attn_metadata.max_prompt_len == max(prompt_lens)
|
||||
|
||||
# Test subquery start locs.
|
||||
@ -83,23 +87,22 @@ def test_prepare_prompt(batch_size):
|
||||
assert torch.allclose(attn_metadata.block_tables, expected)
|
||||
# Cuda graph should not be used for prerill.
|
||||
assert attn_metadata.use_cuda_graph is False
|
||||
assert attn_metadata.kv_cache_dtype == "auto"
|
||||
|
||||
assert input_tokens.shape == (sum(prompt_lens), )
|
||||
assert input_positions.shape == (sum(prompt_lens), )
|
||||
assert len(input_tokens) == sum(prompt_lens)
|
||||
assert len(input_positions) == sum(prompt_lens)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
|
||||
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
|
||||
prompt_lens,
|
||||
subquery_lens=prompt_lens)
|
||||
assert input_tokens.shape == (sum(prompt_lens), )
|
||||
assert input_positions.shape == (sum(prompt_lens), )
|
||||
assert len(input_tokens) == sum(prompt_lens)
|
||||
assert len(input_positions) == sum(prompt_lens)
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
assert input_tokens == input_positions
|
||||
|
||||
actual = sampling_metadata.selected_token_indices
|
||||
expected = torch.tensor(expected_selected_token_indices,
|
||||
@ -122,7 +125,12 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
revision=None,
|
||||
enforce_eager=False,
|
||||
)
|
||||
model_runner = ModelRunner(model_config, None, None, None, None)
|
||||
scheduler_config = SchedulerConfig(100000,
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=False)
|
||||
model_runner = ModelRunner(model_config, None, scheduler_config, None,
|
||||
None)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
prompt_lens = []
|
||||
@ -143,16 +151,15 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
input_tokens, input_positions, attn_metadata, _, _, _ = (
|
||||
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
|
||||
model_runner._prepare_decode(seq_group_metadata_list))
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
|
||||
# Verify input metadata is correct for prompts.
|
||||
device = model_runner.device
|
||||
assert attn_metadata.is_prompt is False
|
||||
assert attn_metadata.prompt_lens is None
|
||||
assert attn_metadata.num_prompt_tokens == 0
|
||||
assert attn_metadata.num_generation_tokens == expected_bs
|
||||
assert attn_metadata.max_prompt_len is None
|
||||
assert attn_metadata.subquery_start_loc is None
|
||||
assert attn_metadata.seq_start_loc is None
|
||||
@ -170,11 +177,10 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
model_runner.get_max_block_per_batch())
|
||||
# Cuda graph should not be used for prerill.
|
||||
assert attn_metadata.use_cuda_graph is True
|
||||
assert attn_metadata.kv_cache_dtype == "auto"
|
||||
|
||||
assert input_tokens.shape == (expected_bs, )
|
||||
assert input_positions.shape == (expected_bs, )
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
assert len(input_tokens) == expected_bs
|
||||
assert len(input_positions) == expected_bs
|
||||
assert input_tokens == input_positions
|
||||
|
||||
# Verify Sampling
|
||||
expected_selected_token_indices = []
|
||||
@ -190,3 +196,148 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
device=actual.device,
|
||||
dtype=actual.dtype)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_empty_seq_group():
|
||||
"""Verify prepare prompt and decode returns empty output."""
|
||||
model_config = ModelConfig(
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
enforce_eager=False,
|
||||
)
|
||||
model_runner = ModelRunner(model_config, None, None, None, None)
|
||||
model_runner.set_block_size(16)
|
||||
seq_group_metadata_list = []
|
||||
input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
|
||||
model_runner._prepare_decode(seq_group_metadata_list))
|
||||
assert len(input_tokens) == 0
|
||||
assert len(input_positions) == 0
|
||||
assert attn_metadata is None
|
||||
assert len(slot_mapping) == 0
|
||||
|
||||
(input_tokens, input_positions, attn_metadata, return_prompt_lens, _, _, _,
|
||||
_, _,
|
||||
slot_mapping) = (model_runner._prepare_prompt(seq_group_metadata_list))
|
||||
assert len(input_tokens) == 0
|
||||
assert len(input_positions) == 0
|
||||
assert attn_metadata is None
|
||||
assert len(slot_mapping) == 0
|
||||
assert len(return_prompt_lens) == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
|
||||
|
||||
def get_world_size(group=None):
|
||||
return 1
|
||||
|
||||
def mock_get_process_group_ranks(group=None):
|
||||
return [0]
|
||||
|
||||
monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
|
||||
monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
|
||||
mock_get_process_group_ranks)
|
||||
|
||||
model_config = ModelConfig(
|
||||
"facebook/opt-125m",
|
||||
"facebook/opt-125m",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
download_dir=None,
|
||||
load_format="dummy",
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
scheduler_config = SchedulerConfig(100000,
|
||||
100000,
|
||||
100000,
|
||||
enable_chunked_prefill=True)
|
||||
model_runner = ModelRunner(model_config,
|
||||
None,
|
||||
scheduler_config,
|
||||
None,
|
||||
None,
|
||||
is_driver_worker=True)
|
||||
model_runner.set_block_size(16)
|
||||
|
||||
# Add prefill requests.
|
||||
prompt_lens = []
|
||||
seq_group_metadata_list = []
|
||||
prefill_metadata_list = []
|
||||
decode_metadata_list = []
|
||||
block_tables = {0: [1]}
|
||||
prefill_batch_size = batch_size // 2
|
||||
decode_batch_size = batch_size - prefill_batch_size
|
||||
for i in range(prefill_batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_lens.append(prompt_len)
|
||||
seq_data = SequenceData(list(range(prompt_len)))
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables=block_tables,
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == seq_data.get_len()
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
prefill_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# Add decode requests
|
||||
for i in range(prefill_batch_size, batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
prompt_len = i % (model_runner.block_size - 1) + 1
|
||||
prompt_toks = list(range(prompt_len))
|
||||
seq_data = SequenceData(prompt_toks)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
seq_data={0: seq_data},
|
||||
sampling_params=SamplingParams(temperature=0),
|
||||
block_tables={0: [1]},
|
||||
)
|
||||
assert seq_group_metadata.token_chunk_size == 1
|
||||
seq_group_metadata_list.append(seq_group_metadata)
|
||||
decode_metadata_list.append(seq_group_metadata)
|
||||
|
||||
(input_tokens, input_positions, attn_metadata, _, _, _,
|
||||
_) = model_runner.prepare_input_tensors(seq_group_metadata_list)
|
||||
|
||||
prefill_meta_actual = attn_metadata.prefill_metadata
|
||||
decode_meta_actual = attn_metadata.decode_metadata
|
||||
|
||||
assert len(attn_metadata.slot_mapping) == len(input_tokens)
|
||||
assert len(input_positions) == len(input_tokens)
|
||||
assert attn_metadata.kv_cache_dtype == "auto"
|
||||
assert attn_metadata.num_prefills == prefill_batch_size
|
||||
if enforce_eager:
|
||||
assert attn_metadata.num_decode_tokens == decode_batch_size
|
||||
else:
|
||||
assert attn_metadata.num_decode_tokens == _get_graph_batch_size(
|
||||
decode_batch_size)
|
||||
assert attn_metadata.num_prefill_tokens == sum(prompt_lens)
|
||||
|
||||
# Verify attn metadata is consistent. We don't need to test individual
|
||||
# values here because they are tested above.
|
||||
prefill_meta = model_runner._prepare_prompt(
|
||||
prefill_metadata_list).attn_metadata
|
||||
decode_meta = model_runner._prepare_decode(
|
||||
decode_metadata_list).attn_metadata
|
||||
|
||||
for attr_expected, attr_actual in zip(vars(prefill_meta),
|
||||
vars(prefill_meta_actual)):
|
||||
assert attr_expected[1] == attr_actual[1]
|
||||
for attr_expected, attr_actual in zip(vars(decode_meta),
|
||||
vars(decode_meta_actual)):
|
||||
assert attr_expected[1] == attr_actual[1]
|
||||
|
||||
193
vllm/_custom_ops.py
Normal file
193
vllm/_custom_ops.py
Normal file
@ -0,0 +1,193 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from vllm._C import cache_ops as vllm_cache_ops
|
||||
from vllm._C import ops as vllm_ops
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# activation ops
|
||||
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
||||
vllm_ops.silu_and_mul(out, x)
|
||||
|
||||
|
||||
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
||||
vllm_ops.gelu_and_mul(out, x)
|
||||
|
||||
|
||||
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
|
||||
vllm_ops.gelu_tanh_and_mul(out, x)
|
||||
|
||||
|
||||
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
|
||||
vllm_ops.gelu_fast(out, x)
|
||||
|
||||
|
||||
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
|
||||
vllm_ops.gelu_new(out, x)
|
||||
|
||||
|
||||
# page attention ops
|
||||
def paged_attention_v1(
|
||||
out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_size: int,
|
||||
max_context_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
kv_cache_dtype: str,
|
||||
kv_scale: float,
|
||||
) -> None:
|
||||
vllm_ops.paged_attention_v1(out, query, key_cache, value_cache,
|
||||
num_kv_heads, scale, block_tables,
|
||||
context_lens, block_size, max_context_len,
|
||||
alibi_slopes, kv_cache_dtype, kv_scale)
|
||||
|
||||
|
||||
def paged_attention_v2(
|
||||
out: torch.Tensor,
|
||||
exp_sum: torch.Tensor,
|
||||
max_logits: torch.Tensor,
|
||||
tmp_out: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_size: int,
|
||||
max_context_len: int,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
kv_cache_dtype: str,
|
||||
kv_scale: float,
|
||||
) -> None:
|
||||
vllm_ops.paged_attention_v2(out, exp_sum, max_logits, tmp_out, query,
|
||||
key_cache, value_cache, num_kv_heads, scale,
|
||||
block_tables, context_lens, block_size,
|
||||
max_context_len, alibi_slopes, kv_cache_dtype,
|
||||
kv_scale)
|
||||
|
||||
|
||||
# pos encoding ops
|
||||
def rotary_embedding(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
) -> None:
|
||||
vllm_ops.rotary_embedding(positions, query, key, head_size, cos_sin_cache,
|
||||
is_neox)
|
||||
|
||||
|
||||
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
|
||||
key: torch.Tensor, head_size: int,
|
||||
cos_sin_cache: torch.Tensor, is_neox: bool,
|
||||
rot_dim: int,
|
||||
cos_sin_cache_offsets: torch.Tensor) -> None:
|
||||
vllm_ops.batched_rotary_embedding(positions, query, key, head_size,
|
||||
cos_sin_cache, is_neox, rot_dim,
|
||||
cos_sin_cache_offsets)
|
||||
|
||||
|
||||
# layer norm ops
|
||||
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
|
||||
epsilon: float) -> None:
|
||||
vllm_ops.rms_norm(out, input, weight, epsilon)
|
||||
|
||||
|
||||
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
|
||||
weight: torch.Tensor, epsilon: float) -> None:
|
||||
vllm_ops.fused_add_rms_norm(input, residual, weight, epsilon)
|
||||
|
||||
|
||||
# quantization ops
|
||||
# awq
|
||||
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
|
||||
zeros: torch.Tensor, split_k_iters: int, thx: int,
|
||||
thy: int) -> torch.Tensor:
|
||||
return vllm_ops.awq_dequantize(qweight, scales, zeros, split_k_iters, thx,
|
||||
thy)
|
||||
|
||||
|
||||
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
|
||||
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
|
||||
return vllm_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
|
||||
|
||||
|
||||
# gptq
|
||||
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
|
||||
b_g_idx: torch.Tensor, use_exllama: bool,
|
||||
bit: int) -> torch.Tensor:
|
||||
return vllm_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
|
||||
b_g_idx, use_exllama, bit)
|
||||
|
||||
|
||||
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
|
||||
bit: int) -> None:
|
||||
vllm_ops.gptq_shuffle(q_weight, q_perm, bit)
|
||||
|
||||
|
||||
# squeezellm
|
||||
def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor,
|
||||
lookup_table: torch.Tensor) -> None:
|
||||
vllm_ops.squeezellm_gemm(vec, mat, mul, lookup_table)
|
||||
|
||||
|
||||
# marlin
|
||||
def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int,
|
||||
size_n: int, size_k: int) -> torch.Tensor:
|
||||
return vllm_ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m,
|
||||
size_n, size_k)
|
||||
|
||||
|
||||
# moe
|
||||
def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int,
|
||||
block_size: int, sorted_token_ids: torch.Tensor,
|
||||
experts_ids: torch.Tensor,
|
||||
num_tokens_post_pad: torch.Tensor) -> None:
|
||||
vllm_ops.moe_align_block_size(topk_ids, num_experts, block_size,
|
||||
sorted_token_ids, experts_ids,
|
||||
num_tokens_post_pad)
|
||||
|
||||
|
||||
def reshape_and_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
kv_scale: float,
|
||||
) -> None:
|
||||
vllm_cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
|
||||
slot_mapping, kv_cache_dtype, kv_scale)
|
||||
|
||||
|
||||
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
|
||||
block_mapping: torch.Tensor) -> None:
|
||||
vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
|
||||
|
||||
|
||||
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
|
||||
block_mapping: Dict[int, int]) -> None:
|
||||
vllm_cache_ops.swap_blocks(src, dst, block_mapping)
|
||||
|
||||
|
||||
def convert_fp8(output: torch.Tensor, input: torch.Tensor) -> None:
|
||||
vllm_cache_ops.convert_fp8(output, input)
|
||||
|
||||
|
||||
#TODO: cuda_utils, custom_ar
|
||||
@ -1,5 +1,6 @@
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
|
||||
@ -8,4 +9,5 @@ __all__ = [
|
||||
"AttentionMetadata",
|
||||
"Attention",
|
||||
"get_attn_backend",
|
||||
"AttentionMetadataPerStage",
|
||||
]
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
@ -47,7 +47,8 @@ class AttentionBackend(ABC):
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionMetadata:
|
||||
class AttentionMetadataPerStage:
|
||||
"""Attention metadata for a specific stage. I.e., prefill or decode."""
|
||||
|
||||
def asdict_zerocopy(self) -> Dict[str, Any]:
|
||||
"""Similar to dataclasses.asdict, but avoids deepcopying."""
|
||||
@ -59,6 +60,41 @@ class AttentionMetadata:
|
||||
}
|
||||
|
||||
|
||||
T = TypeVar("T", bound=AttentionMetadataPerStage)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionMetadata(Generic[T]):
|
||||
"""Attention metadata for prefill and decode batched together."""
|
||||
# Total number of prefill requests.
|
||||
num_prefills: int
|
||||
# Number of prefill tokens.
|
||||
num_prefill_tokens: int
|
||||
# Number of decode tokens. Note that it is equivalent to the number of
|
||||
# decode requests.
|
||||
num_decode_tokens: int
|
||||
# The attention metadata for prefill requests in a batch.
|
||||
# None if there's no prefill requests in a batch.
|
||||
prefill_metadata: Optional[T]
|
||||
# The attention metadata for decode requests in a batch.
|
||||
# None if there's no decode requests in a batch.
|
||||
decode_metadata: Optional[T]
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor
|
||||
# The kv cache's data type.
|
||||
kv_cache_dtype: str
|
||||
|
||||
def __post_init__(self):
|
||||
if self.num_prefill_tokens > 0:
|
||||
assert self.num_prefills > 0
|
||||
assert self.prefill_metadata is not None
|
||||
if self.num_decode_tokens > 0:
|
||||
assert self.decode_metadata is not None
|
||||
|
||||
|
||||
class AttentionImpl(ABC):
|
||||
|
||||
@abstractmethod
|
||||
@ -80,7 +116,7 @@ class AttentionImpl(ABC):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
|
||||
kv_scale: float,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -11,7 +11,8 @@ import torch
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
@ -53,7 +54,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class FlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
PagedAttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -68,10 +70,6 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
prompt_lens: Optional[List[int]]
|
||||
# prompt_lens stored as a tensor.
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
# The number of prompt tokens. Doesn't include padding.
|
||||
num_prompt_tokens: int
|
||||
# The number of generation tokens. Doesn't include padding.
|
||||
num_generation_tokens: int
|
||||
|
||||
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
|
||||
# |---------- N-1 iteration --------|
|
||||
@ -107,18 +105,27 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class FlashAttentionImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prompt_tokens -------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
|
||||
|<--------------- num_prefill_tokens ----------------->|
|
||||
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<------------------ num_generation_tokens (M) ----------------->|
|
||||
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||
|<----------------- num_decode_tokens ------------------>|
|
||||
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
||||
batched together in a flattened 1D query.
|
||||
|
||||
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|
||||
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
|
||||
|
||||
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
||||
padding between prefill and decode tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -155,7 +162,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: FlashAttentionMetadata,
|
||||
attn_metadata: AttentionMetadata[FlashAttentionMetadata],
|
||||
kv_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
@ -188,52 +195,70 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
attn_metadata.kv_cache_dtype,
|
||||
kv_scale)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
# normal attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
output = flash_attn_varlen_func(
|
||||
out = flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=attn_metadata.seq_start_loc,
|
||||
cu_seqlens_k=attn_metadata.seq_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_prompt_len,
|
||||
max_seqlen_k=attn_metadata.max_prompt_len,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prompt_len,
|
||||
max_seqlen_k=prefill_meta.max_prompt_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
# TODO(Hai) this triton kernel has regression issue (broke) to
|
||||
# deal with different data types between KV and FP8 KV cache,
|
||||
# to be addressed separately.
|
||||
output = PagedAttention.forward_prefix(
|
||||
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.subquery_start_loc,
|
||||
attn_metadata.prompt_lens_tensor,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_subquery_len,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.prompt_lens_tensor,
|
||||
prefill_meta.context_lens,
|
||||
prefill_meta.max_subquery_len,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_context_len,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.context_lens,
|
||||
decode_meta.max_context_len,
|
||||
attn_metadata.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
|
||||
@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
@ -51,7 +52,8 @@ class ROCmFlashAttentionBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class ROCmFlashAttentionMetadata(AttentionMetadataPerStage,
|
||||
PagedAttentionMetadata):
|
||||
"""Metadata for FlashAttentionBackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -66,10 +68,6 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
prompt_lens: Optional[List[int]]
|
||||
# prompt_lens stored as a tensor.
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
# The number of prompt tokens. Doesn't include padding.
|
||||
num_prompt_tokens: int
|
||||
# The number of generation tokens. Doesn't include padding.
|
||||
num_generation_tokens: int
|
||||
|
||||
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
|
||||
# |---------- N-1 iteration --------|
|
||||
@ -117,6 +115,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
||||
batched together in a flattened 1D query.
|
||||
|
||||
|<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->|
|
||||
|<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->|
|
||||
|
||||
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
||||
padding between prefill and decode tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -155,7 +162,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
# AMD Radeon 7900 series (gfx1100) currently does not support
|
||||
# xFormers nor FlashAttention. As a temporary workaround, we use
|
||||
# naive PyTorch implementation of attention.
|
||||
self.attn_fuc = _naive_attention()
|
||||
self.attn_fuc = _naive_attention
|
||||
logger.debug("Using naive attention in ROCmBackend")
|
||||
elif self.use_triton_flash_attn:
|
||||
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
|
||||
@ -181,7 +188,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: ROCmFlashAttentionMetadata,
|
||||
attn_metadata: AttentionMetadata[ROCmFlashAttentionMetadata],
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with FlashAttention and PagedAttention.
|
||||
@ -218,9 +225,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
kv_scale,
|
||||
)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
# triton attention
|
||||
# When block_tables are not filled, it means q and k are the
|
||||
# prompt, and they have the same length.
|
||||
@ -230,63 +253,69 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
||||
key = self.repeat_kv(key, self.num_queries_per_kv)
|
||||
value = self.repeat_kv(value, self.num_queries_per_kv)
|
||||
if self.use_naive_attn:
|
||||
output = self.attn_fuc(
|
||||
out = self.attn_fuc(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_metadata.prompt_lens,
|
||||
prefill_meta.prompt_lens,
|
||||
self.scale,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
output, _ = self.attn_func(
|
||||
out, _ = self.attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
attn_metadata.seq_start_loc,
|
||||
attn_metadata.seq_start_loc,
|
||||
attn_metadata.max_prompt_len,
|
||||
attn_metadata.max_prompt_len,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.seq_start_loc,
|
||||
prefill_meta.max_prompt_len,
|
||||
prefill_meta.max_prompt_len,
|
||||
True,
|
||||
self.scale,
|
||||
)
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
output = self.attn_func(
|
||||
out = self.attn_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=attn_metadata.seq_start_loc,
|
||||
cu_seqlens_k=attn_metadata.seq_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_prompt_len,
|
||||
max_seqlen_k=attn_metadata.max_prompt_len,
|
||||
cu_seqlens_q=prefill_meta.seq_start_loc,
|
||||
cu_seqlens_k=prefill_meta.seq_start_loc,
|
||||
max_seqlen_q=prefill_meta.max_prompt_len,
|
||||
max_seqlen_k=prefill_meta.max_prompt_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
output = PagedAttention.forward_prefix(
|
||||
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.subquery_start_loc,
|
||||
attn_metadata.prompt_lens_tensor,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_subquery_len,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.prompt_lens_tensor,
|
||||
prefill_meta.context_lens,
|
||||
prefill_meta.max_subquery_len,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_context_len,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.context_lens,
|
||||
decode_meta.max_context_len,
|
||||
attn_metadata.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
@ -305,26 +334,21 @@ def _naive_attention(
|
||||
prompt_lens: List[int],
|
||||
scale: float,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = query.shape[0]
|
||||
output = torch.empty_like(query)
|
||||
start = 0
|
||||
for _, prompt_len in enumerate(prompt_lens):
|
||||
end = start + prompt_len
|
||||
out = _naive_masked_attention(
|
||||
query[None, start:end],
|
||||
key[None, start:end],
|
||||
value[None, start:end],
|
||||
query[start:end],
|
||||
key[start:end],
|
||||
value[start:end],
|
||||
scale,
|
||||
)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output[start:end].copy_(out)
|
||||
start += prompt_len
|
||||
|
||||
# Using view got RuntimeError: view size is not compatible
|
||||
# with input tensor's size and stride (at least one
|
||||
# dimension spans across two contiguous subspaces).
|
||||
# Use reshape instead.
|
||||
return output.reshape(num_tokens, -1)
|
||||
return output
|
||||
|
||||
|
||||
def _naive_masked_attention(
|
||||
@ -333,14 +357,13 @@ def _naive_masked_attention(
|
||||
value: torch.Tensor,
|
||||
scale: float,
|
||||
) -> torch.Tensor:
|
||||
seq_len, _, _ = query.shape
|
||||
seq_len, head_size, head_dim = query.shape
|
||||
attn_mask = torch.triu(torch.ones(seq_len,
|
||||
seq_len,
|
||||
dtype=query.dtype,
|
||||
device=query.device),
|
||||
diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(query.dtype).min
|
||||
|
||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||
attn_weights = attn_weights + attn_mask.float()
|
||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||
|
||||
@ -7,7 +7,8 @@ import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
|
||||
@ -49,7 +50,8 @@ class TorchSDPABackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
|
||||
AttentionMetadataPerStage):
|
||||
"""Metadata for TorchSDPABackend.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
@ -57,15 +59,6 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
is_prompt: bool
|
||||
slot_mapping: torch.Tensor
|
||||
prompt_lens: Optional[List[int]]
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
num_prompt_tokens: int
|
||||
num_generation_tokens: int
|
||||
|
||||
max_subquery_len: Optional[int] = None
|
||||
max_prompt_len: Optional[int] = None
|
||||
subquery_start_loc: Optional[torch.Tensor] = None
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
use_cuda_graph: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
@ -224,7 +217,7 @@ def _make_alibi_bias(
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
|
||||
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||
bias.mul_(alibi_slopes[:, None, None])
|
||||
inf_mask = torch.empty(
|
||||
(1, prompt_len, prompt_len),
|
||||
|
||||
@ -9,7 +9,8 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
|
||||
LowerTriangularMaskWithTensorBias)
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionMetadata)
|
||||
AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
@ -54,7 +55,7 @@ class XFormersBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class XFormersMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
|
||||
"""Metadata for XFormersbackend.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
@ -65,19 +66,10 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor
|
||||
# (batch_size,). The prompt length per sequence. None if it is a decoding.
|
||||
prompt_lens: Optional[List[int]]
|
||||
# prompt_lens stored as a tensor.
|
||||
prompt_lens_tensor: Optional[torch.Tensor]
|
||||
# The number of prompt tokens. Doesn't include padding.
|
||||
num_prompt_tokens: int
|
||||
# The number of generation tokens. Doesn't include padding.
|
||||
num_generation_tokens: int
|
||||
|
||||
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
|
||||
# |---------- N-1 iteration --------|
|
||||
@ -123,18 +115,27 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
class XFormersImpl(AttentionImpl):
|
||||
"""
|
||||
If the input tensors contain prompt tokens, the layout is as follows:
|
||||
|<--------------- num_prompt_tokens --------------->|
|
||||
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1--->|
|
||||
|<--------------- num_prefill_tokens ----------------->|
|
||||
|<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->|
|
||||
|
||||
Otherwise, the layout is as follows:
|
||||
|<------------------ num_generation_tokens (M) ----------------->|
|
||||
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||
|<----------------- num_decode_tokens ------------------>|
|
||||
|<--decode_0-->|..........|<--decode_M-1-->|<--padding-->|
|
||||
|
||||
Generation tokens can contain padding when cuda-graph is used.
|
||||
Currently, prompt tokens don't contain any padding.
|
||||
|
||||
The prompts might have different lengths, while the generation tokens
|
||||
always have length 1.
|
||||
|
||||
If chunked prefill is enabled, prefill tokens and decode tokens can be
|
||||
batched together in a flattened 1D query.
|
||||
|
||||
|<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->|
|
||||
|<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->|
|
||||
|
||||
Currently, cuda graph is disabled for chunked prefill, meaning there's no
|
||||
padding between prefill and decode tokens.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -170,7 +171,7 @@ class XFormersImpl(AttentionImpl):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: XFormersMetadata,
|
||||
attn_metadata: AttentionMetadata[XFormersMetadata],
|
||||
kv_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with xFormers and PagedAttention.
|
||||
@ -202,59 +203,61 @@ class XFormersImpl(AttentionImpl):
|
||||
attn_metadata.kv_cache_dtype,
|
||||
kv_scale)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
assert key.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
assert value.shape[0] == num_prefill_tokens + num_decode_tokens
|
||||
|
||||
output = torch.empty_like(query)
|
||||
# Query for decode. KV is not needed because it is already cached.
|
||||
decode_query = query[num_prefill_tokens:]
|
||||
# QKV for prefill.
|
||||
query = query[:num_prefill_tokens]
|
||||
key = key[:num_prefill_tokens]
|
||||
value = value[:num_prefill_tokens]
|
||||
|
||||
assert query.shape[0] == num_prefill_tokens
|
||||
assert decode_query.shape[0] == num_decode_tokens
|
||||
|
||||
if prefill_meta := attn_metadata.prefill_metadata:
|
||||
# Prompt run.
|
||||
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
|
||||
# normal attention.
|
||||
# block tables are empty if the prompt does not have a cached
|
||||
# prefix.
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
|
||||
# project the key and value tensors to the desired number of
|
||||
# heads.
|
||||
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
|
||||
query = query.view(query.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
query.shape[-1])
|
||||
key = key[:, :,
|
||||
None, :].expand(key.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
key.shape[-1])
|
||||
value = value[:, :,
|
||||
None, :].expand(value.shape[0],
|
||||
self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
value.shape[-1])
|
||||
|
||||
output = self._run_memory_efficient_xformers_forward(
|
||||
query, key, value, attn_metadata)
|
||||
out = self._run_memory_efficient_xformers_forward(
|
||||
query, key, value, prefill_meta)
|
||||
assert out.shape == output[:num_prefill_tokens].shape
|
||||
output[:num_prefill_tokens] = out
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
# TODO(Hai) this triton kernel has regression issue (broke) to
|
||||
# deal with different data types between KV and FP8 KV cache,
|
||||
# to be addressed separately.
|
||||
output = PagedAttention.forward_prefix(
|
||||
out = PagedAttention.forward_prefix(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.subquery_start_loc,
|
||||
attn_metadata.prompt_lens_tensor,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_subquery_len,
|
||||
prefill_meta.block_tables,
|
||||
prefill_meta.subquery_start_loc,
|
||||
prefill_meta.prompt_lens_tensor,
|
||||
prefill_meta.context_lens,
|
||||
prefill_meta.max_subquery_len,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
else:
|
||||
# Decoding run.
|
||||
output = PagedAttention.forward_decode(
|
||||
query,
|
||||
assert output[:num_prefill_tokens].shape == out.shape
|
||||
output[:num_prefill_tokens] = out
|
||||
|
||||
if decode_meta := attn_metadata.decode_metadata:
|
||||
output[num_prefill_tokens:] = PagedAttention.forward_decode(
|
||||
decode_query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.max_context_len,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.context_lens,
|
||||
decode_meta.max_context_len,
|
||||
attn_metadata.kv_cache_dtype,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
@ -275,13 +278,30 @@ class XFormersImpl(AttentionImpl):
|
||||
"""Attention for 1D query of multiple prompts. Multiple prompt
|
||||
tokens are flattened in to `query` input.
|
||||
|
||||
See https://facebookresearch.github.io/xformers/components/ops.html
|
||||
for API spec.
|
||||
|
||||
Args:
|
||||
output: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
query: shape = [num_prompt_tokens, num_heads, head_size]
|
||||
key: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_prompt_tokens, num_kv_heads, head_size]
|
||||
output: shape = [num_prefill_tokens, num_heads, head_size]
|
||||
query: shape = [num_prefill_tokens, num_heads, head_size]
|
||||
key: shape = [num_prefill_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
"""
|
||||
original_query = query
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
# GQA/MQA requires the shape [B, M, G, H, K].
|
||||
# Note that the output also has the same shape (which is different
|
||||
# from a spec from the doc).
|
||||
query = query.view(query.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv, query.shape[-1])
|
||||
key = key[:, :,
|
||||
None, :].expand(key.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv, key.shape[-1])
|
||||
value = value[:, :,
|
||||
None, :].expand(value.shape[0], self.num_kv_heads,
|
||||
self.num_queries_per_kv,
|
||||
value.shape[-1])
|
||||
# Set attention bias if not provided. This typically happens at
|
||||
# the very attention layer of every iteration.
|
||||
# FIXME(woosuk): This is a hack.
|
||||
@ -302,6 +322,7 @@ class XFormersImpl(AttentionImpl):
|
||||
# TODO(woosuk): Too many view operations. Let's try to reduce
|
||||
# them in the future for code readability.
|
||||
if self.alibi_slopes is None:
|
||||
# Add the batch dimension.
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
@ -312,14 +333,13 @@ class XFormersImpl(AttentionImpl):
|
||||
attn_bias=attn_metadata.attn_bias[0],
|
||||
p=0.0,
|
||||
scale=self.scale)
|
||||
|
||||
return out.view_as(query)
|
||||
return out.view_as(original_query)
|
||||
|
||||
# Attention with alibi slopes.
|
||||
# FIXME(woosuk): Because xformers does not support dynamic sequence
|
||||
# lengths with custom attention bias, we process each prompt one by
|
||||
# one. This is inefficient, especially when we have many short prompts.
|
||||
output = torch.empty_like(query)
|
||||
output = torch.empty_like(original_query)
|
||||
start = 0
|
||||
for i, prompt_len in enumerate(attn_metadata.prompt_lens):
|
||||
end = start + prompt_len
|
||||
@ -331,7 +351,7 @@ class XFormersImpl(AttentionImpl):
|
||||
p=0.0,
|
||||
scale=self.scale)
|
||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||
output[start:end].copy_(out.squeeze(0))
|
||||
output[start:end].copy_(out.view_as(original_query[start:end]))
|
||||
start += prompt_len
|
||||
return output
|
||||
|
||||
|
||||
@ -4,7 +4,8 @@ from typing import List, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.abstract import (AttentionMetadata,
|
||||
AttentionMetadataPerStage)
|
||||
from vllm.attention.selector import get_attn_backend
|
||||
|
||||
|
||||
@ -41,7 +42,7 @@ class Attention(nn.Module):
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
|
||||
kv_scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._C import cache_ops, ops
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.ops.prefix_prefill import context_attention_fwd
|
||||
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
@ -13,11 +13,6 @@ _PARTITION_SIZE = 512
|
||||
@dataclass
|
||||
class PagedAttentionMetadata:
|
||||
"""Metadata for PagedAttention."""
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
# stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
|
||||
# is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
|
||||
# in block 0, and 1st slot in block 1, respectively.
|
||||
slot_mapping: torch.Tensor
|
||||
# (batch_size,). The length of context (tokens stored in KV cache) per
|
||||
# sequence. WARNING: When it is a prefill request, it doesn't include new
|
||||
# tokens. When it is for decoding, it includes a new token.
|
||||
@ -31,7 +26,6 @@ class PagedAttentionMetadata:
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
kv_cache_dtype: str
|
||||
|
||||
|
||||
class PagedAttention:
|
||||
@ -75,7 +69,7 @@ class PagedAttention:
|
||||
kv_cache_dtype: str,
|
||||
kv_scale: float,
|
||||
) -> None:
|
||||
cache_ops.reshape_and_cache(
|
||||
ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
@ -205,11 +199,11 @@ class PagedAttention:
|
||||
) -> None:
|
||||
src_key_cache = src_kv_cache[0]
|
||||
dst_key_cache = dst_kv_cache[0]
|
||||
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
|
||||
|
||||
src_value_cache = src_kv_cache[1]
|
||||
dst_value_cache = dst_kv_cache[1]
|
||||
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
@ -218,4 +212,4 @@ class PagedAttention:
|
||||
) -> None:
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
@ -415,7 +415,11 @@ def attn_fwd(
|
||||
return
|
||||
|
||||
is_mqa = hq != hk
|
||||
off_h_k = off_h_q % hk if is_mqa else off_h_q
|
||||
if is_mqa: # noqa: SIM108
|
||||
off_h_k = off_h_q % hk
|
||||
else:
|
||||
off_h_k = off_h_q
|
||||
|
||||
n_extra_tokens = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
n_extra_tokens = BLOCK_N - seqlen_k
|
||||
@ -677,8 +681,7 @@ def check_args(
|
||||
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||
# TODO: Change assert if we support qkl f8 and v f16
|
||||
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||
# TODO: Fix assert to check head size <=256 once supported
|
||||
assert head_size <= 128
|
||||
assert head_size <= 256
|
||||
assert o.shape == q.shape
|
||||
assert (nheads_q % nheads_k) == 0
|
||||
|
||||
@ -729,7 +732,7 @@ class _attention(torch.autograd.Function):
|
||||
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||
|
||||
# Get closest power of 2 over or equal to 32.
|
||||
unpadded_head_dims = {32, 64, 128}
|
||||
unpadded_head_dims = {32, 64, 128, 256}
|
||||
if head_size not in unpadded_head_dims:
|
||||
padded_d_model = None
|
||||
for i in unpadded_head_dims:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import enum
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Type
|
||||
|
||||
@ -10,6 +11,8 @@ from vllm.utils import is_cpu, is_hip, is_tpu
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
|
||||
|
||||
|
||||
class _Backend(enum.Enum):
|
||||
FLASH_ATTN = enum.auto()
|
||||
@ -83,4 +86,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
|
||||
"Cannot use FlashAttention backend because the flash_attn package "
|
||||
"is not found. Please install it for better performance.")
|
||||
return _Backend.XFORMERS
|
||||
|
||||
backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
|
||||
if backend_by_env_var is not None:
|
||||
return _Backend[backend_by_env_var]
|
||||
|
||||
# Default case.
|
||||
return _Backend.FLASH_ATTN
|
||||
|
||||
135
vllm/config.py
135
vllm/config.py
@ -1,8 +1,10 @@
|
||||
import enum
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import typing
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional, Union
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
@ -16,6 +18,8 @@ from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GB = 1 << 30
|
||||
@ -62,8 +66,8 @@ class ModelConfig:
|
||||
weights. If None, we assume the model weights are not quantized.
|
||||
quantization_param_path: Path to JSON file containing scaling factors.
|
||||
Used to load KV cache scaling factors into the model when KV cache
|
||||
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
|
||||
be used to load activation and weight scaling factors when the
|
||||
type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
|
||||
be used to load activation and weight scaling factors when the
|
||||
model dtype is FP8_E4M3 on ROCm.
|
||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||
disable CUDA graph and always execute the model in eager mode.
|
||||
@ -139,13 +143,14 @@ class ModelConfig:
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
supported_load_format = [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
"auto", "pt", "safetensors", "npcache", "dummy", "tensorizer"
|
||||
]
|
||||
rocm_not_supported_load_format = []
|
||||
rocm_not_supported_load_format: List[str] = []
|
||||
if load_format not in supported_load_format:
|
||||
raise ValueError(
|
||||
f"Unknown load format: {self.load_format}. Must be one of "
|
||||
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||
"'auto', 'pt', 'safetensors', 'npcache', 'tensorizer', or "
|
||||
"'dummy'.")
|
||||
if is_hip() and load_format in rocm_not_supported_load_format:
|
||||
rocm_supported_load_format = [
|
||||
f for f in supported_load_format
|
||||
@ -158,7 +163,9 @@ class ModelConfig:
|
||||
|
||||
# TODO: Remove this check once HF updates the pt weights of Mixtral.
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
if "MixtralForCausalLM" in architectures and load_format == "pt":
|
||||
# architectures can be None instead of []
|
||||
if architectures and "MixtralForCausalLM" in architectures \
|
||||
and load_format == "pt":
|
||||
raise ValueError(
|
||||
"Currently, the 'pt' format is not supported for Mixtral. "
|
||||
"Please use the 'safetensors' format instead. ")
|
||||
@ -415,7 +422,7 @@ class CacheConfig:
|
||||
@dataclass
|
||||
class TokenizerPoolConfig:
|
||||
"""Configuration for the tokenizer pool.
|
||||
|
||||
|
||||
Args:
|
||||
pool_size: Number of tokenizer workers in the pool.
|
||||
pool_type: Type of the pool.
|
||||
@ -439,9 +446,9 @@ class TokenizerPoolConfig:
|
||||
tokenizer_pool_extra_config: Optional[Union[str, dict]]
|
||||
) -> Optional["TokenizerPoolConfig"]:
|
||||
"""Create a TokenizerPoolConfig from the given parameters.
|
||||
|
||||
|
||||
If tokenizer_pool_size is 0, return None.
|
||||
|
||||
|
||||
Args:
|
||||
tokenizer_pool_size: Number of tokenizer workers in the pool.
|
||||
tokenizer_pool_type: Type of the pool.
|
||||
@ -563,9 +570,16 @@ class SchedulerConfig:
|
||||
if max_num_batched_tokens is not None:
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
else:
|
||||
# If max_model_len is too short, use 2048 as the default value for
|
||||
# higher throughput.
|
||||
self.max_num_batched_tokens = max(max_model_len, 2048)
|
||||
if enable_chunked_prefill:
|
||||
# For chunked prefill, choose the well-tuned batch size.
|
||||
self.max_num_batched_tokens = 768
|
||||
else:
|
||||
# If max_model_len is too short, use 2048 as the default value
|
||||
# for higher throughput.
|
||||
self.max_num_batched_tokens = max(max_model_len, 2048)
|
||||
if enable_chunked_prefill:
|
||||
logger.info("Chunked prefill is enabled (EXPERIMENTAL).")
|
||||
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_model_len = max_model_len
|
||||
self.use_v2_block_manager = use_v2_block_manager
|
||||
@ -675,6 +689,9 @@ class SpeculativeConfig:
|
||||
"num_speculative_tokens to be provided, but found "
|
||||
f"{speculative_model=} and {num_speculative_tokens=}.")
|
||||
|
||||
assert (speculative_model is not None
|
||||
and num_speculative_tokens is not None)
|
||||
|
||||
# TODO: The user should be able to specify revision/quantization/max
|
||||
# model len for the draft model. It is not currently supported.
|
||||
draft_revision = None
|
||||
@ -818,9 +835,12 @@ class LoRAConfig:
|
||||
self.lora_dtype = model_config.dtype
|
||||
elif isinstance(self.lora_dtype, str):
|
||||
self.lora_dtype = getattr(torch, self.lora_dtype)
|
||||
if model_config.quantization is not None:
|
||||
raise ValueError(
|
||||
"LoRA is not supported with quantized models yet.")
|
||||
if model_config.quantization and model_config.quantization not in [
|
||||
"awq", "gptq"
|
||||
]:
|
||||
# TODO support marlin and squeezellm
|
||||
logger.warning(f"{model_config.quantization} quantization is not "
|
||||
"tested with LoRA yet.")
|
||||
|
||||
def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
|
||||
if scheduler_config.max_num_batched_tokens > 65528:
|
||||
@ -872,6 +892,65 @@ class VisionLanguageConfig:
|
||||
f"{[x.name for x in cls.ImageInputType]}.") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig:
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO,
|
||||
str, bytes, os.PathLike, int]
|
||||
vllm_tensorized: bool
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = 1
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
model_class: Optional[torch.nn.Module] = None
|
||||
hf_config: Optional[PretrainedConfig] = None
|
||||
dtype: Union[str, torch.dtype] = None
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
tensorizer_args = {
|
||||
"tensorizer_uri": self.tensorizer_uri,
|
||||
"vllm_tensorized": self.vllm_tensorized,
|
||||
"verify_hash": self.verify_hash,
|
||||
"num_readers": self.num_readers,
|
||||
"encryption_keyfile": self.encryption_keyfile,
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
}
|
||||
return TensorizerArgs(**tensorizer_args)
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if (parallel_config.tensor_parallel_size > 1
|
||||
and self.tensorizer_uri is not None):
|
||||
raise ValueError(
|
||||
"Loading to multiple GPUs is not currently supported with "
|
||||
"vLLM-serialized models. Please set tensor_parallel_size=1."
|
||||
" or use a non-vLLM-serialized model, such as a "
|
||||
"serialized Hugging Face `PretrainedModel`.")
|
||||
|
||||
def verify_with_model_config(self, model_config) -> None:
|
||||
if (model_config.quantization is not None
|
||||
and self.tensorizer_uri is not None):
|
||||
from vllm.model_executor.tensorizer_loader import (
|
||||
tensorizer_warning)
|
||||
tensorizer_warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors.")
|
||||
|
||||
if (model_config.load_format != "tensorizer"
|
||||
and self.tensorizer_uri is not None):
|
||||
raise ValueError(
|
||||
"A tensorizer uri was passed for tensorizer loading, but the "
|
||||
f"load format was set to {model_config.load_format}. "
|
||||
"Please set the load format to 'tensorizer' to use "
|
||||
f"tensorizer args.")
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.float16,
|
||||
"float16": torch.float16,
|
||||
@ -986,7 +1065,7 @@ def _get_and_verify_max_len(
|
||||
derived_max_model_len *= scaling_factor
|
||||
|
||||
if max_model_len is None:
|
||||
max_model_len = derived_max_model_len
|
||||
max_model_len = int(derived_max_model_len)
|
||||
elif max_model_len > derived_max_model_len:
|
||||
# Some models might have a separate key for specifying model_max_length
|
||||
# that will be bigger than derived_max_model_len. We compare user input
|
||||
@ -1005,6 +1084,21 @@ def _get_and_verify_max_len(
|
||||
return int(max_model_len)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodingConfig:
|
||||
"""Dataclass which contains the decoding strategy of the engine"""
|
||||
|
||||
# Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
|
||||
guided_decoding_backend: str = 'outlines'
|
||||
|
||||
def __post_init__(self):
|
||||
valid_guided_backends = ['outlines', 'lm-format-enforcer']
|
||||
backend = self.guided_decoding_backend
|
||||
if backend not in valid_guided_backends:
|
||||
raise ValueError(f"Invalid guided_decoding_backend '{backend},"
|
||||
f"must be one of {valid_guided_backends}")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EngineConfig:
|
||||
"""Dataclass which contains all engine-related configuration. This
|
||||
@ -1019,6 +1113,8 @@ class EngineConfig:
|
||||
lora_config: Optional[LoRAConfig]
|
||||
vision_language_config: Optional[VisionLanguageConfig]
|
||||
speculative_config: Optional[SpeculativeConfig]
|
||||
decoding_config: Optional[DecodingConfig]
|
||||
tensorizer_config: Optional[TensorizerConfig]
|
||||
|
||||
def __post_init__(self):
|
||||
"""Verify configs are valid & consistent with each other.
|
||||
@ -1026,6 +1122,11 @@ class EngineConfig:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
if self.tensorizer_config:
|
||||
self.tensorizer_config.verify_with_parallel_config(
|
||||
self.parallel_config)
|
||||
self.tensorizer_config.verify_with_model_config(self.model_config)
|
||||
|
||||
if self.lora_config:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, FrozenSet, List, Optional, Protocol
|
||||
|
||||
from vllm.utils import Device
|
||||
|
||||
@ -10,23 +10,28 @@ class Block(ABC):
|
||||
def append_token_ids(self, token_ids: List[int]) -> None:
|
||||
pass
|
||||
|
||||
@abstractproperty
|
||||
@property
|
||||
@abstractmethod
|
||||
def block_id(self) -> Optional[int]:
|
||||
pass
|
||||
|
||||
@abstractproperty
|
||||
@property
|
||||
@abstractmethod
|
||||
def token_ids(self) -> List[int]:
|
||||
pass
|
||||
|
||||
@abstractproperty
|
||||
@property
|
||||
@abstractmethod
|
||||
def num_empty_slots(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractproperty
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_full(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractproperty
|
||||
@property
|
||||
@abstractmethod
|
||||
def prev_block(self) -> Optional["Block"]:
|
||||
pass
|
||||
|
||||
@ -47,12 +52,13 @@ class Block(ABC):
|
||||
class BlockAllocator(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
|
||||
def allocate_mutable(self, prev_block: Optional[Block],
|
||||
device: Device) -> Block:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate_immutable(self, prev_block: Optional[Block],
|
||||
token_ids: List[int]) -> Block:
|
||||
token_ids: List[int], device: Device) -> Block:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@ -64,11 +70,12 @@ class BlockAllocator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_num_free_blocks(self) -> int:
|
||||
def get_num_free_blocks(self, device: Device) -> int:
|
||||
pass
|
||||
|
||||
@abstractproperty
|
||||
def all_block_ids(self) -> frozenset[int]:
|
||||
@property
|
||||
@abstractmethod
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -2,7 +2,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import count, takewhile
|
||||
from os.path import commonprefix
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set
|
||||
|
||||
from vllm.block import BlockTable, PhysicalTokenBlock
|
||||
from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor
|
||||
@ -231,10 +233,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
|
||||
|
||||
if self.enable_caching:
|
||||
logger.info("Automatic prefix caching is enabled.")
|
||||
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
|
||||
num_gpu_blocks)
|
||||
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
|
||||
num_cpu_blocks)
|
||||
self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
|
||||
Device.GPU, block_size, num_gpu_blocks)
|
||||
self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
|
||||
Device.CPU, block_size, num_cpu_blocks)
|
||||
else:
|
||||
self.gpu_allocator = UncachedBlockAllocator(
|
||||
Device.GPU, block_size, num_gpu_blocks)
|
||||
@ -588,7 +590,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
|
||||
for b in takewhile(lambda b: b.computed, block_table[:-1])
|
||||
]
|
||||
|
||||
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
|
||||
def get_common_computed_block_ids(
|
||||
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
||||
"""Return the block ids that are common for a given sequence group.
|
||||
|
||||
Used in prefill (can skip prefill of some blocks).
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""A block manager that manages token blocks."""
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
|
||||
from vllm.core.block.block_table import BlockTable
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
@ -205,7 +206,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
|
||||
# as computed.
|
||||
self.block_allocator.mark_blocks_as_computed()
|
||||
|
||||
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
|
||||
def get_common_computed_block_ids(
|
||||
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
||||
"""Determine which blocks for which we skip prefill.
|
||||
|
||||
With prefix caching we can skip prefill for previously-generated blocks.
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
from typing import Sequence as GenericSequence
|
||||
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
|
||||
@ -103,7 +104,8 @@ class BlockSpaceManager(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
|
||||
def get_common_computed_block_ids(
|
||||
self, seqs: List[Sequence]) -> GenericSequence[int]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@ -42,8 +42,8 @@ class SchedulingBudget:
|
||||
"""
|
||||
token_budget: int
|
||||
max_num_seqs: int
|
||||
_requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
|
||||
_requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
|
||||
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
|
||||
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
|
||||
_num_batched_tokens: int = 0
|
||||
_num_curr_seqs: int = 0
|
||||
|
||||
@ -133,14 +133,18 @@ class SchedulerOutputs:
|
||||
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
|
||||
and not self.blocks_to_swap_out and not self.blocks_to_copy)
|
||||
|
||||
def _sort_by_lora_ids(self) -> bool:
|
||||
def _sort_by_lora_ids(self):
|
||||
self.scheduled_seq_groups = sorted(
|
||||
self.scheduled_seq_groups,
|
||||
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
|
||||
|
||||
@property
|
||||
def lora_requests(self) -> Set[LoRARequest]:
|
||||
return {g.seq_group.lora_request for g in self.scheduled_seq_groups}
|
||||
return {
|
||||
g.seq_group.lora_request
|
||||
for g in self.scheduled_seq_groups
|
||||
if g.seq_group.lora_request is not None
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -333,7 +337,8 @@ class Scheduler:
|
||||
self.free_seq(seq)
|
||||
|
||||
def has_unfinished_seqs(self) -> bool:
|
||||
return self.waiting or self.running or self.swapped
|
||||
return len(self.waiting) != 0 or len(self.running) != 0 or len(
|
||||
self.swapped) != 0
|
||||
|
||||
def get_num_unfinished_seq_groups(self) -> int:
|
||||
return len(self.waiting) + len(self.running) + len(self.swapped)
|
||||
@ -400,7 +405,7 @@ class Scheduler:
|
||||
budget.subtract_num_seqs(seq_group.request_id,
|
||||
num_running_seqs)
|
||||
if curr_loras is not None and seq_group.lora_int_id > 0:
|
||||
curr_loras.pop(seq_group.lora_int_id)
|
||||
curr_loras.remove(seq_group.lora_int_id)
|
||||
|
||||
if running_queue:
|
||||
# Preempt the lowest-priority sequence groups.
|
||||
@ -492,7 +497,7 @@ class Scheduler:
|
||||
now = time.time()
|
||||
swapped_queue = policy.sort_by_priority(now, swapped_queue)
|
||||
|
||||
leftover_swapped = deque()
|
||||
leftover_swapped: Deque[SequenceGroup] = deque()
|
||||
while swapped_queue:
|
||||
seq_group = swapped_queue[0]
|
||||
|
||||
@ -503,7 +508,9 @@ class Scheduler:
|
||||
lora_int_id = 0
|
||||
if self.lora_enabled:
|
||||
lora_int_id = seq_group.lora_int_id
|
||||
if (lora_int_id > 0 and lora_int_id not in curr_loras
|
||||
assert curr_loras is not None
|
||||
assert self.lora_config is not None
|
||||
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
|
||||
and len(curr_loras) >= self.lora_config.max_loras):
|
||||
# We don't have a space for another LoRA, so
|
||||
# we ignore this request for now.
|
||||
@ -589,7 +596,7 @@ class Scheduler:
|
||||
# Copy the queue so that the input queue is not modified.
|
||||
waiting_queue = deque([s for s in waiting_queue])
|
||||
|
||||
leftover_waiting_sequences = deque()
|
||||
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
|
||||
while self._passed_delay(time.time()) and waiting_queue:
|
||||
seq_group = waiting_queue[0]
|
||||
|
||||
@ -631,6 +638,8 @@ class Scheduler:
|
||||
lora_int_id = 0
|
||||
if self.lora_enabled:
|
||||
lora_int_id = seq_group.lora_int_id
|
||||
assert curr_loras is not None
|
||||
assert self.lora_config is not None
|
||||
if (self.lora_enabled and lora_int_id > 0
|
||||
and lora_int_id not in curr_loras
|
||||
and len(curr_loras) >= self.lora_config.max_loras):
|
||||
@ -670,7 +679,7 @@ class Scheduler:
|
||||
def _schedule_default(self) -> SchedulerOutputs:
|
||||
"""Schedule queued requests.
|
||||
|
||||
The current policy is designed to opimimize the throughput. First,
|
||||
The current policy is designed to optimize the throughput. First,
|
||||
it batches as many prefill requests as possible. And it schedules
|
||||
decodes. If there's a pressure on GPU memory, decode requests can
|
||||
be swapped or preempted.
|
||||
@ -776,7 +785,7 @@ class Scheduler:
|
||||
token_budget=self.scheduler_config.max_num_batched_tokens,
|
||||
max_num_seqs=self.scheduler_config.max_num_seqs,
|
||||
)
|
||||
curr_loras = set()
|
||||
curr_loras: Set[int] = set()
|
||||
|
||||
remaining_waiting, prefills = (self.waiting,
|
||||
SchedulerPrefillOutputs.create_empty())
|
||||
@ -826,13 +835,12 @@ class Scheduler:
|
||||
# Update swapped requests.
|
||||
self.swapped = remaining_swapped
|
||||
self.swapped.extend(running_scheduled.swapped_out)
|
||||
|
||||
return SchedulerOutputs(
|
||||
scheduled_seq_groups=(prefills.seq_groups +
|
||||
running_scheduled.decode_seq_groups +
|
||||
running_scheduled.prefill_seq_groups +
|
||||
swapped_in.decode_seq_groups +
|
||||
swapped_in.prefill_seq_groups),
|
||||
swapped_in.prefill_seq_groups +
|
||||
running_scheduled.decode_seq_groups +
|
||||
swapped_in.decode_seq_groups),
|
||||
num_prefill_groups=(len(prefills.seq_groups) +
|
||||
len(swapped_in.prefill_seq_groups) +
|
||||
len(running_scheduled.prefill_seq_groups)),
|
||||
@ -907,7 +915,7 @@ class Scheduler:
|
||||
|
||||
# It assumes the scheduled_seq_groups is ordered by
|
||||
# prefill < decoding.
|
||||
is_prompt = i < scheduler_outputs.num_prefill_groups
|
||||
is_prompt = seq_group.is_prefill()
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=seq_group.request_id,
|
||||
is_prompt=is_prompt,
|
||||
@ -1084,7 +1092,7 @@ class Scheduler:
|
||||
|
||||
def _get_num_new_tokens(self, seq_group: SequenceGroup,
|
||||
status: SequenceStatus, enable_chunking: bool,
|
||||
budget: SchedulingBudget) -> Tuple[int, bool]:
|
||||
budget: SchedulingBudget) -> int:
|
||||
"""Get the next new tokens to compute for a given sequence group
|
||||
that's in a given `status`.
|
||||
|
||||
|
||||
3
vllm/distributed/__init__.py
Normal file
3
vllm/distributed/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from .communication_op import *
|
||||
from .parallel_state import *
|
||||
from .utils import *
|
||||
@ -1,15 +1,13 @@
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from vllm.model_executor.parallel_utils import pynccl_utils
|
||||
from vllm.model_executor.parallel_utils.custom_all_reduce import (
|
||||
custom_all_reduce)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)
|
||||
from .parallel_state import (get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
is_pynccl_enabled_for_all_reduce)
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
@ -24,6 +22,10 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
|
||||
TLDR: always assume this function modifies its input, but use the return
|
||||
value as the output.
|
||||
"""
|
||||
from vllm.distributed.device_communicators import pynccl_utils
|
||||
from vllm.distributed.device_communicators.custom_all_reduce import (
|
||||
custom_all_reduce)
|
||||
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if get_tensor_model_parallel_world_size() == 1:
|
||||
return input_
|
||||
@ -142,7 +144,7 @@ def broadcast_tensor_dict(
|
||||
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
|
||||
src: int = 0,
|
||||
group: Optional[ProcessGroup] = None,
|
||||
) -> Dict[Any, Union[torch.Tensor, Any]]:
|
||||
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
|
||||
"""Broadcast the input tensor dictionary."""
|
||||
group = group or torch.distributed.group.WORLD
|
||||
ranks = torch.distributed.get_process_group_ranks(group)
|
||||
@ -155,10 +157,10 @@ def broadcast_tensor_dict(
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
if rank == src:
|
||||
metadata_list: List[Tuple[Any, Any]] = []
|
||||
assert isinstance(
|
||||
tensor_dict,
|
||||
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
|
||||
metadata_list = []
|
||||
for key, value in tensor_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert value.is_cuda, (
|
||||
@ -171,19 +173,27 @@ def broadcast_tensor_dict(
|
||||
torch.distributed.broadcast_object_list([metadata_list],
|
||||
src=src,
|
||||
group=group)
|
||||
async_handles = []
|
||||
for key, value in metadata_list:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = tensor_dict[key]
|
||||
torch.distributed.broadcast(tensor, src=src, group=group)
|
||||
async_handles.append(
|
||||
torch.distributed.broadcast(tensor,
|
||||
src=src,
|
||||
group=group,
|
||||
async_op=True))
|
||||
for async_handle in async_handles:
|
||||
async_handle.wait()
|
||||
|
||||
else:
|
||||
recv_metadata_list = [None]
|
||||
torch.distributed.broadcast_object_list(recv_metadata_list,
|
||||
src=src,
|
||||
group=group)
|
||||
metadata_list = recv_metadata_list[0]
|
||||
assert recv_metadata_list[0] is not None
|
||||
tensor_dict = {}
|
||||
async_handles = []
|
||||
for key, value in metadata_list:
|
||||
for key, value in recv_metadata_list[0]:
|
||||
if isinstance(value, TensorMetadata):
|
||||
tensor = torch.empty(value.size,
|
||||
dtype=value.dtype,
|
||||
0
vllm/distributed/device_communicators/__init__.py
Normal file
0
vllm/distributed/device_communicators/__init__.py
Normal file
@ -5,8 +5,6 @@ import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
|
||||
try:
|
||||
import pynvml
|
||||
@ -25,6 +23,9 @@ _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
|
||||
|
||||
|
||||
def init_custom_ar() -> None:
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
|
||||
global _CA_HANDLE
|
||||
if _CA_HANDLE is not None:
|
||||
return
|
||||
@ -41,12 +42,17 @@ def init_custom_ar() -> None:
|
||||
" disable_custom_all_reduce=True explicitly.", world_size,
|
||||
str(_SUPPORTED_WORLD_SIZES))
|
||||
return
|
||||
if not _can_p2p(rank, world_size):
|
||||
num_dev = torch.cuda.device_count()
|
||||
# note: num dev can be larger than world_size if we're only using
|
||||
# first few GPUs
|
||||
if num_dev < world_size:
|
||||
logger.warn(
|
||||
"Custom allreduce is disabled because your platform lacks GPU P2P"
|
||||
" capability or P2P test failed. To silence this warning, specify"
|
||||
" disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
"Cannot test GPU P2P because not all GPUs are visible to the "
|
||||
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
|
||||
" is set.")
|
||||
return False
|
||||
# test nvlink first, this will filter out most of the cases
|
||||
# where custom allreduce is not supported
|
||||
full_nvlink = _is_full_nvlink(rank, world_size)
|
||||
if world_size > 2 and not full_nvlink:
|
||||
logger.warn(
|
||||
@ -54,6 +60,15 @@ def init_custom_ar() -> None:
|
||||
" than two PCIe-only GPUs. To silence this warning, specify"
|
||||
" disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
# test P2P capability
|
||||
# this is expensive to compute at the first time
|
||||
# then we cache the result
|
||||
if not _can_p2p(rank, world_size):
|
||||
logger.warn(
|
||||
"Custom allreduce is disabled because your platform lacks GPU P2P"
|
||||
" capability or P2P test failed. To silence this warning, specify"
|
||||
" disable_custom_all_reduce=True explicitly.")
|
||||
return
|
||||
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)
|
||||
|
||||
|
||||
@ -142,40 +157,15 @@ def _is_full_nvlink(rank, world_size):
|
||||
|
||||
|
||||
def _can_p2p(rank: int, world_size: int) -> bool:
|
||||
num_dev = torch.cuda.device_count()
|
||||
# note: num dev can be larger than world_size if we're only using
|
||||
# first few GPUs
|
||||
if num_dev < world_size:
|
||||
logger.warn(
|
||||
"Cannot test GPU P2P because not all GPUs are visible to the "
|
||||
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
|
||||
" is set.")
|
||||
return False
|
||||
from vllm.distributed.utils import gpu_p2p_access_check
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
continue
|
||||
if not torch.cuda.can_device_access_peer(rank, i):
|
||||
return False
|
||||
# on some platforms, P2P support might be buggy and we need
|
||||
# additional checks. See also:
|
||||
# https://github.com/vllm-project/vllm/issues/2728
|
||||
if not _can_actually_p2p(rank, i):
|
||||
if not gpu_p2p_access_check(rank, i):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# code partly borrowed from
|
||||
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
|
||||
# License: MIT
|
||||
def _can_actually_p2p(idx_a, idx_b):
|
||||
dev_i = f"cuda:{idx_a}"
|
||||
dev_j = f"cuda:{idx_b}"
|
||||
a = torch.randn(5, device=dev_i) + 123.0
|
||||
b = a.to(dev_j)
|
||||
c = b.to(dev_i)
|
||||
return torch.all(a == c)
|
||||
|
||||
|
||||
class CustomAllreduce:
|
||||
|
||||
# max_size: max supported allreduce size
|
||||
@ -9,8 +9,8 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
|
||||
ncclGetVersion)
|
||||
from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
|
||||
ncclGetVersion)
|
||||
except Exception as e:
|
||||
# in non-NVIDIA environments, we can't import the nccl module
|
||||
# e.g. when running on machines with AMD GPUs
|
||||
@ -8,7 +8,9 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils import pynccl_utils
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Tensor model parallel group that the current rank belongs to.
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
@ -39,14 +41,23 @@ _CPU_WORLD_GROUP = None
|
||||
# source rank when broadcasting from the first or last pipeline stage.
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
||||
_LOCAL_RANK = -1
|
||||
|
||||
|
||||
def get_local_rank():
|
||||
global _LOCAL_RANK
|
||||
return _LOCAL_RANK
|
||||
|
||||
|
||||
def init_distributed_environment(
|
||||
world_size: int,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
world_size: int = -1,
|
||||
rank: int = -1,
|
||||
distributed_init_method: str = "env://",
|
||||
local_rank: int = -1,
|
||||
backend: str = "nccl",
|
||||
):
|
||||
logger.debug(f"{world_size=} {rank=} {local_rank=} "
|
||||
f"{distributed_init_method=} {backend=}")
|
||||
if not torch.distributed.is_initialized():
|
||||
assert distributed_init_method is not None, (
|
||||
"distributed_init_method must be provided when initializing "
|
||||
@ -62,6 +73,8 @@ def init_distributed_environment(
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
|
||||
backend="gloo")
|
||||
global _LOCAL_RANK
|
||||
_LOCAL_RANK = local_rank
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
@ -266,6 +279,7 @@ def destroy_model_parallel():
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
from vllm.distributed.device_communicators import pynccl_utils
|
||||
|
||||
# Destroy the pynccl states if any.
|
||||
pynccl_utils.destroy_process_group()
|
||||
@ -279,6 +293,7 @@ _ENABLE_PYNCCL_FOR_ALL_REDUCE = False
|
||||
|
||||
@contextlib.contextmanager
|
||||
def with_pynccl_for_all_reduce():
|
||||
from vllm.distributed.device_communicators import pynccl_utils
|
||||
"""use pynccl instead of torch.distributed for all reduce"""
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if tp_size == 1:
|
||||
133
vllm/distributed/utils.py
Normal file
133
vllm/distributed/utils.py
Normal file
@ -0,0 +1,133 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .parallel_state import get_cpu_world_group, get_local_rank
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
||||
numerator, denominator)
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator and return
|
||||
the division value."""
|
||||
ensure_divisibility(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> Sequence[torch.Tensor]:
|
||||
""" Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# NOTE: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
# code partly borrowed from
|
||||
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
|
||||
# License: MIT
|
||||
def _can_actually_p2p(idx_a, idx_b):
|
||||
dev_i = f"cuda:{idx_a}"
|
||||
dev_j = f"cuda:{idx_b}"
|
||||
a = torch.randn(5, device=dev_i) + 123.0
|
||||
b = a.to(dev_j)
|
||||
c = b.to(dev_i)
|
||||
return torch.all(a == c).cpu().item()
|
||||
|
||||
|
||||
# why do we need this cache?
|
||||
# 1. we can have runtime checks for P2P access, where every process checks
|
||||
# P2P access to all other GPUs. Unfortunately, the test might cost many
|
||||
# (world_size * world_size) cuda context, and reduce the memory available
|
||||
# for the model. see https://github.com/vllm-project/vllm/issues/3821
|
||||
# 2. alternatively, we can have a p2p map that is generated by the master
|
||||
# process and broadcasted to all other processes. This still requires
|
||||
# #world_size of cuda context, belonging to the master process, on each GPU.
|
||||
# 3. we can have a cache file, that records the p2p access status. The first
|
||||
# time the master process checks the p2p access, it will generate the cache
|
||||
# file, at the cost of #world_size of cuda context. Later on, all processes
|
||||
# can read the cache file to check the p2p access status without any cost of
|
||||
# additional cuda context.
|
||||
# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
|
||||
# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
|
||||
# e.g. used by different vllm engines. The device id in the cache file is a
|
||||
# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
|
||||
# of visible devices in the vllm engine.
|
||||
_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None
|
||||
|
||||
|
||||
def gpu_p2p_access_check(i: int, j: int) -> bool:
|
||||
"""Check if GPU i can access GPU j."""
|
||||
|
||||
# if the cache variable is already calculated,
|
||||
# read from the cache instead of checking it again
|
||||
global _gpu_p2p_access_cache
|
||||
if _gpu_p2p_access_cache is not None:
|
||||
return _gpu_p2p_access_cache[f"{i}->{j}"]
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = torch.cuda.device_count()
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
path = os.path.expanduser(
|
||||
f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
if (not is_distributed or get_local_rank() == 0) \
|
||||
and (not os.path.exists(path)):
|
||||
# only the local master process (with local_rank == 0) can
|
||||
# enter this block to calculate the cache
|
||||
logger.info(f"generating GPU P2P access cache for in {path}")
|
||||
cache = {}
|
||||
for _i in range(num_dev):
|
||||
for _j in range(num_dev):
|
||||
# on some platforms, P2P support might be buggy and we need
|
||||
# additional checks. See also:
|
||||
# https://github.com/vllm-project/vllm/issues/2728
|
||||
cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
|
||||
_i, _j) and _can_actually_p2p(_i, _j)
|
||||
with open(path, "w") as f:
|
||||
json.dump(cache, f, indent=4)
|
||||
if is_distributed:
|
||||
cpu_world_group = get_cpu_world_group()
|
||||
dist.barrier(cpu_world_group)
|
||||
logger.info(f"reading GPU P2P access cache from {path}")
|
||||
with open(path, "r") as f:
|
||||
cache = json.load(f)
|
||||
_gpu_p2p_access_cache = cache
|
||||
return _gpu_p2p_access_cache[f"{i}->{j}"]
|
||||
@ -1,12 +1,15 @@
|
||||
import argparse
|
||||
import dataclasses
|
||||
import io
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import BinaryIO, Optional, Union
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, TokenizerPoolConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, TensorizerConfig,
|
||||
TokenizerPoolConfig, VisionLanguageConfig)
|
||||
from vllm.model_executor.tensorizer_loader import TensorizerArgs
|
||||
from vllm.utils import str_to_int_tuple
|
||||
|
||||
|
||||
@ -58,15 +61,26 @@ class EngineArgs:
|
||||
num_gpu_blocks_override: Optional[int] = None
|
||||
num_lookahead_slots: int = 0
|
||||
|
||||
# Tensorizer configuration parameters
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
|
||||
bytes, os.PathLike, int] = None
|
||||
vllm_tensorized: bool = False
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = 1
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
|
||||
# Related to Vision-language models such as llava
|
||||
image_input_type: Optional[str] = None
|
||||
image_token_id: Optional[int] = None
|
||||
image_input_shape: Optional[str] = None
|
||||
image_feature_size: Optional[int] = None
|
||||
|
||||
scheduler_delay_factor: float = 0.0
|
||||
enable_chunked_prefill: bool = False
|
||||
|
||||
guided_decoding_backend: str = 'outlines'
|
||||
# Speculative decoding configuration.
|
||||
speculative_model: Optional[str] = None
|
||||
num_speculative_tokens: Optional[int] = None
|
||||
@ -135,7 +149,9 @@ class EngineArgs:
|
||||
'--load-format',
|
||||
type=str,
|
||||
default=EngineArgs.load_format,
|
||||
choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
|
||||
choices=[
|
||||
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
|
||||
],
|
||||
help='The format of the model weights to load. '
|
||||
'"auto" will try to load the weights in the safetensors format '
|
||||
'and fall back to the pytorch bin format if safetensors format '
|
||||
@ -145,7 +161,10 @@ class EngineArgs:
|
||||
'"npcache" will load the weights in pytorch format and store '
|
||||
'a numpy cache to speed up the loading. '
|
||||
'"dummy" will initialize the weights with random values, '
|
||||
'which is mainly for profiling.')
|
||||
'which is mainly for profiling.'
|
||||
'"tensorizer" will load the weights using tensorizer from CoreWeave'
|
||||
'which assumes tensorizer_uri is set to the location of the '
|
||||
'serialized weights.')
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
@ -182,6 +201,13 @@ class EngineArgs:
|
||||
default=EngineArgs.max_model_len,
|
||||
help='model context length. If unspecified, '
|
||||
'will be automatically derived from the model.')
|
||||
parser.add_argument(
|
||||
'--guided-decoding-backend',
|
||||
type=str,
|
||||
default='outlines',
|
||||
choices=['outlines', 'lm-format-enforcer'],
|
||||
help='Which engine will be used for guided decoding'
|
||||
' (JSON schema / regex etc)')
|
||||
# Parallel arguments
|
||||
parser.add_argument('--worker-use-ray',
|
||||
action='store_true',
|
||||
@ -386,9 +412,8 @@ class EngineArgs:
|
||||
'prompt latency) before scheduling next prompt.')
|
||||
parser.add_argument(
|
||||
'--enable-chunked-prefill',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='If True, the prefill requests can be chunked based on the '
|
||||
action='store_true',
|
||||
help='If set, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens')
|
||||
|
||||
parser.add_argument(
|
||||
@ -404,6 +429,7 @@ class EngineArgs:
|
||||
default=None,
|
||||
help='The number of speculative tokens to sample from '
|
||||
'the draft model in speculative decoding')
|
||||
parser = TensorizerArgs.add_cli_args(parser)
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -466,6 +492,17 @@ class EngineArgs:
|
||||
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
|
||||
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
|
||||
|
||||
tensorizer_config = TensorizerConfig(
|
||||
tensorizer_uri=self.tensorizer_uri,
|
||||
vllm_tensorized=self.vllm_tensorized,
|
||||
verify_hash=self.verify_hash,
|
||||
num_readers=self.num_readers,
|
||||
encryption_keyfile=self.encryption_keyfile,
|
||||
s3_access_key_id=self.s3_access_key_id,
|
||||
s3_secret_access_key=self.s3_secret_access_key,
|
||||
s3_endpoint=self.s3_endpoint,
|
||||
)
|
||||
|
||||
if self.image_input_type:
|
||||
if (not self.image_token_id or not self.image_input_shape
|
||||
or not self.image_feature_size):
|
||||
@ -482,6 +519,9 @@ class EngineArgs:
|
||||
else:
|
||||
vision_language_config = None
|
||||
|
||||
decoding_config = DecodingConfig(
|
||||
guided_decoding_backend=self.guided_decoding_backend)
|
||||
|
||||
return EngineConfig(model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
parallel_config=parallel_config,
|
||||
@ -489,7 +529,9 @@ class EngineArgs:
|
||||
device_config=device_config,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
speculative_config=speculative_config)
|
||||
speculative_config=speculative_config,
|
||||
decoding_config=decoding_config,
|
||||
tensorizer_config=tensorizer_config)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -333,8 +333,7 @@ class AsyncLLMEngine:
|
||||
if engine_config.device_config.device_type == "neuron":
|
||||
raise NotImplementedError("Neuron is not supported for "
|
||||
"async engine yet.")
|
||||
elif (engine_config.parallel_config.worker_use_ray
|
||||
or engine_args.engine_use_ray):
|
||||
elif engine_config.parallel_config.worker_use_ray:
|
||||
initialize_ray_cluster(engine_config.parallel_config)
|
||||
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
|
||||
executor_class = RayGPUExecutorAsync
|
||||
@ -410,8 +409,8 @@ class AsyncLLMEngine:
|
||||
else:
|
||||
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
|
||||
# order of the arguments.
|
||||
cache_config = args[1]
|
||||
parallel_config = args[2]
|
||||
cache_config = kwargs["cache_config"]
|
||||
parallel_config = kwargs["parallel_config"]
|
||||
if parallel_config.tensor_parallel_size == 1:
|
||||
num_gpus = cache_config.gpu_memory_utilization
|
||||
else:
|
||||
|
||||
@ -4,8 +4,9 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
import vllm
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig, TensorizerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
@ -74,6 +75,8 @@ class LLMEngine:
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
decoding_config: Optional[DecodingConfig],
|
||||
tensorizer_config: Optional[TensorizerConfig],
|
||||
executor_class: Type[ExecutorBase],
|
||||
log_stats: bool,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
@ -99,6 +102,7 @@ class LLMEngine:
|
||||
f"kv_cache_dtype={cache_config.cache_dtype}, "
|
||||
f"quantization_param_path={model_config.quantization_param_path}, "
|
||||
f"device_config={device_config.device}, "
|
||||
f"decoding_config={decoding_config!r}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
@ -110,6 +114,8 @@ class LLMEngine:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.speculative_config = speculative_config
|
||||
self.decoding_config = decoding_config or DecodingConfig()
|
||||
self.tensorizer_config = tensorizer_config
|
||||
self.log_stats = log_stats
|
||||
|
||||
self._init_tokenizer()
|
||||
@ -125,6 +131,7 @@ class LLMEngine:
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
speculative_config=speculative_config,
|
||||
tensorizer_config=tensorizer_config,
|
||||
)
|
||||
|
||||
self._initialize_kv_caches()
|
||||
@ -267,6 +274,9 @@ class LLMEngine:
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
if self.tensorizer_config:
|
||||
self.tensorizer_config.verify_with_parallel_config(
|
||||
self.parallel_config)
|
||||
if self.lora_config:
|
||||
self.lora_config.verify_with_model_config(self.model_config)
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
@ -504,9 +514,11 @@ class LLMEngine:
|
||||
|
||||
for seq, _ in child_seqs:
|
||||
if seq_group.sampling_params.detokenize:
|
||||
self.detokenizer.decode_sequence_inplace(
|
||||
new_char_count = self.detokenizer.decode_sequence_inplace(
|
||||
seq, seq_group.sampling_params)
|
||||
self._check_stop(seq, seq_group.sampling_params)
|
||||
else:
|
||||
new_char_count = 0
|
||||
self._check_stop(seq, new_char_count, seq_group.sampling_params)
|
||||
|
||||
# Non-beam search case
|
||||
if not seq_group.sampling_params.use_beam_search:
|
||||
@ -636,7 +648,10 @@ class LLMEngine:
|
||||
seq_group = scheduled_seq_group.seq_group
|
||||
seq_group.update_num_computed_tokens(
|
||||
scheduled_seq_group.token_chunk_size)
|
||||
self._process_sequence_group_outputs(seq_group, outputs)
|
||||
# If uncomputed tokens > 0, it means prefill is chunked.
|
||||
# We don't need to process outputs in that case.
|
||||
if seq_group.get_num_uncomputed_tokens() == 0:
|
||||
self._process_sequence_group_outputs(seq_group, outputs)
|
||||
|
||||
# Free the finished sequence groups.
|
||||
self.scheduler.free_finished_seq_groups()
|
||||
@ -798,9 +813,45 @@ class LLMEngine:
|
||||
time_e2e_requests=time_e2e_requests,
|
||||
)
|
||||
|
||||
def _check_stop(self, seq: Sequence,
|
||||
def _check_stop(self, seq: Sequence, new_char_count: int,
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Stop the finished sequences."""
|
||||
"""Stop the finished sequences.
|
||||
|
||||
new_char_count is the number of chars added to the
|
||||
sequence's output text for the newly generated token
|
||||
"""
|
||||
|
||||
# Check if the minimum number of tokens has been generated yet;
|
||||
# skip the stop string/token checks if not
|
||||
if seq.get_output_len() < sampling_params.min_tokens:
|
||||
return
|
||||
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.get_last_token_id() == seq.eos_token_id):
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
|
||||
# Check if a stop token was encountered.
|
||||
# This assumes a single token produced per step.
|
||||
last_token_id = seq.get_last_token_id()
|
||||
if last_token_id in sampling_params.stop_token_ids:
|
||||
if new_char_count and (
|
||||
not sampling_params.include_stop_str_in_output):
|
||||
# Remove last token
|
||||
seq.output_text = seq.output_text[:-new_char_count]
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = last_token_id
|
||||
return
|
||||
|
||||
# Check if any stop strings are matched.
|
||||
stop_str = self._check_stop_strings(seq, new_char_count,
|
||||
sampling_params)
|
||||
if stop_str is not None:
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = stop_str
|
||||
return
|
||||
|
||||
# Check if the sequence has reached max_model_len.
|
||||
if seq.get_len() > self.scheduler_config.max_model_len:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
@ -811,43 +862,37 @@ class LLMEngine:
|
||||
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
|
||||
return
|
||||
|
||||
# Check if the minimum number of tokens has been generated yet;
|
||||
# skip the stop string/token checks if not
|
||||
if seq.get_output_len() < sampling_params.min_tokens:
|
||||
return
|
||||
@staticmethod
|
||||
def _check_stop_strings(seq: Sequence, new_char_count: int,
|
||||
sampling_params: SamplingParams) -> Optional[str]:
|
||||
"""Check if any stop strings are matched and truncate sequence
|
||||
output text accordingly.
|
||||
|
||||
if sampling_params.detokenize:
|
||||
for stop_str in sampling_params.stop:
|
||||
if seq.output_text.endswith(stop_str):
|
||||
self._finalize_sequence(seq, sampling_params, stop_str)
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = stop_str
|
||||
return
|
||||
last_token_id = seq.get_last_token_id()
|
||||
if last_token_id in sampling_params.stop_token_ids:
|
||||
stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens(
|
||||
last_token_id)
|
||||
self._finalize_sequence(seq, sampling_params, stop_str)
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
seq.stop_reason = last_token_id
|
||||
return
|
||||
Returns the stop string if matched or else None.
|
||||
"""
|
||||
if not new_char_count:
|
||||
return None
|
||||
|
||||
# Check if the sequence has generated the EOS token.
|
||||
if ((not sampling_params.ignore_eos)
|
||||
and seq.get_last_token_id() == seq.eos_token_id):
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
return
|
||||
for stop_str in sampling_params.stop:
|
||||
stop_string_len = len(stop_str)
|
||||
# Avoid searching already-searched text.
|
||||
stop_index = seq.output_text.find(
|
||||
stop_str, -new_char_count - stop_string_len)
|
||||
if stop_index == -1:
|
||||
continue
|
||||
|
||||
def _finalize_sequence(self, seq: Sequence,
|
||||
sampling_params: SamplingParams,
|
||||
stop_string: str) -> None:
|
||||
if sampling_params.include_stop_str_in_output:
|
||||
return
|
||||
if sampling_params.include_stop_str_in_output:
|
||||
# Truncate to end of stop string.
|
||||
stop_index += stop_string_len
|
||||
if stop_index >= len(seq.output_text):
|
||||
# No truncation required.
|
||||
return stop_str
|
||||
|
||||
if stop_string and seq.output_text.endswith(stop_string):
|
||||
# Truncate the output text so that the stop string is
|
||||
# not included in the output.
|
||||
seq.output_text = seq.output_text[:-len(stop_string)]
|
||||
# Truncate the output text to either the beginning
|
||||
# or end of the stop string.
|
||||
seq.output_text = seq.output_text[:stop_index]
|
||||
return stop_str
|
||||
return None
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_executor.add_lora(lora_request)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Protocol
|
||||
|
||||
import numpy as np
|
||||
from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info,
|
||||
@ -119,12 +119,18 @@ class Stats:
|
||||
time_e2e_requests: List[float]
|
||||
|
||||
|
||||
class SupportsMetricsInfo(Protocol):
|
||||
|
||||
def metrics_info(self) -> Dict[str, str]:
|
||||
...
|
||||
|
||||
|
||||
class StatLogger:
|
||||
"""StatLogger is used LLMEngine to log to Promethus and Stdout."""
|
||||
|
||||
def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
|
||||
# Metadata for logging locally.
|
||||
self.last_local_log = time.monotonic()
|
||||
self.last_local_log = time.time()
|
||||
self.local_interval = local_interval
|
||||
|
||||
# Tracked stats over current local logging interval.
|
||||
@ -135,7 +141,7 @@ class StatLogger:
|
||||
self.labels = labels
|
||||
self.metrics = Metrics(labelnames=list(labels.keys()))
|
||||
|
||||
def info(self, type: str, obj: object) -> None:
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
if type == "cache_config":
|
||||
self.metrics.info_cache_config.info(obj.metrics_info())
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import pickle
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -18,15 +19,20 @@ try:
|
||||
if init_cached_hf_modules:
|
||||
from transformers.dynamic_module_utils import init_hf_modules
|
||||
init_hf_modules()
|
||||
self.worker = None
|
||||
self._worker: Optional[Worker] = None
|
||||
# Since the compiled DAG runs a main execution
|
||||
# in a different thread that calls cuda.set_device.
|
||||
# The flag indicates is set_device is called on
|
||||
# that thread.
|
||||
self.compiled_dag_cuda_device_set = False
|
||||
|
||||
def init_worker(self, worker_init_fn):
|
||||
self.worker = worker_init_fn()
|
||||
def init_worker(self, worker_init_fn: Callable[[], Worker]):
|
||||
self._worker = worker_init_fn()
|
||||
|
||||
@property
|
||||
def worker(self) -> Worker:
|
||||
assert self._worker is not None
|
||||
return self._worker
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.worker, name)
|
||||
@ -70,8 +76,8 @@ except ImportError as e:
|
||||
logger.warning(f"Failed to import Ray with {e!r}. "
|
||||
"For distributed inference, please install Ray with "
|
||||
"`pip install ray`.")
|
||||
ray = None
|
||||
RayWorkerVllm = None
|
||||
ray = None # type: ignore
|
||||
RayWorkerVllm = None # type: ignore
|
||||
|
||||
|
||||
def initialize_ray_cluster(
|
||||
|
||||
@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
|
||||
sampling_params = SamplingParams(**request_dict)
|
||||
request_id = random_uuid()
|
||||
|
||||
assert engine is not None
|
||||
results_generator = engine.generate(prompt, sampling_params, request_id)
|
||||
|
||||
# Streaming case
|
||||
|
||||
@ -86,7 +86,7 @@ class LLM:
|
||||
swap_space: int = 4,
|
||||
enforce_eager: bool = False,
|
||||
max_context_len_to_capture: int = 8192,
|
||||
disable_custom_all_reduce: bool = True,
|
||||
disable_custom_all_reduce: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
if "disable_log_stats" not in kwargs:
|
||||
@ -170,8 +170,12 @@ class LLM:
|
||||
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
|
||||
|
||||
# Add requests to the engine.
|
||||
num_requests = len(prompts) if prompts is not None else len(
|
||||
prompt_token_ids)
|
||||
if prompts is not None:
|
||||
num_requests = len(prompts)
|
||||
else:
|
||||
assert prompt_token_ids is not None
|
||||
num_requests = len(prompt_token_ids)
|
||||
|
||||
for i in range(num_requests):
|
||||
prompt = prompts[i] if prompts is not None else None
|
||||
token_ids = None if prompt_token_ids is None else prompt_token_ids[
|
||||
|
||||
@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel):
|
||||
description=(
|
||||
"If specified, the output will follow the context free grammar."),
|
||||
)
|
||||
guided_decoding_backend: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default guided decoding backend "
|
||||
"of the server for this specific request. If set, must be either "
|
||||
"'outlines' / 'lm-format-enforcer'"))
|
||||
|
||||
# doc: end-chat-completion-extra-params
|
||||
|
||||
@ -265,6 +271,12 @@ class CompletionRequest(BaseModel):
|
||||
description=(
|
||||
"If specified, the output will follow the context free grammar."),
|
||||
)
|
||||
guided_decoding_backend: Optional[str] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If specified, will override the default guided decoding backend "
|
||||
"of the server for this specific request. If set, must be one of "
|
||||
"'outlines' / 'lm-format-enforcer'"))
|
||||
|
||||
# doc: end-completion-extra-params
|
||||
|
||||
|
||||
@ -63,13 +63,18 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
request_id = f"cmpl-{random_uuid()}"
|
||||
try:
|
||||
token_ids = self._validate_prompt_and_tokenize(request,
|
||||
prompt=prompt)
|
||||
# Tokenize/detokenize depending on prompt format (string/token list)
|
||||
prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
|
||||
request, prompt=prompt)
|
||||
sampling_params = request.to_sampling_params()
|
||||
lora_request = self._maybe_get_lora(request)
|
||||
decoding_config = self.engine.engine.decoding_config
|
||||
guided_decoding_backend = request.guided_decoding_backend \
|
||||
or decoding_config.guided_decoding_backend
|
||||
guided_decode_logits_processor = (
|
||||
await get_guided_decoding_logits_processor(
|
||||
request, await self.engine.get_tokenizer()))
|
||||
guided_decoding_backend, request, await
|
||||
self.engine.get_tokenizer()))
|
||||
if guided_decode_logits_processor:
|
||||
if sampling_params.logits_processors is None:
|
||||
sampling_params.logits_processors = []
|
||||
@ -78,8 +83,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = self.engine.generate(prompt, sampling_params,
|
||||
request_id, token_ids,
|
||||
result_generator = self.engine.generate(prompt_text, sampling_params,
|
||||
request_id, prompt_ids,
|
||||
lora_request)
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
|
||||
Optional, Tuple)
|
||||
@ -17,7 +16,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
|
||||
return prompt_is_tokens, prompts
|
||||
|
||||
|
||||
def merge_async_iterators(*iterators):
|
||||
"""Merge multiple asynchronous iterators into a single iterator.
|
||||
|
||||
This method handle the case where some iterators finish before others.
|
||||
When it yields, it yields a tuple (i, item) where i is the index of the
|
||||
iterator that yields the item.
|
||||
"""
|
||||
queue = asyncio.Queue()
|
||||
|
||||
finished = [False] * len(iterators)
|
||||
|
||||
async def producer(i, iterator):
|
||||
try:
|
||||
async for item in iterator:
|
||||
await queue.put((i, item))
|
||||
except Exception as e:
|
||||
await queue.put(e)
|
||||
finished[i] = True
|
||||
|
||||
_tasks = [
|
||||
asyncio.create_task(producer(i, iterator))
|
||||
for i, iterator in enumerate(iterators)
|
||||
]
|
||||
|
||||
async def consumer():
|
||||
while not all(finished) or not queue.empty():
|
||||
item = await queue.get()
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
await asyncio.gather(*_tasks)
|
||||
|
||||
return consumer()
|
||||
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(self,
|
||||
@ -124,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
try:
|
||||
sampling_params = request.to_sampling_params()
|
||||
lora_request = self._maybe_get_lora(request)
|
||||
decoding_config = self.engine.engine.decoding_config
|
||||
guided_decoding_backend = request.guided_decoding_backend \
|
||||
or decoding_config.guided_decoding_backend
|
||||
guided_decode_logit_processor = (
|
||||
await get_guided_decoding_logits_processor(
|
||||
request, await self.engine.get_tokenizer()))
|
||||
guided_decoding_backend, request, await
|
||||
self.engine.get_tokenizer()))
|
||||
if guided_decode_logit_processor is not None:
|
||||
if sampling_params.logits_processors is None:
|
||||
sampling_params.logits_processors = []
|
||||
@ -136,23 +104,24 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
if prompt_is_tokens:
|
||||
input_ids = self._validate_prompt_and_tokenize(
|
||||
prompt_formats = self._validate_prompt_and_tokenize(
|
||||
request,
|
||||
prompt_ids=prompt,
|
||||
truncate_prompt_tokens=sampling_params.
|
||||
truncate_prompt_tokens)
|
||||
else:
|
||||
input_ids = self._validate_prompt_and_tokenize(
|
||||
prompt_formats = self._validate_prompt_and_tokenize(
|
||||
request,
|
||||
prompt=prompt,
|
||||
truncate_prompt_tokens=sampling_params.
|
||||
truncate_prompt_tokens)
|
||||
prompt_ids, prompt_text = prompt_formats
|
||||
|
||||
generators.append(
|
||||
self.engine.generate(prompt,
|
||||
self.engine.generate(prompt_text,
|
||||
sampling_params,
|
||||
f"{request_id}-{i}",
|
||||
prompt_token_ids=input_ids,
|
||||
prompt_token_ids=prompt_ids,
|
||||
lora_request=lora_request))
|
||||
except ValueError as e:
|
||||
# TODO: Use a vllm-specific Validation Error
|
||||
@ -326,7 +295,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
output_text = prompt_text
|
||||
elif request.echo and request.max_tokens > 0:
|
||||
token_ids = prompt_token_ids + output.token_ids
|
||||
top_logprobs = prompt_logprobs + output.logprobs
|
||||
top_logprobs = (prompt_logprobs + output.logprobs
|
||||
if request.logprobs else None)
|
||||
output_text = prompt_text + output.text
|
||||
else:
|
||||
token_ids = output.token_ids
|
||||
@ -334,6 +304,9 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
output_text = output.text
|
||||
|
||||
if request.logprobs is not None:
|
||||
assert top_logprobs is not None, (
|
||||
"top_logprobs must be provided when logprobs "
|
||||
"is requested")
|
||||
logprobs = self._create_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import conint
|
||||
|
||||
@ -99,27 +99,32 @@ class OpenAIServing:
|
||||
last_token_len = 0
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is not None:
|
||||
token_logprob = step_top_logprobs[token_id].logprob
|
||||
if step_top_logprobs is None:
|
||||
token = self.tokenizer.decode(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(None)
|
||||
logprobs.top_logprobs.append(None)
|
||||
else:
|
||||
token_logprob = None
|
||||
token = step_top_logprobs[token_id].decoded_token
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
token_logprob = step_top_logprobs[token_id].logprob
|
||||
token = step_top_logprobs[token_id].decoded_token
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append({
|
||||
p.decoded_token: p.logprob
|
||||
for i, p in step_top_logprobs.items()
|
||||
} if step_top_logprobs else None)
|
||||
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] +
|
||||
last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs.append({
|
||||
p.decoded_token: p.logprob
|
||||
for i, p in step_top_logprobs.items()
|
||||
} if step_top_logprobs else None)
|
||||
return logprobs
|
||||
|
||||
def create_error_response(
|
||||
@ -164,12 +169,12 @@ class OpenAIServing:
|
||||
raise ValueError("The model `{request.model}` does not exist.")
|
||||
|
||||
def _validate_prompt_and_tokenize(
|
||||
self,
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None,
|
||||
truncate_prompt_tokens: Optional[conint(ge=1)] = None
|
||||
) -> List[int]:
|
||||
self,
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
prompt: Optional[str] = None,
|
||||
prompt_ids: Optional[List[int]] = None,
|
||||
truncate_prompt_tokens: Optional[conint(ge=1)] = None
|
||||
) -> Tuple[List[int], str]:
|
||||
if not (prompt or prompt_ids):
|
||||
raise ValueError("Either prompt or prompt_ids should be provided.")
|
||||
if (prompt and prompt_ids):
|
||||
@ -187,6 +192,8 @@ class OpenAIServing:
|
||||
else:
|
||||
input_ids = prompt_ids
|
||||
|
||||
input_text = prompt if prompt is not None else self.tokenizer.decode(
|
||||
prompt_ids)
|
||||
token_num = len(input_ids)
|
||||
|
||||
if request.max_tokens is None:
|
||||
@ -201,4 +208,4 @@ class OpenAIServing:
|
||||
f"{request.max_tokens} in the completion). "
|
||||
f"Please reduce the length of the messages or completion.", )
|
||||
else:
|
||||
return input_ids
|
||||
return input_ids, input_text
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig)
|
||||
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -16,22 +15,13 @@ logger = init_logger(__name__)
|
||||
|
||||
class CPUExecutor(ExecutorBase):
|
||||
|
||||
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig], *args, **kwargs) -> None:
|
||||
assert device_config.device_type == "cpu"
|
||||
assert lora_config is None, "cpu backend doesn't support LoRA"
|
||||
model_config = _verify_and_get_model_config(model_config)
|
||||
cache_config = _verify_and_get_cache_config(cache_config)
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
def _init_executor(self) -> None:
|
||||
assert self.device_config.device_type == "cpu"
|
||||
assert self.lora_config is None, "cpu backend doesn't support LoRA"
|
||||
self.model_config = _verify_and_get_model_config(self.model_config)
|
||||
self.cache_config = _verify_and_get_cache_config(self.cache_config)
|
||||
self.scheduler_config = _verify_and_get_scheduler_config(
|
||||
self.scheduler_config)
|
||||
|
||||
# Instantiate the worker and load the model to CPU.
|
||||
self._init_worker()
|
||||
@ -60,7 +50,7 @@ class CPUExecutor(ExecutorBase):
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
"""
|
||||
@ -73,7 +63,10 @@ class CPUExecutor(ExecutorBase):
|
||||
# NOTE: We log here to avoid multiple logs when number of workers is
|
||||
# greater than one. We could log in the engine, but not all executors
|
||||
# have GPUs.
|
||||
logger.info(f"# CPU blocks: {num_cpu_blocks}")
|
||||
# NOTE: `cpu block` for CPU backend is located on CPU memory but is
|
||||
# referred as `gpu block`. Because we want to reuse the existing block
|
||||
# management procedure.
|
||||
logger.info(f"# CPU blocks: {num_gpu_blocks}")
|
||||
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
|
||||
|
||||
def execute_model(self,
|
||||
@ -95,7 +88,7 @@ class CPUExecutor(ExecutorBase):
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.driver_worker.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.driver_worker.list_loras()
|
||||
|
||||
def check_health(self) -> None:
|
||||
@ -116,6 +109,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
|
||||
return config
|
||||
|
||||
|
||||
def _verify_and_get_scheduler_config(
|
||||
config: SchedulerConfig) -> SchedulerConfig:
|
||||
if config.chunked_prefill_enabled:
|
||||
logger.warning("Chunked prefill is not supported on CPU, disable it.")
|
||||
config.chunked_prefill_enabled = False
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
|
||||
_GB = 1 << 30
|
||||
if config.enable_prefix_caching:
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VisionLanguageConfig)
|
||||
TensorizerConfig, VisionLanguageConfig)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
|
||||
@ -16,7 +16,6 @@ class ExecutorBase(ABC):
|
||||
that can execute the model on multiple devices.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
@ -27,11 +26,26 @@ class ExecutorBase(ABC):
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
tensorizer_config: Optional[TensorizerConfig],
|
||||
) -> None:
|
||||
raise NotImplementedError
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.vision_language_config = vision_language_config
|
||||
self.speculative_config = speculative_config
|
||||
self.tensorizer_config = tensorizer_config
|
||||
|
||||
self._init_executor()
|
||||
|
||||
@abstractmethod
|
||||
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||
def _init_executor(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available blocks for the GPU KV cache and
|
||||
swappable CPU KV cache.
|
||||
|
||||
@ -39,7 +53,7 @@ class ExecutorBase(ABC):
|
||||
ExecutorBase may require modification of the result, e.g. to ensure the
|
||||
selected cache sizes are compatible with all workers.
|
||||
|
||||
Returns a tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
||||
Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
|
||||
are blocks that are "active" on the device and can be appended to.
|
||||
num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
|
||||
appended to.
|
||||
@ -71,7 +85,7 @@ class ExecutorBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
@ -94,8 +108,7 @@ class ExecutorAsyncBase(ExecutorBase):
|
||||
"""Executes one model step on the given sequences."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def check_health_async(self) -> None:
|
||||
"""Checks if the executor is healthy. If not, it should raise an
|
||||
exception."""
|
||||
raise NotImplementedError
|
||||
self.check_health()
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -15,26 +12,8 @@ logger = init_logger(__name__)
|
||||
|
||||
class GPUExecutor(ExecutorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.vision_language_config = vision_language_config
|
||||
|
||||
assert (not speculative_config
|
||||
def _init_executor(self) -> None:
|
||||
assert (not self.speculative_config
|
||||
), "Speculative decoding not yet supported for GPU backend"
|
||||
|
||||
# Instantiate the worker and load the model to GPU.
|
||||
@ -61,12 +40,13 @@ class GPUExecutor(ExecutorBase):
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
"""
|
||||
@ -104,7 +84,7 @@ class GPUExecutor(ExecutorBase):
|
||||
assert lora_id > 0, "lora_id must be greater than 0."
|
||||
return self.driver_worker.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.driver_worker.list_loras()
|
||||
|
||||
def check_health(self) -> None:
|
||||
@ -128,8 +108,3 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy)
|
||||
return output
|
||||
|
||||
async def check_health_async(self) -> None:
|
||||
# GPUExecutor will always be healthy as long as
|
||||
# it's running.
|
||||
return
|
||||
|
||||
@ -1,8 +1,5 @@
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -13,23 +10,10 @@ logger = init_logger(__name__)
|
||||
|
||||
class NeuronExecutor(ExecutorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
assert lora_config is None, "LoRA is not supported for Neuron backend."
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
assert (not speculative_config
|
||||
def _init_executor(self) -> None:
|
||||
assert (self.lora_config is
|
||||
None), "LoRA is not supported for Neuron backend."
|
||||
assert (not self.speculative_config
|
||||
), "Speculative decoding not yet supported for Neuron backend."
|
||||
|
||||
# Instantiate the worker and load the model to the device.
|
||||
@ -43,11 +27,12 @@ class NeuronExecutor(ExecutorBase):
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
self.device_config,
|
||||
self.cache_config,
|
||||
)
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks by invoking the
|
||||
underlying worker.
|
||||
"""
|
||||
@ -78,7 +63,7 @@ class NeuronExecutor(ExecutorBase):
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.driver_worker.remove_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.driver_worker.list_loras()
|
||||
|
||||
def check_health(self) -> None:
|
||||
|
||||
@ -3,11 +3,8 @@ import copy
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.engine.ray_utils import RayWorkerVllm, ray
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
@ -32,25 +29,8 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0))
|
||||
|
||||
class RayGPUExecutor(ExecutorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vision_language_config: Optional[VisionLanguageConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.vision_language_config = vision_language_config
|
||||
assert (not speculative_config
|
||||
def _init_executor(self) -> None:
|
||||
assert (not self.speculative_config
|
||||
), "Speculative decoding not yet supported for RayGPU backend."
|
||||
|
||||
assert self.parallel_config.worker_use_ray
|
||||
@ -68,6 +48,21 @@ class RayGPUExecutor(ExecutorBase):
|
||||
if USE_RAY_COMPILED_DAG:
|
||||
self.forward_dag = self._compiled_ray_dag()
|
||||
|
||||
def _configure_ray_workers_use_nsight(self,
|
||||
ray_remote_kwargs) -> Dict[str, Any]:
|
||||
# If nsight profiling is enabled, we need to set the profiling
|
||||
# configuration for the ray workers as runtime env.
|
||||
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
|
||||
runtime_env.update({
|
||||
"nsight": {
|
||||
"t": "cuda,cudnn,cublas",
|
||||
"o": "'worker_process_%p'",
|
||||
"cuda-graph-trace": "node",
|
||||
}
|
||||
})
|
||||
|
||||
return ray_remote_kwargs
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if self.parallel_config.tensor_parallel_size == 1:
|
||||
@ -83,6 +78,10 @@ class RayGPUExecutor(ExecutorBase):
|
||||
# The remaining workers are the actual ray actors.
|
||||
self.workers: List[RayWorkerVllm] = []
|
||||
|
||||
if self.parallel_config.ray_workers_use_nsight:
|
||||
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
|
||||
ray_remote_kwargs)
|
||||
|
||||
# Create the workers.
|
||||
driver_ip = get_ip()
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
@ -171,6 +170,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=lora_config,
|
||||
vision_language_config=vision_language_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
))
|
||||
|
||||
# Initialize the driver worker with the Worker class.
|
||||
@ -187,6 +187,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
distributed_init_method=distributed_init_method,
|
||||
lora_config=self.lora_config,
|
||||
vision_language_config=self.vision_language_config,
|
||||
tensorizer_config=self.tensorizer_config,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
@ -197,7 +198,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
max_parallel_loading_workers,
|
||||
)
|
||||
|
||||
def determine_num_available_blocks(self) -> tuple[int, int]:
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of available KV blocks.
|
||||
|
||||
This invokes `determine_num_available_blocks` on each worker and takes
|
||||
@ -205,7 +206,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
compatible with all workers.
|
||||
|
||||
Returns:
|
||||
- tuple[num_gpu_blocks, num_cpu_blocks]
|
||||
- Tuple[num_gpu_blocks, num_cpu_blocks]
|
||||
"""
|
||||
# Get the maximum number of blocks that can be allocated on GPU and CPU.
|
||||
num_blocks = self._run_workers("determine_num_available_blocks", )
|
||||
@ -269,14 +270,14 @@ class RayGPUExecutor(ExecutorBase):
|
||||
lora_id=lora_id,
|
||||
)
|
||||
|
||||
def list_loras(self) -> List[int]:
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self._run_workers("list_loras")
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
driver_args: Optional[List[Any]] = None,
|
||||
driver_args: Optional[Tuple[Any, ...]] = None,
|
||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
use_ray_compiled_dag: bool = False,
|
||||
@ -291,6 +292,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
if use_ray_compiled_dag:
|
||||
# Right now, compiled DAG can only accept a single
|
||||
# input. TODO(sang): Fix it.
|
||||
assert self.forward_dag is not None
|
||||
output_channels = self.forward_dag.execute(1)
|
||||
else:
|
||||
# Start the ray workers first.
|
||||
@ -369,7 +371,7 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
driver_args: Optional[List[Any]] = None,
|
||||
driver_args: Optional[Tuple[Any, ...]] = None,
|
||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
@ -411,7 +413,3 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
|
||||
# Only the driver worker returns the sampling results.
|
||||
output = all_outputs[0]
|
||||
return output
|
||||
|
||||
async def check_health_async(self) -> None:
|
||||
"""Raises an error if engine is unhealthy."""
|
||||
self._check_if_any_actor_is_dead()
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))
|
||||
|
||||
@ -26,7 +27,7 @@ class NewLineFormatter(logging.Formatter):
|
||||
|
||||
|
||||
_root_logger = logging.getLogger("vllm")
|
||||
_default_handler = None
|
||||
_default_handler: Optional[logging.Handler] = None
|
||||
|
||||
|
||||
def _setup_logger():
|
||||
@ -55,7 +56,12 @@ def init_logger(name: str):
|
||||
# Use the same settings as above for root logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
|
||||
|
||||
if VLLM_CONFIGURE_LOGGING:
|
||||
if _default_handler is None:
|
||||
raise ValueError(
|
||||
"_default_handler is not set up. This should never happen!"
|
||||
" Please open an issue on Github.")
|
||||
logger.addHandler(_default_handler)
|
||||
logger.propagate = False
|
||||
return logger
|
||||
|
||||
@ -10,6 +10,12 @@ import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_gather)
|
||||
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
@ -18,18 +24,27 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_gather)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.utils import (
|
||||
split_tensor_along_last_dim)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _get_lora_device(base_layer: nn.Module) -> torch.device:
|
||||
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
|
||||
"""Returns the device for where to place the LoRA tensors."""
|
||||
# unquantizedLinear
|
||||
if hasattr(base_layer, "weight"):
|
||||
return base_layer.weight.device
|
||||
# GPTQ/AWQ/SqueezeLLM
|
||||
elif hasattr(base_layer, "qweight"):
|
||||
return base_layer.qweight.device
|
||||
# marlin
|
||||
elif hasattr(base_layer, "B"):
|
||||
return base_layer.B.device
|
||||
else:
|
||||
raise ValueError(f"Unsupported base layer: {base_layer}")
|
||||
|
||||
|
||||
def _apply_lora(
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
@ -268,12 +283,13 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
|
||||
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
|
||||
embedding_len = self.indices_len[3]
|
||||
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
|
||||
full_lora_a_embeddings = F.embedding(
|
||||
x + indices,
|
||||
self.lora_a_stacked_2d,
|
||||
)
|
||||
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
|
||||
indices = self.embeddings_indices[0][:embedding_len].view_as(x)
|
||||
full_output = self.base_layer.forward(
|
||||
x.add_(indices * added_tokens_mask))
|
||||
|
||||
@ -302,6 +318,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size = self.base_layer.input_size
|
||||
self.output_size = self.base_layer.output_size_per_partition
|
||||
self.device = _get_lora_device(self.base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
@ -312,17 +331,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
self.input_size,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0],
|
||||
self.output_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
@ -368,7 +387,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
self.base_layer, x, bias)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -402,10 +421,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
if self.base_layer.skip_bias_add else None)
|
||||
return output, output_bias
|
||||
|
||||
@property
|
||||
def linear_weights(self):
|
||||
return self.base_layer.linear_weights
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
@ -446,18 +461,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
self.input_size,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
) for _ in range(n_slices))
|
||||
self.lora_b_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0] // 2,
|
||||
self.output_size // 2,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
) for _ in range(n_slices))
|
||||
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
@ -505,7 +520,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
self.base_layer, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -623,25 +638,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
self.input_size,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
self.input_size,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
self.input_size,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
self.lora_b_stacked = (
|
||||
@ -651,7 +666,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
self.q_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
@ -659,7 +674,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
self.kv_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
@ -667,7 +682,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
self.kv_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
@ -746,7 +761,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
self.base_layer, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -770,6 +785,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(self, base_layer: RowParallelLinear) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.input_size = self.base_layer.input_size_per_partition
|
||||
self.output_size = self.base_layer.output_size
|
||||
self.device = _get_lora_device(self.base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
@ -781,20 +799,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
self.input_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0],
|
||||
self.output_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
device=self.device,
|
||||
)
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
@ -813,7 +831,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
self.reset_lora(index)
|
||||
if self.base_layer.tp_size > 1:
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.base_layer.weight.shape[1]
|
||||
shard_size = self.input_size
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
lora_a = lora_a[start_idx:end_idx, :]
|
||||
@ -838,7 +856,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x)
|
||||
self.base_layer, x)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -888,7 +906,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.base_layer.weight
|
||||
|
||||
return self.base_layer.weight if hasattr(
|
||||
self.base_layer, "weight") else self.base_layer.qweight
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
@ -939,9 +959,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
|
||||
if 32000 < self.base_layer.vocab_size > 33024:
|
||||
if 32000 < self.base_layer.vocab_size > 128512:
|
||||
raise ValueError("When using LoRA, vocab size must be "
|
||||
"32000 >= vocab_size <= 33024")
|
||||
"32000 >= vocab_size <= 128512")
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
|
||||
25
vllm/model_executor/guided_decoding/__init__.py
Normal file
25
vllm/model_executor/guided_decoding/__init__.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest)
|
||||
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
|
||||
get_lm_format_enforcer_guided_decoding_logits_processor)
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||
get_outlines_guided_decoding_logits_processor)
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
|
||||
|
||||
async def get_guided_decoding_logits_processor(
|
||||
guided_decoding_backend: str, request: Union[CompletionRequest,
|
||||
ChatCompletionRequest],
|
||||
tokenizer) -> Optional[LogitsProcessor]:
|
||||
if guided_decoding_backend == 'outlines':
|
||||
return await get_outlines_guided_decoding_logits_processor(
|
||||
request, tokenizer)
|
||||
if guided_decoding_backend == 'lm-format-enforcer':
|
||||
return await get_lm_format_enforcer_guided_decoding_logits_processor(
|
||||
request, tokenizer)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown guided decoding backend '{guided_decoding_backend}'. "
|
||||
"Must be one of 'outlines, 'lm-format-enforcer'")
|
||||
@ -0,0 +1,69 @@
|
||||
from functools import lru_cache
|
||||
from json import loads as json_loads
|
||||
from typing import Optional, Union
|
||||
|
||||
from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
|
||||
RegexParser, StringParser,
|
||||
TokenEnforcerTokenizerData, UnionParser)
|
||||
from lmformatenforcer.integrations.vllm import (
|
||||
build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest)
|
||||
from vllm.model_executor.guided_decoding.outlines_decoding import (
|
||||
get_outlines_guided_decoding_logits_processor)
|
||||
from vllm.sampling_params import LogitsProcessor
|
||||
|
||||
|
||||
async def get_lm_format_enforcer_guided_decoding_logits_processor(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest],
|
||||
tokenizer) -> Optional[LogitsProcessor]:
|
||||
"""
|
||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||
and get the necessary logits processor for the given guide.
|
||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
||||
we make a shallow copy to reuse the same underlying FSM.
|
||||
"""
|
||||
|
||||
tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
|
||||
tokenizer)
|
||||
character_level_parser: CharacterLevelParser
|
||||
if request.guided_json:
|
||||
schema = _normalize_json_schema_object(request.guided_json)
|
||||
character_level_parser = JsonSchemaParser(schema)
|
||||
elif request.guided_choice:
|
||||
character_level_parser = UnionParser(
|
||||
[StringParser(choice) for choice in request.guided_choice])
|
||||
elif request.guided_regex:
|
||||
character_level_parser = RegexParser(request.guided_regex)
|
||||
elif request.guided_grammar:
|
||||
# CFG grammar not supported by LMFE, revert to outlines
|
||||
return await get_outlines_guided_decoding_logits_processor(
|
||||
request, tokenizer)
|
||||
elif (request.response_format is not None
|
||||
and request.response_format.type == "json_object"):
|
||||
character_level_parser = JsonSchemaParser(
|
||||
None) # None means any json object
|
||||
else:
|
||||
return None
|
||||
|
||||
logits_processor = build_vllm_logits_processor(tokenizer_data,
|
||||
character_level_parser)
|
||||
return logits_processor
|
||||
|
||||
|
||||
def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
|
||||
if isinstance(schema, str):
|
||||
return json_loads(schema)
|
||||
if isinstance(schema, dict):
|
||||
return schema
|
||||
if isinstance(schema, BaseModel):
|
||||
return schema.model_json_schema()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _cached_build_vllm_token_enforcer_tokenizer_data(
|
||||
tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
|
||||
return build_vllm_token_enforcer_tokenizer_data(tokenizer)
|
||||
@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest)
|
||||
from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor,
|
||||
JSONLogitsProcessor,
|
||||
RegexLogitsProcessor)
|
||||
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
|
||||
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
|
||||
|
||||
|
||||
class GuidedDecodingMode(Enum):
|
||||
@ -54,7 +53,7 @@ pair : UNESCAPED_STRING ":" value
|
||||
global_thread_pool = None # used for generating logits processor fsm
|
||||
|
||||
|
||||
async def get_guided_decoding_logits_processor(
|
||||
async def get_outlines_guided_decoding_logits_processor(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest],
|
||||
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
|
||||
"""
|
||||
@ -91,7 +90,7 @@ def _get_guide_and_mode(
|
||||
json = request.guided_json
|
||||
if isinstance(json, dict):
|
||||
# turn dict into hashable string
|
||||
json = json_dumps(json, sort_keys=True)
|
||||
json = json_dumps(json)
|
||||
elif isinstance(json, BaseModel):
|
||||
# use pydantic signature so that different model classes
|
||||
# with the same fields will get hashed the same
|
||||
@ -13,9 +13,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Callable, DefaultDict, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -27,50 +29,6 @@ from transformers import PreTrainedTokenizerBase
|
||||
|
||||
class BaseLogitsProcessor:
|
||||
|
||||
def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
|
||||
"""Adapt vLLM's tokenizer to use to compile the FSM.
|
||||
|
||||
The API of Outlines tokenizers is slightly different to that of
|
||||
`transformers`. The decoder of outlines, returns a list whereas
|
||||
the decode of vLLM returns an str. To sync the vLLM decoder with
|
||||
outlines internal api, the decoder should be adapted. In addition
|
||||
we need to handle the missing spaces to Llama's tokenizer to be
|
||||
able to compile FSMs for this model.
|
||||
|
||||
"""
|
||||
if getattr(tokenizer, "_outlines_adapted", False):
|
||||
return tokenizer
|
||||
|
||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||
|
||||
def convert_token_to_string(token: str) -> str:
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
|
||||
string = tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
# A hack to handle missing spaces to HF's Llama tokenizers
|
||||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
||||
return " " + string
|
||||
|
||||
return string
|
||||
|
||||
def change_decoder(
|
||||
decoder: Callable[[List[int]], str]
|
||||
) -> Callable[[List[int]], List[str]]:
|
||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
||||
|
||||
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
||||
return [decoder(inp_tokens)]
|
||||
|
||||
return new_decoder
|
||||
|
||||
tokenizer.convert_token_to_string = convert_token_to_string
|
||||
tokenizer.decode = change_decoder(tokenizer.decode)
|
||||
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
|
||||
|
||||
return tokenizer
|
||||
|
||||
def init_state(self):
|
||||
"""Initialize the FSM states."""
|
||||
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
|
||||
@ -78,7 +36,6 @@ class BaseLogitsProcessor:
|
||||
def __call__(self, input_ids: List[int],
|
||||
scores: torch.Tensor) -> torch.Tensor:
|
||||
"""Use the FSM to bias the logits before sampling the next token."""
|
||||
|
||||
seq_id = hash(tuple(input_ids))
|
||||
|
||||
if len(input_ids) == 0:
|
||||
@ -96,7 +53,6 @@ class BaseLogitsProcessor:
|
||||
device=scores.device)
|
||||
mask[allowed_tokens] = 0
|
||||
scores.add_(mask)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
@ -113,7 +69,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
|
||||
The model's tokenizer
|
||||
|
||||
"""
|
||||
tokenizer = self.adapt_tokenizer(tokenizer)
|
||||
tokenizer = _adapt_tokenizer(tokenizer)
|
||||
fsm = RegexFSM(regex_string, tokenizer)
|
||||
self.fsm = fsm
|
||||
|
||||
@ -167,6 +123,54 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
|
||||
The model's tokenizer
|
||||
|
||||
"""
|
||||
tokenizer = self.adapt_tokenizer(tokenizer)
|
||||
tokenizer = _adapt_tokenizer(tokenizer)
|
||||
fsm = CFGFSM(cfg, tokenizer)
|
||||
self.fsm = fsm
|
||||
|
||||
|
||||
@lru_cache
|
||||
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
|
||||
"""Adapt vLLM's tokenizer to use to compile the FSM.
|
||||
|
||||
The API of Outlines tokenizers is slightly different to that of
|
||||
`transformers`. The decoder of outlines, returns a list whereas
|
||||
the decode of vLLM returns an str. To sync the vLLM decoder with
|
||||
outlines internal api, the decoder should be adapted. In addition
|
||||
we need to handle the missing spaces to Llama's tokenizer to be
|
||||
able to compile FSMs for this model.
|
||||
|
||||
"""
|
||||
if getattr(tokenizer, "_outlines_adapted", False):
|
||||
return tokenizer
|
||||
|
||||
tokenizer = copy.deepcopy(tokenizer)
|
||||
|
||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||
|
||||
def convert_token_to_string(token: str) -> str:
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
|
||||
string = tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
# A hack to handle missing spaces to HF's Llama tokenizers
|
||||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
||||
return " " + string
|
||||
|
||||
return string
|
||||
|
||||
def change_decoder(
|
||||
decoder: Callable[[List[int]],
|
||||
str]) -> Callable[[List[int]], List[str]]:
|
||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
||||
|
||||
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
||||
return [decoder(inp_tokens)]
|
||||
|
||||
return new_decoder
|
||||
|
||||
tokenizer.convert_token_to_string = convert_token_to_string
|
||||
tokenizer.decode = change_decoder(tokenizer.decode)
|
||||
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
|
||||
|
||||
return tokenizer
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user