mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 11:47:09 +08:00
Merge branch 'main' into mla-support-awq-marlin
This commit is contained in:
commit
b8510f1081
@ -107,6 +107,10 @@ steps:
|
||||
mirror_hardwares: [amd]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/llm
|
||||
- tests/entrypoints/openai
|
||||
- tests/entrypoints/test_chat_utils
|
||||
- tests/entrypoints/offline_mode
|
||||
commands:
|
||||
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
|
||||
@ -124,9 +128,10 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/
|
||||
- vllm/core/
|
||||
- tests/distributed
|
||||
- tests/distributed/test_utils
|
||||
- tests/distributed/test_pynccl
|
||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||
- tests/compile
|
||||
- tests/compile/test_basic_correctness
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
commands:
|
||||
@ -174,6 +179,9 @@ steps:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
- tests/tokenization
|
||||
- tests/test_sequence
|
||||
- tests/test_config
|
||||
- tests/test_logger
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py
|
||||
# OOM in the CI unless we run this separately
|
||||
|
||||
2
.github/workflows/cleanup_pr_body.yml
vendored
2
.github/workflows/cleanup_pr_body.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
|
||||
6
.github/workflows/lint-and-deploy.yaml
vendored
6
.github/workflows/lint-and-deploy.yaml
vendored
@ -17,12 +17,12 @@ jobs:
|
||||
version: v3.14.4
|
||||
|
||||
#Python is required because ct lint runs Yamale and yamllint which require Python.
|
||||
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@e6669bcd63d7cb57cb4380c33043eebe5d111992 # v2.6.1
|
||||
uses: helm/chart-testing-action@0d28d3144d3a25ea2cc349d6e59901c4ff469b3b # v2.7.0
|
||||
with:
|
||||
version: v3.10.1
|
||||
|
||||
@ -47,7 +47,7 @@ jobs:
|
||||
aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive
|
||||
|
||||
- name: Create kind cluster
|
||||
uses: helm/kind-action@0025e74a8c7512023d06dc019c617aa3cf561fde # v1.10.0
|
||||
uses: helm/kind-action@a1b0e391336a6ee6713a0583f8c6240d70863de3 # v1.12.0
|
||||
|
||||
- name: Build the Docker image vllm cpu
|
||||
run: docker buildx build -f Dockerfile.cpu -t vllm-cpu-env .
|
||||
|
||||
3
.github/workflows/pre-commit.yml
vendored
3
.github/workflows/pre-commit.yml
vendored
@ -10,10 +10,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
|
||||
- uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: echo "::add-matcher::.github/workflows/matchers/actionlint.json"
|
||||
- run: echo "::add-matcher::.github/workflows/matchers/mypy.json"
|
||||
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
|
||||
with:
|
||||
extra_args: --all-files --hook-stage manual
|
||||
|
||||
2
.github/workflows/stale.yml
vendored
2
.github/workflows/stale.yml
vendored
@ -13,7 +13,7 @@ jobs:
|
||||
actions: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@28ca1036281a5e5922ead5184a1bbf96e5fc984e # v9.0.0
|
||||
- uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0
|
||||
with:
|
||||
# Increasing this value ensures that changes to this workflow
|
||||
# propagate to all issues and PRs in days rather than months
|
||||
|
||||
@ -116,13 +116,6 @@ repos:
|
||||
language: python
|
||||
types: [python]
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
||||
exclude: 'vllm/third_party/.*'
|
||||
- id: check-filenames
|
||||
name: Check for spaces in all filenames
|
||||
entry: bash
|
||||
@ -133,3 +126,12 @@ repos:
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
exclude: 'vllm/third_party/.*'
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
||||
exclude: 'vllm/third_party/.*'
|
||||
# Insert new entries above the `suggestion` entry
|
||||
|
||||
@ -192,7 +192,7 @@ set_gencode_flags_for_srcs(
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
message(STATUS "Enabling cumem allocator extension.")
|
||||
# link against cuda driver library
|
||||
list(APPEND CUMEM_LIBS cuda)
|
||||
list(APPEND CUMEM_LIBS CUDA::cuda_driver)
|
||||
define_gpu_extension_target(
|
||||
cumem_allocator
|
||||
DESTINATION vllm
|
||||
|
||||
@ -15,6 +15,10 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
|
||||
---
|
||||
|
||||
We are excited to invite you to our Menlo Park meetup with Meta, evening of Thursday, February 27! Meta engineers will discuss the improvements on top of vLLM, and vLLM contributors will share updates from the v0.7.x series of releases. [Register Now](https://lu.ma/h7g3kuj9)
|
||||
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
|
||||
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
|
||||
|
||||
@ -19,3 +19,11 @@ mkdir coco -p
|
||||
wget http://images.cocodataset.org/zips/train2017.zip -O coco/train2017.zip
|
||||
unzip coco/train2017.zip -d coco/
|
||||
```
|
||||
|
||||
# Downloading the BurstGPT dataset
|
||||
|
||||
You can download the BurstGPT v1.1 dataset by running:
|
||||
|
||||
```bash
|
||||
wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv
|
||||
```
|
||||
|
||||
@ -38,6 +38,7 @@ from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
||||
RequestFuncOutput)
|
||||
from datasets import load_dataset
|
||||
@ -131,6 +132,35 @@ def sample_sharegpt_requests(
|
||||
return filtered_dataset
|
||||
|
||||
|
||||
def sample_burstgpt_requests(
    dataset_path: str,
    num_requests: int,
    random_seed: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int, None]]:
    """Build synthetic requests whose lengths follow the BurstGPT trace.

    Only rows whose ``Model`` is ``"GPT-4"`` and whose ``Response tokens``
    is positive are used. The prompt text itself is synthetic: it is the
    decoding of a deterministic sequence of token ids, so only the request
    and response lengths come from the trace.

    Args:
        dataset_path: Path to the BurstGPT CSV file.
        num_requests: Number of requests to sample. When it exceeds the
            number of usable rows, rows are sampled with replacement.
        random_seed: Seed passed to ``DataFrame.sample``.
        tokenizer: Tokenizer used to decode the synthetic token ids.

    Returns:
        A list of ``(prompt, input_len, output_len, None)`` tuples.

    Raises:
        ValueError: If requests are asked for but the trace contains no
            successful GPT-4 rows.
    """
    df = pd.read_csv(dataset_path)
    # Keep GPT-4 rows only and drop the failed requests
    # (i.e., response length is 0).
    gpt4_df = df[(df["Model"] == "GPT-4") & (df["Response tokens"] > 0)]
    if num_requests > 0 and gpt4_df.empty:
        raise ValueError(
            f"No successful GPT-4 requests found in {dataset_path!r}.")

    # Randomly sample num_requests rows; fall back to sampling with
    # replacement only when the trace is smaller than the request count.
    gpt4_df = gpt4_df.sample(n=num_requests,
                             random_state=random_seed,
                             replace=num_requests > len(gpt4_df))

    # Read the lengths by column name so the result does not depend on the
    # column order of the CSV file.
    lengths = zip(gpt4_df["Request tokens"], gpt4_df["Response tokens"])
    vocab_size = tokenizer.vocab_size
    input_requests = []
    for i, (request_tokens, response_tokens) in enumerate(lengths):
        input_len = int(request_tokens)
        output_len = int(response_tokens)
        # Synthetic prompt: input_len consecutive token ids starting at i.
        prompt = tokenizer.decode([(i + j) % vocab_size
                                   for j in range(input_len)])
        input_requests.append((prompt, input_len, output_len, None))
    return input_requests
|
||||
|
||||
|
||||
def sample_sonnet_requests(
|
||||
dataset_path: str,
|
||||
num_requests: int,
|
||||
@ -830,6 +860,14 @@ def main(args: argparse.Namespace):
|
||||
fixed_output_len=args.sharegpt_output_len,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "burstgpt":
|
||||
input_requests = sample_burstgpt_requests(
|
||||
dataset_path=args.dataset_path,
|
||||
num_requests=args.num_prompts,
|
||||
random_seed=args.seed,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
elif args.dataset_name == "sonnet":
|
||||
# Do not format the prompt, pass to message directly
|
||||
if args.backend == "openai-chat":
|
||||
@ -995,7 +1033,7 @@ if __name__ == "__main__":
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="sharegpt",
|
||||
choices=["sharegpt", "sonnet", "random", "hf"],
|
||||
choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
parser.add_argument("--dataset-path",
|
||||
@ -1237,11 +1275,12 @@ if __name__ == "__main__":
|
||||
'--tokenizer-mode',
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=['auto', 'slow', 'mistral'],
|
||||
choices=['auto', 'slow', 'mistral', 'custom'],
|
||||
help='The tokenizer mode.\n\n* "auto" will use the '
|
||||
'fast tokenizer if available.\n* "slow" will '
|
||||
'always use the slow tokenizer. \n* '
|
||||
'"mistral" will always use the `mistral_common` tokenizer.')
|
||||
'"mistral" will always use the `mistral_common` tokenizer. \n*'
|
||||
'"custom" will use --tokenizer to select the preregistered tokenizer.')
|
||||
|
||||
parser.add_argument("--served-model-name",
|
||||
type=str,
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
#include <ATen/cuda/Atomic.cuh>
|
||||
|
||||
#include "../cuda_compat.h"
|
||||
#include "../dispatch_utils.h"
|
||||
|
||||
@ -1122,4 +1122,4 @@ void paged_attention(
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
||||
@ -19,17 +19,19 @@ Currently, there are no pre-built OpenVINO wheels.
|
||||
|
||||
### Build wheel from source
|
||||
|
||||
First, install Python. For example, on Ubuntu 22.04, you can run:
|
||||
First, install Python and ensure you have the latest pip. For example, on Ubuntu 22.04, you can run:
|
||||
|
||||
```console
|
||||
sudo apt-get update -y
|
||||
sudo apt-get install python3
|
||||
pip install --upgrade pip
|
||||
```
|
||||
|
||||
Second, install prerequisites vLLM OpenVINO backend installation:
|
||||
Second, clone vLLM and install prerequisites for the vLLM OpenVINO backend installation:
|
||||
|
||||
```console
|
||||
pip install --upgrade pip
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
```
|
||||
|
||||
|
||||
@ -856,7 +856,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
- * `UltravoxModel`
|
||||
* Ultravox
|
||||
* T + A<sup>E+</sup>
|
||||
* `fixie-ai/ultravox-v0_3`
|
||||
* `fixie-ai/ultravox-v0_5-llama-3_2-1b`
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
* ✅︎
|
||||
|
||||
@ -359,12 +359,12 @@ export VLLM_VIDEO_FETCH_TIMEOUT=<timeout>
|
||||
### Audio
|
||||
|
||||
Audio input is supported according to [OpenAI Audio API](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in).
|
||||
Here is a simple example using Ultravox-v0.3.
|
||||
Here is a simple example using Ultravox-v0.5-1B.
|
||||
|
||||
First, launch the OpenAI-compatible server:
|
||||
|
||||
```bash
|
||||
vllm serve fixie-ai/ultravox-v0_3
|
||||
vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b
|
||||
```
|
||||
|
||||
Then, you can use the OpenAI client as follows:
|
||||
|
||||
@ -24,9 +24,9 @@ question_per_audio_count = {
|
||||
# Unless specified, these settings have been tested to work on a single L4.
|
||||
|
||||
|
||||
# Ultravox 0.3
|
||||
# Ultravox 0.5-1B
|
||||
def run_ultravox(question: str, audio_count: int):
|
||||
model_name = "fixie-ai/ultravox-v0_3"
|
||||
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
messages = [{
|
||||
|
||||
530
examples/offline_inference/prithvi_geospatial_mae.py
Normal file
530
examples/offline_inference/prithvi_geospatial_mae.py
Normal file
@ -0,0 +1,530 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This is a demo script showing how to use the
|
||||
PrithviGeospatialMAE model with vLLM
|
||||
This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
|
||||
|
||||
Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
|
||||
|
||||
The requirements for running this script are:
|
||||
- Installing [terratorch, albumentations, rasterio] in your python environment
|
||||
- downloading the model weights in a 'model' folder local to the script
|
||||
(temporary measure until the proper config.json file is uploaded to HF)
|
||||
- download an input example image (India_900498_S2Hand.tif) and place it in
|
||||
the same folder with the script (or specify with the --data_file argument)
|
||||
|
||||
Run the example:
|
||||
python prithvi_geospatial_mae.py
|
||||
|
||||
""" # noqa: E501
|
||||
import argparse
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
import albumentations
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm import LLM
|
||||
|
||||
NO_DATA = -9999
|
||||
NO_DATA_FLOAT = 0.0001
|
||||
OFFSET = 0
|
||||
PERCENTILE = 99
|
||||
|
||||
model_config = """{
|
||||
"architectures": ["PrithviGeoSpatialMAE"],
|
||||
"num_classes": 0,
|
||||
"pretrained_cfg": {
|
||||
"task_args": {
|
||||
"task": "SemanticSegmentationTask",
|
||||
"model_factory": "EncoderDecoderFactory",
|
||||
"loss": "ce",
|
||||
"ignore_index": -1,
|
||||
"lr": 0.001,
|
||||
"freeze_backbone": false,
|
||||
"freeze_decoder": false,
|
||||
"plot_on_val": 10,
|
||||
"optimizer": "AdamW",
|
||||
"scheduler": "CosineAnnealingLR"
|
||||
},
|
||||
"model_args": {
|
||||
"backbone_pretrained": false,
|
||||
"backbone": "prithvi_eo_v2_300_tl",
|
||||
"decoder": "UperNetDecoder",
|
||||
"decoder_channels": 256,
|
||||
"decoder_scale_modules": true,
|
||||
"num_classes": 2,
|
||||
"rescale": true,
|
||||
"backbone_bands": [
|
||||
"BLUE",
|
||||
"GREEN",
|
||||
"RED",
|
||||
"NIR_NARROW",
|
||||
"SWIR_1",
|
||||
"SWIR_2"
|
||||
],
|
||||
"head_dropout": 0.1,
|
||||
"necks": [
|
||||
{
|
||||
"name": "SelectIndices",
|
||||
"indices": [
|
||||
5,
|
||||
11,
|
||||
17,
|
||||
23
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "ReshapeTokensToImage"
|
||||
}
|
||||
]
|
||||
},
|
||||
"optimizer_params" : {
|
||||
"lr": 5.0e-05,
|
||||
"betas": [0.9, 0.999],
|
||||
"eps": [1.0e-08],
|
||||
"weight_decay": 0.05,
|
||||
"amsgrad": false,
|
||||
"maximize": false,
|
||||
"capturable": false,
|
||||
"differentiable": false
|
||||
},
|
||||
"scheduler_params" : {
|
||||
"T_max": 50,
|
||||
"eta_min": 0,
|
||||
"last_epoch": -1,
|
||||
"verbose": "deprecated"
|
||||
}
|
||||
},
|
||||
|
||||
|
||||
"torch_dtype": "float32"
|
||||
}
|
||||
"""
|
||||
|
||||
# Temporarily creating the "config.json" for the model.
|
||||
# This is going to disappear once the correct config.json is available on HF
|
||||
with open(os.path.join(os.path.dirname(__file__), "./model/config.json"),
|
||||
'w') as config_file:
|
||||
config_file.write(model_config)
|
||||
|
||||
datamodule_config = {
|
||||
'bands': ['BLUE', 'GREEN', 'RED', 'NIR_NARROW', 'SWIR_1', 'SWIR_2'],
|
||||
'batch_size':
|
||||
16,
|
||||
'constant_scale':
|
||||
0.0001,
|
||||
'data_root':
|
||||
'/dccstor/geofm-finetuning/datasets/sen1floods11',
|
||||
'drop_last':
|
||||
True,
|
||||
'no_data_replace':
|
||||
0.0,
|
||||
'no_label_replace':
|
||||
-1,
|
||||
'num_workers':
|
||||
8,
|
||||
'test_transform': [
|
||||
albumentations.Resize(always_apply=False,
|
||||
height=448,
|
||||
interpolation=1,
|
||||
p=1,
|
||||
width=448),
|
||||
albumentations.pytorch.ToTensorV2(transpose_mask=False,
|
||||
always_apply=True,
|
||||
p=1.0)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class PrithviMAE:
    """Thin wrapper that runs the Prithvi geospatial MAE model via vLLM."""

    def __init__(self):
        """Load the model from the ``model`` folder next to this script."""
        print("Initializing PrithviMAE model")
        self.model = LLM(model=os.path.join(os.path.dirname(__file__),
                                            "./model"),
                         skip_tokenizer_init=True,
                         dtype="float32")

    def run(self, input_data, location_coords):
        """Run one inference pass and return the raw output data.

        Args:
            input_data: Pixel values passed to the model as
                ``pixel_values``; ``None`` is replaced by an empty tensor.
            location_coords: Location coordinates passed to the model as
                ``location_coords``; ``None`` is replaced by an empty
                tensor.

        Returns:
            The ``outputs.data`` attribute of the first result returned by
            ``self.model.encode``.
        """
        # Local import: the file-level imports do not include ``time``.
        import time

        print("################ Running inference on vLLM ##############")
        # merge the inputs into one data structure
        mm_data = {
            "pixel_values":
            torch.empty(0) if input_data is None else input_data,
            "location_coords":
            torch.empty(0) if location_coords is None else location_coords
        }

        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}

        # Time the call so the completion message reports a real duration.
        start = time.perf_counter()
        outputs = self.model.encode(prompt, use_tqdm=False)
        elapsed = time.perf_counter() - start
        print("################ Inference done "
              f"(it took {elapsed:.2f} seconds) ##############")

        return outputs[0].outputs.data
|
||||
|
||||
|
||||
def generate_datamodule():
    """Create the Sen1Floods11 datamodule from ``datamodule_config``.

    Returns:
        A ``Sen1Floods11NonGeoDataModule`` built from the data root, batch
        size, worker count, bands, drop-last flag and test transform stored
        in the module-level ``datamodule_config`` dict.
    """
    cfg = datamodule_config
    return Sen1Floods11NonGeoDataModule(
        data_root=cfg['data_root'],
        batch_size=cfg["batch_size"],
        num_workers=cfg["num_workers"],
        bands=cfg["bands"],
        drop_last=cfg["drop_last"],
        test_transform=cfg["test_transform"],
    )
|
||||
|
||||
|
||||
def process_channel_group(orig_img, channels):
    """Select channels from an image and stretch them to the [0, 1] range.

    Args:
        orig_img: torch.Tensor representing original image (reference)
            with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
        torch.Tensor with shape (num_channels, height, width) holding the
        selected channels, contrast-stretched and clamped to [0, 1], with
        no-data pixels set to 0.
    """
    selected = orig_img[channels, ...]
    # Pixels equal to NO_DATA_FLOAT are treated as missing.
    is_valid = selected != NO_DATA_FLOAT

    # Rescale (enhancing contrast): the upper bound is the PERCENTILE-th
    # percentile of the valid pixels, but never below 3000.
    lower = OFFSET
    upper = max(3000, np.percentile(selected[is_valid], PERCENTILE))
    stretched = torch.clamp((selected - lower) / (upper - lower), 0, 1)

    # No data as zeros
    stretched[~is_valid] = 0

    return stretched
|
||||
|
||||
|
||||
def read_geotiff(file_path: str):
    """Read all bands of a GeoTIFF together with its metadata.

    Args:
        file_path: path to image file.

    Returns:
        A 3-tuple ``(img, meta, coords)``:
        - img: the result of ``src.read()`` (all bands of the file)
        - meta: the dataset's ``meta`` attribute
        - coords: the result of ``src.lnglat()``, or ``None`` when that
          call raises
    """
    with rasterio.open(file_path) as dataset:
        pixels = dataset.read()
        metadata = dataset.meta
        try:
            lnglat = dataset.lnglat()
        except Exception:
            # Cannot read coords
            lnglat = None

    return pixels, metadata, lnglat
|
||||
|
||||
|
||||
def save_geotiff(image, output_path: str, meta: dict):
    """Write a multi-band image to a GeoTIFF file, one band at a time.

    Args:
        image: np.ndarray with shape (bands, height, width)
        output_path: path where to save the image
        meta: dict with meta info, forwarded as keyword arguments to
            ``rasterio.open``.
    """
    num_bands = image.shape[0]
    with rasterio.open(output_path, "w", **meta) as dest:
        for band_idx in range(num_bands):
            # The band index passed to write() is 1-based.
            dest.write(image[band_idx, :, :], band_idx + 1)
|
||||
|
||||
|
||||
def _convert_np_uint8(float_image: torch.Tensor):
|
||||
image = float_image.numpy() * 255.0
|
||||
image = image.astype(dtype=np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def load_example(
    file_paths: List[str],
    mean: Union[List[float], None] = None,
    std: Union[List[float], None] = None,
    indices: Union[list[int], None] = None,
):
    """Build an input example by loading images in *file_paths*.

    Args:
        file_paths: list of file paths.
        mean: list containing mean values for each band in the images
            in *file_paths*.
        std: list containing std values for each band in the images
            in *file_paths*.
        indices: optional list of band indices to keep from each image.
            Normalization with *mean*/*std* is applied after selection.

    Returns:
        A 4-tuple ``(imgs, temporal_coords, location_coords, metas)``:
        - imgs: float32 np.ndarray with a leading batch axis of size 1,
          followed by (channels, num_frames, H, W)
        - temporal_coords: list of ``[year, day_of_year]`` for each file
          whose name contains a ``YYYYDDDTHHMMSS`` or ``YYYYMMDDTHHMMSS``
          timestamp
        - location_coords: list of the coordinates returned by
          ``read_geotiff`` for each file where they could be read
        - metas: list of meta info for each image in *file_paths*
    """

    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in file_paths:
        img, meta, coords = read_geotiff(file)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        try:
            match = re.search(r'(\d{7,8}T\d{6})', file)
            if match:
                date_part = match.group(1).split('T')[0]
                year = int(date_part[:4])
                if len(date_part) == 7:
                    # YYYYDDD: the day of year is given directly.
                    julian_day = int(date_part[4:])
                else:
                    # YYYYMMDD: parse together with the year. Parsing
                    # only '%m%d' assumes the non-leap year 1900, which
                    # shifts dates after Feb 28 in leap years and
                    # rejects Feb 29.
                    julian_day = datetime.datetime.strptime(
                        date_part, '%Y%m%d').timetuple().tm_yday
                temporal_coords.append([year, julian_day])
        except Exception as e:
            # Best effort: a file without a usable timestamp simply gets
            # no temporal coordinates.
            print(f'Could not extract timestamp for {file} ({e})')

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas
|
||||
|
||||
|
||||
def run_model(input_data,
              temporal_coords,
              location_coords,
              model,
              datamodule,
              img_size,
              lightning_model=None):
    """Run tiled inference over *input_data* and stitch the predictions.

    The last two axes of the input are reflect-padded to a multiple of
    *img_size* and cut into non-overlapping ``img_size`` x ``img_size``
    windows. Each window goes through ``datamodule.test_transform`` and
    ``datamodule.aug`` and is then passed to ``model.run``. The per-window
    argmax predictions are reassembled and cropped to the original size.

    Args:
        input_data: 5-D array whose last two axes are height and width
            (the rearrange pattern below names the axes ``b c t h w``).
        temporal_coords: list of temporal coordinates; only forwarded to
            *lightning_model*. An empty list is treated as ``None``.
        location_coords: list of location coordinates; only the first
            entry is used. An empty list is treated as ``None``.
        model: object exposing ``run(x, location_coords=...)``.
        datamodule: object exposing ``test_transform`` and ``aug``.
        img_size: side length of the square windows.
        lightning_model: optional reference model; when given, its output
            is compared with the ``model.run`` output and a message is
            printed if they differ.

    Returns:
        Tensor of predictions for the first batch element, cropped to the
        original height and width.
    """
    # Reflect pad if not divisible by img_size
    full_h, full_w = input_data.shape[-2:]
    extra_h = (img_size - (full_h % img_size)) % img_size
    extra_w = (img_size - (full_w % img_size)) % img_size
    padding = ((0, 0), (0, 0), (0, 0), (0, extra_h), (0, extra_w))
    input_data = np.pad(input_data, padding, mode="reflect")

    # Build sliding window
    batch_size = 1
    batch = torch.tensor(input_data, device="cpu")
    tiles = batch.unfold(3, img_size, img_size)
    tiles = tiles.unfold(4, img_size, img_size)
    h1, w1 = tiles.shape[3:5]
    tiles = rearrange(tiles,
                      "b c t h1 w1 h w -> (b h1 w1) c t h w",
                      h=img_size,
                      w=img_size)

    # Split into batches if number of windows > batch_size
    num_tiles = tiles.shape[0]
    num_batches = num_tiles // batch_size if num_tiles > batch_size else 1
    tile_batches = torch.tensor_split(tiles, num_batches, dim=0)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    temporal_coords = (torch.tensor(temporal_coords,
                                    device=device).unsqueeze(0)
                       if temporal_coords else None)
    location_coords = (torch.tensor(location_coords[0],
                                    device=device).unsqueeze(0)
                       if location_coords else None)

    # Run model
    tile_preds = []
    for tile in tile_batches:
        # Apply standardization
        channels_last = tile.squeeze().numpy().transpose(1, 2, 0)
        sample = datamodule.test_transform(image=channels_last)
        sample = datamodule.aug(sample)['image']

        with torch.no_grad():
            sample = sample.to(device)
            pred = model.run(sample, location_coords=location_coords)
            if lightning_model:
                reference = lightning_model(sample,
                                            temporal_coords=temporal_coords,
                                            location_coords=location_coords)
                reference = reference.output.detach().cpu()
                if not torch.equal(pred, reference):
                    print("Inference output is not equal")

        labels = pred.argmax(dim=1).unsqueeze(1).float()
        labels = torch.nn.functional.interpolate(labels,
                                                 size=img_size,
                                                 mode="nearest")
        tile_preds.append(labels)

    stitched = torch.concat(tile_preds, dim=0)

    # Build images from patches
    stitched = rearrange(
        stitched,
        "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
        h=img_size,
        w=img_size,
        b=1,
        c=1,
        h1=h1,
        w1=w1,
    )

    # Cut padded area back to original size, then drop the batch axis
    # (batch size 1).
    return stitched[..., :full_h, :full_w][0]
|
||||
|
||||
|
||||
def main(
    data_file: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] = None,
):
    """Run the demo end to end: load one image, predict, write GeoTIFFs.

    Two files are always written to *output_dir*: ``pred_<name>.tiff``
    (the prediction) and ``rgb_pred_<name>.tiff`` (the RGB image blended
    with the prediction). When *rgb_outputs* is true,
    ``original_rgb_<name>.tiff`` is written as well.

    Args:
        data_file: path to the input image.
        output_dir: directory for the output files; created if missing.
        rgb_outputs: whether to also save the original RGB image.
        input_indices: band indices to select from the input image.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Base name of the input file without extension, used in output names.
    stem = os.path.splitext(os.path.basename(data_file))[0]

    # Load model ---------------------------------------------------------------

    model_obj = PrithviMAE()
    datamodule = generate_datamodule()
    img_size = 256  # Size of Sen1Floods11

    # Loading data -------------------------------------------------------------

    input_data, temporal_coords, location_coords, metas = load_example(
        file_paths=[data_file],
        indices=input_indices,
    )
    meta_data = metas[0]  # only one image

    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

    # Running model ------------------------------------------------------------

    band_names = datamodule_config['bands']
    channels = [band_names.index(b)
                for b in ["RED", "GREEN", "BLUE"]]  # BGR -> RGB

    pred = run_model(input_data, temporal_coords, location_coords, model_obj,
                     datamodule, img_size)

    # Save pred
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(output_dir, f"pred_{stem}.tiff")
    save_geotiff(_convert_np_uint8(pred), pred_file, meta_data)

    # Save image + pred
    meta_data.update(count=3, dtype="uint8", compress="lzw", nodata=0)

    if input_data.mean() < 1:
        input_data = input_data * 10000  # Scale to 0-10000

    rgb_orig = process_channel_group(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )

    # Pixels predicted as 0 become NaN so that, after blending, they can
    # be replaced by the untouched RGB values.
    pred[pred == 0.] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
    nan_mask = img_pred.isnan()
    img_pred[nan_mask] = rgb_orig[nan_mask]

    img_pred_file = os.path.join(output_dir, f"rgb_pred_{stem}.tiff")
    save_geotiff(
        image=_convert_np_uint8(img_pred),
        output_path=img_pred_file,
        meta=meta_data,
    )

    # Save image rgb
    if rgb_outputs:
        rgb_file = os.path.join(output_dir, f"original_rgb_{stem}.tiff")
        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
            output_path=rgb_file,
            meta=meta_data,
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to
    # main() as keyword arguments (option names match main()'s parameters).
    # The default help option is left enabled so that `-h` works when the
    # script is run standalone.
    parser = argparse.ArgumentParser("MAE run inference")

    parser.add_argument(
        "--data_file",
        type=str,
        default="./India_900498_S2Hand.tif",
        help="Path to the file.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Path to the directory where to save outputs.",
    )
    parser.add_argument(
        "--input_indices",
        default=[1, 2, 3, 8, 11, 12],
        type=int,
        nargs="+",
        help=
        "0-based indices of the six Prithvi channels to be selected from the "
        "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
    )
    parser.add_argument(
        "--rgb_outputs",
        action="store_true",
        # main() always writes the prediction and the RGB+prediction blend;
        # this flag only adds the original RGB image.
        help="If present, also save the original RGB image "
        "(original_rgb_<name>.tiff) in the output directory.",
    )
    args = parser.parse_args()

    main(**vars(args))
|
||||
@ -92,7 +92,7 @@ class MyLLM(LLM):
|
||||
# a hack to make the script work.
|
||||
# stop ray from manipulating CUDA_VISIBLE_DEVICES
|
||||
# at the top-level
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
@ -59,7 +59,7 @@ class MyLLM(LLM):
|
||||
# a hack to make the script work.
|
||||
# stop ray from manipulating CUDA_VISIBLE_DEVICES
|
||||
# at the top-level
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
# every worker will use 0.4 GPU, so that we can schedule
|
||||
# 2 instances on the same GPUs.
|
||||
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
|
||||
|
||||
@ -12,7 +12,7 @@ vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
|
||||
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
|
||||
|
||||
(audio inference with Ultravox)
|
||||
vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096
|
||||
vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096
|
||||
"""
|
||||
import base64
|
||||
|
||||
|
||||
@ -36,8 +36,8 @@ response = client.chat.completions.create(model=model, messages=messages)
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
content = response.choices[0].message.content
|
||||
|
||||
print("reasoning_content:", reasoning_content)
|
||||
print("content:", content)
|
||||
print("reasoning_content for Round 1:", reasoning_content)
|
||||
print("content for Round 1:", content)
|
||||
|
||||
# Round 2
|
||||
messages.append({"role": "assistant", "content": content})
|
||||
@ -50,5 +50,5 @@ response = client.chat.completions.create(model=model, messages=messages)
|
||||
reasoning_content = response.choices[0].message.reasoning_content
|
||||
content = response.choices[0].message.content
|
||||
|
||||
print("reasoning_content:", reasoning_content)
|
||||
print("content:", content)
|
||||
print("reasoning_content for Round 2:", reasoning_content)
|
||||
print("content for Round 2:", content)
|
||||
|
||||
5
setup.py
5
setup.py
@ -48,8 +48,9 @@ elif not (sys.platform.startswith("linux")
|
||||
"so vLLM may not be able to run correctly", sys.platform)
|
||||
VLLM_TARGET_DEVICE = "empty"
|
||||
elif (sys.platform.startswith("linux") and torch.version.cuda is None
|
||||
and os.getenv("VLLM_TARGET_DEVICE") is None):
|
||||
# if cuda is not available and VLLM_TARGET_DEVICE is not set,
|
||||
and os.getenv("VLLM_TARGET_DEVICE") is None
|
||||
and torch.version.hip is None):
|
||||
# if cuda or hip is not available and VLLM_TARGET_DEVICE is not set,
|
||||
# fallback to cpu
|
||||
VLLM_TARGET_DEVICE = "cpu"
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||
# so that each worker can see all the GPUs
|
||||
# they will be able to set the device to the correct GPU
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
@ -44,7 +44,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||
# so that each worker can see all the GPUs
|
||||
# they will be able to set the device to the correct GPU
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
@ -72,7 +72,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
|
||||
# so that each worker can see all the GPUs
|
||||
# they will be able to set the device to the correct GPU
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
@ -108,7 +108,7 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
@ -148,7 +148,7 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
|
||||
distributed_init_port: str):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
|
||||
@ -24,7 +24,7 @@ for i, v in enumerate(test_sizes):
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
@ -80,7 +80,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
|
||||
|
||||
@ray.remote(num_gpus=1, max_calls=1)
|
||||
def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
|
||||
del os.environ["CUDA_VISIBLE_DEVICES"]
|
||||
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
torch.cuda.set_device(device)
|
||||
init_test_distributed_environment(tp_size, pp_size, rank,
|
||||
|
||||
@ -215,7 +215,7 @@ MULTIMODAL_MODELS = {
|
||||
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True),
|
||||
"Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(),
|
||||
"Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(),
|
||||
"fixie-ai/ultravox-v0_3": PPTestSettings.fast(trust_remote_code=True),
|
||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
|
||||
# [Encoder-decoder]
|
||||
# TODO: Implement PP
|
||||
# "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(),
|
||||
@ -234,7 +234,7 @@ TEST_MODELS = [
|
||||
# [MULTIMODAL GENERATION]
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
"microsoft/Phi-3-vision-128k-instruct",
|
||||
"fixie-ai/ultravox-v0_3",
|
||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||
# [LANGUAGE GENERATION - HYBRID ARCH]
|
||||
"ai21labs/Jamba-tiny-dev",
|
||||
]
|
||||
|
||||
@ -15,32 +15,62 @@ start_token = "<think>"
|
||||
end_token = "</think>"
|
||||
|
||||
SIMPLE_REASONING = {
|
||||
"output": "<think>This is a reasoning section</think>This is the rest",
|
||||
"output": "This is a reasoning section</think>This is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
COMPLETE_REASONING = {
|
||||
"output": "<think>This is a reasoning section</think>",
|
||||
"output": "This is a reasoning section</think>",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
}
|
||||
NO_REASONING = {
|
||||
"output": "This is a reasoning section",
|
||||
"output": "This is content",
|
||||
"reasoning_content": None,
|
||||
"content": "This is a reasoning section",
|
||||
"content": "This is content",
|
||||
}
|
||||
NO_REASONING_STREAMING = {
|
||||
"output": "This is a reasoning section",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
}
|
||||
MULTIPLE_LINES = {
|
||||
"output": "<think>This\nThat</think>This is the rest\nThat",
|
||||
"output": "This\nThat</think>This is the rest\nThat",
|
||||
"reasoning_content": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING = {
|
||||
"output": "<think></think>This is the rest",
|
||||
"output": "</think>This is the rest",
|
||||
"reasoning_content": "",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
SHORTEST_REASONING = {
|
||||
"output": "<think></think>This is the rest",
|
||||
"output": "</think>This is the rest",
|
||||
"reasoning_content": None,
|
||||
"content": "This is the rest",
|
||||
}
|
||||
REASONING_WITH_THINK = {
|
||||
"output": "<think>This is a reasoning section</think>This is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
COMPLETE_REASONING_WITH_THINK = {
|
||||
"output": "<think>This is a reasoning section</think>",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
}
|
||||
MULTIPLE_LINES_WITH_THINK = {
|
||||
"output": "<think>This\nThat</think>This is the rest\nThat",
|
||||
"reasoning_content": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
|
||||
"output": "</think>This is the rest",
|
||||
"reasoning_content": "",
|
||||
"content": "This is the rest",
|
||||
}
|
||||
SHORTEST_REASONING_WITH_THINK = {
|
||||
"output": "</think>This is the rest",
|
||||
"reasoning_content": None,
|
||||
"content": "This is the rest",
|
||||
}
|
||||
@ -49,37 +79,37 @@ TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_streaming",
|
||||
id="simple_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_streaming",
|
||||
id="simple_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_streaming",
|
||||
id="complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_streaming",
|
||||
id="complete_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NO_REASONING,
|
||||
id="no_streaming",
|
||||
id="no_reasoning_token",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
NO_REASONING,
|
||||
id="no_streaming",
|
||||
NO_REASONING_STREAMING,
|
||||
id="no_reasoning_token_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines_streaming",
|
||||
id="multiple_lines",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
@ -89,23 +119,65 @@ TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING,
|
||||
id="shortest_streaming",
|
||||
id="shortest",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING,
|
||||
id="shortest_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
REASONING_WITH_THINK,
|
||||
id="reasoning_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
REASONING_WITH_THINK,
|
||||
id="reasoning_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING_WITH_THINK,
|
||||
id="complete_reasoning_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING_WITH_THINK,
|
||||
id="complete_reasoning_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES_WITH_THINK,
|
||||
id="multiple_lines_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES_WITH_THINK,
|
||||
id="multiple_lines_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
|
||||
id="shortest_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING_WITH_THINK,
|
||||
id="shortest_with_think_streaming",
|
||||
),
|
||||
]
|
||||
|
||||
# Global tokenizer initialization to avoid repeated loading
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
|
||||
tokenizer.add_tokens([start_token, end_token])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
|
||||
tokenizer.add_tokens([start_token, end_token])
|
||||
output = tokenizer.tokenize(param_dict["output"])
|
||||
# decode everything to tokens
|
||||
output_tokens: List[str] = [
|
||||
|
||||
@ -11,7 +11,7 @@ from vllm.multimodal.utils import encode_audio_base64, fetch_audio
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "fixie-ai/ultravox-v0_3"
|
||||
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
TEST_AUDIO_URLS = [
|
||||
AudioAsset("winning_call").url,
|
||||
]
|
||||
|
||||
@ -85,6 +85,10 @@ EXPECTED_VALUES = {
|
||||
"vllm:time_per_output_token_seconds":
|
||||
[("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
|
||||
"vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
|
||||
"vllm:request_prompt_tokens":
|
||||
[("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
|
||||
("_count", _NUM_REQUESTS)],
|
||||
@ -169,6 +173,18 @@ EXPECTED_METRICS = [
|
||||
"vllm:e2e_request_latency_seconds_sum",
|
||||
"vllm:e2e_request_latency_seconds_bucket",
|
||||
"vllm:e2e_request_latency_seconds_count",
|
||||
"vllm:request_queue_time_seconds_sum",
|
||||
"vllm:request_queue_time_seconds_bucket",
|
||||
"vllm:request_queue_time_seconds_count",
|
||||
"vllm:request_inference_time_seconds_sum",
|
||||
"vllm:request_inference_time_seconds_bucket",
|
||||
"vllm:request_inference_time_seconds_count",
|
||||
"vllm:request_prefill_time_seconds_sum",
|
||||
"vllm:request_prefill_time_seconds_bucket",
|
||||
"vllm:request_prefill_time_seconds_count",
|
||||
"vllm:request_decode_time_seconds_sum",
|
||||
"vllm:request_decode_time_seconds_bucket",
|
||||
"vllm:request_decode_time_seconds_count",
|
||||
"vllm:request_prompt_tokens_sum",
|
||||
"vllm:request_prompt_tokens_bucket",
|
||||
"vllm:request_prompt_tokens_count",
|
||||
@ -203,6 +219,8 @@ EXPECTED_METRICS_V1 = [
|
||||
"vllm:num_requests_running",
|
||||
"vllm:num_requests_waiting",
|
||||
"vllm:gpu_cache_usage_perc",
|
||||
"vllm:gpu_prefix_cache_queries",
|
||||
"vllm:gpu_prefix_cache_hits",
|
||||
"vllm:prompt_tokens_total",
|
||||
"vllm:generation_tokens_total",
|
||||
"vllm:request_success_total",
|
||||
@ -218,6 +236,21 @@ EXPECTED_METRICS_V1 = [
|
||||
"vllm:time_per_output_token_seconds_sum",
|
||||
"vllm:time_per_output_token_seconds_bucket",
|
||||
"vllm:time_per_output_token_seconds_count",
|
||||
"vllm:e2e_request_latency_seconds_sum",
|
||||
"vllm:e2e_request_latency_seconds_bucket",
|
||||
"vllm:e2e_request_latency_seconds_count",
|
||||
"vllm:request_queue_time_seconds_sum",
|
||||
"vllm:request_queue_time_seconds_bucket",
|
||||
"vllm:request_queue_time_seconds_count",
|
||||
"vllm:request_inference_time_seconds_sum",
|
||||
"vllm:request_inference_time_seconds_bucket",
|
||||
"vllm:request_inference_time_seconds_count",
|
||||
"vllm:request_prefill_time_seconds_sum",
|
||||
"vllm:request_prefill_time_seconds_bucket",
|
||||
"vllm:request_prefill_time_seconds_count",
|
||||
"vllm:request_decode_time_seconds_sum",
|
||||
"vllm:request_decode_time_seconds_bucket",
|
||||
"vllm:request_decode_time_seconds_count",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ from ..utils import VLLM_PATH
|
||||
EXAMPLES_DIR = VLLM_PATH / "examples"
|
||||
|
||||
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
|
||||
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3"
|
||||
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
|
||||
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
|
||||
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
|
||||
|
||||
@ -606,20 +606,26 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
|
||||
|
||||
assert isinstance(model.get_submodule("gate_up_proj"),
|
||||
MergedColumnParallelLinearWithLoRA)
|
||||
# Verify packed lora is correct
|
||||
model_lora_clone = model_lora.clone(1)
|
||||
model_lora_clone1 = model_lora1.clone(1)
|
||||
assert manager.add_adapter(model_lora)
|
||||
assert manager.add_adapter(model_lora1)
|
||||
|
||||
assert model_lora.get_lora("gate_proj") is None
|
||||
assert model_lora.get_lora("up_proj") is None
|
||||
assert model_lora1.get_lora("up_proj") is None
|
||||
packed_lora = model_lora.get_lora("gate_up_proj")
|
||||
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
|
||||
|
||||
torch.testing.assert_close(packed_lora.lora_a[0],
|
||||
model_lora.get_lora("gate_proj").lora_a)
|
||||
model_lora_clone.get_lora("gate_proj").lora_a)
|
||||
torch.testing.assert_close(packed_lora.lora_b[0],
|
||||
model_lora.get_lora("gate_proj").lora_b)
|
||||
model_lora_clone.get_lora("gate_proj").lora_b)
|
||||
torch.testing.assert_close(packed_lora.lora_a[1],
|
||||
model_lora.get_lora("up_proj").lora_a)
|
||||
model_lora_clone.get_lora("up_proj").lora_a)
|
||||
torch.testing.assert_close(packed_lora.lora_b[1],
|
||||
model_lora.get_lora("up_proj").lora_b)
|
||||
model_lora_clone.get_lora("up_proj").lora_b)
|
||||
|
||||
packed_lora1 = model_lora1.get_lora("gate_up_proj")
|
||||
assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
|
||||
@ -627,6 +633,6 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
|
||||
assert packed_lora1.lora_a[0] is None
|
||||
assert packed_lora1.lora_b[0] is None
|
||||
torch.testing.assert_close(packed_lora1.lora_a[1],
|
||||
model_lora1.get_lora("up_proj").lora_a)
|
||||
model_lora_clone1.get_lora("up_proj").lora_a)
|
||||
torch.testing.assert_close(packed_lora1.lora_b[1],
|
||||
model_lora1.get_lora("up_proj").lora_b)
|
||||
model_lora_clone1.get_lora("up_proj").lora_b)
|
||||
|
||||
652
tests/lora/test_punica_ops.py
Normal file
652
tests/lora/test_punica_ops.py
Normal file
@ -0,0 +1,652 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from threading import Lock
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.lora.ops.triton_ops # noqa: F401
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (PunicaTensors, assert_close, generate_data,
|
||||
generate_data_for_expand_nslices,
|
||||
generate_data_for_nslices)
|
||||
|
||||
|
||||
# Utility shrink and expand operations used as reference implementations.
|
||||
def sgmv_shrink_for_nslices(
        nslices: int, inputs_tensor: torch.Tensor,
        lora_weights_lst: List[torch.Tensor], out_tensor: torch.Tensor,
        b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor,
        prompt_lora_mapping: torch.Tensor, batches: int, max_seq_length: int,
        num_tokens: int, scaling: float):
    """
    Reference shrink for an arbitrary number of slices.

    Calls the torch ``sgmv_shrink`` once per slice index ``i`` in
    ``range(nslices)``, pairing ``lora_weights_lst[i]`` with
    ``out_tensor[i]``. Every other argument is forwarded unchanged to each
    call. Nothing is returned; results are written by ``sgmv_shrink`` into
    the per-slice output it is handed.
    """
    # Arguments that are identical for every slice.
    per_call_tail = (
        b_seq_start_loc,
        seq_len_tensor,
        prompt_lora_mapping,
        batches,
        max_seq_length,
        num_tokens,
        scaling,
    )
    for slice_idx in range(nslices):
        slice_weights = lora_weights_lst[slice_idx]
        slice_out = out_tensor[slice_idx]
        sgmv_shrink(inputs_tensor, slice_weights, slice_out, *per_call_tail)
|
||||
|
||||
|
||||
def sgmv_expand_for_nslices(nslices: int, hidden_size: int,
                            inputs_tensor: torch.Tensor,
                            lora_weights_lst: List[torch.Tensor],
                            out_tensor: torch.Tensor,
                            b_seq_start_loc: torch.Tensor,
                            seq_len_tensor: torch.Tensor,
                            prompt_lora_mapping: torch.Tensor, batches: int,
                            max_seq_length: int, num_tokens: int,
                            add_inputs: bool) -> None:
    """
    Reference expand for an arbitrary number of slices.

    With a single slice the plain torch ``sgmv_expand`` is exercised on
    ``inputs_tensor[0]`` / ``lora_weights_lst[0]``. Otherwise
    ``sgmv_expand_slice`` is called once per slice, each call targeting the
    window of ``out_tensor`` that starts at ``index * hidden_size`` and is
    ``hidden_size`` wide.
    """
    # Arguments that do not depend on the slice being processed.
    seq_args = (
        b_seq_start_loc,
        seq_len_tensor,
        prompt_lora_mapping,
        batches,
        max_seq_length,
        num_tokens,
    )

    if nslices == 1:
        # Verify the torch's sgmv_expand op
        sgmv_expand(inputs_tensor[0],
                    lora_weights_lst[0],
                    out_tensor,
                    *seq_args,
                    add_inputs=add_inputs)
        return

    for index in range(nslices):
        slice_weights = lora_weights_lst[index]
        # Slices are laid out back to back, each hidden_size wide.
        slice_start = index * hidden_size
        sgmv_expand_slice(inputs_tensor[index],
                          slice_weights,
                          out_tensor,
                          *seq_args,
                          slice_start,
                          hidden_size,
                          add_inputs=add_inputs)
|
||||
|
||||
|
||||
# Guards the clear-then-launch sequences on the Triton LoRA pointer dicts
# used by the sgmv checks below.
_dict_lock = Lock()


def check_sgmv_shrink(batches: int, num_loras: int, rank: int,
                      hidden_size: int, nslices: int, dtype: torch.dtype,
                      device: str, seq_length: int, scaling: float):
    """
    Run ``torch.ops.vllm.sgmv_shrink`` and the torch reference on the same
    generated data and assert that their outputs are close.
    """
    data: PunicaTensors = generate_data_for_nslices(batches, hidden_size,
                                                    num_loras, rank,
                                                    seq_length, nslices, dtype,
                                                    "shrink", device)
    max_seq_length, token_nums = data.meta()

    # Trailing arguments passed identically to the kernel and the reference.
    shared_tail = (
        data.b_seq_start_loc,
        data.seq_len_tensor,
        data.prompt_lora_mapping,
        batches,
        max_seq_length,
        token_nums,
        scaling,
    )

    # Preventing cache error pointer: drop whatever _LORA_A_PTR_DICT holds
    # before launching the kernel on freshly generated weights.
    # NOTE(review): the extraction this file was recovered from lost
    # indentation; the lock is assumed to cover only the clear + kernel
    # launch -- confirm against the upstream file.
    with _dict_lock:
        _LORA_A_PTR_DICT.clear()
        torch.ops.vllm.sgmv_shrink(data.inputs_tensor, data.lora_weights,
                                   data.our_out_tensor, *shared_tail)

    sgmv_shrink_for_nslices(nslices, data.inputs_tensor, data.lora_weights,
                            data.ref_out_tensor, *shared_tail)
    assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
|
||||
def check_sgmv_expand(batches: int, num_loras: int, rank: int,
                      hidden_size: int, nslices: int, dtype: torch.dtype,
                      device: str, seq_length: int, add_inputs: bool):
    """
    Run ``torch.ops.vllm.sgmv_expand`` and the torch reference on the same
    generated data and assert that their outputs are close.
    """
    data: PunicaTensors = generate_data_for_nslices(batches, hidden_size,
                                                    num_loras, rank,
                                                    seq_length, nslices, dtype,
                                                    "expand", device)
    max_seq_length, token_nums = data.meta()

    # Positional arguments passed identically to the kernel and the reference.
    seq_args = (
        data.b_seq_start_loc,
        data.seq_len_tensor,
        data.prompt_lora_mapping,
        batches,
        max_seq_length,
        token_nums,
    )

    # Drop whatever _LORA_B_PTR_DICT holds before launching the kernel.
    # NOTE(review): indentation was lost in the recovered source; the lock is
    # assumed to cover only the clear + kernel launch -- confirm upstream.
    with _dict_lock:
        _LORA_B_PTR_DICT.clear()
        torch.ops.vllm.sgmv_expand(data.inputs_tensor,
                                   data.lora_weights,
                                   data.our_out_tensor,
                                   *seq_args,
                                   offset_start=0,
                                   add_inputs=add_inputs)

    sgmv_expand_for_nslices(nslices,
                            hidden_size,
                            data.inputs_tensor,
                            data.lora_weights,
                            data.ref_out_tensor,
                            *seq_args,
                            add_inputs=add_inputs)

    assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_shrink(batches: int, num_loras: int, rank: int,
                      hidden_size: int, dtype: torch.dtype, device: str,
                      scaling: float):
    """
    Run ``torch.ops.vllm.bgmv_shrink`` and the torch ``bgmv_shrink``
    reference on the same generated data and assert the outputs are close.
    """
    # Data for the bgmv checks is always generated with seq_length == 1.
    data: PunicaTensors = generate_data(batches, hidden_size, num_loras, rank,
                                        1, dtype, "shrink", device)

    # Kernel under test writes into our_out_tensor ...
    torch.ops.vllm.bgmv_shrink(data.inputs_tensor, data.lora_weights,
                               data.our_out_tensor, data.token_lora_mapping,
                               scaling)
    # ... and the reference writes into ref_out_tensor.
    bgmv_shrink(data.inputs_tensor, data.lora_weights, data.ref_out_tensor,
                data.token_lora_mapping, scaling)

    # The reference output is converted to float32 before the comparison.
    ref_fp32 = data.ref_out_tensor.to(torch.float32)
    data.ref_out_tensor = ref_fp32
    assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_expand(batches: int, num_loras: int, rank: int,
                      hidden_size: int, dtype: torch.dtype, device: str,
                      add_inputs: bool):
    """
    Run ``torch.ops.vllm.bgmv_expand`` and the torch ``bgmv_expand``
    reference on the same generated data and assert the outputs are close.
    """
    # Data for the bgmv checks is always generated with seq_length == 1.
    data: PunicaTensors = generate_data(batches, hidden_size, num_loras, rank,
                                        1, dtype, "expand", device)

    # Kernel under test first, reference second; each gets its own output.
    torch.ops.vllm.bgmv_expand(data.inputs_tensor,
                               data.lora_weights,
                               data.our_out_tensor,
                               data.token_lora_mapping,
                               add_inputs=add_inputs)
    bgmv_expand(data.inputs_tensor,
                data.lora_weights,
                data.ref_out_tensor,
                data.token_lora_mapping,
                add_inputs=add_inputs)
    assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
|
||||
def check_bgmv_expand_slice(batches: int, num_loras: int, rank: int,
                            hidden_size: int, nslices: int, dtype: torch.dtype,
                            device: str, add_inputs: bool):
    """
    Run ``torch.ops.vllm.bgmv_expand_slice`` and the torch
    ``bgmv_expand_slice`` reference slice by slice on the same generated
    data, then assert that the two accumulated outputs are close.
    """
    # Data for the bgmv checks is always generated with seq_length == 1.
    data: PunicaTensors = generate_data_for_expand_nslices(
        batches, hidden_size, num_loras, rank, 1, dtype, nslices, device)

    for index in range(nslices):
        # Slices sit back to back in the output, each hidden_size wide.
        slice_start = index * hidden_size
        slice_weights_ours = data.lora_weights[index]
        torch.ops.vllm.bgmv_expand_slice(data.inputs_tensor,
                                         slice_weights_ours,
                                         data.our_out_tensor,
                                         data.token_lora_mapping,
                                         slice_start,
                                         slice_size=hidden_size,
                                         add_inputs=add_inputs)
        bgmv_expand_slice(data.inputs_tensor,
                          data.lora_weights[index],
                          data.ref_out_tensor,
                          data.token_lora_mapping,
                          slice_start,
                          slice_size=hidden_size,
                          add_inputs=add_inputs)

    # NOTE(review): compared once after all slices are written; indentation
    # was lost in the recovered source -- confirm this was not per-slice.
    assert_close(data.our_out_tensor, data.ref_out_tensor)
|
||||
|
||||
|
||||
# Tests
|
||||
# We test the punica kernels along 2 verticals mainly.
|
||||
# 1. Variations in hidden_dim size
|
||||
# 2. Variations in all other parameters like (batch_size, max_rank, num_loras
|
||||
# etc.)
|
||||
|
||||
# We have collected the hidden_sizes included in the LoRA models
# currently supported by vLLM. Each base size is divided by every
# tensor-parallel size in `divisibility` below, so the Triton kernels are
# exercised on the resulting per-rank sizes as well as the full sizes.
HIDDEN_SIZES = [
    128, 256, 512, 896, 1024, 1152, 1216, 1280, 1536, 1664,
    2048, 2240, 2304, 2368, 2432, 2560, 2752, 3072, 3328, 3456,
    3584, 3712, 4096, 4480, 4608, 4736, 4864, 5120, 5504, 5632,
    5888, 6144, 6400, 6848, 6912, 7168, 7424, 8192, 8960, 9216,
    9472, 10240, 11008, 11264, 13824, 14336, 14784, 14848, 15360, 18944,
    22016, 22528, 24576, 27392, 27648, 29568, 29696, 32000, 32256, 32512,
    32768, 33024, 36864, 43264, 49152, 49408, 60544, 60672, 64000, 64256,
    102400, 102656, 128000, 128256,
]
# Tensor-parallel sizes each base hidden size is divided by.
divisibility = [1, 2, 8, 16, 64]

all_hidden_size = [
    hidden_size // div for div in divisibility for hidden_size in HIDDEN_SIZES
]

# De-duplicate and sort so the parametrization order (and therefore the
# order of generated test ids) does not depend on set iteration order.
HIDDEN_SIZES = sorted(set(all_hidden_size))

# Test params that focus on hidden_size variation.
hs_test_params = {
    "hidden_sizes": HIDDEN_SIZES,
    "batches": [4],
    "num_loras": [4],
    "max_ranks": [32],
}
|
||||
|
||||
# General test params that cover variations in every dimension
# except hidden_size, which is held at a single value.
test_params = {
    "hidden_sizes": [2049],
    "batches": [1, 4, 16, 32],
    "num_loras": [1, 8, 32, 128],
    "max_ranks": [1, 4, 8, 16, 32, 64, 128, 256],
}

# Dtypes, devices and seeds every test below is parametrized over.
DTYPES = [torch.float16, torch.bfloat16]
DEVICES = ["cuda:0"]
SEED = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """Check the sgmv shrink/expand kernels over the general parameter grid."""
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    # Keyword arguments common to both the shrink and the expand check.
    common = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
    )
    if op_type == "shrink":
        check_sgmv_shrink(scaling=0.5, **common)
    else:
        check_sgmv_expand(add_inputs=True, **common)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_sgmv_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """Check the sgmv shrink/expand kernels across the hidden_size sweep."""
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    # Keyword arguments common to both the shrink and the expand check.
    common = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        seq_length=128,
    )
    if op_type == "shrink":
        check_sgmv_shrink(scaling=0.5, **common)
    else:
        check_sgmv_expand(add_inputs=True, **common)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """Check the bgmv shrink/expand kernels over the general parameter grid."""
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    # Keyword arguments common to both the shrink and the expand check.
    common = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        dtype=dtype,
        device=device,
    )
    if op_type == "shrink":
        check_bgmv_shrink(scaling=0.5, **common)
    else:
        check_bgmv_expand(add_inputs=True, **common)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
def test_punica_bgmv_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
    op_type: str,
):
    """Check the bgmv shrink/expand kernels across the hidden_size sweep."""
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    # Keyword arguments common to both the shrink and the expand check.
    common = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        dtype=dtype,
        device=device,
    )
    if op_type == "shrink":
        check_bgmv_shrink(scaling=0.5, **common)
    else:
        check_bgmv_expand(add_inputs=True, **common)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", test_params['batches'])
@pytest.mark.parametrize("num_loras", test_params['num_loras'])
@pytest.mark.parametrize("rank", test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
):
    """Run the sliced bgmv expand check over the ``test_params`` grid.

    Seeds the platform, then delegates to ``check_bgmv_expand_slice`` with
    ``add_inputs=True`` for the requested number of slices.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    check_kwargs = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        add_inputs=True,
    )
    check_bgmv_expand_slice(**check_kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", hs_test_params['batches'])
@pytest.mark.parametrize("num_loras", hs_test_params['num_loras'])
@pytest.mark.parametrize("rank", hs_test_params['max_ranks'])
@pytest.mark.parametrize("hidden_size", hs_test_params['hidden_sizes'])
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("seed", SEED)
def test_punica_bgmv_expand_nslices_hidden_size(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    device: str,
    seed: int,
):
    """Run the sliced bgmv expand check over the ``hs_test_params`` grid.

    Same flow as the non-hidden-size variant: seed the platform and delegate
    to ``check_bgmv_expand_slice`` with ``add_inputs=True``.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    check_kwargs = dict(
        batches=batches,
        num_loras=num_loras,
        rank=rank,
        hidden_size=hidden_size,
        nslices=nslices,
        dtype=dtype,
        device=device,
        add_inputs=True,
    )
    check_bgmv_expand_slice(**check_kwargs)
|
||||
@ -1,401 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This script is mainly used to test various hidden_sizes. We have collected the
|
||||
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
|
||||
whether the corresponding Triton kernel can run normally when tensor parallelism
|
||||
is set to [1, 2, 8, 16, 64].
|
||||
"""
|
||||
from threading import Lock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.lora.ops.triton_ops # noqa: F401
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (assert_close, generate_data,
|
||||
generate_data_for_expand_nslices,
|
||||
generate_data_for_nslices)
|
||||
|
||||
# Base hidden sizes, before any tensor-parallel division is applied.
HIDDEN_SIZES = [
    128, 256, 512, 896, 1024, 1152, 1216, 1280, 1536, 1664,
    2048, 2240, 2304, 2368, 2432, 2560, 2752, 3072, 3328, 3456,
    3584, 3712, 4096, 4480, 4608, 4736, 4864, 5120, 5504, 5632,
    5888, 6144, 6400, 6848, 6912, 7168, 7424, 8192, 8960, 9216,
    9472, 10240, 11008, 11264, 13824, 14336, 14784, 14848, 15360, 18944,
    22016, 22528, 24576, 27392, 27648, 29568, 29696, 32000, 32256, 32512,
    32768, 33024, 36864, 43264, 49152, 49408, 60544, 60672, 64000, 64256,
    102400, 102656, 128000, 128256,
]

# Tensor-parallel sizes: every base hidden size is floor-divided by each one.
divisibility = [1, 2, 8, 16, 64]

# Outer loop over TP sizes, inner loop over base sizes.
all_hidden_size = [
    base_size // tp_size for tp_size in divisibility
    for base_size in HIDDEN_SIZES
]

# Drop duplicate per-shard sizes before they are used for parametrization.
HIDDEN_SIZES = list(set(all_hidden_size))
|
||||
|
||||
BATCHES = [4]
|
||||
NUM_LORA = [4]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
MAX_RANKS = [32]
|
||||
SCALES = [0.5]
|
||||
SEED = [0]
|
||||
DEVICES = [f"cuda:{0}"]
|
||||
|
||||
_dict_lock = Lock()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_sgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    nslices: int,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Compare the ``torch.ops.vllm`` sgmv ops against the torch_ops versions.

    The ``torch.ops.vllm`` op writes into ``our_out_tensor`` in one call over
    all slices, while the reference ops imported from
    ``vllm.lora.ops.torch_ops`` fill ``ref_out_tensor`` (slice by slice when
    there is more than one). The two outputs are then compared with
    ``assert_close``.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    seq_length = 128
    generated = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        op_type,
        device,
    )
    (
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        _unused_indices,
    ) = generated

    token_nums = seq_len_tensor.sum().item()
    longest = seq_len_tensor.max()
    if isinstance(longest, tuple):
        longest = longest[0]
    max_seq_length = longest.item()

    # Positional arguments shared by every sgmv call below.
    seq_args = (
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        batches,
        max_seq_length,
        token_nums,
    )

    if op_type == "shrink":
        # Clear the cached LoRA-A pointers first so a stale cached pointer
        # is not picked up.
        with _dict_lock:
            _LORA_A_PTR_DICT.clear()
            torch.ops.vllm.sgmv_shrink(
                inputs_tensor,
                lora_weights_lst,
                our_out_tensor,
                *seq_args,
                scaling,
            )
        for slice_id in range(nslices):
            sgmv_shrink(
                inputs_tensor,
                lora_weights_lst[slice_id],
                ref_out_tensor[slice_id],
                *seq_args,
                scaling,
            )
    else:
        with _dict_lock:
            _LORA_B_PTR_DICT.clear()
            torch.ops.vllm.sgmv_expand(
                inputs_tensor,
                lora_weights_lst,
                our_out_tensor,
                *seq_args,
                offset_start=0,
                add_inputs=True,
            )
        if nslices == 1:
            # Verify the torch's sgmv_expand op
            sgmv_expand(
                inputs_tensor[0],
                lora_weights_lst[0],
                ref_out_tensor,
                *seq_args,
                add_inputs=True,
            )
        else:
            # Slice i is written at column offset i * hidden_size.
            for slice_id in range(nslices):
                sgmv_expand_slice(
                    inputs_tensor[slice_id],
                    lora_weights_lst[slice_id],
                    ref_out_tensor,
                    *seq_args,
                    slice_id * hidden_size,
                    hidden_size,
                    add_inputs=True,
                )

    assert_close(our_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Compare the ``torch.ops.vllm`` bgmv ops against the torch_ops versions.

    Data is generated with ``seq_length = 1``. For ``"shrink"`` the reference
    output is cast to float32 before the comparison; for ``"expand"`` it is
    compared as produced.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    seq_length = 1
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        _b_seq_start_loc,
        _lora_indices_tensor,
        _seq_len_tensor,
        indices,
    ) = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        op_type,
        device,
    )

    if op_type == "shrink":
        torch.ops.vllm.bgmv_shrink(inputs_tensor, lora_weights,
                                   our_out_tensor, indices, scaling)
        bgmv_shrink(inputs_tensor, lora_weights, ref_out_tensor, indices,
                    scaling)
        # The shrink reference is cast to float32 for the comparison.
        expected = ref_out_tensor.to(torch.float32)
    else:
        torch.ops.vllm.bgmv_expand(inputs_tensor,
                                   lora_weights,
                                   our_out_tensor,
                                   indices,
                                   add_inputs=True)
        bgmv_expand(inputs_tensor,
                    lora_weights,
                    ref_out_tensor,
                    indices,
                    add_inputs=True)
        expected = ref_out_tensor

    assert_close(our_out_tensor, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
):
    """Compare ``torch.ops.vllm.bgmv_expand_slice`` with the torch_ops version.

    For each slice, both implementations are called with the same inputs and
    the same offset (``slice index * hidden_size``); the accumulated outputs
    are compared once after all slices have been processed.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    seq_length = 1
    (
        inputs_tensor,
        lora_weights_lst,
        our_outputs,
        ref_outputs,
        _b_seq_start_loc,
        _lora_indices_tensor,
        _seq_len_tensor,
        indices,
    ) = generate_data_for_expand_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        nslices,
        device,
    )

    for slice_id in range(nslices):
        slice_weights = lora_weights_lst[slice_id]
        offset = slice_id * hidden_size
        # Run the registered op first, then the reference, as separate calls
        # writing to separate output tensors.
        torch.ops.vllm.bgmv_expand_slice(
            inputs_tensor,
            slice_weights,
            our_outputs,
            indices,
            offset,
            slice_size=hidden_size,
            add_inputs=True,
        )
        bgmv_expand_slice(
            inputs_tensor,
            slice_weights,
            ref_outputs,
            indices,
            offset,
            slice_size=hidden_size,
            add_inputs=True,
        )

    assert_close(our_outputs, ref_outputs)
|
||||
@ -1,317 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
This script is mainly used to test whether Triton kernels can run normally
|
||||
under different conditions, including various batches, numbers of LoRAs, and
|
||||
maximum ranks.
|
||||
"""
|
||||
from threading import Lock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Enable custom op register
|
||||
import vllm.lora.ops.triton_ops # noqa: F401
|
||||
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
|
||||
bgmv_shrink, sgmv_expand,
|
||||
sgmv_expand_slice, sgmv_shrink)
|
||||
from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import (assert_close, generate_data,
|
||||
generate_data_for_expand_nslices,
|
||||
generate_data_for_nslices)
|
||||
|
||||
HIDDEN_SIZES = [2049]
|
||||
|
||||
BATCHES = [1, 4, 16, 32]
|
||||
NUM_LORA = [1, 8, 32, 128]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
SCALES = [0.5]
|
||||
SEED = [0]
|
||||
DEVICES = [f"cuda:{0}"]
|
||||
|
||||
_dict_lock = Lock()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("nslices", [1, 2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_sgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    nslices: int,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Compare the ``torch.ops.vllm`` sgmv ops against the torch_ops versions.

    The ``torch.ops.vllm`` op handles all slices in a single call and writes
    to ``our_out_tensor``; the reference ops from ``vllm.lora.ops.torch_ops``
    populate ``ref_out_tensor``. Both outputs must agree under
    ``assert_close``.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    seq_length = 128
    generated = generate_data_for_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        nslices,
        dtype,
        op_type,
        device,
    )
    (
        inputs_tensor,
        lora_weights_lst,
        our_out_tensor,
        ref_out_tensor,
        b_seq_start_loc,
        lora_indices_tensor,
        seq_len_tensor,
        _unused_indices,
    ) = generated

    token_nums = seq_len_tensor.sum().item()
    longest = seq_len_tensor.max()
    if isinstance(longest, tuple):
        longest = longest[0]
    max_seq_length = longest.item()

    # Positional arguments shared by every sgmv call below.
    seq_args = (
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        batches,
        max_seq_length,
        token_nums,
    )

    if op_type == "shrink":
        # Clear the cached LoRA-A pointers first so a stale cached pointer
        # is not picked up.
        with _dict_lock:
            _LORA_A_PTR_DICT.clear()
            torch.ops.vllm.sgmv_shrink(
                inputs_tensor,
                lora_weights_lst,
                our_out_tensor,
                *seq_args,
                scaling,
            )
        for slice_id in range(nslices):
            sgmv_shrink(
                inputs_tensor,
                lora_weights_lst[slice_id],
                ref_out_tensor[slice_id],
                *seq_args,
                scaling,
            )
    else:
        with _dict_lock:
            _LORA_B_PTR_DICT.clear()
            torch.ops.vllm.sgmv_expand(
                inputs_tensor,
                lora_weights_lst,
                our_out_tensor,
                *seq_args,
                offset_start=0,
                add_inputs=True,
            )
        if nslices == 1:
            # Verify the torch's sgmv_expand op
            sgmv_expand(
                inputs_tensor[0],
                lora_weights_lst[0],
                ref_out_tensor,
                *seq_args,
                add_inputs=True,
            )
        else:
            # Slice i is written at column offset i * hidden_size.
            for slice_id in range(nslices):
                sgmv_expand_slice(
                    inputs_tensor[slice_id],
                    lora_weights_lst[slice_id],
                    ref_out_tensor,
                    *seq_args,
                    slice_id * hidden_size,
                    hidden_size,
                    add_inputs=True,
                )

    assert_close(our_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("scaling", SCALES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    scaling: float,
    dtype: torch.dtype,
    op_type: str,
    seed: int,
    device: str,
):
    """Compare the ``torch.ops.vllm`` bgmv ops against the torch_ops versions.

    Data is generated with ``seq_length = 1``. The shrink reference output is
    cast to float32 before being compared; the expand reference is compared
    unchanged.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    seq_length = 1
    (
        inputs_tensor,
        lora_weights,
        our_out_tensor,
        ref_out_tensor,
        _b_seq_start_loc,
        _lora_indices_tensor,
        _seq_len_tensor,
        indices,
    ) = generate_data(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        op_type,
        device,
    )

    if op_type == "shrink":
        torch.ops.vllm.bgmv_shrink(inputs_tensor, lora_weights,
                                   our_out_tensor, indices, scaling)
        bgmv_shrink(inputs_tensor, lora_weights, ref_out_tensor, indices,
                    scaling)
        # The shrink reference is cast to float32 for the comparison.
        expected = ref_out_tensor.to(torch.float32)
    else:
        torch.ops.vllm.bgmv_expand(inputs_tensor,
                                   lora_weights,
                                   our_out_tensor,
                                   indices,
                                   add_inputs=True)
        bgmv_expand(inputs_tensor,
                    lora_weights,
                    ref_out_tensor,
                    indices,
                    add_inputs=True)
        expected = ref_out_tensor

    assert_close(our_out_tensor, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
@pytest.mark.parametrize("num_loras", NUM_LORA)
@pytest.mark.parametrize("rank", MAX_RANKS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("nslices", [2, 3])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize("device", DEVICES)
def test_punica_bgmv_expand_nslices(
    batches: int,
    num_loras: int,
    rank: int,
    hidden_size: int,
    nslices: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
):
    """Compare ``torch.ops.vllm.bgmv_expand_slice`` with the torch_ops version.

    Each slice is processed by both implementations at offset
    ``slice index * hidden_size``; the two output tensors are compared after
    the last slice.
    """
    torch.set_default_device(device)
    current_platform.seed_everything(seed)

    seq_length = 1
    (
        inputs_tensor,
        lora_weights_lst,
        our_outputs,
        ref_outputs,
        _b_seq_start_loc,
        _lora_indices_tensor,
        _seq_len_tensor,
        indices,
    ) = generate_data_for_expand_nslices(
        batches,
        hidden_size,
        num_loras,
        rank,
        seq_length,
        dtype,
        nslices,
        device,
    )

    for slice_id in range(nslices):
        slice_weights = lora_weights_lst[slice_id]
        offset = slice_id * hidden_size
        # Run the registered op first, then the reference, as separate calls
        # writing to separate output tensors.
        torch.ops.vllm.bgmv_expand_slice(
            inputs_tensor,
            slice_weights,
            our_outputs,
            indices,
            offset,
            slice_size=hidden_size,
            add_inputs=True,
        )
        bgmv_expand_slice(
            inputs_tensor,
            slice_weights,
            ref_outputs,
            indices,
            offset,
            slice_size=hidden_size,
            add_inputs=True,
        )

    assert_close(our_outputs, ref_outputs)
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -106,6 +107,31 @@ def assert_close(a, b):
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@dataclass
class PunicaTensors:
    """Container for the tensors returned by the ``generate_data*`` helpers.

    ``lora_weights`` is either a single tensor or a list of tensors.
    """
    inputs_tensor: torch.Tensor
    lora_weights: Union[torch.Tensor, List[torch.Tensor]]
    our_out_tensor: torch.Tensor
    ref_out_tensor: torch.Tensor
    b_seq_start_loc: torch.Tensor
    prompt_lora_mapping: torch.Tensor
    seq_len_tensor: torch.Tensor
    token_lora_mapping: torch.Tensor

    def meta(self) -> Tuple[int, int]:
        """
        Return ``(max_seq_length, token_nums)``, i.e. the maximum and the
        sum of ``seq_len_tensor``, as Python scalars.
        """
        total_tokens = self.seq_len_tensor.sum().item()
        longest = self.seq_len_tensor.max()
        # Handle a tuple-shaped max() result by taking its first element.
        if isinstance(longest, tuple):
            longest = longest[0]
        return longest.item(), total_tokens
|
||||
|
||||
|
||||
def generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
@ -115,7 +141,7 @@ def generate_data(
|
||||
dtype,
|
||||
op_type,
|
||||
device,
|
||||
):
|
||||
) -> PunicaTensors:
|
||||
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
||||
(batches, )).to(device)
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
@ -164,7 +190,8 @@ def generate_data(
|
||||
indices[current_offset:current_offset +
|
||||
seq_len_tensor[b_id]].copy_(lora_index)
|
||||
current_offset += seq_len_tensor[b_id].item()
|
||||
return (
|
||||
|
||||
return PunicaTensors(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
@ -185,7 +212,7 @@ def generate_data_for_expand_nslices(
|
||||
dtype,
|
||||
nslices,
|
||||
device,
|
||||
):
|
||||
) -> PunicaTensors:
|
||||
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
||||
(batches, )).to(device)
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
@ -222,7 +249,7 @@ def generate_data_for_expand_nslices(
|
||||
current_offset += seq_len_tensor[b_id].item()
|
||||
|
||||
lora_indices_tensor = lora_indices_tensor.to(device)
|
||||
return (
|
||||
return PunicaTensors(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
@ -244,7 +271,7 @@ def generate_data_for_nslices(
|
||||
dtype,
|
||||
op_type,
|
||||
device,
|
||||
):
|
||||
) -> PunicaTensors:
|
||||
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
||||
(batches, )).to(device)
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
@ -302,7 +329,7 @@ def generate_data_for_nslices(
|
||||
current_offset += seq_len_tensor[b_id].item()
|
||||
|
||||
lora_indices_tensor = lora_indices_tensor.to(device)
|
||||
return (
|
||||
return PunicaTensors(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
|
||||
0
tests/mistral_tool_use/__init__.py
Normal file
0
tests/mistral_tool_use/__init__.py
Normal file
40
tests/mistral_tool_use/conftest.py
Normal file
40
tests/mistral_tool_use/conftest.py
Normal file
@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import ARGS, CONFIGS, ServerConfig
|
||||
|
||||
|
||||
# Parametrized over every entry in CONFIGS: skips entries that are not
# supported on ROCm, fetches the model snapshot, then yields the config.
@pytest.fixture(scope="session", params=CONFIGS.keys())
def server_config(request):
    config = CONFIGS[request.param]

    rocm_unsupported = (current_platform.is_rocm()
                        and not config.get("supports_rocm", True))
    if rocm_unsupported:
        pytest.skip("The {} model can't be tested on the ROCm platform".format(
            config["model"]))

    # Fetch the model repository snapshot via huggingface_hub.
    snapshot_download(config["model"])
    yield CONFIGS[request.param]
|
||||
|
||||
|
||||
# Session-scoped server fixture: one RemoteOpenAIServer per server config.
@pytest.fixture(scope="session")
def server(request, server_config: ServerConfig):
    model = server_config["model"]
    # Shared CLI arguments first, then the config-specific ones.
    cli_args = ARGS + server_config["arguments"]
    with RemoteOpenAIServer(model, cli_args,
                            max_wait_seconds=480) as remote_server:
        yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def client(server: RemoteOpenAIServer):
    """Yield the async client obtained from ``server.get_async_client()``."""
    async with server.get_async_client() as openai_client:
        yield openai_client
|
||||
29
tests/mistral_tool_use/test_mistral_tool_calls.py
Normal file
29
tests/mistral_tool_use/test_mistral_tool_calls.py
Normal file
@ -0,0 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from tests.tool_use.utils import MESSAGES_ASKING_FOR_TOOLS, WEATHER_TOOL
|
||||
|
||||
|
||||
# test: a tool_choice with mistral-tokenizer results in an ID of length 9
@pytest.mark.asyncio
async def test_tool_call_with_tool_choice(client: openai.AsyncOpenAI):
    """Check that a named ``tool_choice`` yields one tool call with a 9-char ID.

    The request names ``WEATHER_TOOL`` as ``tool_choice``; the test then
    requires exactly one tool call on the assistant message and checks the
    length of that call's ID.
    """
    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_TOOLS,
        temperature=0,
        max_completion_tokens=100,
        model=model_name,
        tools=[WEATHER_TOOL],
        tool_choice=WEATHER_TOOL,
        logprobs=False)

    choice = chat_completion.choices[0]

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert choice.message.role == "assistant"

    # The ID check below indexes tool_calls[0], so a missing tool call has to
    # fail here as a readable assertion rather than as a TypeError on None.
    tool_calls = choice.message.tool_calls
    assert tool_calls is not None, (
        "expected a tool call because tool_choice names a specific tool")
    assert len(tool_calls) == 1
    assert len(tool_calls[0].id) == 9  # length of 9 for mistral
|
||||
33
tests/mistral_tool_use/utils.py
Normal file
33
tests/mistral_tool_use/utils.py
Normal file
@ -0,0 +1,33 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class ServerConfig(TypedDict, total=False):
    """Per-model server settings; every key is optional (``total=False``)."""
    model: str
    arguments: List[str]
    system_prompt: Optional[str]
    supports_parallel: Optional[bool]
    supports_rocm: Optional[bool]


# Command-line arguments prepended for every server config.
ARGS: List[str] = ["--max-model-len", "1024"]

_MISTRAL_SYSTEM_PROMPT = (
    "You are a helpful assistant with access to tools. If a tool"
    " that you have would be helpful to answer a user query, "
    "call the tool. Otherwise, answer the user's query directly "
    "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT "
    "to the user's question - just respond to it normally.")

# Server configs keyed by a short name; the conftest fixtures are
# parametrized over these keys.
CONFIGS: Dict[str, ServerConfig] = {
    "mistral": {
        "model": "mistralai/Mistral-7B-Instruct-v0.3",
        "arguments": [
            "--tokenizer-mode",
            "mistral",
            '--ignore-patterns="consolidated.safetensors"',
        ],
        "system_prompt": _MISTRAL_SYSTEM_PROMPT,
    },
}
|
||||
@ -15,7 +15,7 @@ from ....conftest import HfRunner, VllmRunner
|
||||
from ....utils import RemoteOpenAIServer
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
MODEL_NAME = "fixie-ai/ultravox-v0_3"
|
||||
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
|
||||
|
||||
AudioTuple = Tuple[np.ndarray, int]
|
||||
|
||||
|
||||
@ -164,7 +164,7 @@ def _test_processing_correctness(
|
||||
"Qwen/Qwen2-VL-2B-Instruct",
|
||||
"Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||
"fixie-ai/ultravox-v0_3",
|
||||
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||
])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
|
||||
@ -214,6 +214,10 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"Phi3VForCausalLM": _HfExamplesInfo("TIGER-Lab/VLM2Vec-Full",
|
||||
trust_remote_code=True),
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501
|
||||
# The model on Huggingface is currently being updated,
|
||||
# hence I temporarily mark it as not available online
|
||||
"PrithviGeoSpatialMAE": _HfExamplesInfo("ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11", # noqa: E501
|
||||
is_available_online=False),
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_EXAMPLE_MODELS = {
|
||||
@ -267,7 +271,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
|
||||
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
|
||||
min_transformers_version="4.49"), # noqa: E501
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
|
||||
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b",
|
||||
trust_remote_code=True),
|
||||
# [Encoder-decoder]
|
||||
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
|
||||
|
||||
@ -1,6 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@ -171,12 +170,22 @@ def ref_context_attention(
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"block_size, large_tile_size",
|
||||
[
|
||||
(32, 2048), # 64 blocks
|
||||
(32, 4096), # 128 blocks
|
||||
(32, 8192), # 256 blocks
|
||||
(64, 8192), # 128 blocks
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_heads,num_queries_per_kv,head_size,mixed_precision",
|
||||
[
|
||||
(4, 2, 8, False),
|
||||
(4, 2, 8, True),
|
||||
(32, 8, 64, True),
|
||||
(16, 2, 128, True),
|
||||
],
|
||||
)
|
||||
@torch.inference_mode()
|
||||
@ -184,6 +193,8 @@ def test_contexted_kv_attention(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
block_size: int,
|
||||
large_tile_size,
|
||||
mixed_precision: bool,
|
||||
) -> None:
|
||||
import os
|
||||
@ -192,40 +203,46 @@ def test_contexted_kv_attention(
|
||||
|
||||
from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc
|
||||
|
||||
assert large_tile_size % block_size == 0
|
||||
|
||||
device = xm.xla_device()
|
||||
|
||||
os.environ["NEURON_CC_FLAGS"] = (
|
||||
" --model-type=transformer -O1 "
|
||||
" --internal-hlo2tensorizer-options='--verify-hlo' ")
|
||||
compiler_flags = [
|
||||
"--model-type=transformer -O1",
|
||||
"--internal-hlo2tensorizer-options='--verify-hlo'",
|
||||
"--retry_failed_compilation",
|
||||
]
|
||||
compiler_flags_str = " ".join(compiler_flags)
|
||||
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str
|
||||
|
||||
random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
torch.set_printoptions(sci_mode=False)
|
||||
|
||||
min_ctx_len = 2
|
||||
max_ctx_len = 64
|
||||
min_query_len = 2
|
||||
max_query_len = 64
|
||||
prefill_batch_size = 2
|
||||
decode_batch_size = 6
|
||||
min_ctx_len = 32
|
||||
max_ctx_len = 1024
|
||||
min_query_len = 16
|
||||
max_query_len = 512
|
||||
prefill_batch_size = 4
|
||||
decode_batch_size = 12
|
||||
batch_size = prefill_batch_size + decode_batch_size
|
||||
block_size = 32
|
||||
max_model_len = (max_query_len + max_ctx_len) * 4
|
||||
|
||||
max_block_per_request = max_model_len // block_size
|
||||
dtype = torch.float32
|
||||
cache_size = (batch_size * max_block_per_request) + 2
|
||||
ctx_lens = [
|
||||
random.randint(min_ctx_len, max_ctx_len)
|
||||
for _ in range(prefill_batch_size)
|
||||
] + [
|
||||
random.randint(min_ctx_len, max_ctx_len)
|
||||
for _ in range(decode_batch_size)
|
||||
]
|
||||
query_lens = [
|
||||
random.randint(min_query_len, max_query_len)
|
||||
for _ in range(prefill_batch_size)
|
||||
] + [1 for _ in range(decode_batch_size)]
|
||||
prefill_ctx_lens = torch.randint(min_ctx_len,
|
||||
max_ctx_len + 1, (prefill_batch_size, ),
|
||||
dtype=torch.long).tolist()
|
||||
decode_ctx_lens = torch.randint(min_ctx_len,
|
||||
max_ctx_len + 1, (decode_batch_size, ),
|
||||
dtype=torch.long).tolist()
|
||||
ctx_lens = prefill_ctx_lens + decode_ctx_lens
|
||||
query_lens = torch.randint(
|
||||
min_query_len,
|
||||
max_query_len + 1,
|
||||
(prefill_batch_size, ),
|
||||
dtype=torch.long,
|
||||
).tolist() + [1 for _ in range(decode_batch_size)]
|
||||
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
|
||||
num_kv_heads = num_heads // num_queries_per_kv
|
||||
|
||||
@ -254,7 +271,6 @@ def test_contexted_kv_attention(
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:batch_size * max_block_per_request].view(
|
||||
batch_size, max_block_per_request)
|
||||
torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
||||
dtype=torch.long),
|
||||
@ -311,9 +327,7 @@ def test_contexted_kv_attention(
|
||||
# build neuron program
|
||||
return_debug_tensors = False
|
||||
B_P_SIZE = 128
|
||||
LARGE_TILE_SZ = 2048
|
||||
max_num_queries = (
|
||||
(sum(query_lens) + block_size - 1) // block_size) * block_size
|
||||
LARGE_TILE_SZ = large_tile_size
|
||||
|
||||
def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
|
||||
num_blocks):
|
||||
@ -332,26 +346,28 @@ def test_contexted_kv_attention(
|
||||
0,
|
||||
)
|
||||
|
||||
def shift_bit_length(x):
|
||||
return 1 << (x - 1).bit_length()
|
||||
def ceil_div(a, b):
    """Return ``a / b`` rounded up, computed with floor division only.

    Uses only ``+``, ``-`` and ``//``, so it also works element-wise when
    ``a`` is a tensor (the test calls it with a tensor of context lengths).
    Assumes ``b`` is a positive integer.
    """
    return (a + b - 1) // b


def pad_to_multiple(a, b):
    """Round ``a`` up to the nearest multiple of ``b``."""
    return ceil_div(a, b) * b


def pad_to_next_power_of_2(a):
    """Return the smallest power of two that is greater than or equal to ``a``.

    ``a`` must be positive; a value that is already a power of two is
    returned unchanged.
    """
    assert a > 0
    # (a - 1).bit_length() is the exponent of the next power of two, so
    # exact powers of two map to themselves instead of being doubled.
    return 2**int(a - 1).bit_length()
|
||||
|
||||
# calculate input shapes
|
||||
max_num_queries_shifted = shift_bit_length(max_num_queries)
|
||||
max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
|
||||
max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
|
||||
assert (max_num_queries_padded == B_P_SIZE
|
||||
), "invalid {max_num_queries_padded=}"
|
||||
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
|
||||
max_num_queries = pad_to_next_power_of_2(max_num_queries)
|
||||
head_size_padded = B_P_SIZE
|
||||
assert head_size_padded >= head_size
|
||||
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
|
||||
num_active_blocks_shifted = shift_bit_length(
|
||||
((context_lens + block_size - 1) // block_size).sum().item())
|
||||
num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
|
||||
num_active_blocks_shifted)
|
||||
num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
|
||||
assert (num_active_blocks *
|
||||
block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
|
||||
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
|
||||
num_active_blocks = pad_to_multiple(num_active_blocks,
|
||||
LARGE_TILE_SZ // block_size)
|
||||
context_kv_len = num_active_blocks * block_size
|
||||
assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
|
||||
assert (context_kv_len %
|
||||
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
|
||||
|
||||
# pad QKV tensors
|
||||
pad_dims = (
|
||||
@ -360,7 +376,7 @@ def test_contexted_kv_attention(
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
max_num_queries_padded - query.shape[0],
|
||||
max_num_queries - query.shape[0],
|
||||
)
|
||||
query = F.pad(query, pad_dims, "constant", 0)
|
||||
k = F.pad(k, pad_dims, "constant", 0)
|
||||
@ -397,7 +413,7 @@ def test_contexted_kv_attention(
|
||||
0,
|
||||
context_kv_len - prior_mask.shape[1],
|
||||
0,
|
||||
B_P_SIZE - prior_mask.shape[0],
|
||||
max_num_queries - prior_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
@ -406,9 +422,9 @@ def test_contexted_kv_attention(
|
||||
active_mask,
|
||||
(
|
||||
0,
|
||||
B_P_SIZE - active_mask.shape[1],
|
||||
max_num_queries - active_mask.shape[1],
|
||||
0,
|
||||
B_P_SIZE - active_mask.shape[0],
|
||||
max_num_queries - active_mask.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
@ -430,6 +446,8 @@ def test_contexted_kv_attention(
|
||||
n_kv_head=num_kv_heads,
|
||||
head_size=head_size,
|
||||
mixed_precision=mixed_precision,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
return_debug_tensors=return_debug_tensors,
|
||||
)
|
||||
|
||||
if return_debug_tensors:
|
||||
@ -439,17 +457,15 @@ def test_contexted_kv_attention(
|
||||
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
|
||||
debug_tensors = []
|
||||
|
||||
output_nki = torch.tensor(output_nki).cpu()
|
||||
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]
|
||||
|
||||
num_actual_tokens = sum(query_lens)
|
||||
print(f"{num_actual_tokens=}")
|
||||
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
|
||||
output_nki = output_nki.permute(
|
||||
0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
|
||||
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
|
||||
output_nki = output_nki[0, :num_actual_tokens, :, :]
|
||||
output_ref_padded = F.pad(
|
||||
output_ref,
|
||||
(0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
|
||||
(0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
|
||||
68
tests/quantization/test_gptq_dynamic.py
Normal file
68
tests/quantization/test_gptq_dynamic.py
Normal file
@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests whether gptq models with dynamic quantized can be loaded.
|
||||
|
||||
Run `pytest tests/quantization/test_gptq_dynamic.py --forked`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_dynamic_override)
|
||||
|
||||
PROMPT = "On the surface of Mars, we found"
|
||||
|
||||
# The first layer is quantized using bits=4, group_size=128
|
||||
# The second layer is quantized using bits=8, group_size=32
|
||||
# All other layers (layer index >= 2) are not quantized
|
||||
MODEL_QUANT = [
|
||||
("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue",
|
||||
True),
|
||||
("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse",
|
||||
False),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id, use_marlin_kernel", MODEL_QUANT)
def test_gptq_with_dynamic(vllm_runner, model_id: str,
                           use_marlin_kernel: bool):
    """Check per-layer quantization methods of a GPTQ model with dynamic config.

    Args:
        vllm_runner: fixture that builds the vLLM runner for ``model_id``.
        model_id: checkpoint to load (see ``MODEL_QUANT``).
        use_marlin_kernel: whether quantized layers are expected to use
            ``GPTQMarlinLinearMethod`` (True) or ``GPTQLinearMethod`` (False).
    """
    linear_method_cls = (GPTQMarlinLinearMethod
                         if use_marlin_kernel else GPTQLinearMethod)

    # Use the runner as a context manager so it is released even when one of
    # the assertions below fails; a trailing ``del`` would be skipped then.
    with vllm_runner(model_id, dtype=torch.float16,
                     max_model_len=2048) as vllm_model:
        model = (vllm_model.model.llm_engine.model_executor.driver_worker.
                 model_runner.model)

        for name, submodule in model.named_modules():
            if name == "lm_head":
                assert isinstance(submodule.quant_method, linear_method_cls)
            elif name == 'model.layers.0.self_attn.qkv_proj':
                # The first layer is quantized using bits=4, group_size=128
                # desc_act=True
                assert isinstance(submodule.quant_method, linear_method_cls)
                config = submodule.quant_method.quant_config
                assert config.weight_bits == 4
                assert config.group_size == 128
                assert config.desc_act
            elif name == 'model.layers.1.self_attn.qkv_proj':
                # The second layer is quantized using bits=8, group_size=32
                # desc_act=False
                assert isinstance(submodule.quant_method, linear_method_cls)
                config = submodule.quant_method.quant_config
                assert get_dynamic_override(config,
                                            layer_name=name,
                                            key="bits") == 8
                assert get_dynamic_override(config,
                                            layer_name=name,
                                            key="group_size") == 32
                assert not get_dynamic_override(
                    config, layer_name=name, key="desc_act")
            elif (name == 'model.layers.2.self_attn.qkv_proj'
                  or name == 'model.layers.2.mlp.gate_up_proj'):
                # All other layers (layer index >= 2) are not quantized
                assert isinstance(submodule.quant_method,
                                  UnquantizedLinearMethod)
|
||||
@ -3,7 +3,6 @@
|
||||
|
||||
Run `pytest tests/quantization/test_quant_lm_head_true.py --forked`.
|
||||
"""
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -17,31 +16,31 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
|
||||
PROMPT = "On the surface of Mars, we found"
|
||||
|
||||
MODELS_QUANT = [(
|
||||
"LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse",
|
||||
True), ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
|
||||
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)]
|
||||
MODELS_QUANT = [
|
||||
("ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head", True),
|
||||
("ModelCloud/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-10-25-2024", False),
|
||||
("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", False),
|
||||
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", False)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_lm_head_quant", MODELS_QUANT)
|
||||
@pytest.mark.parametrize("model_id, lm_head_quantized", MODELS_QUANT)
|
||||
def test_lm_head(
|
||||
vllm_runner,
|
||||
model_lm_head_quant: Tuple[str, bool],
|
||||
model_id: str,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
model, lm_head_quantized = model_lm_head_quant
|
||||
|
||||
with vllm_runner(model, dtype=torch.float16,
|
||||
with vllm_runner(model_id, dtype=torch.float16,
|
||||
max_model_len=2048) as vllm_model:
|
||||
|
||||
def check_model(model):
|
||||
lm_head_layer = model.lm_head
|
||||
|
||||
if lm_head_quantized:
|
||||
assert isinstance(lm_head_layer.linear_method,
|
||||
assert isinstance(lm_head_layer.quant_method,
|
||||
(GPTQLinearMethod, GPTQMarlinLinearMethod,
|
||||
MarlinLinearMethod))
|
||||
else:
|
||||
assert isinstance(lm_head_layer.linear_method,
|
||||
assert isinstance(lm_head_layer.quant_method,
|
||||
UnquantizedEmbeddingMethod)
|
||||
|
||||
vllm_model.apply_model(check_model)
|
||||
|
||||
@ -8,32 +8,17 @@ from vllm.platforms.interface import Platform
|
||||
|
||||
|
||||
def test_seed_behavior():
|
||||
# Test with seed=None
|
||||
Platform.seed_everything(None)
|
||||
# Test with a specific seed
|
||||
Platform.seed_everything(42)
|
||||
random_value_1 = random.randint(0, 100)
|
||||
np_random_value_1 = np.random.randint(0, 100)
|
||||
torch_random_value_1 = torch.randint(0, 100, (1, )).item()
|
||||
|
||||
Platform.seed_everything(None)
|
||||
Platform.seed_everything(42)
|
||||
random_value_2 = random.randint(0, 100)
|
||||
np_random_value_2 = np.random.randint(0, 100)
|
||||
torch_random_value_2 = torch.randint(0, 100, (1, )).item()
|
||||
|
||||
assert random_value_1 != random_value_2
|
||||
assert np_random_value_1 != np_random_value_2
|
||||
assert torch_random_value_1 != torch_random_value_2
|
||||
|
||||
# Test with a specific seed
|
||||
Platform.seed_everything(42)
|
||||
random_value_3 = random.randint(0, 100)
|
||||
np_random_value_3 = np.random.randint(0, 100)
|
||||
torch_random_value_3 = torch.randint(0, 100, (1, )).item()
|
||||
|
||||
Platform.seed_everything(42)
|
||||
random_value_4 = random.randint(0, 100)
|
||||
np_random_value_4 = np.random.randint(0, 100)
|
||||
torch_random_value_4 = torch.randint(0, 100, (1, )).item()
|
||||
|
||||
assert random_value_3 == random_value_4
|
||||
assert np_random_value_3 == np_random_value_4
|
||||
assert torch_random_value_3 == torch_random_value_4
|
||||
assert random_value_1 == random_value_2
|
||||
assert np_random_value_1 == np_random_value_2
|
||||
assert torch_random_value_1 == torch_random_value_2
|
||||
|
||||
50
tests/tokenization/test_mistral_tokenizer.py
Normal file
50
tests/tokenization/test_mistral_tokenizer.py
Normal file
@ -0,0 +1,50 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.instruct.tool_calls import Function, Tool
|
||||
|
||||
from vllm.transformers_utils.tokenizers.mistral import (
|
||||
make_mistral_chat_completion_request)
|
||||
|
||||
|
||||
# yapf: enable
|
||||
@pytest.mark.parametrize(
    "openai_request,expected_mistral_request",
    [(
        {
            "messages": [{
                "role": "user",
                "content": "What is the current local date and time?",
            }],
            "tools": [{
                "type": "function",
                "function": {
                    "description": "Fetch the current local date and time.",
                    "name": "get_current_time",
                },
            }],
        },
        ChatCompletionRequest(
            messages=[
                UserMessage(content="What is the current local date and time?")
            ],
            tools=[
                Tool(
                    type="function",
                    function=Function(
                        name="get_current_time",
                        description="Fetch the current local date and time.",
                        parameters={},
                    ),
                )
            ],
        ),
    )],
)
def test_make_mistral_chat_completion_request(openai_request,
                                              expected_mistral_request):
    """Check conversion of an OpenAI-style request into a Mistral request.

    The parametrized case passes a tool whose ``function`` entry has no
    ``parameters`` key; the expected Mistral request carries
    ``parameters={}`` for that tool.
    """
    actual_request = make_mistral_chat_completion_request(
        openai_request["messages"], openai_request["tools"])
    assert actual_request == expected_mistral_request
|
||||
123
tests/tokenization/test_tokenizer_registry.py
Normal file
123
tests/tokenization/test_tokenizer_registry.py
Normal file
@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
|
||||
TokenizerRegistry)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
|
||||
class TestTokenizer(TokenizerBase):
    """Minimal ``TokenizerBase`` stub used to exercise the tokenizer registry.

    Only ``from_pretrained``, ``bos_token_id`` and ``eos_token_id`` return
    values; every other member raises ``NotImplementedError``.
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
        """Return a fresh stub; all arguments are ignored."""
        return TestTokenizer()

    @property
    def all_special_tokens_extended(self) -> List[str]:
        raise NotImplementedError()

    @property
    def all_special_tokens(self) -> List[str]:
        raise NotImplementedError()

    @property
    def all_special_ids(self) -> List[int]:
        raise NotImplementedError()

    @property
    def bos_token_id(self) -> int:
        # Fixed sentinel value checked by test_customized_tokenizer.
        return 0

    @property
    def eos_token_id(self) -> int:
        # Fixed sentinel value checked by test_customized_tokenizer.
        return 1

    @property
    def sep_token(self) -> str:
        raise NotImplementedError()

    @property
    def pad_token(self) -> str:
        raise NotImplementedError()

    @property
    def is_fast(self) -> bool:
        raise NotImplementedError()

    @property
    def vocab_size(self) -> int:
        raise NotImplementedError()

    @property
    def max_token_id(self) -> int:
        raise NotImplementedError()

    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        raise NotImplementedError()

    def get_vocab(self) -> Dict[str, int]:
        raise NotImplementedError()

    def get_added_vocab(self) -> Dict[str, int]:
        raise NotImplementedError()

    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        raise NotImplementedError()

    def encode(self,
               text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        raise NotImplementedError()

    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            **kwargs) -> List[int]:
        raise NotImplementedError()

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        raise NotImplementedError()

    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        raise NotImplementedError()

    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        raise NotImplementedError()
|
||||
|
||||
|
||||
def test_customized_tokenizer():
    """Register ``TestTokenizer`` and fetch it through both lookup paths."""
    TokenizerRegistry.register("test_tokenizer",
                               "tests.tokenization.test_tokenizer_registry",
                               "TestTokenizer")

    # Lookup directly through the registry.
    tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
    assert isinstance(tokenizer, TestTokenizer)
    assert tokenizer.bos_token_id == 0
    assert tokenizer.eos_token_id == 1

    # Lookup through the generic factory with tokenizer_mode="custom".
    tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
    assert isinstance(tokenizer, TestTokenizer)
    assert tokenizer.bos_token_id == 0
    assert tokenizer.eos_token_id == 1
|
||||
@ -5,10 +5,11 @@ import pytest
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
|
||||
KVCacheBlock,
|
||||
KVCacheBlock, PrefixCachingMetrics,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens,
|
||||
hash_request_tokens)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
@ -277,3 +278,39 @@ def test_hash_request_tokens_no_mm_inputs():
|
||||
assert block_hashes[0].extra_keys is None
|
||||
assert block_hashes[1].token_ids == (3, 4, 5)
|
||||
assert block_hashes[1].extra_keys is None
|
||||
|
||||
|
||||
def test_metrics():
    """
    Test the prefix caching metrics.

    Feeds three batches of stats into a ``PrefixCachingMetrics`` created with
    ``interval=5`` and checks the hit rate and aggregated counters after each
    observation, then checks that ``reset()`` clears everything.
    """

    def stats(requests, queries, hits):
        return PrefixCacheStats(requests=requests, queries=queries, hits=hits)

    metrics = PrefixCachingMetrics(interval=5)
    # Nothing observed yet.
    assert metrics.hit_rate == 0.0

    metrics.observe(stats(1, 20, 9))
    # 9 / 20 = 0.45
    assert metrics.hit_rate == 0.45

    metrics.observe(stats(4, 80, 16))

    # 25 / 100 = 0.25
    assert metrics.hit_rate == 0.25

    metrics.observe(stats(1, 10, 2))

    # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2
    assert metrics.aggregated_requests == 5
    assert metrics.aggregated_query_total == 90
    assert metrics.aggregated_query_hit == 18
    assert metrics.hit_rate == 0.2

    metrics.reset()
    assert metrics.hit_rate == 0.0
    assert metrics.aggregated_requests == 0
    assert metrics.aggregated_query_total == 0
    assert metrics.aggregated_query_hit == 0
    assert not metrics.query_queue
|
||||
|
||||
@ -38,7 +38,8 @@ def create_scheduler(
|
||||
return Scheduler(scheduler_config,
|
||||
model_config,
|
||||
cache_config,
|
||||
lora_config=None)
|
||||
lora_config=None,
|
||||
log_stats=True)
|
||||
|
||||
|
||||
def create_requests(
|
||||
|
||||
@ -50,7 +50,8 @@ def test_engine_core(monkeypatch):
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class)
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
"""Test basic request lifecycle."""
|
||||
|
||||
# First request.
|
||||
@ -157,7 +158,8 @@ def test_engine_core_advanced_sampling(monkeypatch):
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class)
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
"""Test basic request lifecycle."""
|
||||
# First request.
|
||||
request: EngineCoreRequest = make_request()
|
||||
|
||||
@ -94,6 +94,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
@ -163,6 +164,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
@ -15,6 +16,7 @@ from vllm.sequence import PromptLogprobs, SampleLogprobs
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
from vllm.v1.metrics.stats import IterationStats
|
||||
|
||||
|
||||
def _ref_convert_id_to_token(
|
||||
@ -603,6 +605,7 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
log_stats=True)
|
||||
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
|
||||
engine_core_timestamp = time.monotonic()
|
||||
|
||||
# Make N requests.
|
||||
requests = [
|
||||
@ -630,8 +633,9 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
|
||||
# First iteration has 2 prefills.
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
processed_outputs = output_processor.process_outputs(outputs)
|
||||
iteration_stats = processed_outputs.iteration_stats
|
||||
iteration_stats = IterationStats()
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp,
|
||||
iteration_stats)
|
||||
total_prompt_tokens = sum([
|
||||
len(prompt_tokens)
|
||||
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
|
||||
@ -642,8 +646,9 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
|
||||
# Just decodes in this step.
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
processed_outputs = output_processor.process_outputs(outputs)
|
||||
iteration_stats = processed_outputs.iteration_stats
|
||||
iteration_stats = IterationStats()
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp,
|
||||
iteration_stats)
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == 0
|
||||
assert iteration_stats.num_generation_tokens == num_active
|
||||
@ -652,8 +657,9 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
output_processor.add_request(inactive_request)
|
||||
num_active += 1
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
processed_outputs = output_processor.process_outputs(outputs)
|
||||
iteration_stats = processed_outputs.iteration_stats
|
||||
iteration_stats = IterationStats()
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp,
|
||||
iteration_stats)
|
||||
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
|
||||
@ -661,8 +667,9 @@ def test_iteration_stats(dummy_test_vectors):
|
||||
|
||||
# Just decodes in this step.
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
processed_outputs = output_processor.process_outputs(outputs)
|
||||
iteration_stats = processed_outputs.iteration_stats
|
||||
iteration_stats = IterationStats()
|
||||
output_processor.process_outputs(outputs, engine_core_timestamp,
|
||||
iteration_stats)
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == 0
|
||||
assert iteration_stats.num_generation_tokens == num_active
|
||||
|
||||
@ -320,9 +320,14 @@ class PlaceholderAttentionMetadataBuilder(
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
# Some input builders such as ModelInputForCPUBuilder do not have the
|
||||
# "inter_data_list" attribute.
|
||||
# Let's check inter_data_list exists before we reference it.
|
||||
if hasattr(self.input_builder, "inter_data_list"):
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
@ -28,7 +28,6 @@ class FlashConfig:
|
||||
def transpose_p_local(p_local_transposed,
|
||||
p_local,
|
||||
LARGE_TILE_SZ,
|
||||
forward_mask,
|
||||
B_F_SIZE=512):
|
||||
for i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
||||
@ -46,13 +45,13 @@ def transpose_p_local(p_local_transposed,
|
||||
|
||||
if nisa.get_nc_version() == nisa.nc_version.gen3:
|
||||
p_local_t_tmp[:, j_128_slice] = nisa.dma_transpose(
|
||||
p_local[:, i_j_128_slice], mask=forward_mask)
|
||||
p_local[:, i_j_128_slice])
|
||||
else:
|
||||
p_local_t_tmp[:, j_128_slice] = nisa.nc_transpose(
|
||||
p_local[:, i_j_128_slice], mask=forward_mask)
|
||||
p_local[:, i_j_128_slice])
|
||||
|
||||
p_local_transposed[:, nl.ds(i * B_F_SIZE, B_F_SIZE)] = nl.copy(
|
||||
p_local_t_tmp, dtype=p_local_transposed.dtype, mask=forward_mask)
|
||||
p_local_t_tmp, dtype=p_local_transposed.dtype)
|
||||
|
||||
|
||||
@nki.jit
|
||||
@ -60,36 +59,25 @@ def _flash_attention_core(
|
||||
q_local_tile,
|
||||
k,
|
||||
v,
|
||||
q_h_per_k_h,
|
||||
seqlen_q,
|
||||
nheads,
|
||||
o_buffer,
|
||||
l_buffer,
|
||||
m_buffer,
|
||||
batch_id,
|
||||
head_id,
|
||||
gqa_head_idx,
|
||||
q_tile_idx,
|
||||
local_k_large_tile_idx,
|
||||
kernel_dtype,
|
||||
acc_type,
|
||||
flash_config: FlashConfig,
|
||||
use_causal_mask=False,
|
||||
continuous_batching_mask=None,
|
||||
use_causal_mask,
|
||||
tile_mask,
|
||||
initialize=False,
|
||||
B_P_SIZE=128,
|
||||
B_F_SIZE=512,
|
||||
B_D_SIZE=128,
|
||||
dropout_p=0.0,
|
||||
dropout_p_tensor=None,
|
||||
seed_tensor=None,
|
||||
logit_bias_tile=None,
|
||||
qk_res_buffer=None,
|
||||
):
|
||||
"""
|
||||
The flash attention core function to calculate self attention between a tile
|
||||
of q and a block of K and V.
|
||||
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
|
||||
The q_local_tile has (B_P_SIZE, B_F_SIZE), which is loaded into the SBUF
|
||||
already. The block size of K and V
|
||||
is defined in the seq_tile_size of the flash_config. The results are stored
|
||||
in the following three buffers
|
||||
@ -99,24 +87,9 @@ def _flash_attention_core(
|
||||
"""
|
||||
LARGE_TILE_SZ = flash_config.seq_tile_size
|
||||
num_k_tile_per_large_tile = LARGE_TILE_SZ // B_F_SIZE
|
||||
seqlen_k = k.shape[-1]
|
||||
seqlen_q // B_P_SIZE
|
||||
seqlen_k // B_F_SIZE
|
||||
|
||||
# TODO : support logit_bias with continuous_batching_mask
|
||||
assert not use_causal_mask, "causal mask is not supported."
|
||||
assert (continuous_batching_mask
|
||||
is not None), "continuous_batching_mask input is required."
|
||||
if continuous_batching_mask is not None:
|
||||
assert (
|
||||
logit_bias_tile
|
||||
is None), "continuous_batching_mask does not support logit_bias!"
|
||||
|
||||
# mask are used to only apply computation to the lower half of the matrix,
|
||||
# which reduce the arithmetic intensity by half
|
||||
forward_mask = (q_tile_idx * B_P_SIZE >= local_k_large_tile_idx *
|
||||
LARGE_TILE_SZ if use_causal_mask else None)
|
||||
|
||||
qk_res_buf = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
buffer=nl.sbuf,
|
||||
dtype=acc_type)
|
||||
@ -125,20 +98,27 @@ def _flash_attention_core(
|
||||
for k_i in nl.affine_range(num_k_tile_per_large_tile):
|
||||
k_i_b_f_slice = nl.ds(k_i * B_F_SIZE, B_F_SIZE)
|
||||
|
||||
qk_psum = nl.zeros((par_dim(B_P_SIZE), B_F_SIZE),
|
||||
dtype=np.float32,
|
||||
buffer=nl.psum) # (128, 512)
|
||||
qk_psum[:, :] = nl.matmul(q_local_tile,
|
||||
k[:, k_i_b_f_slice],
|
||||
transpose_x=True,
|
||||
mask=None) # (p(128), 512)
|
||||
if use_causal_mask:
|
||||
multiplication_required_selection = (q_tile_idx * B_P_SIZE
|
||||
>= k_i * B_F_SIZE)
|
||||
else:
|
||||
multiplication_required_selection = True
|
||||
|
||||
qk_res_buf[:, k_i_b_f_slice] = nl.where(
|
||||
continuous_batching_mask[:, k_i_b_f_slice],
|
||||
qk_psum[:, nl.ds(0, B_F_SIZE)],
|
||||
-9984.0,
|
||||
dtype=acc_type,
|
||||
)
|
||||
if multiplication_required_selection:
|
||||
qk_psum = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE),
|
||||
dtype=np.float32,
|
||||
buffer=nl.psum) # (128, 512)
|
||||
qk_psum[:, :] = nl.matmul(q_local_tile,
|
||||
k[:, k_i_b_f_slice],
|
||||
transpose_x=True) # (p(128), 512)
|
||||
qk_res_buf[:, k_i_b_f_slice] = nl.where(
|
||||
tile_mask[:, k_i_b_f_slice],
|
||||
qk_psum[:, nl.ds(0, B_F_SIZE)],
|
||||
-9984.0,
|
||||
dtype=acc_type,
|
||||
)
|
||||
else:
|
||||
qk_res_buf[:, k_i_b_f_slice] = -9984.0
|
||||
|
||||
# Calculate max of the current tile
|
||||
max_local[:, k_i] = nisa.tensor_reduce(
|
||||
@ -147,7 +127,6 @@ def _flash_attention_core(
|
||||
axis=(1, ),
|
||||
dtype=acc_type,
|
||||
negate=False,
|
||||
mask=forward_mask,
|
||||
)
|
||||
|
||||
if qk_res_buffer is not None:
|
||||
@ -159,7 +138,6 @@ def _flash_attention_core(
|
||||
axis=(1, ),
|
||||
dtype=acc_type,
|
||||
negate=False,
|
||||
mask=forward_mask,
|
||||
)
|
||||
|
||||
o_previous_scaled = nl.ndarray((par_dim(B_P_SIZE), B_D_SIZE),
|
||||
@ -170,8 +148,7 @@ def _flash_attention_core(
|
||||
m_current = max_
|
||||
else:
|
||||
m_previous = nl.copy(m_buffer[:, 0])
|
||||
m_buffer[:, 0] = nl.maximum(m_previous, max_,
|
||||
mask=forward_mask) # (128,1)
|
||||
m_buffer[:, 0] = nl.maximum(m_previous, max_) # (128,1)
|
||||
|
||||
m_current = m_buffer[:, 0]
|
||||
# Compute scaling factor
|
||||
@ -180,11 +157,8 @@ def _flash_attention_core(
|
||||
m_previous,
|
||||
bias=-1 * m_current,
|
||||
scale=1.0,
|
||||
mask=forward_mask,
|
||||
)
|
||||
o_previous_scaled[...] = nl.multiply(o_buffer[:, :],
|
||||
alpha,
|
||||
mask=forward_mask)
|
||||
o_previous_scaled[...] = nl.multiply(o_buffer[:, :], alpha)
|
||||
|
||||
p_local = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
@ -207,10 +181,9 @@ def _flash_attention_core(
|
||||
reduce_op=nl.add,
|
||||
reduce_res=p_partial_sum[:, k_r_i],
|
||||
dtype=kernel_dtype,
|
||||
mask=forward_mask,
|
||||
)
|
||||
|
||||
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type, mask=forward_mask)
|
||||
ps = nl.sum(p_partial_sum, axis=1, dtype=acc_type)
|
||||
|
||||
p_local_transposed = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=kernel_dtype)
|
||||
@ -218,7 +191,6 @@ def _flash_attention_core(
|
||||
p_local_transposed=p_local_transposed,
|
||||
p_local=p_local,
|
||||
LARGE_TILE_SZ=LARGE_TILE_SZ,
|
||||
forward_mask=forward_mask,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
)
|
||||
|
||||
@ -230,27 +202,20 @@ def _flash_attention_core(
|
||||
p_local_transposed[:, nl.ds(k_i * B_P_SIZE, B_P_SIZE)],
|
||||
v[k_i, :, :],
|
||||
transpose_x=True,
|
||||
mask=forward_mask,
|
||||
) # (128, 128) (p(Br), d)
|
||||
|
||||
if initialize:
|
||||
o_buffer[:, :] = nl.copy(pv_psum[:, :])
|
||||
l_buffer[:, 0] = nl.add(nl.log(ps), max_)
|
||||
else:
|
||||
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum, mask=forward_mask)
|
||||
o_buffer[:, :] = nl.add(o_previous_scaled, pv_psum)
|
||||
|
||||
l_prev = l_buffer[:, 0]
|
||||
l_exp = nl.add(
|
||||
nl.exp(
|
||||
nl.subtract(l_prev, m_current, mask=forward_mask),
|
||||
mask=forward_mask,
|
||||
),
|
||||
nl.exp(nl.subtract(l_prev, m_current)),
|
||||
ps,
|
||||
mask=forward_mask,
|
||||
)
|
||||
l_buffer[:, 0] = nl.add(m_current,
|
||||
nl.log(l_exp, mask=forward_mask),
|
||||
mask=forward_mask)
|
||||
l_buffer[:, 0] = nl.add(m_current, nl.log(l_exp))
|
||||
|
||||
|
||||
@nki.jit
|
||||
@ -279,6 +244,21 @@ def load_v_tile(v_hbm_tile, cur_v_tile, j, v_i, config):
|
||||
)
|
||||
|
||||
|
||||
@nki.jit
def load_block_tables(block_tables_hbm, num_tiles):
    """Load a 1-D block table as a ``(num_tiles, blocks_per_tile)`` buffer.

    Args:
        block_tables_hbm: 1-D tensor of block indices whose length is a
            multiple of ``num_tiles``. (Presumably lives in HBM, going by
            the name -- confirm against the caller.)
        num_tiles: number of rows to split the table into.

    Returns:
        The result of ``nl.load`` on the reshaped table, requested with
        dtype ``nl.int32``.
    """
    # Unpacking into a 1-tuple also enforces that the input is rank 1.
    (num_blocks, ) = block_tables_hbm.shape
    assert num_blocks % num_tiles == 0
    num_blocks_per_tile = num_blocks // num_tiles
    block_tables_hbm = block_tables_hbm.reshape(
        (num_tiles, num_blocks_per_tile))
    block_tables_buffer = nl.load(block_tables_hbm, dtype=nl.int32)
    return block_tables_buffer
|
||||
|
||||
|
||||
def is_power_of_2(x: int) -> bool:
    """Return whether ``x`` is a positive integer power of two.

    A positive power of two has exactly one bit set, so clearing its lowest
    set bit with ``x & (x - 1)`` leaves zero. Zero and negative values are
    rejected by the ``x > 0`` guard.
    """
    return x > 0 and (x & (x - 1)) == 0
|
||||
|
||||
|
||||
@nki.jit
|
||||
def flash_paged_attention(
|
||||
query,
|
||||
@ -316,24 +296,24 @@ def flash_paged_attention(
|
||||
- We use paged cache blocks (key_cache, value_cache) to store KV cache.
|
||||
|
||||
IO tensor dtypes:
|
||||
- This kernel assumes all IO tensors have the same dtype except for
|
||||
- This kernel assumes all IO tensors have the same dtype except for
|
||||
block_tables (int32) and mask (int32)
|
||||
- If mixed_precision is True, then all Tensor Engine operations will be
|
||||
performed in bfloat16 and accumulation will be performed in float32.
|
||||
- If mixed_precision is True, then all Tensor Engine operations will be
|
||||
performed in bfloat16 and accumulation will be performed in float32.
|
||||
Otherwise the intermediates will be in the same type as the inputs.
|
||||
|
||||
Compile-time Constants:
|
||||
- softmax_scale: scaling for softmax, if None, default is `1.0/(d**0.5)`
|
||||
- mixed_precision: flag to set non-matmul ops in fp32 precision, default
|
||||
is set to `true`, if false, we use same precision as input types
|
||||
is set to `true`, if false, we use same precision as input types
|
||||
- config: Instance of dataclass :class:`nki.kernels.attention.FlashConfig`
|
||||
with Performance config parameters for flash attention with default
|
||||
values
|
||||
seq_tile_size: `default=2048`, size of the kv tile size for attention
|
||||
seq_tile_size: `default=2048`, size of the kv tile size for attention
|
||||
computation reduction
|
||||
|
||||
GQA support Notes:
|
||||
the spmd kernel for launching kernel should be on kv_heads instead of
|
||||
the spmd kernel for launching kernel should be on kv_heads instead of
|
||||
nheads
|
||||
|
||||
Example usage:
|
||||
@ -415,18 +395,13 @@ def flash_paged_attention(
|
||||
), f"Need B_P_SIZE ({B_P_SIZE}) to be divisible by {block_size=}"
|
||||
num_large_k_tile = context_kv_len // LARGE_TILE_SZ
|
||||
num_blocks_per_large_tile = LARGE_TILE_SZ // block_size
|
||||
assert (num_blocks_per_large_tile <= B_P_SIZE
|
||||
), f"The number of blocks in each large tile " \
|
||||
f"({num_blocks_per_large_tile}) shouldn't exceed partition size {B_P_SIZE}"
|
||||
assert block_size % 32 == 0, "block_size is expected to be a multiple of 32"
|
||||
assert is_power_of_2(
|
||||
num_blocks_per_large_tile
|
||||
), "The number of blocks in each large tile is expected of be power of 2"
|
||||
assert is_power_of_2(seqlen_q), "seqlen_q is expected to be power of 2"
|
||||
|
||||
block_tables_sbuf = nl.full((par_dim(B_P_SIZE), num_large_k_tile),
|
||||
0,
|
||||
dtype=np.int32,
|
||||
buffer=nl.sbuf)
|
||||
for j in nl.affine_range(num_large_k_tile):
|
||||
i_p = nl.arange(num_blocks_per_large_tile)[:, None]
|
||||
block_tables_sbuf[i_p, j] = nl.load(
|
||||
block_tables[j * num_blocks_per_large_tile + i_p], dtype=np.int32)
|
||||
block_tables_sbuf = load_block_tables(block_tables, num_large_k_tile)
|
||||
|
||||
# Global Flash Attention accumulators
|
||||
o_buffer = nl.zeros(
|
||||
@ -457,7 +432,7 @@ def flash_paged_attention(
|
||||
)
|
||||
|
||||
for k_i in nl.affine_range(num_blocks_per_large_tile):
|
||||
loaded = nl.load(key_cache[block_tables_sbuf[k_i, j], :,
|
||||
loaded = nl.load(key_cache[block_tables_sbuf[j, k_i], :,
|
||||
head_id, :])
|
||||
cur_k_tile[:, nl.ds(k_i *
|
||||
block_size, block_size)] = nl.transpose(loaded)
|
||||
@ -469,7 +444,7 @@ def flash_paged_attention(
|
||||
num_blocks_per_partition):
|
||||
v_i = (partition_idx * num_blocks_per_partition +
|
||||
block_in_partition)
|
||||
loaded_v = nl.load(value_cache[block_tables_sbuf[v_i, j], :,
|
||||
loaded_v = nl.load(value_cache[block_tables_sbuf[j, v_i], :,
|
||||
head_id, :])
|
||||
cur_v_tile[
|
||||
partition_idx,
|
||||
@ -477,14 +452,15 @@ def flash_paged_attention(
|
||||
:,
|
||||
] = loaded_v
|
||||
|
||||
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=mask.dtype)
|
||||
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
||||
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(
|
||||
mask[:, nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE)])
|
||||
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
cur_mask = nl.ndarray((par_dim(B_P_SIZE), LARGE_TILE_SZ),
|
||||
dtype=mask.dtype)
|
||||
for m_i in nl.affine_range(LARGE_TILE_SZ // B_F_SIZE):
|
||||
cur_mask[:, nl.ds(m_i * B_F_SIZE, B_F_SIZE)] = nl.load(mask[
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
nl.ds(j * LARGE_TILE_SZ + m_i * B_F_SIZE, B_F_SIZE),
|
||||
])
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||
q_sbuf_tile = nl.load(
|
||||
@ -497,35 +473,24 @@ def flash_paged_attention(
|
||||
q_local_tile=q_tile,
|
||||
k=cur_k_tile,
|
||||
v=cur_v_tile,
|
||||
q_h_per_k_h=q_h_per_k_h,
|
||||
seqlen_q=seqlen_q,
|
||||
nheads=h,
|
||||
o_buffer=o_buffer[i, i_q_h],
|
||||
l_buffer=l_buffer[:, i, i_q_h],
|
||||
m_buffer=m_buffer[i, i_q_h],
|
||||
batch_id=batch_id,
|
||||
head_id=head_id,
|
||||
gqa_head_idx=i_q_h,
|
||||
q_tile_idx=i,
|
||||
local_k_large_tile_idx=j,
|
||||
kernel_dtype=kernel_dtype,
|
||||
acc_type=acc_type,
|
||||
flash_config=config,
|
||||
use_causal_mask=False,
|
||||
continuous_batching_mask=cur_mask,
|
||||
tile_mask=cur_mask,
|
||||
initialize=j == 0,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
dropout_p=0.0,
|
||||
dropout_p_tensor=None,
|
||||
seed_tensor=None,
|
||||
logit_bias_tile=None,
|
||||
)
|
||||
|
||||
# compute attention between input query, key and value
|
||||
if key is not None and value is not None:
|
||||
B_F_SIZE = seqlen_q
|
||||
B_F_SIZE = min(seqlen_q, B_F_SIZE)
|
||||
LARGE_TILE_SZ = seqlen_q
|
||||
active_config = FlashConfig(
|
||||
seq_tile_size=LARGE_TILE_SZ,
|
||||
@ -552,11 +517,16 @@ def flash_paged_attention(
|
||||
config=active_config,
|
||||
)
|
||||
|
||||
cur_mask = nl.ndarray((par_dim(B_P_SIZE), B_F_SIZE), dtype=mask.dtype)
|
||||
cur_mask[:, :] = nl.load(mask[:, nl.ds(context_kv_len, B_F_SIZE)])
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
cur_mask = nl.load(
|
||||
mask[
|
||||
nl.ds(i * B_P_SIZE, B_P_SIZE),
|
||||
nl.ds(context_kv_len, LARGE_TILE_SZ),
|
||||
],
|
||||
dtype=mask.dtype,
|
||||
)
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
|
||||
for i_q_h in nl.affine_range(q_h_per_k_h):
|
||||
for i in nl.affine_range(n_tile_q):
|
||||
q_tile = nl.ndarray((B_D_SIZE, B_P_SIZE), dtype=kernel_dtype)
|
||||
q_hbm_tile = query[batch_id, head_id * q_h_per_k_h + i_q_h]
|
||||
q_sbuf_tile = nl.load(
|
||||
@ -568,32 +538,21 @@ def flash_paged_attention(
|
||||
q_local_tile=q_tile,
|
||||
k=cur_k_tile,
|
||||
v=cur_v_tile,
|
||||
q_h_per_k_h=q_h_per_k_h,
|
||||
seqlen_q=seqlen_q,
|
||||
nheads=h,
|
||||
o_buffer=o_buffer[i, i_q_h],
|
||||
l_buffer=l_buffer[:, i, i_q_h],
|
||||
m_buffer=m_buffer[i, i_q_h],
|
||||
batch_id=batch_id,
|
||||
head_id=head_id,
|
||||
gqa_head_idx=i_q_h,
|
||||
q_tile_idx=i,
|
||||
local_k_large_tile_idx=0,
|
||||
kernel_dtype=kernel_dtype,
|
||||
acc_type=acc_type,
|
||||
flash_config=active_config,
|
||||
use_causal_mask=False,
|
||||
continuous_batching_mask=cur_mask,
|
||||
use_causal_mask=True,
|
||||
tile_mask=cur_mask,
|
||||
initialize=False,
|
||||
B_P_SIZE=B_P_SIZE,
|
||||
B_F_SIZE=B_F_SIZE,
|
||||
B_D_SIZE=B_D_SIZE,
|
||||
dropout_p=0.0,
|
||||
dropout_p_tensor=None,
|
||||
seed_tensor=None,
|
||||
logit_bias_tile=None,
|
||||
qk_res_buffer=qk_res_buffer[i, i_q_h]
|
||||
if qk_res_buffer is not None else None,
|
||||
qk_res_buffer=(qk_res_buffer[i, i_q_h]
|
||||
if qk_res_buffer is not None else None),
|
||||
)
|
||||
|
||||
# -- -- -- -- write output to buffer on HBM -- -- -- -- -- -- #
|
||||
@ -652,7 +611,6 @@ def flash_attn_varlen_nkifunc(
|
||||
attn_mask,
|
||||
n_kv_head=None,
|
||||
head_size=None,
|
||||
B_P_SIZE=128,
|
||||
LARGE_TILE_SZ=2048,
|
||||
return_debug_tensors=False,
|
||||
mixed_precision=True,
|
||||
|
||||
@ -102,8 +102,9 @@ class ModelConfig:
|
||||
it; otherwise, you must specify explicitly which task to use.
|
||||
tokenizer: Name or path of the huggingface tokenizer to use.
|
||||
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
|
||||
available, "slow" will always use the slow tokenizer, and
|
||||
"mistral" will always use the tokenizer from `mistral_common`.
|
||||
available, "slow" will always use the slow tokenizer,
|
||||
"mistral" will always use the tokenizer from `mistral_common`, and
|
||||
"custom" will use --tokenizer to select the preregistered tokenizer.
|
||||
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
|
||||
downloading the model and tokenizer.
|
||||
allowed_local_media_path: Allowing API requests to read local images or
|
||||
@ -467,10 +468,10 @@ class ModelConfig:
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = self.tokenizer_mode.lower()
|
||||
if tokenizer_mode not in ["auto", "slow", "mistral"]:
|
||||
if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
|
||||
raise ValueError(
|
||||
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
|
||||
"either 'auto', 'slow' or 'mistral'.")
|
||||
"either 'auto', 'slow', 'mistral' or 'custom'.")
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _get_preferred_task(
|
||||
@ -3057,7 +3058,8 @@ class VllmConfig:
|
||||
kv_transfer_config: KVTransferConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
# some opaque config, only used to provide additional information
|
||||
# for the hash computation, mainly used for testing and debugging.
|
||||
# for the hash computation, mainly used for testing, debugging or out of
|
||||
# tree config registration.
|
||||
additional_config: SupportsHash = field(default=None,
|
||||
init=True) # type: ignore
|
||||
instance_id: str = ""
|
||||
|
||||
@ -5,12 +5,14 @@ convenient for use when we just need to call a few functions.
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import glob
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# this line makes it possible to directly load `libcudart.so` using `ctypes`
|
||||
import torch # noqa
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -60,6 +62,29 @@ def find_loaded_library(lib_name) -> Optional[str]:
|
||||
return path
|
||||
|
||||
|
||||
def get_cudart_lib_path_from_env() -> Optional[str]:
    """Locate the CUDA runtime library via ``VLLM_CUDART_SO_PATH``.

    On some systems find_loaded_library() may not work, so users are allowed
    to specify the library path through the environment variable
    VLLM_CUDART_SO_PATH. The value is expanded with ``glob.glob``, so it may
    contain shell-style wildcards.

    Returns:
        The first path returned by ``glob.glob`` for the configured value, or
        ``None`` when the variable is unset or matches no existing path.
    """
    cudart_so_env = envs.VLLM_CUDART_SO_PATH
    if cudart_so_env is None:
        return None

    file_paths = glob.glob(cudart_so_env)
    if not file_paths:
        return None

    # The trailing space after "env var" matters: the two adjacent literals
    # are concatenated into a single format string.
    logger.info(
        "Found cudart library at %s through env var "
        "VLLM_CUDART_SO_PATH=%s",
        file_paths[0],
        cudart_so_env,
    )
    return file_paths[0]
|
||||
|
||||
|
||||
class CudaRTLibrary:
|
||||
exported_functions = [
|
||||
# cudaError_t cudaSetDevice ( int device )
|
||||
@ -105,8 +130,13 @@ class CudaRTLibrary:
|
||||
def __init__(self, so_file: Optional[str] = None):
|
||||
if so_file is None:
|
||||
so_file = find_loaded_library("libcudart")
|
||||
if so_file is None:
|
||||
so_file = get_cudart_lib_path_from_env()
|
||||
assert so_file is not None, \
|
||||
"libcudart is not loaded in the current process"
|
||||
(
|
||||
"libcudart is not loaded in the current process, "
|
||||
"try setting VLLM_CUDART_SO_PATH"
|
||||
)
|
||||
if so_file not in CudaRTLibrary.path_to_library_cache:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
CudaRTLibrary.path_to_library_cache[so_file] = lib
|
||||
|
||||
@ -20,6 +20,7 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.plugins import load_general_plugins
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, StoreBoolean
|
||||
@ -203,6 +204,8 @@ class EngineArgs:
|
||||
|
||||
calculate_kv_scales: Optional[bool] = None
|
||||
|
||||
additional_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.tokenizer:
|
||||
self.tokenizer = self.model
|
||||
@ -281,11 +284,13 @@ class EngineArgs:
|
||||
'--tokenizer-mode',
|
||||
type=str,
|
||||
default=EngineArgs.tokenizer_mode,
|
||||
choices=['auto', 'slow', 'mistral'],
|
||||
choices=['auto', 'slow', 'mistral', 'custom'],
|
||||
help='The tokenizer mode.\n\n* "auto" will use the '
|
||||
'fast tokenizer if available.\n* "slow" will '
|
||||
'always use the slow tokenizer. \n* '
|
||||
'"mistral" will always use the `mistral_common` tokenizer.')
|
||||
'"mistral" will always use the `mistral_common` tokenizer. \n* '
|
||||
'"custom" will use --tokenizer to select the '
|
||||
'preregistered tokenizer.')
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='Trust remote code from huggingface.')
|
||||
@ -984,6 +989,14 @@ class EngineArgs:
|
||||
'be loaded from the model checkpoint if available. '
|
||||
'Otherwise, the scales will default to 1.0.')
|
||||
|
||||
parser.add_argument(
|
||||
"--additional-config",
|
||||
type=json.loads,
|
||||
default=None,
|
||||
help="Additional config for specified platform in JSON format. "
|
||||
"Different platforms may support different configs. Make sure the "
|
||||
"configs are valid for the platform you are using. The input format"
|
||||
" is like '{\"config_key\":\"config_value\"}'")
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -1044,6 +1057,9 @@ class EngineArgs:
|
||||
def create_engine_config(self,
|
||||
usage_context: Optional[UsageContext] = None
|
||||
) -> VllmConfig:
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.pre_register_and_update()
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
self._override_v1_engine_args(usage_context)
|
||||
|
||||
@ -1287,6 +1303,7 @@ class EngineArgs:
|
||||
prompt_adapter_config=prompt_adapter_config,
|
||||
compilation_config=self.compilation_config,
|
||||
kv_transfer_config=self.kv_transfer_config,
|
||||
additional_config=self.additional_config,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
@ -1347,6 +1364,12 @@ class AsyncEngineArgs(EngineArgs):
|
||||
parser.add_argument('--disable-log-requests',
|
||||
action='store_true',
|
||||
help='Disable logging requests.')
|
||||
# Initialize plugins to update the parser. For example, a plugin may
|
||||
# add a new kind of quantization method to the --quantization argument or
|
||||
# a new device to the --device argument.
|
||||
load_general_plugins()
|
||||
from vllm.platforms import current_platform
|
||||
current_platform.pre_register_and_update(parser)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@ -237,7 +237,7 @@ class Metrics:
|
||||
documentation="Count of successfully processed requests.",
|
||||
labelnames=labelnames + [Metrics.labelname_finish_reason])
|
||||
|
||||
# Speculatie decoding stats
|
||||
# Speculative decoding stats
|
||||
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
|
||||
name="vllm:spec_decode_draft_acceptance_rate",
|
||||
documentation="Speulative token acceptance rate.",
|
||||
|
||||
@ -1051,9 +1051,9 @@ class LLM:
|
||||
|
||||
def _cross_encoding_score(
|
||||
self,
|
||||
tokenizer: Union[AnyTokenizer],
|
||||
text_1: List[Union[str, TextPrompt, TokensPrompt]],
|
||||
text_2: List[Union[str, TextPrompt, TokensPrompt]],
|
||||
tokenizer: AnyTokenizer,
|
||||
text_1: List[str],
|
||||
text_2: List[str],
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
@ -1176,29 +1176,36 @@ class LLM:
|
||||
if isinstance(text_1, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_1 = [text_1]
|
||||
text_1 = [ensure_str(t) for t in text_1]
|
||||
input_text_1: List[str] = [ensure_str(t) for t in text_1]
|
||||
|
||||
if isinstance(text_2, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
text_2 = [text_2]
|
||||
text_2 = [ensure_str(t) for t in text_2]
|
||||
input_text_2: List[str] = [ensure_str(t) for t in text_2]
|
||||
|
||||
if len(text_1) > 1 and len(text_1) != len(text_2):
|
||||
if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
|
||||
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
|
||||
if len(text_1) == 0:
|
||||
if len(input_text_1) == 0:
|
||||
raise ValueError("At least one text element must be given")
|
||||
if len(text_2) == 0:
|
||||
if len(input_text_2) == 0:
|
||||
raise ValueError("At least one text_pair element must be given")
|
||||
|
||||
if self.llm_engine.model_config.is_cross_encoder:
|
||||
return self._cross_encoding_score(tokenizer, text_1, text_2,
|
||||
return self._cross_encoding_score(tokenizer, input_text_1,
|
||||
input_text_2,
|
||||
truncate_prompt_tokens, use_tqdm,
|
||||
lora_request,
|
||||
prompt_adapter_request)
|
||||
else:
|
||||
return self._embedding_score(tokenizer, text_1, text_2,
|
||||
truncate_prompt_tokens, use_tqdm,
|
||||
lora_request, prompt_adapter_request)
|
||||
|
||||
return self._embedding_score(
|
||||
tokenizer,
|
||||
input_text_1, # type: ignore[arg-type]
|
||||
input_text_2, # type: ignore[arg-type]
|
||||
truncate_prompt_tokens,
|
||||
use_tqdm,
|
||||
lora_request,
|
||||
prompt_adapter_request)
|
||||
|
||||
def start_profile(self) -> None:
|
||||
self.llm_engine.start_profile()
|
||||
|
||||
@ -67,6 +67,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
|
||||
]):
|
||||
return None
|
||||
|
||||
# Check if <think> is present in previous or delta.
|
||||
# Keep compatibility with models that don't generate <think> tokens.
|
||||
if self.think_start_token_id in previous_token_ids:
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
# <think> in previous, </think> in delta,
|
||||
@ -85,7 +87,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
|
||||
# reasoning content continues
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
elif self.think_start_token_id in delta_token_ids:
|
||||
logger.info(delta_text)
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
# <think> in delta, </think> in delta, extract reasoning content
|
||||
start_index = delta_text.find(self.think_start_token)
|
||||
@ -101,35 +102,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
|
||||
# reasoning content continues
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
else:
|
||||
# No <think> in previous or delta, reasoning content continues.
|
||||
return DeltaMessage(content=delta_text)
|
||||
# No <think> in previous or delta, also need to check for </think>.
|
||||
# Because the model may have generated </think> without <think>
|
||||
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
# </think> in delta with more tokens,
|
||||
# extract reasoning content and content
|
||||
end_index = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[:end_index]
|
||||
content = delta_text[end_index + len(self.think_end_token):]
|
||||
return DeltaMessage(reasoning_content=reasoning_content,
|
||||
content=content if content else None)
|
||||
elif self.think_end_token_id in previous_token_ids:
|
||||
# </think> in previous, thinking content ends
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
# no </think> in previous or delta, reasoning content continues
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
# Check if the model output contains the <think> tokens.
|
||||
if (self.think_start_token not in model_output
|
||||
or self.think_end_token not in model_output):
|
||||
# DeepSeek R1 doesn't generate <think> now.
|
||||
# Thus we assume the reasoning content is always at the start.
|
||||
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
|
||||
if self.think_end_token not in model_output:
|
||||
return None, model_output
|
||||
else:
|
||||
# Add a start token if it's missing to keep compatibility.
|
||||
if self.think_start_token not in model_output:
|
||||
model_output = f"{self.think_start_token}{model_output}"
|
||||
# Use a regex to find the reasoning content
|
||||
reasoning_content = self.reasoning_regex.findall(model_output)[0]
|
||||
|
||||
# Remove the reasoning content from the model output
|
||||
# Although deepseek's <think> token is always at the
|
||||
# beginning of the line, we cannot guarantee that the
|
||||
# other models will follow this convention.
|
||||
# Therefore, we need to add :start_index.
|
||||
start_index = model_output.find(self.think_start_token)
|
||||
if start_index != -1:
|
||||
end_index = start_index + len(
|
||||
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
|
||||
)
|
||||
model_output = model_output[:start_index] + \
|
||||
model_output[end_index:]
|
||||
end_index = len(
|
||||
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
|
||||
)
|
||||
final_output = model_output[end_index:]
|
||||
|
||||
if len(model_output) == 0:
|
||||
return reasoning_content, None
|
||||
if len(final_output) == 0:
|
||||
return reasoning_content, None
|
||||
|
||||
return reasoning_content, model_output
|
||||
return reasoning_content, final_output
|
||||
|
||||
@ -28,12 +28,15 @@ from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
|
||||
MistralToolCall)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
|
||||
from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls,
|
||||
truncate_tool_call_ids)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -150,11 +153,12 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
"tool_choice = \"required\" is not supported!")
|
||||
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
maybe_serialize_tool_calls(request)
|
||||
truncate_tool_call_ids(request)
|
||||
|
||||
if (request.tool_choice == "auto" and
|
||||
not (self.enable_auto_tools and tool_parser is not None)
|
||||
@ -745,11 +749,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
elif request.tool_choice and type(
|
||||
request.tool_choice) is ChatCompletionNamedToolChoiceParam:
|
||||
|
||||
tool_call_class = MistralToolCall if isinstance(
|
||||
tokenizer, MistralTokenizer) else ToolCall
|
||||
message = ChatMessage(
|
||||
role=role,
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCall(function=FunctionCall(
|
||||
tool_call_class(function=FunctionCall(
|
||||
name=request.tool_choice.function.name,
|
||||
arguments=output.text))
|
||||
])
|
||||
|
||||
@ -400,8 +400,7 @@ class OpenAIServing:
|
||||
_chat_template_kwargs.update(chat_template_kwargs or {})
|
||||
|
||||
request_prompt: Union[str, List[int]]
|
||||
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
|
||||
if is_mistral_tokenizer:
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
request_prompt = apply_mistral_chat_template(
|
||||
tokenizer,
|
||||
messages=messages,
|
||||
|
||||
@ -121,7 +121,7 @@ class OpenAIServingScores(OpenAIServing):
|
||||
|
||||
tokenize_async = make_async(tokenizer.__call__,
|
||||
executor=self._tokenizer_executor)
|
||||
prompt_inputs = await tokenize_async(text=q,
|
||||
prompt_inputs = await tokenize_async(q,
|
||||
text_pair=t,
|
||||
**tokenization_kwargs)
|
||||
|
||||
|
||||
@ -33,7 +33,7 @@ class MistralToolCall(ToolCall):
|
||||
|
||||
@staticmethod
|
||||
def generate_random_id():
|
||||
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
|
||||
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
|
||||
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
||||
return "".join(choices(ALPHANUMERIC, k=9))
|
||||
|
||||
|
||||
@ -87,6 +87,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
|
||||
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
||||
VLLM_RAY_BUNDLE_INDICES: str = ""
|
||||
VLLM_CUDART_SO_PATH: Optional[str] = None
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -572,6 +573,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
# models the alignment is already naturally aligned to 256 bytes.
|
||||
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
|
||||
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
|
||||
|
||||
# On some systems, find_loaded_library() may not work. So we allow users to
|
||||
# specify the path through environment variable VLLM_CUDART_SO_PATH.
|
||||
"VLLM_CUDART_SO_PATH":
|
||||
lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
@ -8,11 +8,11 @@ from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
|
||||
import torch.nn as nn
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.platforms
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||
from vllm.utils import make_async
|
||||
@ -108,8 +108,8 @@ class ExecutorBase(ABC):
|
||||
"""
|
||||
# NOTE: This is logged in the executor because there can be >1 workers.
|
||||
logger.info("# %s blocks: %d, # CPU blocks: %d",
|
||||
current_platform.dispatch_key, num_gpu_blocks,
|
||||
num_cpu_blocks)
|
||||
vllm.platforms.current_platform.dispatch_key,
|
||||
num_gpu_blocks, num_cpu_blocks)
|
||||
max_concurrency = (num_gpu_blocks * self.cache_config.block_size /
|
||||
self.model_config.max_model_len)
|
||||
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
|
||||
|
||||
@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import msgspec
|
||||
|
||||
import vllm.platforms
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import get_ip
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
@ -54,10 +54,10 @@ try:
|
||||
|
||||
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
|
||||
node_id = ray.get_runtime_context().get_node_id()
|
||||
device_key = current_platform.ray_device_key
|
||||
device_key = vllm.platforms.current_platform.ray_device_key
|
||||
if not device_key:
|
||||
raise RuntimeError("current platform %s does not support ray.",
|
||||
current_platform.device_name)
|
||||
vllm.platforms.current_platform.device_name)
|
||||
gpu_ids = ray.get_runtime_context().get_accelerator_ids(
|
||||
)[device_key]
|
||||
return node_id, gpu_ids
|
||||
|
||||
@ -28,6 +28,11 @@ class UniProcExecutor(ExecutorBase):
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
local_rank = 0
|
||||
# set local rank as the device index if specified
|
||||
device_info = self.vllm_config.device_config.device.__str__().split(
|
||||
":")
|
||||
if len(device_info) > 1:
|
||||
local_rank = int(device_info[1])
|
||||
rank = 0
|
||||
kwargs = dict(
|
||||
vllm_config=self.vllm_config,
|
||||
|
||||
@ -254,8 +254,14 @@ class InputPreprocessor:
|
||||
Apply the model's multi-modal processor to a multi-modal prompt,
|
||||
returning the corresponding token IDs and metadata.
|
||||
"""
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
||||
# At the moment one model (PrithviGeoSpatialMAE) requires to be
|
||||
# initialized without a tokenizer while using also multi-modal
|
||||
# input.
|
||||
if not self.tokenizer:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config, tokenizer)
|
||||
@ -273,9 +279,15 @@ class InputPreprocessor:
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> MultiModalInputs:
|
||||
"""Async version of :meth:`_process_multimodal`."""
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(lora_request
|
||||
)
|
||||
# At the moment one model (PrithviGeoSpatialMAE) requires to be
|
||||
# initialized without a tokenizer while using also multi-modal
|
||||
# input.
|
||||
if not self.tokenizer:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer_group = self.get_tokenizer_group()
|
||||
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
|
||||
lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config, tokenizer)
|
||||
|
||||
@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# Mistral tokenizers should not add special tokens
|
||||
prompt_token_ids = tokenizer.encode(prompt=prompt)
|
||||
prompt_token_ids = tokenizer.encode(text=prompt)
|
||||
else:
|
||||
prompt_token_ids = tokenizer.encode(text=prompt,
|
||||
add_special_tokens=False)
|
||||
|
||||
@ -1039,7 +1039,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
# Get the logits for the next tokens.
|
||||
logits = lm_head.linear_method.apply(lm_head, hidden_states)
|
||||
logits = lm_head.quant_method.apply(lm_head, hidden_states)
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
|
||||
|
||||
@ -5,7 +5,8 @@ import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
|
||||
Union)
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
|
||||
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
|
||||
for module_name, new_module_names in self.packed_modules.items():
|
||||
replacement_loras: List[Optional[LoRALayerWeights]] = []
|
||||
replaced_module: Set[str] = set()
|
||||
has_replacement = False
|
||||
for r in new_module_names:
|
||||
lora = lora_model.get_lora(r)
|
||||
replacement_loras.append(lora)
|
||||
if lora:
|
||||
has_replacement = True
|
||||
replaced_module.add(r)
|
||||
if not has_replacement:
|
||||
continue
|
||||
for i in range(len(replacement_loras)):
|
||||
@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
|
||||
replacement_loras[i] = None
|
||||
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
|
||||
replacement_loras)
|
||||
# Remove the modules that have been replaced.
|
||||
for module in replaced_module:
|
||||
lora_model.loras.pop(module, None)
|
||||
|
||||
def deactivate_adapter(self, adapter_id: int) -> bool:
|
||||
return deactivate_adapter(adapter_id, self._active_adapters,
|
||||
|
||||
@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
# 5 is the number of indicies tensors.
|
||||
# 5 is the number of indices tensors.
|
||||
# base_indices, sampler_indices, sampler_indices_padded,
|
||||
# embeddings_indices,long_lora_indices
|
||||
self.indices_len: List[Optional[int]] = [None] * 5
|
||||
|
||||
@ -40,6 +40,8 @@ def maybe_backend_fallback(
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
if guided_params.backend == "xgrammar":
|
||||
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
|
||||
xgr_installed)
|
||||
# xgrammar only has x86 wheels for linux, fallback to outlines
|
||||
from vllm.platforms import current_platform
|
||||
if current_platform.get_cpu_architecture() is not CpuArchEnum.X86:
|
||||
@ -77,6 +79,13 @@ def maybe_backend_fallback(
|
||||
"Falling back to use outlines instead.")
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
# If the xgrammar module cannot be imported successfully,
|
||||
# we should still allow users to use guided decoding with a fallback.
|
||||
elif not xgr_installed:
|
||||
logger.warning("xgrammar module cannot be imported successfully. "
|
||||
"Falling back to use outlines instead.")
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
if (guided_params.backend == "outlines"
|
||||
and guided_params.json_object is not None):
|
||||
# outlines doesn't support json_object, fallback to xgrammar
|
||||
|
||||
@ -14,7 +14,9 @@ from transformers import PreTrainedTokenizerFast
|
||||
try:
|
||||
import xgrammar as xgr
|
||||
from xgrammar.base import _core as xgr_core
|
||||
xgr_installed = True
|
||||
except ImportError:
|
||||
xgr_installed = False
|
||||
pass
|
||||
|
||||
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
|
||||
|
||||
@ -290,29 +290,30 @@ class ColumnParallelLinear(LinearBase):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[list[int]] = None,
|
||||
prefix: str = ""):
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = input_size
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, self.tp_size)
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config, prefix)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.quant_method is not None
|
||||
self.output_size_per_partition = divide(self.output_size, tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, tp_size)
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
@ -335,6 +336,12 @@ class ColumnParallelLinear(LinearBase):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
|
||||
# Special case for GGUF
|
||||
is_gguf_weight = getattr(param, "is_gguf_weight", False)
|
||||
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
|
||||
@ -343,13 +350,12 @@ class ColumnParallelLinear(LinearBase):
|
||||
|
||||
# Materialize GGUF UninitializedParameter
|
||||
if is_gguf_weight and isinstance(param, UninitializedParameter):
|
||||
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
|
||||
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
|
||||
is_sharded_weight = getattr(param, "is_sharded_weight", False)
|
||||
# bitsandbytes loads the weights of the specific portion
|
||||
# no need to narrow
|
||||
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
|
||||
final_shape = list(loaded_weight.shape)
|
||||
if output_dim is not None:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert final_shape[output_dim] % tp_size == 0
|
||||
final_shape[output_dim] = final_shape[output_dim] // tp_size
|
||||
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
||||
|
||||
param_data = param.data
|
||||
if output_dim is not None and not is_sharded_weight:
|
||||
@ -1039,22 +1045,24 @@ class RowParallelLinear(LinearBase):
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config, prefix)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=[self.output_size],
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
|
||||
@ -108,9 +108,9 @@ class LogitsProcessor(nn.Module):
|
||||
embedding_bias: Optional[torch.Tensor],
|
||||
) -> Optional[torch.Tensor]:
|
||||
# Get the logits for the next tokens.
|
||||
logits = lm_head.linear_method.apply(lm_head,
|
||||
hidden_states,
|
||||
bias=embedding_bias)
|
||||
logits = lm_head.quant_method.apply(lm_head,
|
||||
hidden_states,
|
||||
bias=embedding_bias)
|
||||
|
||||
# Gather logits for TP
|
||||
logits = self._gather_logits(logits)
|
||||
|
||||
@ -13,15 +13,17 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
|
||||
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
|
||||
is_layer_skipped_awq)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
|
||||
marlin_permute_scales, moe_awq_to_marlin_zero_points,
|
||||
verify_marlin_supported, verify_marlin_supports_shape)
|
||||
check_marlin_supports_layer, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
|
||||
moe_awq_to_marlin_zero_points, verify_marlin_supported,
|
||||
verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
8: scalar_types.uint8,
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
modules_to_not_convert: Optional[List[str]] = None) -> None:
|
||||
modules_to_not_convert: Optional[List[str]],
|
||||
full_config: Dict[str, Any]) -> None:
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.weight_bits = weight_bits
|
||||
self.modules_to_not_convert = modules_to_not_convert or []
|
||||
self.full_config = full_config
|
||||
|
||||
if self.weight_bits not in self.TYPE_MAP:
|
||||
raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
|
||||
@ -96,7 +97,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
return cls(weight_bits, group_size, zero_point, lm_head_quantized,
|
||||
modules_to_not_convert)
|
||||
modules_to_not_convert, config)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
@ -124,6 +125,13 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
|
||||
return UnquantizedLinearMethod()
|
||||
# Check if the layer is supported by AWQMarlin.
|
||||
if not check_marlin_supports_layer(layer, self.group_size):
|
||||
logger.warning_once(
|
||||
f"Layer '{prefix}' is not supported by AWQMarlin. "
|
||||
"Falling back to unoptimized AWQ kernels.")
|
||||
return AWQConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
return AWQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return AWQMoEMethod(self)
|
||||
|
||||
@ -3,16 +3,17 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
@ -32,7 +33,33 @@ class GPTQConfig(QuantizationConfig):
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: Dict[str, Dict[str, Union[int, bool]]],
|
||||
) -> None:
|
||||
# GPTQModel uses the `dynamic` config property to allow per-module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is Dict[str, Dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
@ -47,7 +74,8 @@ class GPTQConfig(QuantizationConfig):
|
||||
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}),"
|
||||
f"lm_head_quantized={self.lm_head_quantized}")
|
||||
f"lm_head_quantized={self.lm_head_quantized}), "
|
||||
f"dynamic={self.dynamic}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
@ -68,19 +96,20 @@ class GPTQConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, lm_head_quantized)
|
||||
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
|
||||
dynamic)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return GPTQLinearMethod(self)
|
||||
return None
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
|
||||
class ExllamaState(Enum):
|
||||
|
||||
@ -9,17 +9,21 @@ from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig, choose_mp_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, marlin_moe_permute_scales,
|
||||
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
@ -47,12 +51,41 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: Dict[str, Dict[str, Union[int, bool]]],
|
||||
) -> None:
|
||||
if desc_act and group_size == -1:
|
||||
# In this case, act_order == True is the same as act_order == False
|
||||
# (since we have only one group per output channel)
|
||||
desc_act = False
|
||||
|
||||
# GPTQModel uses the `dynamic` config property to allow per-module
|
||||
# quantization config so each module can be individually optimized.
|
||||
# Format is Dict[str, Dict] where key is a regex string that can
|
||||
# perform both positive ("+:" prefixed) or negative ("-:" prefixed)
|
||||
# matching of a module.
|
||||
# Default to positive match, override base quant config mode, if no
|
||||
# prefix is used. Value is in dict format of field key and override
|
||||
# value.
|
||||
# Negative matching will skip quantization init for this module
|
||||
# entirely:
|
||||
# non-quantized inference. More details and quantization examples can be
|
||||
# found at: https://github.com/ModelCloud/GPTQModel
|
||||
# Example:
|
||||
# # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9
|
||||
# # last 1/4 of the layers 16-21 has 8bit and group_size 64
|
||||
# dynamic = {
|
||||
# #`.*\.` matches the layers_node prefix
|
||||
# # positive match layer 10-15
|
||||
# r"+:.*\.(?:1[0-5])\..*": {"bits": 8,},
|
||||
# # positive match layer 16-21
|
||||
# r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,},
|
||||
# r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers
|
||||
# }
|
||||
self.dynamic = dynamic
|
||||
|
||||
self.weight_bits = weight_bits
|
||||
self.is_sym = is_sym
|
||||
|
||||
self.pack_factor = 32 // weight_bits # packed into int32
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
@ -68,7 +101,8 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized})")
|
||||
f"lm_head_quantized={self.lm_head_quantized}), "
|
||||
f"dynamic={self.dynamic}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
@ -88,6 +122,9 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
|
||||
dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
|
||||
dynamic = {} if dynamic is None else dynamic
|
||||
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
@ -95,7 +132,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym,
|
||||
lm_head_quantized)
|
||||
lm_head_quantized, dynamic)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
@ -120,17 +157,15 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]:
|
||||
if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead)
|
||||
and self.lm_head_quantized):
|
||||
return GPTQMarlinLinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
|
||||
UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
return None
|
||||
return get_linear_quant_method(self, layer, prefix,
|
||||
GPTQMarlinLinearMethod)
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits")
|
||||
group_size = quant_config.get("group_size")
|
||||
@ -143,7 +178,7 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
# Marlin conversion is only valid if required properties are found
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
@ -16,6 +16,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQMarlinConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supports_layer)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -87,8 +89,8 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
modules_to_not_convert = []
|
||||
elif linear_quant_method == "awq":
|
||||
has_zp = cls.get_from_keys(config, ["zero_point"])
|
||||
modules_to_not_convert = cls.get_from_keys(
|
||||
config, ["modules_to_not_convert"])
|
||||
modules_to_not_convert = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None)
|
||||
else:
|
||||
raise ValueError("moe_wna16 only support gptq and awq.")
|
||||
|
||||
@ -135,7 +137,8 @@ class MoeWNA16Config(QuantizationConfig):
|
||||
return GPTQConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
elif self.linear_quant_method == "awq":
|
||||
if self.use_marlin:
|
||||
if self.use_marlin and check_marlin_supports_layer(
|
||||
layer, self.group_size):
|
||||
return AWQMarlinConfig.from_config(
|
||||
self.full_config).get_quant_method(layer, prefix)
|
||||
else:
|
||||
|
||||
94
vllm/model_executor/layers/quantization/utils/gptq_utils.py
Normal file
94
vllm/model_executor/layers/quantization/utils/gptq_utils.py
Normal file
@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import QuantizationConfig
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, UnquantizedEmbeddingMethod)
|
||||
|
||||
|
||||
# Match dynamic rules with module name (prefix) and override quantize
|
||||
# config if module (prefix) matches a rule
|
||||
def override_config(config: QuantizationConfig, prefix: str):
    """Apply the ``dynamic`` per-module overrides for ``prefix`` in place.

    The "bits", "group_size" and "desc_act" overrides (plus "sym" for
    ``gptq_marlin``) are looked up through ``get_dynamic_override`` and
    written onto ``config`` when they have the expected type. Afterwards
    ``pack_factor`` is recomputed and, depending on ``config.get_name()``,
    the resulting settings are validated (and ``quant_type`` refreshed for
    ``gptq_marlin``).

    ``config`` is mutated, so callers sharing one config between layers
    should pass a copy.

    Args:
        config: Quantization config whose fields are overridden.
        prefix: Module name matched against the ``config.dynamic`` rules.

    Raises:
        ValueError: If the resulting bit width (``gptq``) or bits/sym
            combination (``gptq_marlin``) is not supported.
    """
    # A negative ("-:") rule makes get_dynamic_override() return False for
    # every key. Such a module is excluded from quantization, so there is
    # nothing to override. Without this guard the False sentinel passes the
    # isinstance(..., int) checks below, is stored as weight_bits, and
    # `32 // False` raises ZeroDivisionError.
    if get_dynamic_override(config, prefix) is False:
        return

    weight_bits = get_dynamic_override(config, prefix, "bits",
                                       config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits

    group_size = get_dynamic_override(config, prefix, "group_size",
                                      config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size

    desc_act = get_dynamic_override(config, prefix, "desc_act",
                                    config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    # Derived from weight_bits, so it must follow any override of it.
    config.pack_factor = 32 // config.weight_bits  # packed into int32

    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError("Unsupported quantization config: "
                             f"bits={config.weight_bits}, sym={config.is_sym}")

        config.quant_type = config.TYPE_MAP[(config.weight_bits,
                                             config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits.")
|
||||
|
||||
|
||||
def get_dynamic_override(
        config: QuantizationConfig,
        layer_name: str,
        key: Optional[str] = None,
        default_value: Union[int, bool,
                             None] = None) -> Union[Dict, int, bool, None]:
    """Resolve the ``config.dynamic`` rule that applies to ``layer_name``.

    Rules are visited in dictionary order and the first one whose regex
    matches ``layer_name`` (via ``re.match``, i.e. anchored at the start)
    decides the result:

    * a rule prefixed with "-:" (negative match) yields ``False``;
    * any other rule (optionally prefixed with "+:") yields its whole
      override dict when ``key`` is ``None``, otherwise the value stored
      under ``key``, falling back to ``default_value``.

    ``default_value`` is returned when no rule matches.
    """
    for rule, overrides in config.dynamic.items():
        negative = rule.startswith("-:")
        regex = rule.removeprefix("-:" if negative else "+:")
        if not re.match(regex, layer_name):
            continue
        if negative:
            # Matched modules are excluded from quantized init.
            return False
        # Matched modules override properties of the base quant config.
        return overrides if key is None else overrides.get(
            key, default_value)
    return default_value
|
||||
|
||||
|
||||
def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    """Choose the quantization method for ``layer`` per ``dynamic`` rules.

    Args:
        config: Base quantization config; it is never mutated here.
        layer: Module for which a quantization method is requested.
        prefix: Module name matched against the ``config.dynamic`` rules.
        linear_method_cls: Class instantiated with the per-layer config
            for quantized layers.

    Returns:
        ``None`` when ``layer`` is neither a ``LinearBase`` nor a
        ``ParallelLMHead`` with ``config.lm_head_quantized`` set.
        ``UnquantizedEmbeddingMethod()`` (quantized LM head) or
        ``UnquantizedLinearMethod()`` when ``prefix`` matches a negative
        rule. Otherwise ``linear_method_cls`` built on a private copy of
        ``config`` with the overrides for ``prefix`` applied.
    """
    parallel_lm_head_quantized = (isinstance(layer, ParallelLMHead)
                                  and config.lm_head_quantized)
    if not (isinstance(layer, LinearBase) or parallel_lm_head_quantized):
        return None

    # False = skip module, None = no override, else = positive match.
    if get_dynamic_override(config, layer_name=prefix) is False:
        if parallel_lm_head_quantized:
            return UnquantizedEmbeddingMethod()
        return UnquantizedLinearMethod()

    # Copy only on the path that hands the config to a linear method, so
    # layers that are not quantized do not pay for a deep copy. The copy
    # keeps per-layer overrides from leaking into the shared base config.
    cloned_config = deepcopy(config)
    if prefix:
        # Dynamic per module/layer rules may override base config
        override_config(cloned_config, prefix=prefix)

    return linear_method_cls(cloned_config)
|
||||
@ -6,6 +6,7 @@ import numpy
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
|
||||
@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int,
|
||||
return True, None
|
||||
|
||||
|
||||
def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
                                    -> bool:
    """Report whether ``layer``'s shapes pass ``check_marlin_supports_shape``.

    The per-partition sizes are read from ``layer`` when the attributes
    exist and are truthy; otherwise the full ``output_size`` /
    ``input_size`` are used. Only the first element of the shape check's
    result is returned.
    """
    out_per_partition = getattr(layer, "output_size_per_partition", None)
    if not out_per_partition:
        out_per_partition = layer.output_size

    in_per_partition = getattr(layer, "input_size_per_partition", None)
    if not in_per_partition:
        in_per_partition = layer.input_size

    shape_check = check_marlin_supports_shape(
        output_size_per_partition=out_per_partition,
        input_size_per_partition=in_per_partition,
        input_size=layer.input_size,
        group_size=group_size)
    return shape_check[0]
|
||||
|
||||
|
||||
def marlin_make_workspace(output_size_per_partition: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
max_workspace_size = (output_size_per_partition //
|
||||
|
||||
@ -226,24 +226,24 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
self.tp_size)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
linear_method = None
|
||||
quant_method = None
|
||||
if quant_config is not None:
|
||||
linear_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedEmbeddingMethod()
|
||||
quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||
if quant_method is None:
|
||||
quant_method = UnquantizedEmbeddingMethod()
|
||||
|
||||
# If we are making an embedding layer, then our quantization linear
|
||||
# method must implement the embedding operation. If we are another
|
||||
# layer type like ParallelLMHead, this is not important.
|
||||
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
|
||||
linear_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(linear_method))
|
||||
if is_embedding_layer and not linear_method_implements_embedding:
|
||||
quant_method_implements_embedding = method_has_implemented_embedding(
|
||||
type(quant_method))
|
||||
if is_embedding_layer and not quant_method_implements_embedding:
|
||||
raise NotImplementedError(
|
||||
f"The class {type(linear_method).__name__} must implement "
|
||||
f"The class {type(quant_method).__name__} must implement "
|
||||
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
|
||||
|
||||
self.linear_method: QuantizeMethodBase = linear_method
|
||||
self.quant_method: QuantizeMethodBase = quant_method
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@ -260,13 +260,13 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
self.shard_indices.added_vocab_end_index -
|
||||
self.shard_indices.added_vocab_start_index)
|
||||
|
||||
self.linear_method.create_weights(self,
|
||||
self.embedding_dim,
|
||||
[self.num_embeddings_per_partition],
|
||||
self.embedding_dim,
|
||||
self.num_embeddings_padded,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
self.quant_method.create_weights(self,
|
||||
self.embedding_dim,
|
||||
[self.num_embeddings_per_partition],
|
||||
self.embedding_dim,
|
||||
self.num_embeddings_padded,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
@classmethod
|
||||
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
|
||||
@ -412,8 +412,8 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = self.linear_method.embedding(self,
|
||||
masked_input.long())
|
||||
output_parallel = self.quant_method.embedding(self,
|
||||
masked_input.long())
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
|
||||
|
||||
238
vllm/model_executor/models/prithvi_geospatial_mae.py
Normal file
238
vllm/model_executor/models/prithvi_geospatial_mae.py
Normal file
@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 IBM.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM/NASA Prithvi Geospatial model."""
|
||||
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (IsAttentionFree,
|
||||
SupportsMultiModal)
|
||||
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs)
|
||||
from vllm.multimodal.parse import MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput,
|
||||
PoolingSequenceGroupOutput)
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
    """Processing information for the Prithvi geospatial model."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits: a single "image" entry set to None."""
        limits: dict[str, Optional[int]] = dict(image=None)
        return limits

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Placeholder hook; ``seq_len`` is ignored and nothing is computed.

        NOTE(review): this returns ``None`` although the annotation promises
        a mapping. Confirm whether any caller consumes the result.
        """
        return None
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEInputBuilder(
        BaseDummyInputsBuilder[PrithviGeoSpatialMAEProcessingInfo]):
    """Produces the dummy processor inputs for the Prithvi geospatial model."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return fixed dummy inputs.

        Neither ``seq_len`` nor ``mm_counts`` is consulted: the prompt text is
        empty and the multi-modal data always has the same two tensors.
        """
        # This model input is fixed and is in the form of a torch Tensor.
        # The size of pixel_values might change in the cases where we resize
        # the input but never exceeds the dimensions below.
        dummy_pixel_values = torch.full((1, 6, 512, 512), 1.0)
        dummy_location_coords = torch.full((1, 2), 1.0)

        dummy_mm_data = {
            "pixel_values": dummy_pixel_values,
            "location_coords": dummy_location_coords,
        }
        return ProcessorInputs(prompt_text="", mm_data=dummy_mm_data)
|
||||
|
||||
|
||||
class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor for the Prithvi geospatial model.

    ``apply`` is overridden so that the multi-modal data handed in by the
    caller is forwarded as model keyword arguments without further
    processing.
    """

    # NOTE: this method used to be defined twice in the class body; the
    # second definition (body `pass`) silently replaced this one, so the
    # field configuration below was never returned. Only one definition
    # is kept.
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare both model inputs as batched fields of the "image" modality.

        Neither argument is inspected; the configuration is static.
        """
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            location_coords=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Return no prompt replacements.

        An empty list is returned (instead of the implicit ``None`` of a bare
        ``pass``) so the result matches the declared return type.
        """
        return []

    def apply(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> MultiModalInputs:
        """Wrap ``mm_data`` into ``MultiModalInputs`` without processing it.

        Args:
            prompt: Passed through unchanged as the ``prompt`` field.
            mm_data: Every key/value pair is copied into the model kwargs.
            hf_processor_mm_kwargs: Not used by this implementation.

        Returns:
            Inputs of type "multimodal" with the fixed token id list ``[1]``
            and no multi-modal placeholders.
        """
        # Shallow copy so the caller's mapping is not shared with the kwargs.
        mm_kwargs = dict(mm_data.items())

        return MultiModalInputs(
            type="multimodal",
            prompt=prompt,
            prompt_token_ids=[1],
            mm_kwargs=MultiModalKwargs(mm_kwargs),
            mm_placeholders={},
        )
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
    PrithviGeoSpatialMAEMultiModalProcessor,
    info=PrithviGeoSpatialMAEProcessingInfo,
    dummy_inputs=PrithviGeoSpatialMAEInputBuilder)
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
    """Prithvi Masked Autoencoder.

    The wrapped model is created at construction time through terratorch
    and stored in ``self.model``.
    """

    def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
        """Build the wrapped model from the ``pretrained_cfg`` dictionary.

        Args:
            config: Dictionary read for the keys ``task_args``,
                ``model_args``, ``optimizer_params`` and
                ``scheduler_params``.

        Returns:
            The ``model`` attribute of the created task, or ``None`` when
            ``config["task_args"]["task"]`` names an unsupported task.
        """
        task_args = config["task_args"]

        # We might be able/need to support different tasks with this same model
        if task_args["task"] != "SemanticSegmentationTask":
            return None

        # Function-scope import, as in the original implementation.
        from terratorch.cli_tools import SemanticSegmentationTask
        task = SemanticSegmentationTask(
            config["model_args"],
            task_args["model_factory"],
            loss=task_args["loss"],
            lr=task_args["lr"],
            ignore_index=task_args["ignore_index"],
            optimizer=task_args["optimizer"],
            optimizer_hparams=config["optimizer_params"],
            scheduler=task_args["scheduler"],
            scheduler_hparams=config["scheduler_params"],
            plot_on_val=task_args["plot_on_val"],
            freeze_decoder=task_args["freeze_decoder"],
            freeze_backbone=task_args["freeze_backbone"])

        return task.model

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate the wrapped model from the HF config.

        Args:
            vllm_config: Its ``model_config.hf_config`` must contain a
                ``pretrained_cfg`` entry.
            prefix: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If the configured task is not supported.
        """
        super().__init__()

        # the actual model is dynamically instantiated using terratorch
        # allowing us to perform changes to the model architecture
        # at startup time (e.g., change the model decoder class.)
        self.model = self._instantiate_model(
            vllm_config.model_config.hf_config.to_dict()["pretrained_cfg"])
        if self.model is None:
            # The fragments are joined by implicit concatenation, so each
            # one carries its own trailing space.
            raise ValueError(
                "Unsupported task. "
                "Only SemanticSegmentationTask is supported for now "
                "by PrithviGeospatialMAE.")

    def _parse_and_validate_multimodal_data(
            self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Extract ``pixel_values`` and ``location_coords`` from ``kwargs``.

        Both entries must be tensors. Only the first slice along dim 0 of
        each is kept. ``location_coords`` becomes ``None`` when that slice
        has shape ``(0,)``.

        Raises:
            ValueError: If either entry is missing or is not a tensor.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError(f"Incorrect type of pixel_values. "
                             f"Got type: {type(pixel_values)}")
        pixel_values = torch.unbind(pixel_values, dim=0)[0]

        location_coords = kwargs.pop("location_coords", None)
        if not isinstance(location_coords, torch.Tensor):
            raise ValueError(f"Incorrect type of location_coords. "
                             f"Got type: {type(location_coords)}")
        location_coords = torch.unbind(location_coords, dim=0)[0]
        if location_coords.shape == torch.Size([0]):
            location_coords = None

        return pixel_values, location_coords

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the wrapped model on the multi-modal inputs in ``kwargs``.

        Only ``kwargs`` is used: ``pixel_values`` and ``location_coords``
        are validated and passed to ``self.model``. The other parameters
        are accepted but not read.

        Returns:
            The ``output`` attribute of the wrapped model's result.
        """
        pixel_values, location_coords = (
            self._parse_and_validate_multimodal_data(**kwargs))
        model_output = self.model(pixel_values,
                                  location_coords=location_coords)

        return model_output.output

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Wrap ``hidden_states`` unchanged in a one-element ``PoolerOutput``.

        ``pooling_metadata`` is not used.
        """
        return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)])

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load buffers and parameters from the ``state_dict`` entry.

        The first item of ``weights`` whose key is ``"state_dict"`` is
        treated as a mapping of names to tensors. Entries whose name
        contains ``"pos_embed"`` are skipped, and the ``"_timm_module."``
        marker is stripped from names. Entries matching a registered buffer
        are loaded directly; the rest go through ``AutoWeightsLoader``.

        Returns:
            The names loaded by ``AutoWeightsLoader`` together with the
            names of the buffers loaded here.
        """
        params_list = []
        model_buffers = dict(self.named_buffers())
        loaded_buffers = []
        for key, value in weights:
            if key != "state_dict":
                continue

            for name, weight in value.items():
                if "pos_embed" in name:
                    continue

                # Stripped once here; a second strip inside the buffer
                # branch could never match and has been removed.
                name = name.replace("_timm_module.", "")

                # this model requires a couple of buffers to be loaded
                # that are not loadable with the AutoWeightsLoader
                if name in model_buffers:
                    buffer = model_buffers[name]
                    weight_loader = getattr(buffer, "weight_loader",
                                            default_weight_loader)
                    weight_loader(buffer, weight)
                    loaded_buffers.append(name)
                else:
                    params_list.append((name, weight))
            # Only the first "state_dict" entry is processed.
            break

        # Load the remaining model parameters
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(params_list)

        return autoloaded_weights.union(set(loaded_buffers))
|
||||
@ -800,7 +800,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
preprocessed_size = ImageSize(width=image_width,
|
||||
height=image_height)
|
||||
|
||||
grid_t = max(num_frames // temporal_patch_size, 1)
|
||||
# NOTE: Frames are padded to be divisible by `temporal_patch_size`
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py#L294
|
||||
padded_num_frames = num_frames + num_frames % temporal_patch_size
|
||||
|
||||
grid_t = max(padded_num_frames // temporal_patch_size, 1)
|
||||
grid_h = preprocessed_size.height // patch_size
|
||||
grid_w = preprocessed_size.width // patch_size
|
||||
|
||||
|
||||
@ -137,6 +137,10 @@ _EMBEDDING_MODELS = {
|
||||
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
|
||||
# [Auto-converted (see adapters.py)]
|
||||
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForCausalLM"),
|
||||
# Technically PrithviGeoSpatialMAE is a model that works on images, both in
|
||||
# input and output. I am adding it here because it piggy-backs on embedding
|
||||
# models for the time being.
|
||||
"PrithviGeoSpatialMAE": ("prithvi_geospatial_mae", "PrithviGeoSpatialMAE"),
|
||||
}
|
||||
|
||||
_CROSS_ENCODER_MODELS = {
|
||||
@ -201,6 +205,14 @@ _VLLM_MODELS = {
|
||||
**_FALLBACK_MODEL,
|
||||
}
|
||||
|
||||
# This variable is used as the args for subprocess.run(). We
|
||||
# can modify this variable to alter the args if needed. e.g.
|
||||
# when we use par format to pack things together, sys.executable
|
||||
# might not be the target we want to run.
|
||||
_SUBPROCESS_COMMAND = [
|
||||
sys.executable, "-m", "vllm.model_executor.models.registry"
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ModelInfo:
|
||||
@ -498,10 +510,9 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T:
|
||||
|
||||
# cannot use `sys.executable __file__` here because the script
|
||||
# contains relative imports
|
||||
returned = subprocess.run(
|
||||
[sys.executable, "-m", "vllm.model_executor.models.registry"],
|
||||
input=input_bytes,
|
||||
capture_output=True)
|
||||
returned = subprocess.run(_SUBPROCESS_COMMAND,
|
||||
input=input_bytes,
|
||||
capture_output=True)
|
||||
|
||||
# check if the subprocess is successful
|
||||
try:
|
||||
|
||||
@ -143,6 +143,7 @@ class TransformersModel(nn.Module):
|
||||
self.model: PreTrainedModel = AutoModel.from_config(
|
||||
self.config,
|
||||
attn_implementation="vllm",
|
||||
torch_dtype=vllm_config.model_config.dtype,
|
||||
trust_remote_code=vllm_config.model_config.trust_remote_code,
|
||||
)
|
||||
prefix = self.model.base_model_prefix
|
||||
|
||||
@ -258,27 +258,35 @@ class UltravoxProjector(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_dim = config.hidden_size
|
||||
self._pad_and_stack = StackAudioFrames(config.stack_factor)
|
||||
dim = config.audio_config.hidden_size * config.stack_factor
|
||||
self.ln_pre = RMSNorm(dim)
|
||||
self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
|
||||
dim = self.hidden_dim
|
||||
dim_in = config.audio_config.hidden_size * config.stack_factor
|
||||
self.ln_pre = RMSNorm(dim_in)
|
||||
self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
|
||||
dim_mid = self.hidden_dim
|
||||
|
||||
if config.projector_act == "swiglu":
|
||||
self.act = MulAndSilu()
|
||||
dim = dim // 2
|
||||
dim_mid = dim_mid // 2
|
||||
else:
|
||||
self.act = get_act_fn(config.projector_act)
|
||||
|
||||
self.linear_2 = nn.Linear(dim,
|
||||
config.text_config.hidden_size,
|
||||
bias=False)
|
||||
self.ln_post = RMSNorm(config.text_config.hidden_size)
|
||||
dim_out = config.text_config.hidden_size
|
||||
self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
|
||||
|
||||
# Ultravox v0.4.1 and below use layer_norm after the second linear layer
|
||||
# while v0.5.0 and above uses layer_norm after the first linear layer.
|
||||
if config.projector_ln_mid:
|
||||
self.ln_mid: nn.Module = RMSNorm(dim_mid)
|
||||
self.ln_post = nn.Identity()
|
||||
else:
|
||||
self.ln_mid = nn.Identity()
|
||||
self.ln_post = RMSNorm(dim_out)
|
||||
|
||||
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
    """Project audio features by applying the projector stages in order.

    The stages are, in sequence: ``_pad_and_stack``, ``ln_pre``,
    ``linear_1``, ``act``, ``ln_mid``, ``linear_2`` and ``ln_post``. Each
    stage receives the previous stage's output.
    """
    pipeline = (
        self._pad_and_stack,
        self.ln_pre,
        self.linear_1,
        self.act,
        self.ln_mid,
        self.linear_2,
        self.ln_post,
    )
    projected = audio_features
    for stage in pipeline:
        projected = stage(projected)
    return projected
|
||||
|
||||
@ -115,6 +115,9 @@ class CpuPlatform(Platform):
|
||||
# Environment variables for CPU executor
|
||||
#
|
||||
|
||||
# Set default threads num for OpenMP parallel
|
||||
os.environ["OMP_NUM_THREADS"] = str(torch.get_num_threads())
|
||||
|
||||
# Disable torch async compiling which won't work with daemonic processes
|
||||
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
|
||||
|
||||
|
||||
@ -334,10 +334,10 @@ class NvmlCudaPlatform(CudaPlatformBase):
|
||||
if (len(set(device_names)) > 1
|
||||
and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID"):
|
||||
logger.warning(
|
||||
"Detected different devices in the system: \n%s\nPlease"
|
||||
"Detected different devices in the system: %s. Please"
|
||||
" make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to "
|
||||
"avoid unexpected behavior.",
|
||||
"\n".join(device_names),
|
||||
", ".join(device_names),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -13,8 +13,10 @@ from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
else:
|
||||
VllmConfig = None
|
||||
FlexibleArgumentParser = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -223,6 +225,22 @@ class Platform:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
@classmethod
def pre_register_and_update(cls,
                            parser: Optional[FlexibleArgumentParser] = None
                            ) -> None:
    """
    Perform pre-registration or update actions for the current platform.

    This hook is called before the global VllmConfig is initialized and
    before CLI arguments are parsed. Out-of-tree platforms can override it
    to register or update configuration; for example, an out-of-tree
    quantization config can be imported and registered here dynamically.

    The base implementation does nothing.
    """
    return None
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
"""
|
||||
|
||||
@ -169,4 +169,5 @@ class RocmPlatform(Platform):
|
||||
device: Optional[torch.types.Device] = None
|
||||
) -> float:
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return torch.cuda.max_memory_allocated(device)
|
||||
return torch.cuda.mem_get_info(device)[1] - torch.cuda.mem_get_info(
|
||||
device)[0]
|
||||
|
||||
@ -4,12 +4,14 @@ import enum
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Type, Union
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Type, Union
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
|
||||
try_to_load_from_cache)
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub import list_repo_files as hf_list_repo_files
|
||||
from huggingface_hub import try_to_load_from_cache
|
||||
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
|
||||
HFValidationError, LocalEntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
|
||||
MISTRAL = "mistral"
|
||||
|
||||
|
||||
def with_retry(func: Callable[[], Any],
               log_msg: str,
               max_retries: int = 2,
               retry_delay: int = 2):
    """Call ``func`` up to ``max_retries`` times and return its result.

    Every failure is logged with ``log_msg``. After a failed attempt that
    is not the last one, the function sleeps for the current delay and then
    doubles it. The exception raised by the final attempt is re-raised.

    Args:
        func: Zero-argument callable to invoke.
        log_msg: Prefix used in the error log lines.
        max_retries: Total number of attempts. With a value below 1 the
            callable is never invoked and ``None`` is returned.
        retry_delay: Initial sleep, in seconds, between attempts.
    """
    last_attempt = max_retries - 1
    delay = retry_delay
    attempt = 0
    while attempt < max_retries:
        try:
            return func()
        except Exception as exc:
            if attempt == last_attempt:
                logger.error("%s: %s", log_msg, exc)
                raise
            logger.error("%s: %s, retrying %d of %d", log_msg, exc,
                         attempt + 1, max_retries)
            time.sleep(delay)
            delay *= 2
        attempt += 1
|
||||
|
||||
|
||||
# @cache doesn't cache exceptions
@cache
def list_repo_files(
    repo_id: str,
    *,
    revision: Optional[str] = None,
    repo_type: Optional[str] = None,
    token: Union[str, bool, None] = None,
) -> list[str]:
    """Return the file list of a hub repository, with retries and caching.

    The lookup is delegated to ``hf_list_repo_files`` and wrapped in
    ``with_retry``. Successful results are memoized per argument
    combination; failures are not, so they are attempted again on the next
    call. In offline mode an empty list is returned instead of raising.
    """
    lookup_kwargs = dict(revision=revision, repo_type=repo_type, token=token)

    def _fetch_listing():
        try:
            listing = hf_list_repo_files(repo_id, **lookup_kwargs)
        except huggingface_hub.errors.OfflineModeIsEnabled:
            # Don't raise in offline mode,
            # all we know is that we don't have this
            # file cached.
            return []
        return listing

    return with_retry(_fetch_listing, "Error retrieving file list")
|
||||
|
||||
|
||||
def file_exists(
    repo_id: str,
    file_name: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    token: Union[str, bool, None] = None,
) -> bool:
    """Report whether ``file_name`` appears in the repository's file list.

    The check is a membership test against the result of
    ``list_repo_files`` for the same repository, type, revision and token.
    """
    # Keyword order is kept as-is: with functools.cache, calls that differ
    # only in keyword order may be stored as separate cache entries.
    return file_name in list_repo_files(repo_id,
                                        repo_type=repo_type,
                                        revision=revision,
                                        token=token)
|
||||
|
||||
|
||||
# In offline mode the result can be a false negative
|
||||
def file_or_path_exists(model: Union[str, Path], config_name: str,
|
||||
revision: Optional[str]) -> bool:
|
||||
if Path(model).exists():
|
||||
@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
|
||||
# hf_hub. This will fail in offline mode.
|
||||
|
||||
# Call HF to check if the file exists
|
||||
# 2 retries and exponential backoff
|
||||
max_retries = 2
|
||||
retry_delay = 2
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return file_exists(model,
|
||||
config_name,
|
||||
revision=revision,
|
||||
token=HF_TOKEN)
|
||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||
# Don't raise in offline mode,
|
||||
# all we know is that we don't have this
|
||||
# file cached.
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error checking file existence: %s, retrying %d of %d", e,
|
||||
attempt + 1, max_retries)
|
||||
if attempt == max_retries - 1:
|
||||
logger.error("Error checking file existence: %s", e)
|
||||
raise
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
continue
|
||||
return False
|
||||
return file_exists(str(model),
|
||||
config_name,
|
||||
revision=revision,
|
||||
token=HF_TOKEN)
|
||||
|
||||
|
||||
def patch_rope_scaling(config: PretrainedConfig) -> None:
|
||||
@ -208,32 +248,7 @@ def get_config(
|
||||
revision=revision):
|
||||
config_format = ConfigFormat.MISTRAL
|
||||
else:
|
||||
# If we're in offline mode and found no valid config format, then
|
||||
# raise an offline mode error to indicate to the user that they
|
||||
# don't have files cached and may need to go online.
|
||||
# This is conveniently triggered by calling file_exists().
|
||||
|
||||
# Call HF to check if the file exists
|
||||
# 2 retries and exponential backoff
|
||||
max_retries = 2
|
||||
retry_delay = 2
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
file_exists(model,
|
||||
HF_CONFIG_NAME,
|
||||
revision=revision,
|
||||
token=HF_TOKEN)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error checking file existence: %s, retrying %d of %d",
|
||||
e, attempt + 1, max_retries)
|
||||
if attempt == max_retries:
|
||||
logger.error("Error checking file existence: %s", e)
|
||||
raise e
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
|
||||
raise ValueError(f"No supported config format found in {model}")
|
||||
raise ValueError(f"No supported config format found in {model}.")
|
||||
|
||||
if config_format == ConfigFormat.HF:
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
|
||||
file_name=file_name,
|
||||
revision=revision)
|
||||
|
||||
if file_path is None and file_or_path_exists(
|
||||
model=model, config_name=file_name, revision=revision):
|
||||
if file_path is None:
|
||||
try:
|
||||
hf_hub_file = hf_hub_download(model, file_name, revision=revision)
|
||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||
return None
|
||||
except (RepositoryNotFoundError, RevisionNotFoundError,
|
||||
EntryNotFoundError, LocalEntryNotFoundError) as e:
|
||||
logger.debug("File or repository not found in hf_hub_download", e)
|
||||
@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
||||
"""
|
||||
This function gets the pooling and normalize
|
||||
@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
||||
if modules_dict is None:
|
||||
return None
|
||||
|
||||
logger.info("Found sentence-transformers modules configuration.")
|
||||
|
||||
pooling = next((item for item in modules_dict
|
||||
if item["type"] == "sentence_transformers.models.Pooling"),
|
||||
None)
|
||||
@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
|
||||
if pooling_type_name is not None:
|
||||
pooling_type_name = get_pooling_config_name(pooling_type_name)
|
||||
|
||||
logger.info("Found pooling configuration.")
|
||||
return {"pooling_type": pooling_type_name, "normalize": normalize}
|
||||
|
||||
return None
|
||||
@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
|
||||
return None
|
||||
|
||||
|
||||
@cache
|
||||
def get_sentence_transformer_tokenizer_config(model: str,
|
||||
revision: Optional[str] = 'main'
|
||||
):
|
||||
@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
|
||||
if not encoder_dict:
|
||||
return None
|
||||
|
||||
logger.info("Found sentence-transformers tokenize configuration.")
|
||||
|
||||
if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
|
||||
return encoder_dict
|
||||
return None
|
||||
|
||||
@ -45,4 +45,4 @@ __all__ = [
|
||||
"SolarConfig",
|
||||
"Telechat2Config",
|
||||
"UltravoxConfig",
|
||||
]
|
||||
]
|
||||
|
||||
@ -37,6 +37,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
The LoRA configuration for finetuning the text model.
|
||||
audio_model_lora_config (`LoraConfigSimplified`, *optional*):
|
||||
The LoRA configuration for finetuning the audio model.
|
||||
projector_ln_mid (`bool`, *optional*, defaults to `False`):
|
||||
Whether to apply layer normalization at the middle of the
|
||||
projector or at the end. Versions v0.4.1 and below
|
||||
use `False`, but v0.5 and above use `True`.
|
||||
"""
|
||||
|
||||
model_type = "ultravox"
|
||||
@ -56,6 +60,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
projector_act: str = "swiglu",
|
||||
text_model_lora_config: Optional[Dict[str, Any]] = None,
|
||||
audio_model_lora_config: Optional[Dict[str, Any]] = None,
|
||||
projector_ln_mid: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
@ -68,6 +73,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
|
||||
self.stack_factor = stack_factor
|
||||
self.norm_init = norm_init
|
||||
self.projector_act = projector_act
|
||||
self.projector_ln_mid = projector_ln_mid
|
||||
|
||||
if text_model_id is not None:
|
||||
# Avoid circular import
|
||||
|
||||
@ -14,6 +14,8 @@ from transformers import (AutoTokenizer, PreTrainedTokenizer,
|
||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
|
||||
TokenizerRegistry)
|
||||
from vllm.transformers_utils.tokenizers import MistralTokenizer
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.utils import make_async
|
||||
@ -21,7 +23,7 @@ from vllm.utils import make_async
|
||||
logger = init_logger(__name__)
|
||||
|
||||
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
|
||||
MistralTokenizer]
|
||||
TokenizerBase]
|
||||
|
||||
|
||||
def decode_tokens(
|
||||
@ -47,11 +49,7 @@ def encode_tokens(
|
||||
Backend-agnostic equivalent of HF's
|
||||
:code:`tokenizer.encode(text, add_special_tokens=...)`.
|
||||
"""
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
return tokenizer.tokenizer.encode(text,
|
||||
bos=add_special_tokens,
|
||||
eos=add_special_tokens)
|
||||
elif add_special_tokens is not None:
|
||||
if add_special_tokens is not None:
|
||||
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
|
||||
return tokenizer.encode(text)
|
||||
|
||||
@ -183,9 +181,17 @@ def get_tokenizer(
|
||||
'encoding and decoding.',
|
||||
FutureWarning,
|
||||
stacklevel=2)
|
||||
|
||||
tokenizer: AnyTokenizer
|
||||
if tokenizer_mode == "mistral":
|
||||
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
|
||||
revision=revision)
|
||||
elif tokenizer_mode == "custom":
|
||||
tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
|
||||
*args,
|
||||
revision=revision,
|
||||
download_dir=download_dir,
|
||||
**kwargs)
|
||||
else:
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
|
||||
146
vllm/transformers_utils/tokenizer_base.py
Normal file
146
vllm/transformers_utils/tokenizer_base.py
Normal file
@ -0,0 +1,146 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import importlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
|
||||
|
||||
|
||||
class TokenizerBase(ABC):
    """Abstract interface that concrete tokenizers must implement.

    Every member except ``__len__`` is abstract and raises
    ``NotImplementedError`` if reached through the base class.
    """

    @property
    @abstractmethod
    def all_special_tokens_extended(self) -> List[str]:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    @property
    @abstractmethod
    def all_special_tokens(self) -> List[str]:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    @property
    @abstractmethod
    def all_special_ids(self) -> List[int]:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    @property
    @abstractmethod
    def bos_token_id(self) -> int:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    @property
    @abstractmethod
    def eos_token_id(self) -> int:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    @property
    @abstractmethod
    def sep_token(self) -> str:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    @property
    @abstractmethod
    def pad_token(self) -> str:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    @property
    @abstractmethod
    def is_fast(self) -> bool:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    @property
    @abstractmethod
    def vocab_size(self) -> int:
        """Abstract property; also the value reported by ``len(tokenizer)``."""
        raise NotImplementedError

    @property
    @abstractmethod
    def max_token_id(self) -> int:
        """Abstract property; must be provided by the subclass."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return ``self.vocab_size``."""
        size = self.vocab_size
        return size

    @abstractmethod
    def __call__(
        self,
        text: Union[str, List[str], List[int]],
        text_pair: Optional[str] = None,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Abstract call operator; must be provided by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def get_vocab(self) -> Dict[str, int]:
        """Abstract method; must be provided by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def get_added_vocab(self) -> Dict[str, int]:
        """Abstract method; must be provided by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def encode_one(
        self,
        text: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        """Abstract method; must be provided by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def encode(self,
               text: str,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        """Abstract method; must be provided by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            **kwargs) -> List[int]:
        """Abstract method; must be provided by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Abstract method; must be provided by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        """Abstract method; must be provided by the subclass."""
        raise NotImplementedError

    @abstractmethod
    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        """Abstract method; must be provided by the subclass."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class TokenizerRegistry:
    """Name-based registry of tokenizer classes that are imported on demand."""

    # Tokenizer name -> (tokenizer module, tokenizer class)
    REGISTRY: Dict[str, Tuple[str, str]] = {}

    @staticmethod
    def register(name: str, module: str, class_name: str) -> None:
        """Record ``(module, class_name)`` under ``name``.

        An existing entry with the same name is overwritten.
        """
        entry = (module, class_name)
        TokenizerRegistry.REGISTRY[name] = entry

    @staticmethod
    def get_tokenizer(
        tokenizer_name: str,
        *args,
        **kwargs,
    ) -> TokenizerBase:
        """Import the registered class and build a tokenizer from it.

        The module registered for ``tokenizer_name`` is imported and the
        registered class's ``from_pretrained`` is called with ``*args`` and
        ``**kwargs``. ``tokenizer_name`` itself is not forwarded.

        Raises:
            ValueError: If no entry is registered under ``tokenizer_name``.
        """
        entry = TokenizerRegistry.REGISTRY.get(tokenizer_name)
        if entry is None:
            raise ValueError(f"Tokenizer {tokenizer_name} not found.")

        module_path, class_name = entry[0], entry[1]
        loaded_module = importlib.import_module(module_path)
        tokenizer_class = getattr(loaded_module, class_name)
        return tokenizer_class.from_pretrained(*args, **kwargs)
|
||||
@ -1,5 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .mistral import MistralTokenizer, maybe_serialize_tool_calls
|
||||
from .mistral import (MistralTokenizer, maybe_serialize_tool_calls,
|
||||
truncate_tool_call_ids)
|
||||
|
||||
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
|
||||
__all__ = [
|
||||
"MistralTokenizer", "maybe_serialize_tool_calls", "truncate_tool_call_ids"
|
||||
]
|
||||
|
||||
@ -10,6 +10,7 @@ import huggingface_hub
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer_base import TokenizerBase
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -67,6 +68,36 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
|
||||
request.messages[i]["tool_calls"] = validated_tool_calls
|
||||
|
||||
|
||||
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements.

    Walks ``request.messages`` and shortens, in place, every tool call ID
    longer than 9 characters to its last 9 characters:

    * for ``assistant`` messages, the ``"id"`` of each entry in
      ``"tool_calls"``;
    * for ``tool`` / ``tool_results`` messages, the ``"tool_call_id"`` value.

    A ``"tool_calls"`` value that is missing or ``None`` is treated as an
    empty list (and stored back as such), and a ``"tool_call_id"`` that is
    missing or ``None`` is left alone, so null fields no longer raise
    ``TypeError``.

    Args:
        request: Object exposing a ``messages`` list of dict-like messages.
            The messages (and their tool call dicts) are mutated in place.
    """
    max_id_len = 9

    for message in request.messages:
        role = message.get("role")

        if role == 'assistant':
            # The key may be present with an explicit None value; a plain
            # ``.get(key, [])`` would hand that None to the loop below.
            tool_calls = message.get("tool_calls") or []
            for tool_call in tool_calls:
                call_id = tool_call["id"]
                if len(call_id) > max_id_len:
                    shortened = call_id[-max_id_len:]
                    logger.warning(
                        "Truncating tool call ID: %s to %s",
                        call_id,
                        shortened,
                    )
                    tool_call["id"] = shortened

            message["tool_calls"] = tool_calls

        elif role in {"tool_results", "tool"}:
            tool_call_id = message.get("tool_call_id")
            if tool_call_id is None:
                continue

            if len(tool_call_id) > max_id_len:
                shortened = tool_call_id[-max_id_len:]
                logger.warning(
                    "Truncating tool_call_id: %s to %s",
                    tool_call_id,
                    shortened,
                )
                message["tool_call_id"] = shortened
|
||||
|
||||
|
||||
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
|
||||
repo_cache = os.path.join(
|
||||
huggingface_hub.constants.HF_HUB_CACHE,
|
||||
@ -104,7 +135,43 @@ def find_tokenizer_file(files: List[str]):
|
||||
return matched_files[0]
|
||||
|
||||
|
||||
class MistralTokenizer:
|
||||
def make_mistral_chat_completion_request(
        messages: List["ChatCompletionMessageParam"],
        tools: Optional[List[Dict[str,
                                  Any]]] = None) -> "ChatCompletionRequest":
    """Build a mistral-common ``ChatCompletionRequest`` from chat messages.

    The inputs are normalized in place before the request is constructed:

    * if the final message has role ``"assistant"``, it is flagged with
      ``prefix=True``;
    * assistant messages whose ``content`` is a list have it replaced by the
      newline-joined ``"text"`` of each chunk;
    * every tool of type ``"function"`` gets an empty ``"parameters"`` dict
      if it has none.

    Args:
        messages: Chat messages; must be non-empty. Mutated in place.
        tools: Optional tool definitions. Mutated in place.

    Returns:
        The ``ChatCompletionRequest`` built from ``messages`` and ``tools``.

    Raises:
        IndexError: If ``messages`` is empty.
    """
    # Flag a trailing assistant message as a prefix (this check used to be
    # written twice in a row; once is sufficient).
    last_message = cast(Dict[str, Any], messages[-1])
    if last_message["role"] == "assistant":
        last_message["prefix"] = True

    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
        if message.get("role") == "assistant":
            content = message.get("content")
            if isinstance(content, list):
                content = "\n".join(chunk.get("text") for chunk in content)
                message["content"] = content

    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict to be present, even if it's empty.
    if tools:
        for function in [
                tool["function"] for tool in tools
                if tool["type"] == "function"
        ]:
            function.setdefault("parameters", {})

    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    return ChatCompletionRequest(messages=messages,
                                 tools=tools)  # type: ignore[type-var]
|
||||
|
||||
|
||||
class MistralTokenizer(TokenizerBase):
|
||||
|
||||
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
|
||||
self.mistral = tokenizer
|
||||
@ -215,6 +282,14 @@ class MistralTokenizer:
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_id
|
||||
|
||||
@property
def sep_token(self) -> str:
    """Separator token; not provided by this tokenizer.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
|
||||
@property
def pad_token(self) -> str:
    """Padding token; not provided by this tokenizer.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
|
||||
@property
def is_fast(self) -> bool:
    """Report this tokenizer as a "fast" tokenizer; the value is constant."""
    fast = True
    return fast
|
||||
@ -232,25 +307,26 @@ class MistralTokenizer:
|
||||
|
||||
def __call__(
    self,
    text: Union[str, List[str], List[int]],
    text_pair: Optional[str] = None,
    add_special_tokens: bool = False,
    truncation: bool = False,
    max_length: Optional[int] = None,
):
    """Tokenize ``text`` and wrap the ids in an ``Encoding``.

    Three input shapes are handled:

    * a list of strings: each string is encoded with ``encode_one``;
    * a list of ints: taken as already-tokenized ids and passed through
      unchanged (no truncation is applied on this path);
    * anything else (a single string): encoded with ``encode_one``.

    Args:
        text: Prompt text, a batch of prompt texts, or token ids.
        text_pair: Accepted for signature compatibility; not used here.
        add_special_tokens: Accepted for signature compatibility; not used
            here.
        truncation: Forwarded to ``encode_one``.
        max_length: Forwarded to ``encode_one``.

    Returns:
        ``Encoding(input_ids=...)`` holding the resulting ids.
    """
    input_ids: Union[List[int], List[List[int]]]
    # For List[str], original prompt text
    if is_list_of(text, str):
        input_ids_: List[List[int]] = []
        for p in text:
            each_input_ids = self.encode_one(p, truncation, max_length)
            input_ids_.append(each_input_ids)
        input_ids = input_ids_
    # For List[int], apply chat template output, already tokens.
    elif is_list_of(text, int):
        input_ids = text
    # For str, single prompt text
    else:
        input_ids = self.encode_one(text, truncation, max_length)
    return Encoding(input_ids=input_ids)
|
||||
|
||||
def get_vocab(self) -> Dict[str, int]:
|
||||
@ -264,46 +340,36 @@ class MistralTokenizer:
|
||||
|
||||
def encode_one(
    self,
    text: str,
    truncation: bool = False,
    max_length: Optional[int] = None,
) -> List[int]:
    """Encode a single string, optionally truncating the resulting ids.

    Args:
        text: The string to encode via ``self.encode``.
        truncation: When true, keep only the first ``max_length`` ids.
        max_length: Cut-off used when ``truncation`` is true; ``None``
            keeps every id.

    Returns:
        The (possibly truncated) list of token ids.
    """
    # Mistral Tokenizers should not add special tokens
    input_ids = self.encode(text)

    if truncation:
        input_ids = input_ids[:max_length]
    return input_ids
|
||||
|
||||
def encode(self,
           text: str,
           add_special_tokens: Optional[bool] = None) -> List[int]:
    """Encode ``text`` with the underlying tokenizer.

    Args:
        text: The string to encode.
        add_special_tokens: When given, its value is used for both the
            ``bos`` and ``eos`` flags of the underlying ``encode`` call.
            When ``None`` (the default), ``bos=True, eos=False`` is used.

    Returns:
        The token ids produced by ``self.tokenizer.encode``.
    """
    # `encode` should only be used for prompt completion
    # it should never be used for chat_completion.
    # For chat completion use `apply_chat_template`
    if add_special_tokens is not None:
        return self.tokenizer.encode(text,
                                     bos=add_special_tokens,
                                     eos=add_special_tokens)
    else:
        return self.tokenizer.encode(text, bos=True, eos=False)
|
||||
|
||||
def apply_chat_template(self,
|
||||
messages: List["ChatCompletionMessageParam"],
|
||||
tools: Optional[Dict[str, Any]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs) -> List[int]:
|
||||
|
||||
last_message = cast(Dict[str, Any], messages[-1])
|
||||
if last_message["role"] == "assistant":
|
||||
last_message["prefix"] = True
|
||||
|
||||
from mistral_common.protocol.instruct.request import (
|
||||
ChatCompletionRequest)
|
||||
|
||||
# mistral-common requires AssistantMessage content to be string [1].
|
||||
#
|
||||
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
|
||||
for message in messages:
|
||||
if message.get("role") == "assistant":
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
content = "\n".join(chunk.get("text") for chunk in content)
|
||||
message["content"] = content
|
||||
request = ChatCompletionRequest(messages=messages,
|
||||
tools=tools) # type: ignore[type-var]
|
||||
request = make_mistral_chat_completion_request(messages, tools)
|
||||
encoded = self.mistral.encode_chat_completion(request)
|
||||
|
||||
# encode-decode to get clean prompt
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user