Previous News
+- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
diff --git a/SECURITY.md b/SECURITY.md
index 414669fb3712e..d6319cdb1ac27 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -42,4 +42,9 @@ For certain security issues of CRITICAL, HIGH, or MODERATE severity level, we ma
* If you wish to be added to the prenotification group, please send an email copying all the members of the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Each vendor contact will be analyzed on a case-by-case basis.
+* Organizations and vendors who either ship or use vLLM, are eligible to join the prenotification group if they meet at least one of the following qualifications
+ * Substantial internal deployment leveraging the upstream vLLM project.
+ * Established internal security teams and comprehensive compliance measures.
+ * Active and consistent contributions to the upstream vLLM project.
+
* We may withdraw organizations from receiving future prenotifications if they release fixes or any other information about issues before they are public. Group membership may also change based on policy refinements for who may be included.
diff --git a/benchmarks/README.md b/benchmarks/README.md
index d6442a4fc3872..38072152b653b 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -22,6 +22,25 @@ become available.
✅ |
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json |
+
+ | ShareGPT4V (Image) |
+ ✅ |
+ ✅ |
+
+ wget https://huggingface.co/datasets/Lin-Chen/ShareGPT4V/blob/main/sharegpt4v_instruct_gpt4-vision_cap100k.json
+
+ Note that the images need to be downloaded separately. For example, to download COCO's 2017 Train images:
+ wget http://images.cocodataset.org/zips/train2017.zip
+ |
+
+
+ | ShareGPT4Video (Video) |
+ ✅ |
+ ✅ |
+
+ git clone https://huggingface.co/datasets/ShareGPT4Video/ShareGPT4Video
+ |
+
| BurstGPT |
✅ |
@@ -29,7 +48,7 @@ become available.
wget https://github.com/HPMLL/BurstGPT/releases/download/v1.1/BurstGPT_without_fails_2.csv |
- | Sonnet |
+ Sonnet (deprecated) |
✅ |
✅ |
Local file: benchmarks/sonnet.txt |
@@ -40,6 +59,18 @@ become available.
✅ |
synthetic |
+
+ | RandomMultiModal (Image/Video) |
+ 🟡 |
+ 🚧 |
+ synthetic |
+
+
+ | Prefix Repetition |
+ ✅ |
+ ✅ |
+ synthetic |
+
| HuggingFace-VisionArena |
✅ |
@@ -177,6 +208,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
vllm bench serve \
--backend openai-chat \
+ --endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -213,6 +245,7 @@ vllm serve Qwen/Qwen2-VL-7B-Instruct
```bash
vllm bench serve \
--backend openai-chat \
+ --endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -227,6 +260,7 @@ vllm bench serve \
```bash
vllm bench serve \
--backend openai-chat \
+ --endpoint-type openai-chat \
--model Qwen/Qwen2-VL-7B-Instruct \
--endpoint /v1/chat/completions \
--dataset-name hf \
@@ -581,6 +615,20 @@ python3 benchmarks/benchmark_prefix_caching.py \
--input-length-range 128:256
```
+### Prefix Repetition Dataset
+
+```bash
+vllm bench serve \
+ --backend openai \
+ --model meta-llama/Llama-2-7b-chat-hf \
+ --dataset-name prefix_repetition \
+ --num-prompts 100 \
+ --prefix-repetition-prefix-len 512 \
+ --prefix-repetition-suffix-len 128 \
+ --prefix-repetition-num-prefixes 5 \
+ --prefix-repetition-output-len 128
+```
+
## ⚡ Example - Request Prioritization Benchmark
@@ -616,3 +664,139 @@ python3 benchmarks/benchmark_prioritization.py \
```
+
+## 👁️ Example - Multi-Modal Benchmark
+
+
+Show more
+
+
+
+Benchmark the performance of multi-modal requests in vLLM.
+
+### Images (ShareGPT4V)
+
+Start vLLM:
+
+```bash
+python -m vllm.entrypoints.openai.api_server \
+ --model Qwen/Qwen2.5-VL-7B-Instruct \
+ --dtype bfloat16 \
+ --limit-mm-per-prompt '{"image": 1}' \
+ --allowed-local-media-path /path/to/sharegpt4v/images
+```
+
+Send requests with images:
+
+```bash
+python benchmarks/benchmark_serving.py \
+ --backend openai-chat \
+ --model Qwen/Qwen2.5-VL-7B-Instruct \
+ --dataset-name sharegpt \
+ --dataset-path /path/to/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json \
+ --num-prompts 100 \
+ --save-result \
+ --result-dir ~/vllm_benchmark_results \
+ --save-detailed \
+ --endpoint /v1/chat/completion
+```
+
+### Videos (ShareGPT4Video)
+
+Start vLLM:
+
+```bash
+python -m vllm.entrypoints.openai.api_server \
+ --model Qwen/Qwen2.5-VL-7B-Instruct \
+ --dtype bfloat16 \
+ --limit-mm-per-prompt '{"video": 1}' \
+ --allowed-local-media-path /path/to/sharegpt4video/videos
+```
+
+Send requests with videos:
+
+```bash
+python benchmarks/benchmark_serving.py \
+ --backend openai-chat \
+ --model Qwen/Qwen2.5-VL-7B-Instruct \
+ --dataset-name sharegpt \
+ --dataset-path /path/to/ShareGPT4Video/llava_v1_5_mix665k_with_video_chatgpt72k_share4video28k.json \
+ --num-prompts 100 \
+ --save-result \
+ --result-dir ~/vllm_benchmark_results \
+ --save-detailed \
+ --endpoint /v1/chat/completion
+```
+
+### Synthetic Random Images (random-mm)
+
+Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets.
+
+Notes:
+
+- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`.
+- Video sampling is not yet implemented.
+
+Start the server (example):
+
+```bash
+vllm serve Qwen/Qwen2.5-VL-3B-Instruct \
+ --dtype bfloat16 \
+ --max-model-len 16384 \
+ --limit-mm-per-prompt '{"image": 3, "video": 0}' \
+ --mm-processor-kwargs max_pixels=1003520
+```
+
+Benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via the arg `random-output-len`.
+
+Ex.1: Fixed number of items and a single image resolution, enforcing generation of approx 40 tokens:
+
+```bash
+vllm bench serve \
+ --backend openai-chat \
+ --model Qwen/Qwen2.5-VL-3B-Instruct \
+ --endpoint /v1/chat/completions \
+ --dataset-name random-mm \
+ --num-prompts 100 \
+ --max-concurrency 10 \
+ --random-prefix-len 25 \
+ --random-input-len 300 \
+ --random-output-len 40 \
+ --random-range-ratio 0.2 \
+ --random-mm-base-items-per-request 2 \
+ --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \
+ --random-mm-bucket-config '{(224, 224, 1): 1.0}' \
+ --request-rate inf \
+ --ignore-eos \
+ --seed 42
+```
+
+The number of items per request can be controlled by passing multiple image buckets:
+
+```bash
+ --random-mm-base-items-per-request 2 \
+ --random-mm-num-mm-items-range-ratio 0.5 \
+ --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \
+ --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \
+```
+
+Flags specific to `random-mm`:
+
+- `--random-mm-base-items-per-request`: base number of multimodal items per request.
+- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(n·(1−r)), ceil(n·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items.
+- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
+- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported).
+
+Behavioral notes:
+
+- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping.
+
+How sampling works:
+
+- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits.
+- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added.
+- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing.
+This should be seen as an edge case, and if this behavior can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number. Note that this might result in errors due to engine config `--limit-mm-per-prompt`.
+- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`.
+
+
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 1559ca2d92841..ba7c733be0b25 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -34,6 +34,7 @@ class RequestFuncInput:
multi_modal_content: Optional[dict | list[dict]] = None
ignore_eos: bool = False
language: Optional[str] = None
+ request_id: Optional[str] = None
@dataclass
@@ -71,6 +72,9 @@ async def async_request_tgi(
"inputs": request_func_input.prompt,
"parameters": params,
}
+ headers = None
+ if request_func_input.request_id:
+ headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
if request_func_input.ignore_eos:
@@ -82,7 +86,9 @@ async def async_request_tgi(
st = time.perf_counter()
most_recent_timestamp = st
try:
- async with session.post(url=api_url, json=payload) as response:
+ async with session.post(
+ url=api_url, json=payload, headers=headers
+ ) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
@@ -145,6 +151,9 @@ async def async_request_trt_llm(
}
if request_func_input.ignore_eos:
payload["min_length"] = request_func_input.output_len
+ headers = None
+ if request_func_input.request_id:
+ headers = {"x-request-id": request_func_input.request_id}
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -152,7 +161,9 @@ async def async_request_trt_llm(
st = time.perf_counter()
most_recent_timestamp = st
try:
- async with session.post(url=api_url, json=payload) as response:
+ async with session.post(
+ url=api_url, json=payload, headers=headers
+ ) as response:
if response.status == 200:
async for chunk_bytes in response.content:
chunk_bytes = chunk_bytes.strip()
@@ -211,6 +222,8 @@ async def async_request_deepspeed_mii(
"top_p": 1.0,
}
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -283,6 +296,8 @@ async def async_request_openai_completions(
if request_func_input.extra_body:
payload.update(request_func_input.extra_body)
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -395,6 +410,8 @@ async def async_request_openai_chat_completions(
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
@@ -491,6 +508,8 @@ async def async_request_openai_audio(
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
}
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
# Send audio file
def to_bytes(y, sr):
diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py
index ea684f18a7421..2ea4f9ccaff2b 100644
--- a/benchmarks/benchmark_dataset.py
+++ b/benchmarks/benchmark_dataset.py
@@ -19,6 +19,7 @@ import logging
import random
from abc import ABC, abstractmethod
from collections.abc import Mapping
+from copy import deepcopy
from dataclasses import dataclass
from functools import cache
from io import BytesIO
@@ -54,6 +55,7 @@ class SampleRequest:
expected_output_len: int
multi_modal_data: Optional[Union[MultiModalDataDict, dict, list[dict]]] = None
lora_request: Optional[LoRARequest] = None
+ request_id: Optional[str] = None
# -----------------------------------------------------------------------------
@@ -155,7 +157,10 @@ class BenchmarkDataset(ABC):
@abstractmethod
def sample(
- self, tokenizer: PreTrainedTokenizerBase, num_requests: int
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ request_id_prefix: str = "",
) -> list[SampleRequest]:
"""
Abstract method to generate sample requests from the dataset.
@@ -167,6 +172,7 @@ class BenchmarkDataset(ABC):
tokenizer (PreTrainedTokenizerBase): The tokenizer to be used
for processing the dataset's text.
num_requests (int): The number of sample requests to generate.
+ request_id_prefix (str) The prefix of request_id.
Returns:
list[SampleRequest]: A list of sample requests generated from the
@@ -175,7 +181,10 @@ class BenchmarkDataset(ABC):
raise NotImplementedError("sample must be implemented in subclasses.")
def maybe_oversample_requests(
- self, requests: list[SampleRequest], num_requests: int
+ self,
+ requests: list[SampleRequest],
+ num_requests: int,
+ request_id_prefix: str = "",
) -> None:
"""
Oversamples the list of requests if its size is less than the desired
@@ -183,11 +192,18 @@ class BenchmarkDataset(ABC):
Args:
requests (List[SampleRequest]): The current list of sampled
- requests. num_requests (int): The target number of requests.
+ requests.
+ num_requests (int): The target number of requests.
+ request_id_prefix (str) The prefix of the request ids.
"""
if len(requests) < num_requests:
random.seed(self.random_seed)
- additional = random.choices(requests, k=num_requests - len(requests))
+ additional = deepcopy(
+ random.choices(requests, k=num_requests - len(requests))
+ )
+ for i in range(len(additional)):
+ req = additional[i]
+ req.request_id = request_id_prefix + str(len(requests) + i)
requests.extend(additional)
logger.info("Oversampled requests to reach %d total samples.", num_requests)
@@ -277,6 +293,41 @@ def process_image(image: Any) -> Mapping[str, Any]:
)
+def process_video(video: Any) -> Mapping[str, Any]:
+ """
+ Process a single video input and return a multimedia content dictionary.
+
+ Supports the following input types:
+
+ 1. Dictionary with raw video bytes: - Expects a dict with a 'bytes' key
+ containing raw video data.
+
+ 2. String input: - Treats the string as a URL or local file path. -
+ Prepends "file://" if the string doesn't start with "http://" or
+ "file://". - Returns a dictionary with the image URL.
+
+ Raises:
+ ValueError: If the input is not a supported type.
+ """
+ if isinstance(video, dict) and "bytes" in video:
+ video_bytes = video["bytes"]
+ video_base64 = base64.b64encode(video_bytes).decode("utf-8")
+ return {
+ "type": "video_url",
+ "video_url": {"url": f"data:video/mp4;base64,{video_base64}"},
+ }
+
+ if isinstance(video, str):
+ video_url = (
+ video if video.startswith(("http://", "file://")) else f"file://{video}"
+ )
+ return {"type": "video_url", "video_url": {"url": video_url}}
+
+ raise ValueError(
+ f"Invalid video input {video}. Must be a string of local path/remote url, or a dictionary with raw video bytes in the form of `{{'bytes': raw_video_bytes}}`." # noqa: E501
+ )
+
+
# -----------------------------------------------------------------------------
# Random Dataset Implementation (Synthetic Data)
# -----------------------------------------------------------------------------
@@ -303,6 +354,7 @@ class RandomDataset(BenchmarkDataset):
range_ratio: float = DEFAULT_RANGE_RATIO,
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
+ request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
# Enforce range_ratio < 1
@@ -363,8 +415,10 @@ class RandomDataset(BenchmarkDataset):
prompt=prompt,
prompt_len=total_input_len,
expected_output_len=int(output_lens[i]),
+ request_id=request_id_prefix + str(i),
)
)
+
return requests
@@ -406,9 +460,11 @@ class ShareGPTDataset(BenchmarkDataset):
max_loras: Optional[int] = None,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
samples: list = []
+ ind = 0
for entry in self.data:
if len(samples) >= num_requests:
break
@@ -430,17 +486,26 @@ class ShareGPTDataset(BenchmarkDataset):
skip_min_output_len_check=output_len is not None,
):
continue
+ if image_path := entry.get("image"):
+ mm_content = process_image(image_path)
+ elif video_path := entry.get("video"):
+ mm_content = process_video(video_path)
+ else:
+ mm_content = None
if enable_multimodal_chat:
- prompt = self.apply_multimodal_chat_transformation(prompt, None)
+ prompt = self.apply_multimodal_chat_transformation(prompt, mm_content)
samples.append(
SampleRequest(
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=new_output_len,
lora_request=lora_request,
+ multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(ind),
)
)
- self.maybe_oversample_requests(samples, num_requests)
+ ind += 1
+ self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples
@@ -506,10 +571,11 @@ class CustomDataset(BenchmarkDataset):
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
skip_chat_template: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
sampled_requests = []
- for item in self.data:
+ for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = item["prompt"]
@@ -528,9 +594,12 @@ class CustomDataset(BenchmarkDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
+ request_id=request_id_prefix + str(i),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -572,6 +641,7 @@ class SonnetDataset(BenchmarkDataset):
input_len: int = DEFAULT_INPUT_LEN,
output_len: int = DEFAULT_OUTPUT_LEN,
return_prompt_formatted: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
# Calculate average token length for a poem line.
@@ -597,6 +667,7 @@ class SonnetDataset(BenchmarkDataset):
prefix_lines = self.data[:num_prefix_lines]
samples = []
+ ind = 0
while len(samples) < num_requests:
extra_lines = random.choices(
self.data, k=num_input_lines - num_prefix_lines
@@ -607,14 +678,17 @@ class SonnetDataset(BenchmarkDataset):
msg, add_generation_prompt=True, tokenize=False
)
prompt_len = len(tokenizer(prompt_formatted).input_ids)
+
if prompt_len <= input_len:
samples.append(
SampleRequest(
prompt=prompt_formatted if return_prompt_formatted else prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
+ request_id=request_id_prefix + str(ind),
)
)
+ ind += 1
return samples
@@ -666,6 +740,7 @@ class BurstGPTDataset(BenchmarkDataset):
num_requests: int,
max_loras: Optional[int] = None,
lora_path: Optional[str] = None,
+ request_id_prefix: str = "",
**kwargs,
) -> list[SampleRequest]:
samples = []
@@ -687,6 +762,7 @@ class BurstGPTDataset(BenchmarkDataset):
prompt_len=input_len,
expected_output_len=output_len,
lora_request=lora_req,
+ request_id=request_id_prefix + str(i),
)
)
return samples
@@ -746,12 +822,14 @@ class ConversationDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
# Filter examples with at least 2 conversations
filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2)
sampled_requests = []
dynamic_output = output_len is None
+ ind = 0
for item in filtered_data:
if len(sampled_requests) >= num_requests:
@@ -779,9 +857,13 @@ class ConversationDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(ind),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ ind += 1
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -808,11 +890,12 @@ class VisionArenaDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
- for item in self.data:
+ for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path)
@@ -832,9 +915,12 @@ class VisionArenaDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(i),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -864,15 +950,18 @@ class InstructCoderDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
- for item in self.data:
+ for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
- prompt = f"{item['input']}\n\n{item['instruction']} Just output \
- the code, do not include any explanation."
+ prompt = (
+ f"{item['input']}\n\n{item['instruction']} Just output "
+ "the code, do not include any explanation."
+ )
# apply template
prompt = tokenizer.apply_chat_template(
@@ -886,9 +975,12 @@ class InstructCoderDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
+ request_id=request_id_prefix + str(i),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -918,12 +1010,13 @@ class MTBenchDataset(HuggingFaceDataset):
num_requests: int,
output_len: Optional[int] = None,
enable_multimodal_chat: bool = False,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
sampled_requests = []
- for item in self.data:
+ for i, item in enumerate(self.data):
if len(sampled_requests) >= num_requests:
break
prompt = item["turns"][0]
@@ -941,9 +1034,12 @@ class MTBenchDataset(HuggingFaceDataset):
prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
+ request_id=request_id_prefix + str(i),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -968,10 +1064,12 @@ class AIMODataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
sampled_requests = []
dynamic_output = output_len is None
+ ind = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
@@ -994,9 +1092,13 @@ class AIMODataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=None,
+ request_id=request_id_prefix + str(ind),
)
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ ind += 1
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
@@ -1066,12 +1168,18 @@ class NextEditPredictionDataset(HuggingFaceDataset):
"zed-industries/zeta": _format_zeta_prompt,
}
- def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs):
+ def sample(
+ self,
+ tokenizer: PreTrainedTokenizerBase,
+ num_requests: int,
+ request_id_prefix: str = "",
+ **kwargs,
+ ):
formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path)
if formatting_prompt_func is None:
raise ValueError(f"Unsupported dataset path: {self.dataset_path}")
samples = []
- for sample in self.data:
+ for i, sample in enumerate(self.data):
sample = formatting_prompt_func(sample)
samples.append(
SampleRequest(
@@ -1080,11 +1188,12 @@ class NextEditPredictionDataset(HuggingFaceDataset):
expected_output_len=len(
tokenizer(sample["expected_output"]).input_ids
),
+ request_id=request_id_prefix + str(i),
)
)
if len(samples) >= num_requests:
break
- self.maybe_oversample_requests(samples, num_requests)
+ self.maybe_oversample_requests(samples, num_requests, request_id_prefix)
return samples
@@ -1133,6 +1242,7 @@ class ASRDataset(HuggingFaceDataset):
tokenizer: PreTrainedTokenizerBase,
num_requests: int,
output_len: Optional[int] = None,
+ request_id_prefix: str = "",
**kwargs,
) -> list:
import librosa
@@ -1142,6 +1252,7 @@ class ASRDataset(HuggingFaceDataset):
prompt_len = len(tokenizer(prompt).input_ids)
sampled_requests = []
skipped = 0
+ ind = 0
for item in self.data:
if len(sampled_requests) >= num_requests:
break
@@ -1160,8 +1271,10 @@ class ASRDataset(HuggingFaceDataset):
prompt_len=prompt_len,
expected_output_len=output_len,
multi_modal_data=mm_content,
+ request_id=request_id_prefix + str(ind),
)
)
+ ind += 1
if skipped:
logger.warning(
"%d samples discarded from dataset due to"
@@ -1169,5 +1282,7 @@ class ASRDataset(HuggingFaceDataset):
" what Whisper supports.",
skipped,
)
- self.maybe_oversample_requests(sampled_requests, num_requests)
+ self.maybe_oversample_requests(
+ sampled_requests, num_requests, request_id_prefix
+ )
return sampled_requests
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index ae38caf7290b1..02f5f585c0c16 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -375,11 +375,12 @@ async def benchmark(
rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
last_int_rps = current_int_rps
- prompt, prompt_len, output_len, mm_content = (
+ prompt, prompt_len, output_len, mm_content, request_id = (
request.prompt,
request.prompt_len,
request.expected_output_len,
request.multi_modal_data,
+ request.request_id,
)
req_model_id, req_model_name = model_id, model_name
if lora_modules:
@@ -397,6 +398,7 @@ async def benchmark(
multi_modal_content=mm_content,
ignore_eos=ignore_eos,
extra_body=extra_body,
+ request_id=request_id,
)
task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
tasks.append(asyncio.create_task(task))
@@ -665,6 +667,7 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer,
output_len=args.custom_output_len,
skip_chat_template=args.custom_skip_chat_template,
+ request_id_prefix=args.request_id_prefix,
)
elif args.dataset_name == "sonnet":
@@ -678,6 +681,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=False,
+ request_id_prefix=args.request_id_prefix,
)
else:
assert tokenizer.chat_template or tokenizer.default_chat_template, (
@@ -690,6 +694,7 @@ def main(args: argparse.Namespace):
prefix_len=args.sonnet_prefix_len,
tokenizer=tokenizer,
return_prompt_formatted=True,
+ request_id_prefix=args.request_id_prefix,
)
elif args.dataset_name == "hf":
@@ -751,6 +756,7 @@ def main(args: argparse.Namespace):
num_requests=args.num_prompts,
tokenizer=tokenizer,
output_len=args.hf_output_len,
+ request_id_prefix=args.request_id_prefix,
)
else:
@@ -762,10 +768,15 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer,
num_requests=args.num_prompts,
output_len=args.sharegpt_output_len,
+ request_id_prefix=args.request_id_prefix,
),
"burstgpt": lambda: BurstGPTDataset(
random_seed=args.seed, dataset_path=args.dataset_path
- ).sample(tokenizer=tokenizer, num_requests=args.num_prompts),
+ ).sample(
+ tokenizer=tokenizer,
+ num_requests=args.num_prompts,
+ request_id_prefix=args.request_id_prefix,
+ ),
"random": lambda: RandomDataset(dataset_path=args.dataset_path).sample(
tokenizer=tokenizer,
num_requests=args.num_prompts,
@@ -773,6 +784,7 @@ def main(args: argparse.Namespace):
input_len=args.random_input_len,
output_len=args.random_output_len,
range_ratio=args.random_range_ratio,
+ request_id_prefix=args.request_id_prefix,
),
}
@@ -1118,6 +1130,13 @@ def create_argument_parser():
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve",
)
+ parser.add_argument(
+ "--request-id-prefix",
+ type=str,
+ required=False,
+ default="benchmark-serving",
+ help="Specify the prefix of request id.",
+ )
# group for dataset specific arguments
custom_group = parser.add_argument_group("custom dataset options")
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index c51b579686529..6b24b8c8f3c67 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -96,7 +96,6 @@ def run_vllm(
end = time.perf_counter()
else:
assert lora_requests is None, "BeamSearch API does not support LoRA"
- prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0].expected_output_len
for request in requests:
@@ -597,8 +596,8 @@ def validate_args(args):
# https://github.com/vllm-project/vllm/issues/16222
if args.data_parallel_size > 1:
raise ValueError(
- "Data parallel is not supported in offline benchmark, \
- please use benchmark serving instead"
+ "Data parallel is not supported in offline benchmark, "
+ "please use benchmark serving instead"
)
diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py
deleted file mode 100644
index 42de062b08e42..0000000000000
--- a/benchmarks/kernels/benchmark_aqlm.py
+++ /dev/null
@@ -1,345 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import os
-import sys
-from typing import Optional
-
-import torch
-import torch.nn.functional as F
-
-from vllm import _custom_ops as ops
-from vllm.model_executor.layers.quantization.aqlm import (
- dequantize_weight,
- generic_dequantize_gemm,
- get_int_dtype,
- optimized_dequantize_gemm,
-)
-from vllm.utils import FlexibleArgumentParser
-
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
-
-def torch_mult(
- # [..., in_features]
- input: torch.Tensor,
- weights: torch.Tensor,
- # [num_out_groups, 1, 1, 1]
- scales: torch.Tensor,
-) -> torch.Tensor:
- output = F.linear(input, weights)
- return output
-
-
-def dequant_out_scale(
- # [..., in_features]
- input: torch.Tensor,
- # [num_out_groups, num_in_groups, num_codebooks]
- codes: torch.IntTensor,
- # [num_codebooks, codebook_size, out_group_size, in_group_size]
- codebooks: torch.Tensor,
- # [num_out_groups, 1, 1, 1]
- scales: torch.Tensor,
- output_partition_sizes: torch.IntTensor,
- bias: Optional[torch.Tensor],
-) -> torch.Tensor:
- weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
-
- if bias is None:
- output = F.linear(input, weights, bias)
- orig_shape = output.shape
- flattened_output = output.view(-1, output.size(-1))
- f_scales = scales.view(-1, scales.shape[0])
- b_scales = f_scales.expand(flattened_output.shape[0], -1)
- flattened_output *= b_scales
- return flattened_output.view(orig_shape)
- else:
- b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
- weights *= b_scales
- return F.linear(input, weights, bias)
-
-
-def dequant_weight_scale(
- # [..., in_features]
- input: torch.Tensor,
- # [num_out_groups, num_in_groups, num_codebooks]
- codes: torch.IntTensor,
- # [num_codebooks, codebook_size, out_group_size, in_group_size]
- codebooks: torch.Tensor,
- # [num_out_groups, 1, 1, 1]
- scales: torch.Tensor,
- output_partition_sizes: torch.IntTensor,
- bias: Optional[torch.Tensor],
-) -> torch.Tensor:
- weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
-
- b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1])
- weights *= b_scales
- return F.linear(input, weights, bias)
-
-
-def dequant_no_scale(
- # [..., in_features]
- input: torch.Tensor,
- # [num_out_groups, num_in_groups, num_codebooks]
- codes: torch.IntTensor,
- # [num_codebooks, codebook_size, out_group_size, in_group_size]
- codebooks: torch.Tensor,
- # [num_out_groups, 1, 1, 1]
- scales: torch.Tensor,
- output_partition_sizes: torch.IntTensor,
- bias: Optional[torch.Tensor],
-) -> torch.Tensor:
- weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
-
- return F.linear(input, weights, bias)
-
-
-# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
-# the generic pytorch version.
-# Just visual comparison.
-def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:
- n = int(parts.sum().item())
-
- device = torch.device("cuda:0")
-
- code_range = (1 << bits) // 2
- ingroups = 8
-
- codes = torch.randint(
- -code_range,
- code_range,
- size=(n, k // ingroups, nbooks),
- dtype=get_int_dtype(bits),
- device=device,
- )
-
- codebooks = torch.randn(
- size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
- dtype=torch.float16,
- device=device,
- )
-
- count = 0
- for index in range(16):
- for i in range(8):
- for book in range(nbooks):
- codebooks[book, index, 0, i] = count * (10**book)
- count += 1
-
- print("codes shape", codes.shape)
-
- for i in range(16):
- for book in range(nbooks):
- codes[0, i, book] = i
- codes[0, -i, book] = i
-
- weights = dequantize_weight(codes, codebooks, None)
- weights2 = ops.aqlm_dequant(codes, codebooks, parts)
-
- print("weights shape:", weights.shape)
- print("weights2 shape:", weights2.shape)
-
- print("weights are:", weights)
- print("weights2 are:", weights2)
-
- print("first 128 weights are", weights[0, 0:128].to(torch.int32))
- print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
-
- print("last 128 weights are", weights[0, -128:])
- print("last 128 weights2 are:", weights2[0, -128:])
-
-
-def main():
- parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")
-
- # Add arguments
- parser.add_argument(
- "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)"
- )
- parser.add_argument(
- "--bits",
- type=int,
- default=16,
- help="Number of bits per code element (default: 16)",
- )
- parser.add_argument(
- "--test",
- type=bool,
- default=False,
- help="Run the decompression/dequant tester rather than benchmarking "
- "(default: False)",
- )
-
- # Parse the arguments
- args = parser.parse_args()
-
- # Extract values
- nbooks = args.nbooks
- bits = args.bits
-
- if args.test:
- dequant_test(4096, torch.tensor((4096,)), nbooks, bits)
- return
-
- # Otherwise, benchmark.
- methods = [
- ops.aqlm_gemm,
- dequant_out_scale,
- generic_dequantize_gemm,
- optimized_dequantize_gemm,
- dequant_weight_scale,
- torch_mult,
- dequant_no_scale,
- ]
-
- filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
- print(f"writing benchmarks to file {filename}")
- with open(filename, "w") as f:
- sys.stdout = f
-
- print("m | k | n | n parts", end="")
- for method in methods:
- print(f" | {method.__name__.replace('_', ' ')} (µs)", end="")
- print("")
-
- # These are reasonable prefill sizes.
- ksandpartions = (
- (4096, (4096, 4096, 4096)),
- (4096, (4096,)),
- (4096, (11008, 11008)),
- (11008, (4096,)),
- )
-
- # reasonable ranges for m.
- for m in [
- 1,
- 2,
- 4,
- 8,
- 10,
- 12,
- 14,
- 16,
- 24,
- 32,
- 48,
- 52,
- 56,
- 64,
- 96,
- 112,
- 128,
- 256,
- 512,
- 1024,
- 1536,
- 2048,
- 3072,
- 4096,
- ]:
- print(f"{m}", file=sys.__stdout__)
- for ksp in ksandpartions:
- run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods)
-
- sys.stdout = sys.__stdout__
-
-
-def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods):
- # I didn't see visible improvements from increasing these, but feel free :)
- num_warmup_trials = 1
- num_trials = 1
-
- num_calls = 100
-
- # warmup.
- for method in methods:
- for _ in range(num_warmup_trials):
- run_timing(
- num_calls=num_calls,
- m=m,
- k=k,
- parts=parts,
- nbooks=nbooks,
- bits=bits,
- method=method,
- )
-
- n = parts.sum().item()
- print(f"{m} | {k} | {n} | {parts.tolist()}", end="")
-
- for method in methods:
- best_time_us = 1e20
- for _ in range(num_trials):
- kernel_dur_ms = run_timing(
- num_calls=num_calls,
- m=m,
- k=k,
- parts=parts,
- nbooks=nbooks,
- bits=bits,
- method=method,
- )
-
- kernel_dur_us = 1000 * kernel_dur_ms
-
- if kernel_dur_us < best_time_us:
- best_time_us = kernel_dur_us
-
- print(f" | {kernel_dur_us:.0f}", end="")
-
- print("")
-
-
-def run_timing(
- num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method
-) -> float:
- n = int(parts.sum().item())
-
- device = torch.device("cuda:0")
-
- input = torch.randn((1, m, k), dtype=torch.float16, device=device)
-
- code_range = (1 << bits) // 2
- ingroups = 8
-
- codes = torch.randint(
- -code_range,
- code_range,
- size=(n, k // ingroups, nbooks),
- dtype=get_int_dtype(bits),
- device=device,
- )
-
- codebooks = torch.randn(
- size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
- dtype=torch.float16,
- device=device,
- )
-
- scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
-
- # for comparison to just a pytorch mult.
- weights = torch.randn((n, k), dtype=torch.float16, device=device)
-
- start_event = torch.cuda.Event(enable_timing=True)
- end_event = torch.cuda.Event(enable_timing=True)
-
- start_event.record()
-
- if method is torch_mult:
- for i in range(num_calls):
- torch_mult(input, weights, scales)
- else:
- for i in range(num_calls):
- method(input, codes, codebooks, scales, parts, None)
-
- end_event.record()
- end_event.synchronize()
-
- dur_ms = start_event.elapsed_time(end_event) / num_calls
- return dur_ms
-
-
-if __name__ == "__main__":
- sys.exit(main())
diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
index 1d4e730f99ae9..a6b42406b5cb0 100644
--- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
+++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
@@ -80,6 +80,11 @@ def bench_run(
a, score, topk, renormalize=False
)
+ ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+ ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
+ c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
+ c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
+
def run_triton_moe(
a: torch.Tensor,
w1: torch.Tensor,
@@ -111,6 +116,10 @@ def bench_run(
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
@@ -125,6 +134,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
per_act_token,
a1_scale=None,
)
@@ -136,6 +149,10 @@ def bench_run(
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
+ ab_strides1: torch.Tensor,
+ ab_strides2: torch.Tensor,
+ c_strides1: torch.Tensor,
+ c_strides2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
):
@@ -150,6 +167,10 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
per_act_token,
a1_scale=None,
)
@@ -194,6 +215,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
topk_weights,
topk_ids,
)
@@ -231,6 +256,10 @@ def bench_run(
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
+ "ab_strides1": ab_strides1,
+ "ab_strides2": ab_strides2,
+ "c_strides1": c_strides1,
+ "c_strides2": c_strides2,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@@ -289,6 +318,10 @@ def bench_run(
w2_q,
w1_scale,
w2_scale,
+ ab_strides1,
+ ab_strides2,
+ c_strides1,
+ c_strides2,
topk_weights,
topk_ids,
per_act_token,
@@ -297,7 +330,7 @@ def bench_run(
results.append(
benchmark.Timer(
- stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
+ stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,
diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py
index 975d10f2e92ec..1b1c3b321cce4 100644
--- a/benchmarks/kernels/benchmark_machete.py
+++ b/benchmarks/kernels/benchmark_machete.py
@@ -253,28 +253,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
else:
assert bt.a.dtype == torch.int8
assert bt.wtype == scalar_types.uint4b8
-
- if bt.w_ch_s is not None:
- s_ch = bt.w_ch_s.to(torch.float32)
- else:
- s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device)
-
- if bt.w_tok_s is not None:
- s_tok = bt.w_tok_s.to(torch.float32)
- else:
- s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device)
-
- fn = lambda: ops.marlin_qqq_gemm(
- a=bt.a,
- b_q_weight=w_q,
- s_group=w_s,
- s_tok=s_tok,
- s_ch=s_ch,
- workspace=workspace.scratch,
- size_m=bt.a.shape[0],
- size_n=bt.w_ref.shape[1],
- size_k=bt.w_ref.shape[0],
- )
+ raise NotImplementedError("QQQ is not supported anymore")
return fn
@@ -305,6 +284,25 @@ def machete_create_bench_fn(
)
+def cutlass_w4a8_create_bench_fn(
+ bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
+) -> Callable:
+ w_q = bt.w_q.t().contiguous().t() # make col major
+ w_q = ops.cutlass_encode_and_reorder_int4b(w_q)
+ # expects fp8 scales
+ w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn))
+
+ return lambda: ops.cutlass_w4a8_mm(
+ a=bt.a,
+ b_q=w_q,
+ b_group_scales=w_s,
+ b_group_size=bt.group_size,
+ b_channel_scales=bt.w_ch_s,
+ a_token_scales=bt.w_tok_s,
+ maybe_schedule=schedule,
+ )
+
+
# impl
# bench
@@ -406,6 +404,20 @@ def bench(
)
)
+ # cutlass w4a8
+ if types.act_type == torch.float8_e4m3fn and group_size == 128:
+ timers.append(
+ bench_fns(
+ label,
+ sub_label,
+ f"cutlass w4a8 ({name_type_string})",
+ [
+ cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
+ for bt in benchmark_tensors
+ ],
+ )
+ )
+
if sweep_schedules:
global _SWEEP_SCHEDULES_RESULTS
diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py
index 13bf1be836f6a..752c2d0082167 100644
--- a/benchmarks/kernels/benchmark_moe.py
+++ b/benchmarks/kernels/benchmark_moe.py
@@ -3,6 +3,7 @@
import argparse
import json
+import os
import time
from contextlib import nullcontext
from datetime import datetime
@@ -429,7 +430,6 @@ class BenchmarkWorker:
hidden_size,
topk,
dtype_str,
- is_marlin=False,
)
else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
@@ -542,6 +542,7 @@ def save_configs(
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_quant_shape: list[int],
+ save_dir: str,
) -> None:
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@@ -552,7 +553,8 @@ def save_configs(
filename = get_config_file_name(
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
)
-
+ os.makedirs(save_dir, exist_ok=True)
+ filename = os.path.join(save_dir, filename)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
@@ -707,6 +709,7 @@ def main(args: argparse.Namespace):
use_fp8_w8a8,
use_int8_w8a16,
block_quant_shape,
+ args.save_dir,
)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
@@ -748,6 +751,9 @@ if __name__ == "__main__":
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
)
parser.add_argument("--use-deep-gemm", action="store_true")
+ parser.add_argument(
+ "--save-dir", type=str, default="./", help="Directory to save tuned results"
+ )
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
parser.add_argument("--tune", action="store_true")
diff --git a/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
new file mode 100644
index 0000000000000..0650cbf3cc18e
--- /dev/null
+++ b/benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+
+import torch
+
+from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+ silu_mul_fp8_quant_deep_gemm,
+)
+from vllm.platforms import current_platform
+
+
+def benchmark(E, T, H, G=128, runs=50):
+ current_platform.seed_everything(42)
+ y = torch.randn((E, T, 2 * H), dtype=torch.bfloat16, device="cuda")
+ tokens_per_expert = torch.randint(
+ T // 2, T, size=(E,), dtype=torch.int32, device="cuda"
+ )
+
+ # Warmup
+ for _ in range(10):
+ silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+ torch.cuda.synchronize()
+
+ # Benchmark
+ torch.cuda.synchronize()
+ start = time.perf_counter()
+ for _ in range(runs):
+ silu_mul_fp8_quant_deep_gemm(y, tokens_per_expert, group_size=G)
+ torch.cuda.synchronize()
+
+ avg_time = (time.perf_counter() - start) / runs * 1000
+
+ # Calculate actual work done (only count valid tokens)
+ actual_tokens = tokens_per_expert.sum().item()
+ actual_elements = actual_tokens * H
+
+ # GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
+ ops_per_element = 8
+ total_ops = actual_elements * ops_per_element
+ gflops = total_ops / (avg_time / 1000) / 1e9
+
+ # Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
+ input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs
+ output_bytes = actual_tokens * H * 1 # H fp8 outputs
+ scale_bytes = actual_tokens * (H // G) * 4 # scales in float32
+ total_bytes = input_bytes + output_bytes + scale_bytes
+ memory_bw = total_bytes / (avg_time / 1000) / 1e9
+
+ return avg_time, gflops, memory_bw
+
+
+configs = [
+ (8, 32, 1024),
+ (16, 64, 2048),
+ (32, 128, 4096),
+ # DeepSeekV3 Configs
+ (256, 16, 7168),
+ (256, 32, 7168),
+ (256, 64, 7168),
+ (256, 128, 7168),
+ (256, 256, 7168),
+ (256, 512, 7168),
+ (256, 1024, 7168),
+]
+
+print(f"GPU: {torch.cuda.get_device_name()}")
+print(f"{'Config':<20} {'Time(ms)':<10} {'GFLOPS':<10} {'GB/s':<10}")
+print("-" * 50)
+
+for E, T, H in configs:
+ try:
+ time_ms, gflops, gbps = benchmark(E, T, H)
+ print(f"E={E:3d},T={T:4d},H={H:4d} {time_ms:8.3f} {gflops:8.1f} {gbps:8.1f}")
+ except Exception:
+ print(f"E={E:3d},T={T:4d},H={H:4d} FAILED")
diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
index 77136edca45b5..603ce5ecf0d2c 100644
--- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py
@@ -3,16 +3,17 @@
import csv
import os
-import random
from datetime import datetime
+from typing import Optional
import flashinfer
import torch
-FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+from vllm.utils import round_up
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
+FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+FP8_DTYPE = torch.float8_e4m3fn
+FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,65 +27,106 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad()
def benchmark_decode(
- num_seqs,
- max_seq_len,
- page_size=16,
- dtype=torch.bfloat16,
- kv_layout="HND",
- num_kv_heads=8,
- kv_cache_dtype="auto",
- head_dim=128,
- warmup=10,
- trials=20,
+ dtype: torch.dtype,
+ quant_dtypes: tuple[
+ Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+ ],
+ batch_size: int,
+ max_seq_len: int,
+ num_heads: tuple[int, int] = (64, 8),
+ head_size: int = 128,
+ kv_layout: str = "HND",
+ block_size: int = 16,
+ warmup: int = 10,
+ trials: int = 20,
):
torch.set_default_device("cuda")
- device = "cuda"
torch.manual_seed(0)
- HEAD_GRP_SIZE = 8
- MAX_SEQ_LEN = max_seq_len
+ q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
+ q_quant_dtype = q_quant_dtype or dtype
+ kv_quant_dtype = kv_quant_dtype or dtype
+ o_quant_dtype = o_quant_dtype or dtype
+
+ num_qo_heads, num_kv_heads = num_heads
+ assert num_qo_heads % num_kv_heads == 0
+
+ sm_scale = float(1.0 / (head_size**0.5))
# large number to reduce kv_cache reuse
- NUM_BLOCKS = int(256000 / page_size)
+ NUM_BLOCKS = int(256000 / block_size)
- workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
+ kv_cache_shape = None
+ if kv_layout == "NHD":
+ kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
+ elif kv_layout == "HND":
+ kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
+ else:
+ raise ValueError(f"Invalid kv_layout: {kv_layout}")
- # For decode, batch_size is num_decode_token
- num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
- sm_scale = float(1.0 / (head_dim**0.5))
- q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
- kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+ # Always using 1.0 scale to reflect the real perf in benchmarking
+ q_scale = 1.0
+ ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+ if q_quant_dtype == FP8_DTYPE:
+ query, _ = to_float8(ref_query)
+ else:
+ query = ref_query
- max_kv_len = max(kv_lens)
- kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
- max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
+ kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
+ kv_lens[-1] = max_seq_len
- block_tables = torch.randint(
- 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
- )
+ seq_lens = kv_lens
+ max_seq_len = torch.max(seq_lens).item()
- kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
- kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
+ # Always using 1.0 scale to reflect the real perf in benchmarking
k_scale = v_scale = 1.0
+ ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+ if kv_quant_dtype == FP8_DTYPE:
+ kv_cache, _ = to_float8(ref_kv_cache)
+ else:
+ kv_cache = ref_kv_cache
- if kv_cache_dtype.startswith("fp8"):
- kv_cache, _ = to_float8(kv_cache)
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+ block_tables = torch.randint(
+ 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
+ )
+ kv_indptr = [0]
+ kv_indices = []
+ kv_last_page_lens = []
+ for i in range(batch_size):
+ seq_len = seq_lens[i]
+ assert seq_len > 0
+ num_blocks = (seq_len + block_size - 1) // block_size
+ kv_indices.extend(block_tables[i, :num_blocks])
+ kv_indptr.append(kv_indptr[-1] + num_blocks)
+ kv_last_page_len = seq_len % block_size
+ if kv_last_page_len == 0:
+ kv_last_page_len = block_size
+ kv_last_page_lens.append(kv_last_page_len)
- output_trtllm = torch.empty(q.shape, dtype=dtype)
+ kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
+ kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
+ kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
+ workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
- # Benchmark TRT decode
- def trt_decode():
- return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
- q,
- kv_cache,
- workspace_buffer,
- block_tables,
- kv_lens_tensor,
- max_kv_len,
- bmm1_scale=k_scale * sm_scale,
- bmm2_scale=v_scale,
- out=output_trtllm,
- )
+ wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
+ workspace_buffer,
+ kv_layout,
+ use_tensor_cores=True,
+ )
+ wrapper.plan(
+ kv_indptr,
+ kv_indices,
+ kv_last_page_lens,
+ num_qo_heads,
+ num_kv_heads,
+ head_size,
+ block_size,
+ "NONE",
+ sm_scale=sm_scale,
+ q_data_type=dtype,
+ kv_data_type=dtype,
+ )
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
@@ -101,74 +143,72 @@ def benchmark_decode(
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
- # TRT Decode
- trt_mean, trt_std = time_fn(trt_decode)
-
- kv_indptr = [0]
- kv_indices = []
- kv_last_page_lens = []
- for i in range(num_seqs):
- seq_len = kv_lens[i]
- assert seq_len > 0
- num_blocks = (seq_len + page_size - 1) // page_size
- kv_indices.extend(block_tables[i, :num_blocks])
- kv_indptr.append(kv_indptr[-1] + num_blocks)
- kv_last_page_len = seq_len % page_size
- if kv_last_page_len == 0:
- kv_last_page_len = page_size
- kv_last_page_lens.append(kv_last_page_len)
-
- kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
- kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
- kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-
- output_baseline = torch.empty(q.shape, dtype=dtype)
-
- wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
- workspace_buffer,
- kv_layout,
- use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
- )
-
- wrapper.plan(
- kv_indptr,
- kv_indices,
- kv_last_page_lens,
- num_qo_heads,
- num_kv_heads,
- head_dim,
- page_size,
- "NONE",
- q_data_type=dtype,
- kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
- )
+ o_scale = 1.0
+ o_sf_scale = None
+ output_baseline = torch.empty(ref_query.shape, dtype=dtype)
+ if o_quant_dtype == FP4_DTYPE:
+ o_sf_scale = 500.0
+ output_trtllm = flashinfer.utils.FP4Tensor(
+ torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
+ torch.empty(
+ (
+ round_up(query.shape[0], 128),
+ round_up(query.shape[1] * query.shape[2] // 16, 4),
+ ),
+ dtype=torch.float8_e4m3fn,
+ ),
+ )
+ else:
+ output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_decode():
- return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
+ return wrapper.run(
+ ref_query,
+ ref_kv_cache,
+ k_scale=k_scale,
+ v_scale=v_scale,
+ out=output_baseline,
+ )
+
+ def trtllm_decode():
+ return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+ query=query,
+ kv_cache=kv_cache,
+ workspace_buffer=workspace_buffer,
+ block_tables=block_tables,
+ seq_lens=seq_lens,
+ max_seq_len=max_seq_len,
+ bmm1_scale=q_scale * k_scale * sm_scale,
+ bmm2_scale=v_scale / o_scale,
+ o_sf_scale=o_sf_scale,
+ out=output_trtllm,
+ )
baseline_mean, baseline_std = time_fn(baseline_decode)
+ trtllm_mean, trtllm_std = time_fn(trtllm_decode)
# Calculate percentage speedup (positive means TRT is faster)
- speedup_percent = (baseline_mean - trt_mean) / baseline_mean
+ speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
print(
- f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
+ f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
)
# Return results for CSV writing
return {
- "num_seqs": num_seqs,
- "trt_mean": trt_mean,
- "trt_std": trt_std.item(),
+ "batch_size": batch_size,
+ "trtllm_mean": trtllm_mean,
+ "trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
- "q_dtype": str(dtype),
- "kv_cache_dtype": kv_cache_dtype,
- "page_size": page_size,
+ "q_dtype": str(q_quant_dtype),
+ "kv_cache_dtype": str(kv_quant_dtype),
+ "output_dtype": str(o_quant_dtype),
+ "block_size": block_size,
"num_kv_heads": num_kv_heads,
- "head_dim": head_dim,
+ "head_size": head_size,
"max_seq_len": max_seq_len,
}
@@ -180,17 +220,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [
- "num_seqs",
- "trt_mean",
- "trt_std",
+ "batch_size",
+ "trtllm_mean",
+ "trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
- "page_size",
+ "output_dtype",
+ "block_size",
"num_kv_heads",
- "head_dim",
+ "head_size",
"max_seq_len",
]
@@ -209,45 +250,43 @@ def write_results_to_csv(results, filename=None):
if __name__ == "__main__":
- num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
+ batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []
- print(
- "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
- "output_dtype: bfloat16"
- )
- print(
- "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
- "baseline_std\tspeedup_percent"
- )
- for max_seq_len in max_seq_lens:
- for bs in num_seqs:
- result = benchmark_decode(
- bs,
- max_seq_len,
- dtype=torch.bfloat16,
- kv_cache_dtype="auto",
- )
- all_results.append(result)
+ dtype = torch.bfloat16
+ quant_dtypes = [
+ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
+ (None, None, None),
+ (None, FP8_DTYPE, None),
+ (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+ (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
+ ]
- print(
- "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
- "output_dtype: bfloat16"
- )
- print(
- "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
- "baseline_std\tspeedup_percent"
- )
- for max_seq_len in max_seq_lens:
- for bs in num_seqs:
- result = benchmark_decode(
- bs,
- max_seq_len,
- dtype=torch.bfloat16,
- kv_cache_dtype="fp8",
- )
- all_results.append(result)
+ for quant_dtype in quant_dtypes:
+ q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+ q_quant_dtype = q_quant_dtype or dtype
+ kv_quant_dtype = kv_quant_dtype or dtype
+ o_quant_dtype = o_quant_dtype or dtype
+
+ print(
+ f"Running benchmark for q_dtype = {q_quant_dtype}, "
+ f"kv_cache_dtype: {kv_quant_dtype}, "
+ f"output_dtype: {o_quant_dtype}"
+ )
+ print(
+ "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
+ "baseline_std\tspeedup_percent"
+ )
+ for max_seq_len in max_seq_lens:
+ for bs in batch_sizes:
+ result = benchmark_decode(
+ dtype=dtype,
+ quant_dtypes=quant_dtype,
+ batch_size=bs,
+ max_seq_len=max_seq_len,
+ )
+ all_results.append(result)
# Write all results to CSV
write_results_to_csv(all_results)
diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
index 67bd9aebbcca9..40903c6c3444f 100644
--- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
+++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py
@@ -3,16 +3,17 @@
import csv
import os
-import random
from datetime import datetime
+from typing import Optional
import flashinfer
import torch
-FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+from vllm.utils import round_up
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
+FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
+FP8_DTYPE = torch.float8_e4m3fn
+FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,84 +27,100 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
@torch.no_grad()
def benchmark_prefill(
- num_seqs,
- max_seq_len,
- page_size=16,
- dtype=torch.bfloat16,
- kv_layout="HND",
- num_kv_heads=8,
- kv_cache_dtype="auto",
- head_dim=128,
- warmup=10,
- trials=20,
+ dtype: torch.dtype,
+ quant_dtypes: tuple[
+ Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+ ],
+ batch_size: int,
+ max_seq_len: int,
+ num_heads: tuple[int, int] = (64, 8),
+ head_size: int = 128,
+ kv_layout: str = "HND",
+ block_size: int = 16,
+ warmup: int = 10,
+ trials: int = 20,
):
torch.set_default_device("cuda")
torch.manual_seed(0)
- HEAD_GRP_SIZE = 8
- MAX_SEQ_LEN = max_seq_len
+ q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
+ q_quant_dtype = q_quant_dtype or dtype
+ kv_quant_dtype = kv_quant_dtype or dtype
+ o_quant_dtype = o_quant_dtype or dtype
+
+ max_q_len = max_kv_len = max_seq_len
+
+ num_qo_heads, num_kv_heads = num_heads
+ assert num_qo_heads % num_kv_heads == 0
+
+ sm_scale = float(1.0 / (head_size**0.5))
# large number to reduce kv_cache reuse
- NUM_BLOCKS = int(256000 / page_size)
+ NUM_BLOCKS = int(256000 / block_size)
- workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8)
+ kv_cache_shape = None
+ if kv_layout == "NHD":
+ kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
+ elif kv_layout == "HND":
+ kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
+ else:
+ raise ValueError(f"Invalid kv_layout: {kv_layout}")
- num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
- sm_scale = float(1.0 / (head_dim**0.5))
-
- q_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
- q_lens[-1] = MAX_SEQ_LEN
- max_q_len = max(q_lens)
+ q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
+ q_lens[-1] = max_q_len
q_indptr = torch.cat(
[
torch.tensor([0], dtype=torch.int32),
- torch.cumsum(
- torch.tensor(q_lens, dtype=torch.int32), dim=0, dtype=torch.int32
- ),
+ torch.cumsum(q_lens, dim=0, dtype=torch.int32),
]
)
- q = torch.randn(sum(q_lens), num_qo_heads, head_dim, dtype=dtype)
- kv_lens = [random.randint(0, MAX_SEQ_LEN) for _ in range(num_seqs)]
- kv_lens[-1] = MAX_SEQ_LEN
-
- seq_lens = [q_len + kv_len for q_len, kv_len in zip(q_lens, kv_lens)]
- max_seq_len = max(seq_lens)
- seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int32)
-
- max_num_blocks_per_seq = (max_seq_len + page_size - 1) // page_size
- block_tables = torch.randint(
- 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
+ # Always using 1.0 scale to reflect the real perf in benchmarking
+ q_scale = 1.0
+ ref_query = torch.randn(
+ torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
)
+ if q_quant_dtype == FP8_DTYPE:
+ query, _ = to_float8(ref_query)
+ else:
+ query = ref_query
- kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
- kv_cache = torch.randn(size=kv_cache_shape, dtype=dtype)
+ kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
+ kv_lens[-1] = max_kv_len
+
+ seq_lens = kv_lens + q_lens
+ max_seq_len = torch.max(seq_lens).item()
+
+ # Always using 1.0 scale to reflect the real perf in benchmarking
k_scale = v_scale = 1.0
+ ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+ if kv_quant_dtype == FP8_DTYPE:
+ kv_cache, _ = to_float8(ref_kv_cache)
+ else:
+ kv_cache = ref_kv_cache
- if kv_cache_dtype.startswith("fp8"):
- kv_cache, _ = to_float8(kv_cache)
-
- output_trtllm = torch.empty(q.shape, dtype=dtype)
-
+ max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+ block_tables = torch.randint(
+ 0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
+ )
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
- for i in range(num_seqs):
+ for i in range(batch_size):
seq_len = seq_lens[i]
assert seq_len > 0
- num_blocks = (seq_len + page_size - 1) // page_size
+ num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
- kv_last_page_len = seq_len % page_size
+ kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
- kv_last_page_len = page_size
+ kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-
- output_baseline = torch.empty(q.shape, dtype=dtype)
+ workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, kv_layout
@@ -115,12 +132,12 @@ def benchmark_prefill(
kv_last_page_lens,
num_qo_heads,
num_kv_heads,
- head_dim,
- page_size,
+ head_size,
+ block_size,
causal=True,
sm_scale=sm_scale,
q_data_type=dtype,
- kv_data_type=kv_cache.dtype,
+ kv_data_type=dtype,
)
def time_fn(fn, warmup=10, trials=20):
@@ -138,52 +155,76 @@ def benchmark_prefill(
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
+ o_scale = 1.0
+ o_sf_scale = None
+ output_baseline = torch.empty(ref_query.shape, dtype=dtype)
+ if o_quant_dtype == FP4_DTYPE:
+ o_sf_scale = 500.0
+ output_trtllm = flashinfer.utils.FP4Tensor(
+ torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
+ torch.empty(
+ (
+ round_up(query.shape[0], 128),
+ round_up(query.shape[1] * query.shape[2] // 16, 4),
+ ),
+ dtype=torch.float8_e4m3fn,
+ ),
+ )
+ else:
+ output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
+
def baseline_prefill():
return wrapper.run(
- q, kv_cache, k_scale=k_scale, v_scale=v_scale, out=output_baseline
+ ref_query,
+ ref_kv_cache,
+ k_scale=k_scale,
+ v_scale=v_scale,
+ out=output_baseline,
)
- def trt_prefill():
+ def trtllm_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
- query=q,
+ query=query,
kv_cache=kv_cache,
workspace_buffer=workspace_buffer,
block_tables=block_tables,
- seq_lens=seq_lens_tensor,
+ seq_lens=seq_lens,
max_q_len=max_q_len,
max_kv_len=max_seq_len,
- bmm1_scale=k_scale * sm_scale,
- bmm2_scale=v_scale,
- batch_size=num_seqs,
+ bmm1_scale=q_scale * k_scale * sm_scale,
+ bmm2_scale=v_scale / o_scale,
+ batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
+ o_sf_scale=o_sf_scale,
out=output_trtllm,
)
- trt_mean, trt_std = time_fn(trt_prefill)
baseline_mean, baseline_std = time_fn(baseline_prefill)
+ trtllm_mean, trtllm_std = time_fn(trtllm_prefill)
# Calculate percentage speedup (positive means TRT is faster)
- speedup_percent = (baseline_mean - trt_mean) / baseline_mean
+ speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
print(
- f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.5f}\t{trt_std.item():.5f}"
- f"\t{baseline_mean:.5f}\t{baseline_std.item():.5f}\t{speedup_percent:.5f}"
+ f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
+ f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
)
# Return results for CSV writing
return {
- "num_seqs": num_seqs,
- "trt_mean": trt_mean,
- "trt_std": trt_std.item(),
+ "batch_size": batch_size,
+ "trtllm_mean": trtllm_mean,
+ "trtllm_std": trtllm_std.item(),
"baseline_mean": baseline_mean,
"baseline_std": baseline_std.item(),
"speedup_percent": speedup_percent,
- "q_dtype": str(dtype),
- "kv_cache_dtype": kv_cache_dtype,
- "page_size": page_size,
+ "q_dtype": str(q_quant_dtype),
+ "kv_cache_dtype": str(kv_quant_dtype),
+ "output_dtype": str(o_quant_dtype),
+ "block_size": block_size,
"num_kv_heads": num_kv_heads,
- "head_dim": head_dim,
+ "head_size": head_size,
"max_seq_len": max_seq_len,
}
@@ -195,17 +236,18 @@ def write_results_to_csv(results, filename=None):
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
fieldnames = [
- "num_seqs",
- "trt_mean",
- "trt_std",
+ "batch_size",
+ "trtllm_mean",
+ "trtllm_std",
"baseline_mean",
"baseline_std",
"speedup_percent",
"q_dtype",
"kv_cache_dtype",
- "page_size",
+ "output_dtype",
+ "block_size",
"num_kv_heads",
- "head_dim",
+ "head_size",
"max_seq_len",
]
@@ -224,27 +266,42 @@ def write_results_to_csv(results, filename=None):
if __name__ == "__main__":
- num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
+ batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
all_results = []
- print(
- "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
- "output_dtype: bfloat16"
- )
- print(
- "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
- "baseline_std\tspeedup_percent"
- )
- for max_seq_len in max_seq_lens:
- for bs in num_seqs:
- result = benchmark_prefill(
- bs,
- max_seq_len,
- dtype=torch.bfloat16,
- kv_cache_dtype="auto",
- )
- all_results.append(result)
+ dtype = torch.bfloat16
+ quant_dtypes = [
+ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
+ (None, None, None),
+ (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+ (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
+ ]
+
+ for quant_dtype in quant_dtypes:
+ q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+ q_quant_dtype = q_quant_dtype or dtype
+ kv_quant_dtype = kv_quant_dtype or dtype
+ o_quant_dtype = o_quant_dtype or dtype
+
+ print(
+ f"Running benchmark for q_dtype = {q_quant_dtype}, "
+ f"kv_cache_dtype: {kv_quant_dtype}, "
+ f"output_dtype: {o_quant_dtype}"
+ )
+ print(
+ "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
+ "baseline_std\tspeedup_percent"
+ )
+ for max_seq_len in max_seq_lens:
+ for bs in batch_sizes:
+ result = benchmark_prefill(
+ dtype=dtype,
+ quant_dtypes=quant_dtype,
+ batch_size=bs,
+ max_seq_len=max_seq_len,
+ )
+ all_results.append(result)
# Write all results to CSV
write_results_to_csv(all_results)
diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py
index 4fcdbadd65ecd..e648a91077fdb 100644
--- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py
+++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py
@@ -11,8 +11,8 @@ from datetime import datetime
from typing import Any
import torch
-import tqdm
import triton
+from tqdm import tqdm
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul,
diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py
index a27f02394afbd..9a057990bda5f 100644
--- a/benchmarks/kernels/weight_shapes.py
+++ b/benchmarks/kernels/weight_shapes.py
@@ -95,4 +95,10 @@ WEIGHT_SHAPES = {
([2048, 2816], 1),
([1408, 2048], 0),
],
+ "CohereLabs/c4ai-command-a-03-2025": [
+ ([12288, 14336], 1),
+ ([12288, 12288], 0),
+ ([12288, 73728], 1),
+ ([36864, 12288], 0),
+ ],
}
diff --git a/benchmarks/multi_turn/README.md b/benchmarks/multi_turn/README.md
index ae0866ae60751..7adf97bcf5622 100644
--- a/benchmarks/multi_turn/README.md
+++ b/benchmarks/multi_turn/README.md
@@ -5,11 +5,13 @@ The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `re
First start serving your model
```bash
-export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
-vllm serve $MODEL_NAME --disable-log-requests
+vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests
```
+The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface).
+
## Synthetic Multi-Turn Conversations
Download the following text file (used for generation of synthetic conversations)
@@ -26,10 +28,10 @@ But you may use other text files if you prefer (using this specific file is not
Then run the benchmarking script
```bash
-export MODEL_NAME=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
+export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
-python benchmark_serving_multi_turn.py --model $MODEL_NAME --input-file generate_multi_turn.json \
---num-clients 2 --max-active-conversations 6
+python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \
+--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6
```
You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
diff --git a/benchmarks/multi_turn/benchmark_serving_multi_turn.py b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
index 53c3207491d18..d23b7b6e4571d 100644
--- a/benchmarks/multi_turn/benchmark_serving_multi_turn.py
+++ b/benchmarks/multi_turn/benchmark_serving_multi_turn.py
@@ -825,9 +825,11 @@ def get_client_config(
# Arguments for API requests
chat_url = f"{args.url}/v1/chat/completions"
+ model_name = args.served_model_name if args.served_model_name else args.model
+
req_args = RequestArgs(
chat_url=chat_url,
- model=args.model,
+ model=model_name,
stream=not args.no_stream,
limit_min_tokens=args.limit_min_tokens,
limit_max_tokens=args.limit_max_tokens,
@@ -1247,9 +1249,19 @@ async def main() -> None:
default=0,
help="Seed for random number generators (default: 0)",
)
+
parser.add_argument(
"-m", "--model", type=str, required=True, help="Path of the LLM model"
)
+ parser.add_argument(
+ "--served-model-name",
+ type=str,
+ default=None,
+ help="The model name used in the API. "
+ "If not specified, the model name will be the "
+ "same as the ``--model`` argument. ",
+ )
+
parser.add_argument(
"-u",
"--url",
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
index e0da46e2accaa..52bfd82c7fcfe 100644
--- a/cmake/cpu_extension.cmake
+++ b/cmake/cpu_extension.cmake
@@ -1,6 +1,7 @@
include(FetchContent)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@@ -182,17 +183,17 @@ endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
-if ( VLLM_BUILD_ACL STREQUAL "ON")
+if (VLLM_BUILD_ACL STREQUAL "ON")
set(USE_ACL ON)
else()
set(USE_ACL OFF)
endif()
-if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
+if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
- GIT_TAG v3.8.1
+ GIT_TAG v3.9
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
@@ -204,7 +205,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
endif()
set(ONEDNN_AARCH64_USE_ACL "ON")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
- endif()
+ endif()
set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF")
@@ -217,38 +218,23 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
+ set(ONEDNN_VERBOSE "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
FetchContent_MakeAvailable(oneDNN)
-
- list(APPEND LIBS dnnl)
-elseif(POWER10_FOUND)
- FetchContent_Declare(
- oneDNN
- GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
- GIT_TAG v3.7.2
- GIT_PROGRESS TRUE
- GIT_SHALLOW TRUE
+ add_library(dnnl_ext OBJECT "csrc/cpu/dnnl_helper.cpp")
+ target_include_directories(
+ dnnl_ext
+ PUBLIC ${oneDNN_SOURCE_DIR}/include
+ PUBLIC ${oneDNN_BINARY_DIR}/include
+ PRIVATE ${oneDNN_SOURCE_DIR}/src
)
-
- set(ONEDNN_LIBRARY_TYPE "STATIC")
- set(ONEDNN_BUILD_DOC "OFF")
- set(ONEDNN_BUILD_EXAMPLES "OFF")
- set(ONEDNN_BUILD_TESTS "OFF")
- set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
- set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
- set(ONEDNN_BUILD_GRAPH "OFF")
- set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
- set(ONEDNN_ENABLE_ITT_TASKS "OFF")
- set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
- set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
- set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
-
- set(DNNL_CPU_RUNTIME "OMP")
-
- FetchContent_MakeAvailable(oneDNN)
-
- list(APPEND LIBS dnnl)
+ target_link_libraries(dnnl_ext dnnl)
+ target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
+ list(APPEND LIBS dnnl_ext)
+ set(USE_ONEDNN ON)
+else()
+ set(USE_ONEDNN OFF)
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
@@ -275,7 +261,6 @@ set(VLLM_EXT_SRC
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
@@ -289,14 +274,11 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
${VLLM_EXT_SRC})
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
endif()
-elseif(POWER10_FOUND)
- set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
- ${VLLM_EXT_SRC})
endif()
-if (ASIMD_FOUND)
+
+if(USE_ONEDNN)
set(VLLM_EXT_SRC
- "csrc/cpu/quant.cpp"
+ "csrc/cpu/dnnl_kernels.cpp"
${VLLM_EXT_SRC})
endif()
diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake
index ee6768bce26ca..02224cfe3ee81 100644
--- a/cmake/external_projects/flashmla.cmake
+++ b/cmake/external_projects/flashmla.cmake
@@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
- GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
+ GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
@@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
- ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+ ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
- ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
+ ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+ ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc/cutlass/include
- ${flashmla_SOURCE_DIR}/csrc/include)
+ ${flashmla_SOURCE_DIR}/csrc)
set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"
diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake
index d24d8e8e5e795..49defccbb1fa4 100644
--- a/cmake/external_projects/vllm_flash_attn.cmake
+++ b/cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
- GIT_TAG 93cf5a08f421a3efd0c4a7e005ef8f742b578ce0
+ GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 55e6596797010..a4a880f13cf7e 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -128,6 +128,45 @@ __global__ void act_and_mul_kernel_with_param(
}
}
+template
+__device__ __forceinline__ T swigluoai_and_mul(const T& gate, const T& up,
+ float alpha, float limit) {
+ // clamp gate: min=None, max=limit
+ const float gate_f = (float)gate;
+ const float clamped_gate = gate_f > limit ? limit : gate_f;
+
+ // clamp up: min=-limit, max=limit
+ const float up_f = (float)up;
+ const float clamped_up =
+ up_f > limit ? limit : (up_f < -limit ? -limit : up_f);
+
+ // glu = gate * sigmoid(gate * alpha)
+ const float sigmoid_val = 1.0f / (1.0f + expf(-clamped_gate * alpha));
+ const float glu = clamped_gate * sigmoid_val;
+
+ // (up + 1) * glu
+ return (T)((clamped_up + 1.0f) * glu);
+}
+
+template
+__global__ void swigluoai_and_mul_kernel(
+ scalar_t* __restrict__ out, // [..., d]
+ const scalar_t* __restrict__ input, // [..., 2, d]
+ const int d, const float alpha, const float limit) {
+ const int64_t token_idx = blockIdx.x;
+ // TODO: Vectorize loads and stores.
+ for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+ // gate = x[..., ::2] (even indices)
+ const scalar_t gate = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx]);
+ // up = x[..., 1::2] (odd indices)
+ const scalar_t up = VLLM_LDG(&input[token_idx * 2 * d + 2 * idx + 1]);
+
+ out[token_idx * d + idx] = ACT_FN(gate, up, alpha, limit);
+ }
+}
+
} // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
@@ -145,11 +184,31 @@ __global__ void act_and_mul_kernel_with_param(
PARAM); \
});
+#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
+ int d = input.size(-1) / 2; \
+ int64_t num_tokens = input.numel() / input.size(-1); \
+ dim3 grid(num_tokens); \
+ dim3 block(std::min(d, 1024)); \
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
+ VLLM_DISPATCH_FLOATING_TYPES( \
+ input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
+ vllm::swigluoai_and_mul_kernel> \
+ <<>>(out.data_ptr(), \
+ input.data_ptr(), d, ALPHA, \
+ LIMIT); \
+ });
+
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
}
+void swigluoai_and_mul(torch::Tensor& out, // [..., d]
+ torch::Tensor& input, // [..., 2 * d]
+ double alpha, double limit) {
+ LAUNCH_SIGLUOAI_AND_MUL(vllm::swigluoai_and_mul, alpha, limit);
+}
namespace vllm {
// Element-wise activation kernel template.
diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
index e0e95d06290df..6dd6f269f3dc9 100644
--- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
+++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu
@@ -167,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
- num_kv_splits, // split_kv
+ static_cast(num_kv_splits), // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
@@ -264,7 +264,7 @@ int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_ba
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
- arguments.split_kv = num_kv_splits;
+ arguments.split_kv = static_cast(num_kv_splits);
MlaSm100Type::Fmha::set_split_kv(arguments);
return MlaSm100Type::Fmha::get_workspace_size(arguments);
diff --git a/csrc/cache.h b/csrc/cache.h
index 0970b704be3ab..fb0c353b96137 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);
-void gather_cache(
+void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size, std::optional seq_starts = std::nullopt);
\ No newline at end of file
+ int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& scale,
+ std::optional seq_starts = std::nullopt);
\ No newline at end of file
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 131dcb15cd7e9..b3a985c2d5bbb 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -624,9 +624,9 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
namespace vllm {
// grid is launched with dimensions (batch, num_splits)
-template
-__global__ void gather_cache(
- const scalar_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
+template
+__global__ void gather_and_maybe_dequant_cache(
+ const cache_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE,
// ENTRIES...]
scalar_t* __restrict__ dst, // [TOT_TOKENS, ENTRIES...]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
@@ -634,6 +634,7 @@ __global__ void gather_cache(
const int32_t block_size, const int32_t entry_size,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+ const float* __restrict__ scale,
const int32_t* __restrict__ seq_starts) { // Optional: starting offsets per
// batch
@@ -675,10 +676,16 @@ __global__ void gather_cache(
if (partial_block_size) full_blocks_end -= 1;
}
- auto copy_entry = [&](const scalar_t* __restrict__ _src,
+ auto copy_entry = [&](const cache_t* __restrict__ _src,
scalar_t* __restrict__ _dst) {
- for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
- _dst[i] = _src[i];
+ for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+ if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+ _dst[i] = static_cast(_src[i]);
+ } else {
+ _dst[i] =
+ fp8::scaled_convert(_src[i], *scale);
+ }
+ }
};
for (int pid = split_start; pid < full_blocks_end; ++pid) {
@@ -705,25 +712,31 @@ __global__ void gather_cache(
} // namespace vllm
// Macro to dispatch the kernel based on the data type.
-#define CALL_GATHER_CACHE(CPY_DTYPE) \
- vllm::gather_cache<<>>( \
- reinterpret_cast(src_cache.data_ptr()), \
- reinterpret_cast(dst.data_ptr()), \
- block_table.data_ptr(), cu_seq_lens.data_ptr(), \
- block_size, entry_size, block_table_stride, cache_block_stride, \
- cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+// SCALAR_T is the data type of the destination tensor.
+// CACHE_T is the stored data type of kv-cache.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
+ vllm::gather_and_maybe_dequant_cache \
+ <<>>( \
+ reinterpret_cast(src_cache.data_ptr()), \
+ reinterpret_cast(dst.data_ptr()), \
+ block_table.data_ptr(), cu_seq_lens.data_ptr(), \
+ block_size, entry_size, block_table_stride, cache_block_stride, \
+ cache_entry_stride, dst_entry_stride, \
+ reinterpret_cast(scale.data_ptr()), seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
// - Optionally, seq_starts (if provided) offsets the starting block index by
// (seq_starts[bid] / page_size)
-void gather_cache(
+void gather_and_maybe_dequant_cache(
torch::Tensor const& src_cache, // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
torch::Tensor const& dst, // [TOT_TOKENS, ENTRIES...]
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
- int64_t batch_size,
+ int64_t batch_size, const std::string& kv_cache_dtype,
+ torch::Tensor const& scale,
std::optional seq_starts = std::nullopt) {
at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -761,20 +774,8 @@ void gather_cache(
dim3 grid(batch_size, num_splits);
dim3 block(1024);
- TORCH_CHECK(src_cache.dtype() == dst.dtype(),
- "src_cache and dst must have the same dtype");
-
- const int dtype_bits = src_cache.element_size() * 8;
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr() : nullptr;
- if (dtype_bits == 32) {
- CALL_GATHER_CACHE(uint32_t);
- } else if (dtype_bits == 16) {
- CALL_GATHER_CACHE(uint16_t);
- } else if (dtype_bits == 8) {
- CALL_GATHER_CACHE(uint8_t);
- } else {
- TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
- }
+ DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
}
diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp
index 3952c43cbc727..982f7c07a13bd 100644
--- a/csrc/cpu/cpu_types_x86.hpp
+++ b/csrc/cpu/cpu_types_x86.hpp
@@ -89,7 +89,7 @@ struct FP16Vec16 : public Vec {
explicit FP16Vec16(const FP32Vec16&);
- void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
+ void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
@@ -126,7 +126,7 @@ struct BF16Vec16 : public Vec {
explicit BF16Vec16(const FP32Vec16&);
- void save(void* ptr) const { *reinterpret_cast<__m256i*>(ptr) = reg; }
+ void save(void* ptr) const { _mm256_storeu_si256((__m256i*)ptr, reg); }
void save(void* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
@@ -180,8 +180,8 @@ struct BF16Vec32 : public Vec {
(__m128i)vec8_data.reg, 1)) {}
void save(void* ptr) const {
- *reinterpret_cast<__m256i*>(ptr) = reg_low;
- *reinterpret_cast<__m256i*>((__m256i*)ptr + 1) = reg_high;
+ _mm256_storeu_si256((__m256i*)ptr, reg_low);
+ _mm256_storeu_si256((__m256i*)ptr + 1, reg_high);
}
};
#endif
diff --git a/csrc/cpu/dnnl_helper.cpp b/csrc/cpu/dnnl_helper.cpp
new file mode 100644
index 0000000000000..f3f00edb36068
--- /dev/null
+++ b/csrc/cpu/dnnl_helper.cpp
@@ -0,0 +1,346 @@
+#include
+#include
+
+#include "common/memory_desc.hpp"
+#include "common/memory.hpp"
+
+#include "dnnl_helper.h"
+
+static dnnl::engine& default_engine() {
+ static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+ return engine;
+}
+
+static dnnl::stream& default_stream() {
+ static dnnl::stream stream(default_engine());
+ return stream;
+}
+
+void release_dnnl_matmul_handler(int64_t handler) {
+ DNNLMatMulPrimitiveHandler* ptr =
+ reinterpret_cast(handler);
+ delete ptr;
+}
+
+template
+class DNNLPrimitiveCache {
+ public:
+ using cache_value_t = std::pair;
+ using result_value_t = VT;
+ using container_t = std::list;
+ using value_iterator_t = typename container_t::iterator;
+ using map_t = std::unordered_map;
+ using creator_t = VT (*)();
+
+ public:
+ DNNLPrimitiveCache(size_t capacity)
+ : capacity_(capacity),
+ values_(),
+ key_to_value_(std::min(256lu, capacity)) {
+ assert(capacity > 0);
+ }
+
+ template
+ result_value_t get_or_create(const KT& key, F&& creator) {
+ std::optional value = get_value(key);
+ if (value.has_value()) {
+ return value.value()->second;
+ } else {
+ return add_value({key, creator()})->second;
+ }
+ }
+
+ size_t size() const { return values_.size(); }
+
+ private:
+ void dump_data() {
+ std::stringstream ss;
+ ss << "table_id: " << std::hex << reinterpret_cast(this) << std::dec
+ << "\n";
+ ss << "container: [";
+ for (auto&& iter : values_) {
+ ss << "(" << iter.first << ", " << std::hex
+ << reinterpret_cast(iter.second.get()) << "), " << std::dec;
+ }
+ ss << "]\n";
+
+ ss << "map: [";
+ for (auto&& iter : key_to_value_) {
+ ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
+ << reinterpret_cast(iter.second->second.get()) << std::dec
+ << "), ";
+ }
+ ss << "]\n";
+ std::printf("%s\n", ss.str().c_str());
+ }
+
+ value_iterator_t add_value(cache_value_t&& new_value) {
+ if (size() == capacity_) {
+ cache_value_t& last_item = values_.back();
+ key_to_value_.erase(last_item.first);
+ values_.pop_back();
+ }
+
+ auto& added_value_ = values_.emplace_front(std::move(new_value));
+ key_to_value_.emplace(added_value_.first, values_.begin());
+ return values_.begin();
+ }
+
+ std::optional get_value(const KT& key) {
+ if (key_to_value_.size() > 0 && key == values_.begin()->first) {
+ return values_.begin();
+ }
+
+ auto value_map_iterator = key_to_value_.find(key);
+ if (value_map_iterator != key_to_value_.end()) {
+ values_.splice(values_.begin(), values_, value_map_iterator->second);
+ return value_map_iterator->second;
+ } else {
+ return {};
+ }
+ }
+
+ private:
+ const size_t capacity_;
+ container_t values_;
+ map_t key_to_value_;
+};
+
+DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
+ const Args& args, dnnl::memory::data_type b_type)
+ : b_n_size_(args.b_n_size),
+ b_n_stride_(args.b_n_stride),
+ b_k_size_(args.b_k_size),
+ b_k_stride_(args.b_k_stride),
+ b_type_(b_type),
+ c_type_(args.c_type),
+ runtime_memory_ptrs_(8),
+ primitive_cache_size_(args.primitive_cache_size) {
+ assert(primitive_cache_size_ > 0);
+}
+
+void DNNLMatMulPrimitiveHandler::prepack_weight(
+ void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
+ dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+ {b_k_stride_, b_n_stride_});
+ dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
+ dnnl::memory packed_weight(b_target_mem_desc, default_engine());
+ {
+ dnnl::reorder(original_weight, packed_weight)
+ .execute(default_stream(), original_weight, packed_weight);
+ default_stream().wait();
+ }
+ memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
+ b_target_mem_desc_ = b_target_mem_desc;
+}
+
+void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
+ size_t index, dnnl_memory* memory_ptr) {
+ dnnl::impl::memory_storage_t* mem_storage_ptr = memory_ptr->memory_storage();
+ dnnl_memory_desc* mem_desc = const_cast(memory_ptr->md());
+ runtime_memory_ptrs_[index] = {mem_storage_ptr, mem_desc};
+}
+
+std::pair
+DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
+ return runtime_memory_ptrs_[index];
+}
+
+namespace std {
+template <>
+struct hash {
+ size_t operator()(
+ const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
+ return hash()(val.b_n_size) ^ hash()(val.b_k_size) ^
+ hash()(static_cast(val.a_qs)) ^
+ hash()(static_cast(val.b_qs)) ^ hash()(val.use_azp) ^
+ hash()(static_cast(val.c_type));
+ }
+};
+
+template <>
+struct hash {
+ size_t operator()(
+ const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
+ return hash()(val.a_m_size) ^ hash()(val.use_bias) ^
+ hash()(static_cast(val.bias_type));
+ }
+};
+} // namespace std
+
+bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
+ const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
+ return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
+ l.a_qs == r.a_qs && l.b_qs == r.b_qs && l.use_azp == r.use_azp &&
+ l.c_type == r.c_type;
+}
+
+bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
+ const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
+ return l.use_bias == r.use_bias && l.a_m_size == r.a_m_size &&
+ l.bias_type == r.bias_type;
+}
+
+static std::shared_ptr
+get_w8a8_class_primitive_cache(
+ const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
+ int64_t cache_size) {
+ static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
+ assert(cache_size > 0);
+ return cache.get_or_create(key, [&]() {
+ return std::make_shared(cache_size);
+ });
+}
+
+W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
+ : DNNLMatMulPrimitiveHandler(
+ static_cast(args),
+ dnnl::memory::data_type::s8),
+ use_azp_(args.use_a_zero_point),
+ a_qs_(args.a_quantization_strategy),
+ b_qs_(args.b_quantization_strategy),
+ m_size_cache_(nullptr) {
+ assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
+ assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
+ if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
+ assert(!use_azp_);
+ };
+ prepack_weight(args.b_ptr,
+ create_primitive_desc(
+ MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
+ .use_bias = false,
+ .bias_type = dnnl::memory::data_type::undef},
+ true)
+ .weights_desc());
+ init_runtime_memory_cache(args);
+}
+
+void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
+ auto&& [a_storage, a_mem_desc] = get_runtime_memory_ptr(0);
+ auto&& [c_storage, c_mem_desc] = get_runtime_memory_ptr(1);
+ a_storage->set_data_handle((void*)args.a_ptr);
+ a_mem_desc->dims[0] = args.a_m_size;
+ c_storage->set_data_handle((void*)args.c_ptr);
+ c_mem_desc->dims[0] = args.a_m_size;
+
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ auto&& [a_scale_storage, a_scale_mem_desc] = get_runtime_memory_ptr(2);
+ a_scale_storage->set_data_handle((void*)args.a_scales_ptr);
+ }
+ if (use_azp_) {
+ auto&& [a_zero_point_storage, a_zero_point_mem_desc] =
+ get_runtime_memory_ptr(3);
+ a_zero_point_storage->set_data_handle((void*)args.a_zero_points_ptr);
+ }
+
+ if (args.use_bias) {
+ auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(4);
+ bias_storage->set_data_handle((void*)args.bias_ptr);
+ }
+
+ dnnl::matmul matmul = get_matmul_cache(args);
+ matmul.execute(default_stream(), memory_cache_);
+ default_stream().wait();
+}
+
+dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
+ const MSizeCacheKey& key) {
+ if (m_size_cache_.get() == nullptr) {
+ ClassMatmulCacheKey key = {.b_n_size = b_n_size_,
+ .b_k_size = b_k_size_,
+ .a_qs = a_qs_,
+ .b_qs = b_qs_,
+ .use_azp = use_azp_,
+ .c_type = c_type_};
+ m_size_cache_ = get_w8a8_class_primitive_cache(key, primitive_cache_size_);
+ }
+
+ return m_size_cache_->get_or_create(key, [&]() {
+ dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);
+ return dnnl::matmul(desc);
+ });
+}
+
+void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
+ memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
+ dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::ab},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
+ memory_cache_[DNNL_ARG_DST] =
+ dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
+
+ // For PER_TOKEN, scales will be applied in outside epilogue
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
+ {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
+ set_runtime_memory_ptr(
+ 2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
+ if (use_azp_) {
+ memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
+ {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
+ set_runtime_memory_ptr(
+ 3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
+ }
+ }
+
+ if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
+ dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
+ (void*)args.b_scales_ptr);
+ } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
+ memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
+ dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+ default_engine(), (void*)args.b_scales_ptr);
+ }
+
+ memory_cache_[DNNL_ARG_BIAS] =
+ dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
+ default_engine(), nullptr);
+ set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
+}
+
+dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
+ const MSizeCacheKey& key, bool first_time) {
+ dnnl::memory::desc a_md({key.a_m_size, b_k_size_},
+ dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::ab);
+ dnnl::memory::desc b_md;
+ if (first_time) {
+ b_md =
+ dnnl::memory::desc({b_k_size_, b_n_size_}, dnnl::memory::data_type::s8,
+ dnnl::memory::format_tag::any);
+ } else {
+ b_md = b_target_mem_desc_;
+ }
+ dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
+ dnnl::memory::format_tag::ab);
+
+ dnnl::primitive_attr attr;
+ // For PER_TOKEN, scales will be applied in outside epilogue
+ if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
+ attr.set_scales_mask(DNNL_ARG_SRC, 0);
+ if (use_azp_) {
+ attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
+ }
+ }
+
+ if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
+ attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
+ } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
+ attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
+ }
+
+ if (key.use_bias) {
+ // For PER_TOKEN, bias will be applied in epilogue
+ assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
+ dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+ return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
+ c_md, attr);
+ } else {
+ return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+ attr);
+ }
+}
diff --git a/csrc/cpu/dnnl_helper.h b/csrc/cpu/dnnl_helper.h
new file mode 100644
index 0000000000000..54ceefced9e98
--- /dev/null
+++ b/csrc/cpu/dnnl_helper.h
@@ -0,0 +1,169 @@
+#ifndef DNNL_HELPER_H
+#define DNNL_HELPER_H
+
+#include
+#include
+
+#include "oneapi/dnnl/dnnl.hpp"
+
+namespace c10 {
+struct BFloat16;
+struct Half;
+} // namespace c10
+
+namespace dnnl {
+namespace impl {
+struct memory_storage_t;
+struct matmul_pd_t;
+struct matmul_desc_t;
+} // namespace impl
+} // namespace dnnl
+struct dnnl_memory_desc;
+
+template
+class DNNLPrimitiveCache;
+
+template
+struct DNNLType {
+ static constexpr dnnl::memory::data_type type =
+ dnnl::memory::data_type::undef;
+};
+
+template <>
+struct DNNLType {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
+};
+
+template <>
+struct DNNLType {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
+};
+
+template <>
+struct DNNLType {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
+};
+
+template <>
+struct DNNLType {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
+};
+
+template <>
+struct DNNLType {
+ static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
+};
+
+template
+constexpr inline dnnl::memory::data_type get_dnnl_type() {
+ return DNNLType>::type;
+}
+
+class DNNLMatMulPrimitiveHandler {
+ public:
+ virtual ~DNNLMatMulPrimitiveHandler() = default;
+
+ protected:
+ struct Args {
+ dnnl_dim_t b_n_size;
+ dnnl_dim_t b_n_stride;
+ dnnl_dim_t b_k_size;
+ dnnl_dim_t b_k_stride;
+ void* b_ptr;
+ dnnl::memory::data_type c_type;
+ size_t primitive_cache_size;
+ };
+
+ protected:
+ DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
+
+ void prepack_weight(void* original_b_ptr,
+ dnnl::memory::desc b_target_mem_desc);
+
+ void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
+
+ std::pair
+ get_runtime_memory_ptr(size_t index);
+
+ protected:
+ const dnnl_dim_t b_n_size_;
+ const dnnl_dim_t b_n_stride_;
+ const dnnl_dim_t b_k_size_;
+ const dnnl_dim_t b_k_stride_;
+ dnnl::memory::data_type b_type_;
+ dnnl::memory::data_type c_type_;
+ std::unordered_map memory_cache_;
+ std::vector>
+ runtime_memory_ptrs_;
+ dnnl::memory::desc b_target_mem_desc_;
+ int64_t primitive_cache_size_;
+};
+
+class W8A8MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
+ public:
+ enum class QuantizationStrategy { PER_TOKEN, PER_TENSOR, PER_OUTPUT_CHANNEL };
+
+ struct Args : public DNNLMatMulPrimitiveHandler::Args {
+ bool use_a_zero_point;
+ QuantizationStrategy a_quantization_strategy;
+ QuantizationStrategy b_quantization_strategy;
+ float* b_scales_ptr;
+ };
+
+ struct ClassMatmulCacheKey {
+ dnnl_dim_t b_n_size;
+ dnnl_dim_t b_k_size;
+ QuantizationStrategy a_qs;
+ QuantizationStrategy b_qs;
+ bool use_azp;
+ dnnl::memory::data_type c_type;
+
+ friend bool operator==(const ClassMatmulCacheKey& l,
+ const ClassMatmulCacheKey& r);
+ };
+
+ struct MSizeCacheKey {
+ dnnl_dim_t a_m_size;
+ bool use_bias;
+ dnnl::memory::data_type bias_type;
+
+ friend bool operator==(const MSizeCacheKey& l, const MSizeCacheKey& r);
+ };
+
+ using MSizeCache = DNNLPrimitiveCache;
+ using ClassMatmulCache =
+ DNNLPrimitiveCache>;
+
+ struct ExecArgs : public MSizeCacheKey {
+ const int8_t* a_ptr;
+ const float* a_scales_ptr;
+ const int32_t* a_zero_points_ptr;
+ const void* bias_ptr;
+ void* c_ptr;
+ };
+
+ public:
+ W8A8MatMulPrimitiveHandler(const Args& args);
+
+ QuantizationStrategy get_input_scale_strategy() const { return a_qs_; }
+
+ bool get_input_use_zero_point() const { return use_azp_; }
+
+ void execute(ExecArgs& args);
+
+ private:
+ dnnl::matmul::primitive_desc create_primitive_desc(const MSizeCacheKey& key,
+ bool first_time);
+
+ void init_runtime_memory_cache(const Args& args);
+
+ dnnl::matmul get_matmul_cache(const MSizeCacheKey& key);
+
+ private:
+ const bool use_azp_;
+ const QuantizationStrategy a_qs_;
+ const QuantizationStrategy b_qs_;
+ std::shared_ptr m_size_cache_;
+};
+
+#endif
diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp
deleted file mode 100644
index 1cb8dc5b25a66..0000000000000
--- a/csrc/cpu/dnnl_helper.hpp
+++ /dev/null
@@ -1,206 +0,0 @@
-#ifndef DNNL_HELPER_HPP
-#define DNNL_HELPER_HPP
-
-#include
-#include
-
-#include "oneapi/dnnl/dnnl.hpp"
-
-namespace {
-template
-struct DNNLType {
- static constexpr dnnl::memory::data_type type =
- dnnl::memory::data_type::undef;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
-};
-
-template <>
-struct DNNLType {
- static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f16;
-};
-
-template
-constexpr inline dnnl::memory::data_type get_dnnl_type() {
- return DNNLType>::type;
-}
-}; // namespace
-
-template
-class DNNLPrimitiveHelper {
- public:
- // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
- // A: [M, K], row-major
- // B: [K, N], column-major
- // C: [M, N], row-major
- // bias: [N], row-major, optional
- // a_scales: [MS]
- // b_scales: [NS]
- // Note: Due to the limitation of oneDNN
- // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
- // not supported.
-
- template
- static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
- const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
- dnnl_dim_t K, const float* a_scales,
- const float* b_scales, dnnl_dim_t MS,
- dnnl_dim_t NS) {
- auto&& OutputType = get_dnnl_type();
- auto&& BiasType = get_dnnl_type();
-
- dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
- dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
- dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
-
- dnnl::primitive_attr attr;
- if constexpr (!InputNoScale) {
- if (MS == 1) {
- // per-tensor
- attr.set_scales_mask(DNNL_ARG_SRC, 0);
- } else {
- // per-token
- TORCH_CHECK(false, "per-token quantization is unsupported.");
- }
- }
-
- if (NS == 1) {
- // per-tensor
- attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
- } else {
- // per-channel
- attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
- }
-
- dnnl::matmul::primitive_desc matmul_pd;
-// Create memory descriptors with format_tag::any for the primitive. This
-// enables the matmul primitive to choose memory layouts for an
-// optimized primitive implementation, and these layouts may differ from the
-// ones provided by the user.
-#ifdef __aarch64__
- auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
- dnnl::memory::format_tag::any);
- auto mat_weights_md = dnnl::memory::desc(
- {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
- auto mat_dst_md =
- dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
- if (bias) {
- dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
- mat_weights_md, bias_md,
- mat_dst_md, attr);
- } else {
- matmul_pd = dnnl::matmul::primitive_desc(
- default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
- }
-#else
- if (bias) {
- dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
- bias_md, c_md, attr);
- } else {
- matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
- c_md, attr);
- }
-#endif
- dnnl::matmul matmul(matmul_pd);
-
- auto& engine = default_engine();
-
- dnnl::memory a_m(a_md, engine, (void*)a);
- dnnl::memory b_m(b_md, engine, (void*)b);
- dnnl::memory c_m(c_md, engine, (void*)c);
- dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
- (void*)a_scales);
- dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
- (void*)b_scales);
-
- auto& stream = default_stream();
-
- auto mat_src_mem = a_m;
- auto mat_weights_mem = b_m;
- auto mat_dst_mem = c_m;
-#ifdef __aarch64__
- if (matmul_pd.weights_desc() != b_m.get_desc()) {
- mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
- dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
- }
-#endif
- if constexpr (InputNoScale) {
- if (bias) {
- dnnl::memory::desc bias_md({N}, BiasType, {1});
- dnnl::memory bias_m(bias_md, engine, (void*)bias);
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_BIAS, bias_m},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- } else {
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- }
- } else {
- if (bias) {
- dnnl::memory::desc bias_md({N}, BiasType, {1});
- dnnl::memory bias_m(bias_md, engine, (void*)bias);
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_BIAS, bias_m},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- } else {
- matmul.execute(
- stream, {
- {DNNL_ARG_SRC, mat_src_mem},
- {DNNL_ARG_WEIGHTS, mat_weights_mem},
- {DNNL_ARG_DST, mat_dst_mem},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
- {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
- });
- }
- }
- stream.wait();
- }
-
- private:
- static dnnl::engine& default_engine() {
- static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
- return engine;
- }
-
- static dnnl::stream& default_stream() {
- static dnnl::stream stream(default_engine());
- return stream;
- }
-};
-#endif
diff --git a/csrc/cpu/dnnl_kernels.cpp b/csrc/cpu/dnnl_kernels.cpp
new file mode 100644
index 0000000000000..acc3b9ecde143
--- /dev/null
+++ b/csrc/cpu/dnnl_kernels.cpp
@@ -0,0 +1,494 @@
+#include "cpu_types.hpp"
+#include "dnnl_helper.h"
+
+namespace {
+template
+struct KernelVecType {
+ using load_vec_type = void;
+ using cvt_vec_type = void;
+};
+
+template <>
+struct KernelVecType {
+ using load_vec_type = vec_op::FP32Vec16;
+ using cvt_vec_type = vec_op::FP32Vec16;
+};
+
+#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
+template <>
+struct KernelVecType {
+ using load_vec_type = vec_op::BF16Vec16;
+ using cvt_vec_type = vec_op::FP32Vec16;
+};
+#endif
+
+template <>
+struct KernelVecType {
+#if defined(__powerpc64__) || defined(__s390x__)
+ // Power architecture-specific vector type
+ using load_vec_type = vec_op::FP32Vec16;
+#else
+ // Fallback for other architectures
+ using load_vec_type = vec_op::FP16Vec16;
+#endif
+ using cvt_vec_type = vec_op::FP32Vec16;
+};
+
+template
+void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+ const float* scale, const int32_t* azp,
+ const int64_t num_tokens,
+ const int64_t input_stride,
+ const int64_t hidden_size) {
+ using load_vec_t = typename KernelVecType::load_vec_type;
+ using cvt_vec_t = typename KernelVecType::cvt_vec_type;
+ constexpr int64_t vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+ constexpr float i8_min =
+ static_cast(std::numeric_limits::min());
+ constexpr float i8_max =
+ static_cast(std::numeric_limits::max());
+ const cvt_vec_t inv_scale(1.0 / *scale);
+ const cvt_vec_t i8_min_vec(i8_min);
+ const cvt_vec_t i8_max_vec(i8_max);
+
+ cvt_vec_t zp_vec;
+ if constexpr (AZP) {
+ zp_vec = cvt_vec_t(static_cast(*azp));
+ }
+
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ int8_t* output_ptr = output + i * hidden_size;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = elems_fp32 * inv_scale;
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + zp_vec;
+ }
+
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j);
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = elems_fp32 * inv_scale;
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + zp_vec;
+ }
+
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j, hidden_size - j);
+ }
+}
+
+template
+void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+ float* scale, int32_t* azp,
+ const int64_t num_tokens,
+ const int64_t input_stride,
+ const int64_t hidden_size) {
+ using load_vec_t = typename KernelVecType::load_vec_type;
+ using cvt_vec_t = typename KernelVecType::cvt_vec_type;
+ constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+ constexpr float i8_min =
+ static_cast(std::numeric_limits::min());
+ constexpr float i8_max =
+ static_cast(std::numeric_limits::max());
+ const cvt_vec_t i8_min_vec(i8_min);
+ const cvt_vec_t i8_max_vec(i8_max);
+
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+ cvt_vec_t max_value(std::numeric_limits::lowest());
+ cvt_vec_t min_value(std::numeric_limits::max());
+ {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32);
+ min_value = min_value.min(elems_fp32);
+ } else {
+ max_value = max_value.max(elems_fp32.abs());
+ }
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+
+ if (j + vec_elem_num == hidden_size) {
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32);
+ min_value = min_value.min(elems_fp32);
+ } else {
+ max_value = max_value.max(elems_fp32.abs());
+ }
+ } else {
+ if constexpr (AZP) {
+ max_value = max_value.max(elems_fp32, hidden_size - j);
+ min_value = min_value.min(elems_fp32, hidden_size - j);
+ } else {
+ max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
+ }
+ }
+ }
+
+ float scale_val, azp_val;
+ if constexpr (AZP) {
+ float max_scalar = max_value.reduce_max();
+ float min_scalar = min_value.reduce_min();
+ scale_val = (max_scalar - min_scalar) / 255.0f;
+ azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
+ azp[i] = azp_val;
+ scale[i] = scale_val;
+ } else {
+ scale_val = max_value.reduce_max() / 127.0f;
+ scale[i] = scale_val;
+ }
+
+ const cvt_vec_t inv_scale(1.0 / scale_val);
+ const cvt_vec_t azp_vec(azp_val);
+
+ {
+ int64_t j = 0;
+ const scalar_t* input_ptr = input + i * input_stride;
+ int8_t* output_ptr = output + i * hidden_size;
+ for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = (elems_fp32 * inv_scale);
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + azp_vec;
+ }
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j);
+ }
+
+ load_vec_t elems(input_ptr + j);
+ cvt_vec_t elems_fp32(elems);
+ elems_fp32 = (elems_fp32 * inv_scale);
+
+ if constexpr (AZP) {
+ elems_fp32 = elems_fp32 + azp_vec;
+ }
+ elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
+ vec_op::INT8Vec16 elems_int8(elems_fp32);
+ elems_int8.save(output_ptr + j, hidden_size - j);
+ }
+ }
+}
+
+template
+void dynamic_quant_epilogue(const float* input, scalar_t* output,
+ const float* a_scale, const int32_t* azp,
+ const float* azp_adj, const scalar_t* bias,
+ const int64_t num_tokens,
+ const int64_t hidden_size) {
+ CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
+ using load_vec_t = typename KernelVecType::load_vec_type;
+ using cvt_vec_t = typename KernelVecType::cvt_vec_type;
+ constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+
+ const int64_t thread_num = omp_get_max_threads();
+ if (num_tokens > thread_num) {
+#pragma omp parallel for
+ for (int64_t i = 0; i < num_tokens; ++i) {
+ const float* input_ptr = input + i * hidden_size;
+ scalar_t* output_ptr = output + i * hidden_size;
+ int64_t j = 0;
+ cvt_vec_t token_scale_vec(a_scale[i]);
+ cvt_vec_t token_zp_scale_vec;
+ if constexpr (AZP) {
+ float zp_scale_val = a_scale[i] * static_cast(azp[i]);
+ token_zp_scale_vec = cvt_vec_t(zp_scale_val);
+ }
+ for (; j < hidden_size - vec_elem_num; ++j) {
+ cvt_vec_t elems_fp32(input_ptr + j);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + j);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + j);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + j);
+ }
+ cvt_vec_t elems_fp32(input_ptr + j);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + j);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + j);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + j, hidden_size - j);
+ }
+ } else {
+ const int64_t vec_iteration =
+ (hidden_size + vec_elem_num - 1) / vec_elem_num;
+ const int64_t vec_iteration_per_thread =
+ (vec_iteration + thread_num - 1) / thread_num;
+ const int64_t elem_num_per_thread = vec_iteration_per_thread * vec_elem_num;
+#pragma omp parallel for schedule(static, 1)
+ for (int64_t i = 0; i < thread_num; ++i) {
+ const int64_t start = elem_num_per_thread * i;
+ const int64_t end = std::min(hidden_size, elem_num_per_thread + start);
+ for (int64_t j = 0; j < num_tokens; ++j) {
+ cvt_vec_t token_scale_vec(a_scale[j]);
+ cvt_vec_t token_zp_scale_vec;
+ if constexpr (AZP) {
+ float zp_scale_val = a_scale[j] * static_cast(azp[j]);
+ token_zp_scale_vec = cvt_vec_t(zp_scale_val);
+ }
+ int64_t k = start;
+ const float* input_ptr = input + j * hidden_size;
+ scalar_t* output_ptr = output + j * hidden_size;
+ for (; k < end - vec_elem_num; k += vec_elem_num) {
+ cvt_vec_t elems_fp32(input_ptr + k);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + k);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + k);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + k);
+ }
+ if (k < end) {
+ cvt_vec_t elems_fp32(input_ptr + k);
+ elems_fp32 = elems_fp32 * token_scale_vec;
+ if constexpr (AZP) {
+ cvt_vec_t azp_adj_fp32(azp_adj + k);
+ elems_fp32 = elems_fp32 - azp_adj_fp32 * token_zp_scale_vec;
+ }
+ if constexpr (Bias) {
+ load_vec_t bias_vec(bias + k);
+ cvt_vec_t bias_vec_fp32(bias_vec);
+ elems_fp32 = elems_fp32 + bias_vec_fp32;
+ }
+ load_vec_t elems_out(elems_fp32);
+ elems_out.save(output_ptr + k, end - k);
+ }
+ }
+ }
+ }
+}
+} // namespace
+
+int64_t create_onednn_scaled_mm_handler(
+ const torch::Tensor& b, // [IC, OC], column-major
+ const torch::Tensor& b_scales, // [1] or [OC]
+ at::ScalarType output_type, bool dynamic_act_quant, bool use_azp,
+ int64_t primitive_cache_size) {
+ TORCH_CHECK(b.dim() == 2);
+ TORCH_CHECK(b.stride(0) == 1); // Column-major
+ TORCH_CHECK(b_scales.is_contiguous());
+
+ W8A8MatMulPrimitiveHandler::Args args;
+ args.primitive_cache_size = primitive_cache_size;
+
+ if (b_scales.numel() == 1) {
+ args.b_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
+ } else {
+ TORCH_CHECK_EQ(b_scales.numel(), b.size(1));
+ args.b_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_OUTPUT_CHANNEL;
+ }
+ args.b_scales_ptr = b_scales.data_ptr();
+ args.b_k_size = b.size(0);
+ args.b_k_stride = b.stride(0);
+ args.b_n_size = b.size(1);
+ args.b_n_stride = b.stride(1);
+ args.b_ptr = b.data_ptr();
+
+ if (dynamic_act_quant) {
+ // dynamic per-token, bias, A scales and A zps will be applied in outside.
+ args.a_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN;
+ args.use_a_zero_point = false;
+ } else {
+ // static per-tensor
+ args.a_quantization_strategy =
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR;
+ args.use_a_zero_point = use_azp;
+ }
+
+ VLLM_DISPATCH_FLOATING_TYPES(output_type, "create_onednn_scaled_mm_handler",
+ [&] {
+ if (dynamic_act_quant) {
+ args.c_type = get_dnnl_type();
+ } else {
+ args.c_type = get_dnnl_type();
+ }
+ });
+
+ return reinterpret_cast(new W8A8MatMulPrimitiveHandler(args));
+}
+
+void onednn_scaled_mm(
+ torch::Tensor& c, // [M, OC], row-major
+ const torch::Tensor& a, // [M, IC], row-major
+ const torch::Tensor& a_scales, // [M] or [1]
+ const std::optional& azp, // [M] or [1]
+ const std::optional& azp_adj, // [M] or [1]
+ const std::optional& bias, // [N]
+ int64_t handler) {
+ CPU_KERNEL_GUARD_IN(onednn_scaled_mm)
+ TORCH_CHECK(a.dim() == 2);
+ TORCH_CHECK(a.is_contiguous());
+ TORCH_CHECK(c.is_contiguous());
+ W8A8MatMulPrimitiveHandler* ptr =
+ reinterpret_cast(handler);
+ const int32_t* azp_ptr = nullptr;
+ if (azp.has_value()) {
+ azp_ptr = azp->data_ptr();
+ }
+ if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
+ TORCH_CHECK_EQ(a_scales.numel(), 1);
+ }
+
+ W8A8MatMulPrimitiveHandler::ExecArgs exec_args;
+ exec_args.a_ptr = a.data_ptr();
+ exec_args.a_m_size = a.size(0);
+ exec_args.bias_ptr = nullptr;
+ exec_args.use_bias = false;
+ exec_args.a_scales_ptr = nullptr;
+ exec_args.a_zero_points_ptr = nullptr;
+
+ VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "onednn_scaled_mm", [&] {
+ if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TENSOR) {
+ if (bias.has_value()) {
+ exec_args.bias_ptr = bias->data_ptr();
+ exec_args.bias_type = get_dnnl_type();
+ exec_args.use_bias = true;
+ }
+ exec_args.a_scales_ptr = a_scales.data_ptr();
+ exec_args.a_zero_points_ptr = azp_ptr;
+ exec_args.c_ptr = c.data_ptr();
+ ptr->execute(exec_args);
+ } else if (ptr->get_input_scale_strategy() ==
+ W8A8MatMulPrimitiveHandler::QuantizationStrategy::PER_TOKEN) {
+ torch::Tensor tmp_fp32_out =
+ torch::empty_like(c, ::at::ScalarType::Float);
+ exec_args.c_ptr = tmp_fp32_out.data_ptr();
+ ptr->execute(exec_args);
+ if (bias.has_value()) {
+ if (azp.has_value()) {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(),
+ bias->data_ptr(), c.size(0), c.size(1));
+ } else {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, nullptr,
+ bias->data_ptr(), c.size(0), c.size(1));
+ }
+ } else {
+ if (azp.has_value()) {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, azp_adj->data_ptr(),
+ (scalar_t*)nullptr, c.size(0), c.size(1));
+ } else {
+ dynamic_quant_epilogue(
+ tmp_fp32_out.data_ptr(), c.data_ptr(),
+ a_scales.data_ptr(), azp_ptr, nullptr, (scalar_t*)nullptr,
+ c.size(0), c.size(1));
+ }
+ }
+ } else {
+ TORCH_CHECK(false, "invalid act quant type.");
+ }
+ });
+}
+
+// static-per-tensor quantization.
+void static_scaled_int8_quant(
+ torch::Tensor& out, // [batch, hidden_size]
+ const torch::Tensor& input, // [batch, hidden_size]
+ const torch::Tensor& scale, std::optional const& azp) {
+ CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
+ TORCH_CHECK(out.is_contiguous());
+ TORCH_CHECK_EQ(input.dim(), 2);
+ TORCH_CHECK_EQ(input.stride(1), 1);
+ TORCH_CHECK(scale.numel() == 1);
+ TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
+
+ const int64_t stride = input.stride(0);
+ const int64_t hidden_size = input.size(1);
+ const int64_t num_tokens = input.size(0);
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
+ if (azp.has_value()) {
+ static_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), azp->data_ptr(), num_tokens,
+ stride, hidden_size);
+ } else {
+ static_scaled_int8_quant_impl(input.data_ptr(),
+ out.data_ptr(),
+ scale.data_ptr(), nullptr,
+ num_tokens, stride, hidden_size);
+ }
+ });
+}
+
+// dynamic-per-token quantization.
+void dynamic_scaled_int8_quant(
+ torch::Tensor& out, // [batch, hidden_size]
+ const torch::Tensor& input, // [batch, hidden_size]
+ torch::Tensor& scale, // [batch, 1]
+ std::optional const& azp) {
+ CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
+ TORCH_CHECK(out.is_contiguous());
+ TORCH_CHECK_EQ(input.dim(), 2);
+ TORCH_CHECK_EQ(input.stride(1), 1);
+
+ const int64_t hidden_size = input.size(1);
+ const int64_t num_tokens = input.size(0);
+ const int64_t stride = input.stride(0);
+ VLLM_DISPATCH_FLOATING_TYPES(
+ input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
+ if (azp.has_value()) {
+ dynamic_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), azp->data_ptr(), num_tokens,
+ stride, hidden_size);
+ } else {
+ dynamic_scaled_int8_quant_impl(
+ input.data_ptr(), out.data_ptr(),
+ scale.data_ptr(), nullptr, num_tokens, stride,
+ hidden_size);
+ }
+ });
+}
diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp
deleted file mode 100644
index 6e120b8d20a7e..0000000000000
--- a/csrc/cpu/quant.cpp
+++ /dev/null
@@ -1,951 +0,0 @@
-#include "cpu_types.hpp"
-#include "dnnl_helper.hpp"
-
-namespace {
-template
-struct KernelVecType {
- using load_vec_type = void;
- using azp_adj_load_vec_type = void;
- using cvt_vec_type = void;
-};
-
-template <>
-struct KernelVecType {
- using load_vec_type = vec_op::FP32Vec16;
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-
-#if !defined(__aarch64__) || defined(ARM_BF16_SUPPORT)
-template <>
-struct KernelVecType {
- using load_vec_type = vec_op::BF16Vec16;
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-#endif
-
-template <>
-struct KernelVecType {
-#if defined(__powerpc64__) || defined(__s390x__)
- // Power architecture-specific vector type
- using load_vec_type = vec_op::FP32Vec16;
-#else
- // Fallback for other architectures
- using load_vec_type = vec_op::FP16Vec16;
-#endif
- using azp_adj_load_vec_type = vec_op::INT32Vec16;
- using cvt_vec_type = vec_op::FP32Vec16;
-};
-
-#if defined(__AVX512F__) || defined(__aarch64__)
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t inv_scale(1.0 / *scale);
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- cvt_vec_t zp_vec;
- if constexpr (AZP) {
- zp_vec = cvt_vec_t(static_cast(*azp));
- }
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t max_value(std::numeric_limits::lowest());
- cvt_vec_t min_value(std::numeric_limits::max());
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
-
- if (j + vec_elem_num == hidden_size) {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- } else {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32, hidden_size - j);
- min_value = min_value.min(elems_fp32, hidden_size - j);
- } else {
- max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
- }
- }
- }
-
- float scale_val, azp_val;
- if constexpr (AZP) {
- float max_scalar = max_value.reduce_max();
- float min_scalar = min_value.reduce_min();
- scale_val = (max_scalar - min_scalar) / 255.0f;
- azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
- azp[i] = static_cast(azp_val);
- scale[i] = scale_val;
- } else {
- scale_val = max_value.reduce_max() / 127.0f;
- scale[i] = scale_val;
- }
-
- const cvt_vec_t inv_scale(1.0 / scale_val);
- const cvt_vec_t azp_vec(azp_val);
-
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
- }
-}
-
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t a_scale_vec(a_scale);
- cvt_vec_t b_scale_vec(*b_scale);
- cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
-
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- cvt_vec_t token_scale_vec(a_scale[i]);
- cvt_vec_t token_zp_scale_vec;
- if constexpr (AZP) {
- float zp_scale_val = a_scale[i] * static_cast(azp[i]);
- if constexpr (!PerChannel) {
- zp_scale_val *= *b_scale;
- }
- token_zp_scale_vec = cvt_vec_t(zp_scale_val);
- }
-
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-#elif defined(__powerpc64__)
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
-
- const cvt_vec_t inv_scale(1.0 / *scale);
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- cvt_vec_t zp_vec;
- if constexpr (AZP) {
- zp_vec = cvt_vec_t(static_cast(*azp));
- }
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = elems_fp32 * inv_scale;
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + zp_vec;
- }
-
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- using load_vec_t = typename KernelVecType::load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- constexpr float i8_min =
- static_cast(std::numeric_limits::min());
- constexpr float i8_max =
- static_cast(std::numeric_limits::max());
- const cvt_vec_t i8_min_vec(i8_min);
- const cvt_vec_t i8_max_vec(i8_max);
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t max_value(std::numeric_limits::lowest());
- cvt_vec_t min_value(std::numeric_limits::max());
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
-
- if (j + vec_elem_num == hidden_size) {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32);
- min_value = min_value.min(elems_fp32);
- } else {
- max_value = max_value.max(elems_fp32.abs());
- }
- } else {
- if constexpr (AZP) {
- max_value = max_value.max(elems_fp32, hidden_size - j);
- min_value = min_value.min(elems_fp32, hidden_size - j);
- } else {
- max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
- }
- }
- }
-
- float scale_val, azp_val;
- if constexpr (AZP) {
- float max_scalar = max_value.reduce_max();
- float min_scalar = min_value.reduce_min();
- scale_val = (max_scalar - min_scalar) / 255.0f;
- azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
- azp[i] = static_cast(azp_val);
- scale[i] = scale_val;
- } else {
- scale_val = max_value.reduce_max() / 127.0f;
- scale[i] = scale_val;
- }
-
- const cvt_vec_t inv_scale(1.0 / scale_val);
- const cvt_vec_t azp_vec(azp_val);
-
- {
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j);
- }
-
- load_vec_t elems(input + i * hidden_size + j);
- cvt_vec_t elems_fp32(elems);
- elems_fp32 = (elems_fp32 * inv_scale);
-
- if constexpr (AZP) {
- elems_fp32 = elems_fp32 + azp_vec;
- }
- elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
- vec_op::INT8Vec16 elems_int8(elems_fp32);
- elems_int8.save(output + i * hidden_size + j, hidden_size - j);
- }
- }
-}
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- cvt_vec_t a_scale_vec(a_scale);
- cvt_vec_t b_scale_vec(*b_scale);
- cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
-
- int j = 0;
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
-
- if constexpr (PerChannel) {
- b_scale_vec = cvt_vec_t(b_scale + j);
- scale_vec = b_scale_vec * a_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
- using load_vec_t = typename KernelVecType::load_vec_type;
- using azp_adj_load_vec_t =
- typename KernelVecType::azp_adj_load_vec_type;
- using cvt_vec_t = typename KernelVecType::cvt_vec_type;
- constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
-
- #pragma omp parallel for
- for (int i = 0; i < num_tokens; ++i) {
- int j = 0;
- cvt_vec_t token_scale_vec(a_scale[i]);
- cvt_vec_t token_zp_scale_vec;
- if constexpr (AZP) {
- float zp_scale_val = a_scale[i] * static_cast(azp[i]);
- if constexpr (!PerChannel) {
- zp_scale_val *= *b_scale;
- }
- token_zp_scale_vec = cvt_vec_t(zp_scale_val);
- }
-
- for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j);
- }
-
- cvt_vec_t elems_fp32(input + i * hidden_size + j);
- elems_fp32 = elems_fp32 * token_scale_vec;
-
- if constexpr (AZP) {
- azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
- cvt_vec_t azp_adj_fp32(azp_adj_vec);
- azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
-
- if constexpr (PerChannel) {
- cvt_vec_t b_scale_vec(b_scale + j);
- azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
- }
-
- elems_fp32 = elems_fp32 - azp_adj_fp32;
- }
-
- if constexpr (Bias) {
- load_vec_t bias_vec(bias + j);
- cvt_vec_t bias_vec_fp32(bias_vec);
- elems_fp32 = elems_fp32 + bias_vec_fp32;
- }
-
- load_vec_t elems_out(elems_fp32);
- elems_out.save(output + i * hidden_size + j, hidden_size - j);
- }
-}
-#else
-template
-void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- const float* scale, const int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(false,
- "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
- "support.")
-}
-
-template
-void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
- float* scale, int32_t* azp,
- const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(false,
- "dynamic_scaled_int8_quant_impl requires "
- "AVX512/powerpc64/AArch64 support.")
-}
-
-template
-void static_quant_epilogue(const float* input, scalar_t* output,
- const float a_scale, const float* b_scale,
- const int32_t* azp_with_adj, const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(
- false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
-}
-
-template
-void dynamic_quant_epilogue(const float* input, scalar_t* output,
- const float* a_scale, const float* b_scale,
- const int32_t* azp, const int32_t* azp_with_adj,
- const scalar_t* bias, const int num_tokens,
- const int hidden_size) {
- TORCH_CHECK(
- false,
- "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
-}
-#endif
-} // namespace
-
-void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
- const torch::Tensor& a, // [M, IC], row-major
- const torch::Tensor& b, // [IC, OC], column-major
- const torch::Tensor& a_scales, // [1] or [M]
- const torch::Tensor& b_scales, // [1] or [OC]
- const std::optional& bias // [OC]
-) {
- CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
- // Checks for conformality
- TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
- "int8_scaled_mm only supports INT8 inputs.")
- TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
- TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
- b.size(1) == c.size(1));
- TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
- TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
-
- // Check for strides and alignment
- TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
- TORCH_CHECK(b.stride(0) == 1); // Column-major
- TORCH_CHECK(c.stride(0) % 16 == 0 &&
- b.stride(1) % 16 == 0); // 16 Byte Alignment
- TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
-
- if (bias) {
- TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
- bias->dim() == 1);
- }
-
- VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
- if (a_scales.numel() != 1) {
- // per-token
- // Note: oneDNN doesn't support per-token activation quantization
- // Ideally we want to fuse the GEMM and the scale procedure with oneDNN
- // JIT, the intermediate data is cached in registers or L1. But for now
- // the oneDNN GEMM code generation only supports two quantization
- // patterns: per-tensor or per-output-channel of weight.
- // So we have to apply the per-token scale with a 'epilogue'. In C=s_a *
- // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN
- // GEMM, then the per-token scale (and bias) is applied with the epilogue
- // C=s_a * C_inter + bias.
- torch::Tensor tmp_fp32_out =
- torch::empty_like(c, ::at::ScalarType::Float);
- // Compute C_inter=s_b * (A@B)
- DNNLPrimitiveHelper::gemm_s8s8_jit(
- a.data_ptr(), b.data_ptr(),
- tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1),
- a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel());
- if (bias.has_value()) {
- // Compute C=s_a * C_inter + bias
- dynamic_quant_epilogue(
- tmp_fp32_out.data_ptr(), c.data_ptr(),
- a_scales.data_ptr(), nullptr, nullptr, nullptr,
- bias->data_ptr