Merge branch 'main' of https://github.com/neuralmagic/vllm into sage/dbo-full-cudagraphs

This commit is contained in:
Sage Moore 2025-07-31 21:24:57 +00:00
commit e283eff060
34 changed files with 3296 additions and 271 deletions

View File

@ -82,7 +82,7 @@ steps:
- bash standalone_tests/python_only_compile.sh
- label: Basic Correctness Test # 30min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
fast_check: true
torch_nightly: true
source_file_dependencies:
@ -99,7 +99,7 @@ steps:
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
- label: Chunked Prefill Test
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/basic_correctness/test_chunked_prefill
@ -108,7 +108,7 @@ steps:
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- label: Core Test # 10min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
fast_check: true
source_file_dependencies:
- vllm/core
@ -209,7 +209,7 @@ steps:
- pytest -v -s distributed/test_eplb_execute.py
- label: Metrics, Tracing Test # 10min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
num_gpus: 2
source_file_dependencies:
- vllm/
@ -228,7 +228,7 @@ steps:
##### 1 GPU test #####
- label: Regression Test # 5min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/test_regression
@ -280,7 +280,7 @@ steps:
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
- label: Examples Test # 25min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/examples"
source_file_dependencies:
- vllm/entrypoints
@ -305,7 +305,7 @@ steps:
- VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2
- label: Prefix Caching Test # 9min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/prefix_caching
@ -314,7 +314,7 @@ steps:
- label: Platform Tests (CUDA)
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/cuda
@ -353,9 +353,10 @@ steps:
- pytest -v -s compile/test_silu_mul_quant_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py
- pytest -v -s compile/test_fusion_all_reduce.py
- label: PyTorch Fullgraph Smoke Test # 9min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies:
- vllm/
@ -368,7 +369,7 @@ steps:
- pytest -v -s compile/piecewise/test_full_cudagraph.py
- label: PyTorch Fullgraph Test # 18min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
torch_nightly: true
source_file_dependencies:
- vllm/
@ -377,7 +378,7 @@ steps:
- pytest -v -s compile/test_full_graph.py
- label: Kernels Core Operation Test
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/
- tests/kernels/core
@ -416,7 +417,7 @@ steps:
parallelism: 2
- label: Kernels Mamba Test
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- csrc/mamba/
- tests/kernels/mamba
@ -424,7 +425,7 @@ steps:
- pytest -v -s kernels/mamba
- label: Tensorizer Test # 11min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
soft_fail: true
source_file_dependencies:
- vllm/model_executor/model_loader
@ -437,7 +438,7 @@ steps:
- pytest -v -s entrypoints/openai/test_tensorizer_entrypoint.py
- label: Model Executor Test
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/model_executor
- tests/model_executor
@ -447,7 +448,7 @@ steps:
- pytest -v -s model_executor
- label: Benchmarks # 9min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/.buildkite"
source_file_dependencies:
- benchmarks/
@ -455,7 +456,7 @@ steps:
- bash scripts/run-benchmarks.sh
- label: Benchmarks CLI Test # 10min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/benchmarks/
@ -494,7 +495,7 @@ steps:
- pytest -s entrypoints/openai/correctness/
- label: Encoder Decoder tests # 5min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/encoder_decoder
@ -502,7 +503,7 @@ steps:
- pytest -v -s encoder_decoder
- label: OpenAI-Compatible Tool Use # 20 min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
fast_check: false
source_file_dependencies:
- vllm/
@ -623,7 +624,7 @@ steps:
# This test is used only in PR development phase to test individual models and should never run on main
- label: Custom Models Test
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
optional: true
commands:
- echo 'Testing custom models...'
@ -658,7 +659,7 @@ steps:
##### multi gpus test #####
- label: Distributed Comm Ops Test # 7min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 2
source_file_dependencies:
@ -755,7 +756,7 @@ steps:
- pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
- label: Multi-step Tests (4 GPUs) # 36min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@ -776,7 +777,7 @@ steps:
- pytest -v -s multi_step/test_correctness_llm.py
- label: Pipeline Parallelism Test # 45min
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
working_dir: "/vllm-workspace/tests"
num_gpus: 4
source_file_dependencies:
@ -790,7 +791,7 @@ steps:
- pytest -v -s distributed/test_pipeline_parallel.py
- label: LoRA TP Test (Distributed)
mirror_hardwares: [amdexperimental, amdproduction]
mirror_hardwares: [amdexperimental]
num_gpus: 4
source_file_dependencies:
- vllm/lora

View File

@ -370,6 +370,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
fi
# Install vllm wheel first, so that torch etc will be installed.
# !bang
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/uv \
uv pip install --system dist/*.whl --verbose \

View File

@ -1,4 +1,4 @@
ARG NIGHTLY_DATE="20250724"
ARG NIGHTLY_DATE="20250730"
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE"
FROM $BASE_IMAGE

View File

@ -625,6 +625,7 @@ See [this page](generative_models.md) for more information on how to use generat
| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ |
| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ |
| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `HuggingFaceTB/SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ |
| `Step3VLForConditionalGeneration` | Step3-VL | T + I<sup>+</sup> | `stepfun-ai/step3` | | ✅︎ | ✅︎ |
| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-research/Tarsier-7b`, `omni-research/Tarsier-34b` | | ✅︎ | ✅︎ |
| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ |

View File

@ -13,6 +13,38 @@ except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
# Text part attached to every multimodal prompt built below.
QUESTION = "What is the content of each image?"

# Image URLs used as the image part of the prompts; reused round-robin when
# more prompts are requested than there are URLs.
IMAGE_URLS = [
    "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/2/26/Ultramarine_Flycatcher_%28Ficedula_superciliaris%29_Naggar%2C_Himachal_Pradesh%2C_2013_%28cropped%29.JPG",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e5/Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg/2560px-Anim1754_-_Flickr_-_NOAA_Photo_Library_%281%29.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/d/d4/Starfish%2C_Caswell_Bay_-_geograph.org.uk_-_409413.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/6/69/Grapevinesnail_01.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0b/Texas_invasive_Musk_Thistle_1.jpg/1920px-Texas_invasive_Musk_Thistle_1.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7a/Huskiesatrest.jpg/2880px-Huskiesatrest.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/6/68/Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg/1920px-Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/3/30/George_the_amazing_guinea_pig.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/thumb/1/1f/Oryctolagus_cuniculus_Rcdo.jpg/1920px-Oryctolagus_cuniculus_Rcdo.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/9/98/Horse-and-pony.jpg",
]


def get_custom_mm_prompts(num_prompts):
    """Build ``num_prompts`` single-turn chat conversations about an image.

    Each conversation is a one-message list of the form
    ``[{"role": "user", "content": [image_url_part, text_part]}]`` where the
    image URL is taken from ``IMAGE_URLS`` and the text is ``QUESTION``.
    URLs are reused in order (round-robin) when ``num_prompts`` exceeds
    ``len(IMAGE_URLS)``.

    Args:
        num_prompts: Number of conversations to build. Zero or a negative
            count yields an empty list.

    Returns:
        A list containing ``max(num_prompts, 0)`` conversations.
    """
    url_count = len(IMAGE_URLS)
    conversations = []
    # range() is empty for non-positive counts, so a negative value can no
    # longer act as a "drop from the end" slice bound.
    for index in range(num_prompts):
        image_part = {
            "type": "image_url",
            "image_url": {"url": IMAGE_URLS[index % url_count]},
        }
        text_part = {"type": "text", "text": QUESTION}
        conversations.append([{"role": "user", "content": [image_part, text_part]}])
    return conversations
def parse_args():
parser = FlexibleArgumentParser()
add_dataset_parser(parser)
@ -35,6 +67,7 @@ def parse_args():
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--model-dir", type=str, default=None)
parser.add_argument("--eagle-dir", type=str, default=None)
parser.add_argument("--custom-mm-prompts", action="store_true")
return parser.parse_args()
@ -44,14 +77,26 @@ def main():
model_dir = args.model_dir
if args.model_dir is None:
if args.custom_mm_prompts:
    # Adjacent string literals are concatenated verbatim, so each fragment
    # carries its own separator to keep the final message readable.
    raise ValueError(
        "custom_mm_prompts requires mm based models; "
        "default llama3.1-8b-instruct is not mm based, "
        "please specify model_dir to give a mm based model"
    )
model_dir = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
args.custom_skip_chat_template = True
prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False) for prompt in prompts
]
if not args.custom_mm_prompts:
prompts = get_samples(args, tokenizer)
# add_special_tokens is False to avoid adding bos twice
# when using chat templates
prompt_ids = [
tokenizer.encode(prompt.prompt, add_special_tokens=False)
for prompt in prompts
]
else:
prompts = get_custom_mm_prompts(args.num_prompts)
if args.method == "eagle" or args.method == "eagle3":
eagle_dir = args.eagle_dir
@ -85,10 +130,17 @@ def main():
speculative_config=speculative_config,
disable_log_stats=False,
max_model_len=16384,
limit_mm_per_prompt={"image": 5},
disable_chunked_mm_input=True,
)
sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)
outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params)
if not args.custom_mm_prompts:
outputs = llm.generate(
prompt_token_ids=prompt_ids, sampling_params=sampling_params
)
else:
outputs = llm.chat(prompts, sampling_params=sampling_params)
# print the generated text
if args.print_output:

View File

@ -22,9 +22,7 @@ aiohttp==3.10.11
aiohttp-cors==0.8.1
# via ray
aiosignal==1.3.1
# via
# aiohttp
# ray
# via aiohttp
albucore==0.0.16
# via terratorch
albumentations==1.4.6
@ -139,7 +137,7 @@ contourpy==1.3.0
# via matplotlib
cramjam==2.9.0
# via fastparquet
cupy-cuda12x==13.3.0
cupy-cuda12x==13.5.1
# via ray
cycler==0.12.1
# via matplotlib
@ -226,7 +224,6 @@ frozenlist==1.5.0
# via
# aiohttp
# aiosignal
# ray
fsspec==2024.9.0
# via
# datasets
@ -603,10 +600,18 @@ opencv-python-headless==4.11.0.86
opentelemetry-api==1.35.0
# via
# mlflow-skinny
# opentelemetry-exporter-prometheus
# opentelemetry-sdk
# opentelemetry-semantic-conventions
opentelemetry-exporter-prometheus==0.56b0
# via ray
opentelemetry-proto==1.36.0
# via ray
opentelemetry-sdk==1.35.0
# via mlflow-skinny
# via
# mlflow-skinny
# opentelemetry-exporter-prometheus
# ray
opentelemetry-semantic-conventions==0.56b0
# via opentelemetry-sdk
packaging==24.2
@ -697,7 +702,9 @@ pqdm==0.2.0
pretrainedmodels==0.7.4
# via segmentation-models-pytorch
prometheus-client==0.22.0
# via ray
# via
# opentelemetry-exporter-prometheus
# ray
propcache==0.2.0
# via yarl
proto-plus==1.26.1
@ -707,6 +714,7 @@ protobuf==5.28.3
# google-api-core
# googleapis-common-protos
# mlflow-skinny
# opentelemetry-proto
# proto-plus
# ray
# tensorboardx
@ -854,7 +862,7 @@ rasterio==1.4.3
# rioxarray
# terratorch
# torchgeo
ray==2.43.0
ray==2.48.0
# via -r requirements/test.in
redis==5.2.0
# via tensorizer

View File

@ -19,8 +19,8 @@ nixl==0.3.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.9.0.dev20250724
torchvision==0.24.0.dev20250724
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250724-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"
torch==2.9.0.dev20250730
torchvision==0.24.0.dev20250730
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250730-cp312-cp312-linux_x86_64.whl ; python_version == "3.12"

203
setup.py
View File

@ -282,10 +282,69 @@ class cmake_build_ext(build_ext):
self.copy_file(file, dst_file)
class repackage_wheel(build_ext):
class precompiled_wheel_utils:
"""Extracts libraries and other files from an existing wheel."""
def get_base_commit_in_main_branch(self) -> str:
@staticmethod
def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict:
# Copy the precompiled shared objects (and the vllm_flash_attn Python
# sources) out of a wheel into the current working tree.
#
# Args:
#   wheel_url_or_path: Path of a wheel on disk. Anything that is not an
#     existing file is passed to urlretrieve as a URL and downloaded first.
#
# Returns:
#   Mapping of dotted package name -> list of extracted file basenames.
import tempfile
import zipfile
# Stays None unless a download happens, so `finally` knows whether there
# is a temporary directory to remove.
temp_dir = None
try:
if not os.path.isfile(wheel_url_or_path):
# The download target is named after the last "/"-separated segment.
wheel_filename = wheel_url_or_path.split("/")[-1]
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
wheel_path = os.path.join(temp_dir, wheel_filename)
print(f"Downloading wheel from {wheel_url_or_path} "
f"to {wheel_path}")
from urllib.request import urlretrieve
urlretrieve(wheel_url_or_path, filename=wheel_path)
else:
wheel_path = wheel_url_or_path
print(f"Using existing wheel at {wheel_path}")
package_data_patch = {}
with zipfile.ZipFile(wheel_path) as wheel:
# Exact archive member names that are always extracted when present.
files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/_flashmla_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/cumem_allocator.abi3.so",
]
# Additionally selects .py members under vllm/vllm_flash_attn/ whose
# directory and file names do not start with a dot.
# NOTE(review): match() is only anchored at the start, so a name that
# continues after ".py" would also be selected - confirm this is intended.
compiled_regex = re.compile(
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
file_members = list(
filter(lambda x: x.filename in files_to_copy,
wheel.filelist))
file_members += list(
filter(lambda x: compiled_regex.match(x.filename),
wheel.filelist))
for file in file_members:
print(f"[extract] {file.filename}")
# Members are written relative to the current working directory,
# mirroring their path inside the wheel.
target_path = os.path.join(".", file.filename)
os.makedirs(os.path.dirname(target_path), exist_ok=True)
with wheel.open(file.filename) as src, open(
target_path, "wb") as dst:
shutil.copyfileobj(src, dst)
# Record the basename under its dotted package name
# (directory path with "/" replaced by ".").
pkg = os.path.dirname(file.filename).replace("/", ".")
package_data_patch.setdefault(pkg, []).append(
os.path.basename(file.filename))
return package_data_patch
finally:
# Runs on success and on failure; only a downloaded wheel's temporary
# directory is removed, never a caller-supplied local wheel.
if temp_dir is not None:
print(f"Removing temporary directory {temp_dir}")
shutil.rmtree(temp_dir)
@staticmethod
def get_base_commit_in_main_branch() -> str:
# Force to use the nightly wheel. This is mainly used for CI testing.
if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL:
return "nightly"
@ -334,115 +393,6 @@ class repackage_wheel(build_ext):
"wheel may not be compatible with your dev branch: %s", err)
return "nightly"
def run(self) -> None:
assert _is_cuda(
), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
if wheel_location is None:
base_commit = self.get_base_commit_in_main_branch()
wheel_location = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
# Fallback to nightly wheel if latest commit wheel is unavailable,
# in this rare case, the nightly release CI hasn't finished on main.
if not is_url_available(wheel_location):
wheel_location = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
import zipfile
if os.path.isfile(wheel_location):
wheel_path = wheel_location
print(f"Using existing wheel={wheel_path}")
else:
# Download the wheel from a given URL, assume
# the filename is the last part of the URL
wheel_filename = wheel_location.split("/")[-1]
import tempfile
# create a temporary directory to store the wheel
temp_dir = tempfile.mkdtemp(prefix="vllm-wheels")
wheel_path = os.path.join(temp_dir, wheel_filename)
print(f"Downloading wheel from {wheel_location} to {wheel_path}")
from urllib.request import urlretrieve
try:
urlretrieve(wheel_location, filename=wheel_path)
except Exception as e:
from setuptools.errors import SetupError
raise SetupError(
f"Failed to get vLLM wheel from {wheel_location}") from e
# Set the dist_dir for Docker build context
dist_dir = ("/workspace/dist"
if envs.VLLM_DOCKER_BUILD_CONTEXT else "dist")
os.makedirs(dist_dir, exist_ok=True)
# Extract only necessary compiled .so files from precompiled wheel
with zipfile.ZipFile(wheel_path) as wheel:
# Get version from METADATA (optional, mostly useful for logging)
metadata_file = next((n for n in wheel.namelist()
if n.endswith(".dist-info/METADATA")), None)
if not metadata_file:
raise RuntimeError(
"Could not find METADATA in precompiled wheel.")
metadata = wheel.read(metadata_file).decode()
version_line = next((line for line in metadata.splitlines()
if line.startswith("Version: ")), None)
if not version_line:
raise RuntimeError(
"Could not determine version from METADATA.")
version = version_line.split(": ")[1].strip()
print(f"Extracting precompiled kernels from vLLM wheel version: "
f"{version}")
# List of compiled shared objects to extract
files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/_flashmla_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/cumem_allocator.abi3.so",
]
file_members = list(
filter(lambda x: x.filename in files_to_copy, wheel.filelist))
compiled_regex = re.compile(
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
file_members += list(
filter(lambda x: compiled_regex.match(x.filename),
wheel.filelist))
for file in file_members:
print(f"Extracting and including {file.filename} "
"from existing wheel")
package_name = os.path.dirname(file.filename).replace("/", ".")
file_name = os.path.basename(file.filename)
if package_name not in package_data:
package_data[package_name] = []
output_base = (dist_dir
if envs.VLLM_DOCKER_BUILD_CONTEXT else ".")
target_path = os.path.join(output_base, file.filename)
os.makedirs(os.path.dirname(target_path), exist_ok=True)
with wheel.open(file.filename) as src, open(target_path,
"wb") as dst:
shutil.copyfileobj(src, dst)
package_data[package_name].append(file_name)
# Copy wheel into dist dir for Docker to consume (e.g., via --mount)
if envs.VLLM_DOCKER_BUILD_CONTEXT:
arch_tag = "cp38-abi3-manylinux1_x86_64"
corrected_wheel_name = f"vllm-{version}-{arch_tag}.whl"
final_wheel_path = os.path.join(dist_dir, corrected_wheel_name)
print(
"Docker build context detected, copying precompiled wheel to "
f"{final_wheel_path}")
shutil.copy2(wheel_path, final_wheel_path)
def _no_device() -> bool:
return VLLM_TARGET_DEVICE == "empty"
@ -676,16 +626,37 @@ package_data = {
]
}
# If using precompiled, extract and patch package_data (in advance of setup)
if envs.VLLM_USE_PRECOMPILED:
assert _is_cuda(), "VLLM_USE_PRECOMPILED is only supported for CUDA builds"
wheel_location = os.getenv("VLLM_PRECOMPILED_WHEEL_LOCATION", None)
if wheel_location is not None:
wheel_url = wheel_location
else:
base_commit = precompiled_wheel_utils.get_base_commit_in_main_branch()
wheel_url = f"https://wheels.vllm.ai/{base_commit}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
from urllib.request import urlopen
try:
with urlopen(wheel_url) as resp:
if resp.status != 200:
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
except Exception as e:
print(f"[warn] Falling back to nightly wheel: {e}")
wheel_url = "https://wheels.vllm.ai/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl"
patch = precompiled_wheel_utils.extract_precompiled_and_patch_package(
wheel_url)
for pkg, files in patch.items():
package_data.setdefault(pkg, []).extend(files)
if _no_device():
ext_modules = []
if not ext_modules:
if not ext_modules or envs.VLLM_USE_PRECOMPILED:
# Disable build_ext when using precompiled wheel
cmdclass = {}
else:
cmdclass = {
"build_ext":
repackage_wheel if envs.VLLM_USE_PRECOMPILED else cmake_build_ext
}
cmdclass = {"build_ext": cmake_build_ext}
setup(
# static metadata should rather go in pyproject.toml

View File

@ -7,22 +7,26 @@ import torch
import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
GroupShape, QuantFP8)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
from ..utils import multi_gpu_test
from ..utils import has_module_attribute, multi_gpu_test
from .backend import TestBackend
class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, eps=1e-6):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@ -43,7 +47,7 @@ class TestAllReduceRMSNormModel(torch.nn.Module):
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, eps=1e-6):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
@ -62,24 +66,101 @@ class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
class TestAllReduceFusedAddRMSNormStaticQuantFP8Model(torch.nn.Module):
# Test module that chains tensor-parallel all-reduce -> RMSNorm with a
# residual input -> static_scaled_fp8_quant into a preallocated buffer.
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
# NOTE(review): quant_fp8 is constructed here but forward() below calls
# the _C quant op directly and never uses it - confirm it is needed.
self.quant_fp8 = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
# Single-element scale tensor handed to the quant op.
self.scale = torch.rand(1, dtype=torch.float32)
# Destination buffer the quant op writes into; also returned by forward().
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
def forward(self, hidden_states, residual):
# Flatten to 2-D with hidden_size as the last dimension before reducing.
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(self.output,
norm_output.contiguous(),
self.scale)
return self.output, residual_output
def ops_in_model_after(self):
# Op named as present after the pass - presumably checked by the test
# backend once fusion has run; verify against the test body.
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
def ops_in_model_before(self):
# Ops named as present before the pass: the all-reduce and the FP8 quant
# op that forward() calls.
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
]
class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
# Test module that chains tensor-parallel all-reduce -> RMSNorm with a
# residual input -> scaled_fp4_quant into preallocated buffers.
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = RMSNorm(hidden_size, eps)
# Single-element scale tensor handed to the quant op.
self.scale = torch.rand(1, dtype=torch.float32)
# Destination buffer the quant op writes into; also returned by forward().
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
# round_up(x, y): smallest multiple of y that is >= x.
round_up = lambda x, y: (x + y - 1) // y * y
# Scale-buffer shape: rows are token_num rounded up to a multiple of 128;
# columns are hidden_size // 16 rounded up to a multiple of 4, divided by 4.
# NOTE(review): presumably matches the layout scaled_fp4_quant expects
# for its scale output - confirm against the kernel.
rounded_m = round_up(token_num, 128)
scale_n = hidden_size // 16
rounded_n = round_up(scale_n, 4)
self.output_scale = torch.empty((rounded_m, rounded_n // 4),
dtype=torch.int32)
def forward(self, hidden_states, residual):
# Flatten to 2-D with hidden_size as the last dimension before reducing.
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
# Collapse any leading dimensions so the quant op receives a 2-D input.
norm_output = norm_output.reshape(-1, norm_output.shape[-1])
torch.ops._C.scaled_fp4_quant(self.output, norm_output,
self.output_scale, self.scale)
return self.output, residual_output, self.output_scale
def ops_in_model_after(self):
# Op named as present after the pass - presumably checked by the test
# backend once fusion has run; verify against the test body.
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
def ops_in_model_before(self):
# Ops named as present before the pass: the all-reduce and the FP4 quant
# op that forward() calls.
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.scaled_fp4_quant.default
]
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"test_model",
[TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
@pytest.mark.parametrize("test_model", [
TestAllReduceRMSNormModel,
TestAllReduceFusedAddRMSNormModel,
TestAllReduceFusedAddRMSNormStaticQuantFP8Model,
TestAllReduceFusedAddRMSNormStaticQuantFP4Model,
])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [8])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
@pytest.mark.skipif(not find_spec("flashinfer"),
reason="flashinfer is not installed")
@pytest.mark.skipif(not current_platform.is_device_capability(100),
reason="Only test on SM100")
@pytest.mark.skipif(
not find_spec("flashinfer")
or not has_module_attribute("flashinfer.comm", "trtllm_allreduce_fusion"),
reason="flashinfer is not found or flashinfer "
"is not compiled with trtllm_allreduce_fusion")
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
num_processes = 2
if (test_model == TestAllReduceFusedAddRMSNormStaticQuantFP4Model
and not current_platform.has_device_capability(100)):
pytest.skip("Skip as nvfp4 is only supported on "
"devices with compute capability 10.0 (Blackwell)")
def run_torch_spawn(fn, nprocs):
torch.multiprocessing.spawn(fn,
@ -113,12 +194,11 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
vllm_config = VllmConfig(
compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm"],
compile_sizes=[2, 4, 8]))
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"]))
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True)
enable_fi_allreduce_fusion=True, enable_noop=True)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
@ -130,14 +210,16 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
seed=42)
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
backend = TestBackend(all_reduce_fusion_pass)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)
model = test_model_cls(hidden_size)
backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
requires_grad=False)
residual = torch.randn((batch_size * seq_len, hidden_size),
requires_grad=False)
token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)
compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)

View File

@ -279,6 +279,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
"Starcoder2ForCausalLM": _HfExamplesInfo("bigcode/starcoder2-3b"),
"Step3TextForCausalLM": _HfExamplesInfo("stepfun-ai/step3",
trust_remote_code=True,
is_available_online=False),
"SolarForCausalLM": _HfExamplesInfo("upstage/solar-pro-preview-instruct",
trust_remote_code=True),
"TeleChat2ForCausalLM": _HfExamplesInfo("Tele-AI/TeleChat2-3B",
@ -457,6 +460,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B",
trust_remote_code=True),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"Step3VLForConditionalGeneration": _HfExamplesInfo("stepfun-ai/step3",
trust_remote_code=True,
is_available_online=False),
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
trust_remote_code=True),
"TarsierForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier-7b", # noqa: E501

View File

@ -4,6 +4,7 @@
import asyncio
import copy
import functools
import importlib
import os
import signal
import subprocess
@ -974,3 +975,14 @@ def get_client_text_logprob_generations(
return [(text_generations, text,
(None if x.logprobs is None else x.logprobs.top_logprobs))
for completion in completions for x in completion.choices]
def has_module_attribute(module_name, attribute_name):
    """Return whether ``module_name`` imports and exposes ``attribute_name``.

    Args:
        module_name: Absolute dotted name of the module to import.
        attribute_name: Name of the attribute to look up on the module.

    Returns:
        ``True`` if the module can be imported and ``hasattr`` finds the
        attribute on it; ``False`` if the import fails or the module name is
        empty.
    """
    try:
        module = importlib.import_module(module_name)
    except (ImportError, ValueError):
        # ImportError: the module (or something it imports) is unavailable.
        # ValueError: importlib rejects an empty module name; report that as
        # "not available" instead of letting it escape to the caller.
        return False
    return hasattr(module, attribute_name)

View File

@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import numpy as np
import pytest
import torch
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches)
@dataclass
class LocalAttentionTestData:
# One test case: the inputs plus the values expected back for them.
# NOTE(review): presumably the expectations are for
# make_local_attention_virtual_batches (imported above) - confirm against
# the test function that consumes this class.
# Input parameters
# Per-request query and sequence lengths for the batch under test.
batch_spec: BatchSpec
# Chunk size and block size given as inputs for the case.
attn_chunk_size: int
block_size: int
# Expected return values
# Expected query / key sequence lengths, as flat lists of ints.
expected_q_seqlens: list[int]
expected_k_seqlens: list[int]
# Expected block table as a list of rows of block indices.
expected_local_block_table: list[list[int]]
# Parametrized cases for test_local_attention_virtual_batches. The test builds
# the block table with arange_block_indices=True, so row b of the original
# block table holds the consecutive indices
# [b * ncols, ..., (b + 1) * ncols - 1] with ncols = cdiv(max_seq_len,
# block_size); the expected local block tables below rely on that layout.
test_data_list = [
    # Same as example in docstring of make_local_attention_virtual_batches
    # except block table has 9 columns instead of 10
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[4, 10, 5],
            seq_lens=[6, 17, 9],
        ),
        attn_chunk_size=4,
        block_size=2,
        expected_q_seqlens=[2, 2, 1, 4, 4, 1, 4, 1],
        expected_k_seqlens=[4, 2, 4, 4, 4, 1, 4, 1],
        # 2 pages per local batch
        # (chunk size 4 // block size 2)
        expected_local_block_table=[
            [0, 1],  # local-batch 0, (batch 0, starting from k[0])
            [2, 3],  # local-batch 1, (batch 0, starting from k[4])
            [11, 12],  # local-batch 2, (batch 1, starting from k[4])
            [13, 14],  # local-batch 3, (batch 1, starting from k[8])
            [15, 16],  # local-batch 4, (batch 1, starting from k[12])
            [17, 17],  # local-batch 5, (batch 1, starting from k[16])
            [20, 21],  # local-batch 6, (batch 2, starting from k[4])
            [22, 23],  # local-batch 7, (batch 2, starting from k[8])
        ]),
    # Case where block indices are not clipped to block table ncols-1
    # because tokens_in_last_block == attn_chunk_size
    LocalAttentionTestData(batch_spec=BatchSpec(
        query_lens=[8],
        seq_lens=[12],
    ),
                           attn_chunk_size=4,
                           block_size=2,
                           expected_q_seqlens=[4, 4],
                           expected_k_seqlens=[4, 4],
                           expected_local_block_table=[
                               [2, 3],
                               [4, 5],
                           ]),
    # Case where all kv_seq positions are involved in attn
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[7],
            # 10 - 7 = 3 previously computed tokens
            seq_lens=[10],
        ),
        attn_chunk_size=4,
        block_size=2,
        expected_q_seqlens=[1, 4, 2],
        expected_k_seqlens=[4, 4, 2],
        expected_local_block_table=[
            [0, 1],
            [2, 3],
            [4, 4],
        ]),
    # Case where attn_chunk_size > kv_seq_len
    # so no extra mini virtual batches are created
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[4],
            seq_lens=[6],
        ),
        # Larger than kv_seq_len
        attn_chunk_size=10,
        block_size=2,
        # No change to q_seqlens and k_seqlens
        expected_q_seqlens=[4],
        expected_k_seqlens=[6],
        # In this case, we only need a block-table like:
        # block_table = [ [0, 1, 2] ] # 1 batch, 3 pages
        # But we need to pad it to 5 pages per local batch
        # because currently the pages_per_local_batch
        # is calculated as (attn_chunk_size // block_size)
        expected_local_block_table=[
            [0, 1, 2, 2, 2],
        ]),
    # Block size equal to chunk size
    # Expect single page per batch in local batch table
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[6, 6],
            seq_lens=[8, 8],
        ),
        attn_chunk_size=4,
        block_size=4,
        expected_q_seqlens=[2, 4, 2, 4],
        expected_k_seqlens=[4, 4, 4, 4],
        # Initial block table = [
        #     [0, 1],  < batch 0
        #     [2, 3],  < batch 1
        # ]
        expected_local_block_table=[
            [0],  # local-batch 0, (batch 0, starting from k[0])
            [1],  # local-batch 1, (batch 0, starting from k[4])
            [2],  # local-batch 2, (batch 1, starting from k[0])
            [3],  # local-batch 3, (batch 1, starting from k[4])
        ]),
    # Case where query falls in the second attention chunk
    #   k_toks >   0 1 2 3 4
    #   q_toks v  _____________
    #          0 | 1
    #          1 | 1 1
    #          2 | 1 1 1
    #          3 | 1 1 1 1
    #          4 |         1
    # where tokens 0,1,2,3 have been pre-computed
    LocalAttentionTestData(batch_spec=BatchSpec(
        query_lens=[1],
        seq_lens=[5],
    ),
                           attn_chunk_size=4,
                           block_size=2,
                           expected_q_seqlens=[1],
                           expected_k_seqlens=[1],
                           expected_local_block_table=[
                               [2, 2],
                           ]),
]
@pytest.mark.parametrize("test_data", test_data_list)
def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
    """Check the virtual batches produced for chunked local attention.

    Builds common attention metadata for the case's batch spec, runs
    make_local_attention_virtual_batches on it, and compares the resulting
    query lengths, key lengths and block table with the expected values.
    """
    device = torch.device("cuda:0")
    spec = test_data.batch_spec
    chunk = test_data.attn_chunk_size
    page = test_data.block_size

    # arange_block_indices=True (rather than random indices) makes the block
    # table deterministic, so its values can be asserted on below. The block
    # table has shape (num_batches, cdiv(max_seq_len, block_size)).
    metadata = create_common_attn_metadata(
        spec,
        page,
        device,
        arange_block_indices=True,
    )

    result = make_local_attention_virtual_batches(chunk, metadata, page)

    # Work in numpy for the length comparisons.
    actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
    actual_k_seqlens = result.seq_lens_cpu.numpy()

    # No virtual batch may have a query longer than one attention chunk ...
    assert not any(q_len > chunk for q_len in actual_q_seqlens)
    # ... nor a key longer than one attention chunk.
    assert not any(k_len > chunk for k_len in actual_k_seqlens)
    # Splitting into virtual batches must not drop or add query tokens.
    assert sum(actual_q_seqlens) == sum(spec.query_lens)

    # Exact per-virtual-batch lengths.
    np.testing.assert_array_equal(actual_q_seqlens,
                                  test_data.expected_q_seqlens)
    np.testing.assert_array_equal(actual_k_seqlens,
                                  test_data.expected_k_seqlens)

    # Exact local block table, compared on the same device and dtype.
    expected_block_table_tensor = torch.tensor(
        test_data.expected_local_block_table,
        dtype=torch.int32,
        device=device,
    )
    print(f"Expected block table:\n{expected_block_table_tensor}")
    print(f"Actual block table:\n{result.block_table_tensor}")
    torch.testing.assert_close(result.block_table_tensor,
                               expected_block_table_tensor)

View File

@ -40,7 +40,8 @@ def create_common_attn_metadata(
batch_spec: BatchSpec,
block_size: int,
device: torch.device,
max_block_idx: int = 1000) -> CommonAttentionMetadata:
max_block_idx: int = 1000,
arange_block_indices: bool = False) -> CommonAttentionMetadata:
"""Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
# Create query start locations
query_start_loc = torch.zeros(batch_spec.batch_size + 1,
@ -65,19 +66,28 @@ def create_common_attn_metadata(
]
num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32)
# Create block table (random for testing)
# Create block table and slot mapping
max_blocks = (max(batch_spec.seq_lens) + block_size - 1) // block_size
block_table_tensor = torch.randint(0,
max_block_idx,
(batch_spec.batch_size, max_blocks),
dtype=torch.int32,
device=device)
# Create slot mapping
slot_mapping = torch.randint(0,
max_block_idx, (num_tokens, ),
dtype=torch.int64,
device=device)
if arange_block_indices:
num_blocks = batch_spec.batch_size * max_blocks
block_table_tensor = torch.arange(num_blocks,
dtype=torch.int32,
device=device).view(
batch_spec.batch_size,
max_blocks)
slot_mapping = torch.arange(num_tokens,
dtype=torch.int64,
device=device).view(num_tokens)
else:
block_table_tensor = torch.randint(0,
max_block_idx,
(batch_spec.batch_size, max_blocks),
dtype=torch.int32,
device=device)
slot_mapping = torch.randint(0,
max_block_idx, (num_tokens, ),
dtype=torch.int64,
device=device)
# Calculate max query length
max_query_len = max(batch_spec.query_lens)

View File

@ -3,29 +3,34 @@
from __future__ import annotations
import random
from typing import Any
from typing import Any, Union
import pytest
import torch
from vllm import LLM, SamplingParams
from vllm.assets.base import VLLM_S3_BUCKET_URL
from vllm.assets.image import VLM_IMAGES_DIR
from vllm.distributed import cleanup_dist_env_and_memory
@pytest.fixture
def test_prompts():
def get_test_prompts(mm_enabled: bool):
prompt_types = ["repeat", "sentence"]
if mm_enabled:
prompt_types.append("mm")
num_prompts = 100
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
print(f"Prompt types: {random_prompt_type_choices}")
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
prompt: Union[str, list[dict[str, Any]]] = ""
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
@ -38,6 +43,21 @@ def test_prompts():
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
elif kind == "mm":
placeholders = [{
"type": "image_url",
"image_url": {
"url":
f"{VLLM_S3_BUCKET_URL}/{VLM_IMAGES_DIR}/stop_sign.jpg"
},
}]
prompt = [
*placeholders,
{
"type": "text",
"text": "The meaning of the image is"
},
]
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
@ -57,7 +77,6 @@ def model_name():
def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
@ -67,6 +86,7 @@ def test_ngram_correctness(
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
test_prompts = get_test_prompts(mm_enabled=False)
ref_llm = LLM(model=model_name, max_model_len=1024)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
@ -103,23 +123,32 @@ def test_ngram_correctness(
cleanup_dist_env_and_memory()
@pytest.mark.parametrize("model_setup", [
("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1),
("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"])
@pytest.mark.parametrize(
["model_setup", "mm_enabled"], [
(("eagle", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
(("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
False,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
pytest.param(
("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
"morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
True,
marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
],
ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_setup: tuple[str, str, str, int],
mm_enabled: bool,
):
# Generate test prompts inside the function instead of using fixture
test_prompts = get_test_prompts(mm_enabled)
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.

View File

@ -37,6 +37,8 @@ logger = init_logger(__name__)
ALLREDUCE_OP = torch.ops.vllm.all_reduce.default
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
STATIC_FP8_QUANT_OP = torch.ops._C.static_scaled_fp8_quant.default
STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
class BasePattern:
@ -394,7 +396,7 @@ if flashinfer_comm is not None:
# Max size of the input tensor per world size
# to use flashinfer fused allreduce
_FI_MAX_SIZES = {
2: MiB, # 1MB
2: 64 * MiB, # 64MB
4: MiB, # 1MB
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
@ -414,9 +416,13 @@ if flashinfer_comm is not None:
trigger_completion_at_end: bool,
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
fuse_rms_quant: bool,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
scale_out: Optional[torch.Tensor] = None,
scale_factor: Optional[torch.Tensor] = None,
) -> None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
@ -425,7 +431,6 @@ if flashinfer_comm is not None:
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size,
)
if use_flashinfer:
assert (_FI_WORKSPACE_TENSOR is not None
), "Flashinfer must be enabled when using flashinfer"
@ -455,37 +460,65 @@ if flashinfer_comm is not None:
use_oneshot=True,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNorm,
pattern_code=pattern_code,
allreduce_out=None,
quant_out=None,
scale_out=None,
layout_code=None,
scale_factor=None,
quant_out=quant_out,
scale_out=scale_out,
# in vllm we only support swizzled layout
layout_code=flashinfer_comm.FP4QuantizationSFLayout.SWIZZLED,
scale_factor=scale_factor,
)
else:
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
if norm_out is None:
torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
rms_gamma, rms_eps)
if (scale_factor is not None and scale_out is None
and fuse_rms_quant):
# Do fused rms norm static fp8 quant fused op
if norm_out is None:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
quant_out, allreduce_out, residual, rms_gamma,
scale_factor, rms_eps)
else:
torch.ops._C.rms_norm_static_fp8_quant(
quant_out, allreduce_out, rms_gamma, scale_factor,
rms_eps)
else:
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
rms_eps)
allreduce_in.copy_(allreduce_out)
if norm_out is None:
torch.ops._C.fused_add_rms_norm(allreduce_out, residual,
rms_gamma, rms_eps)
norm_out = allreduce_out
else:
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma,
rms_eps)
if scale_factor is not None:
if scale_out is not None:
torch.ops._C.scaled_fp4_quant(quant_out, norm_out,
scale_out, scale_factor)
else:
torch.ops._C.static_scaled_fp8_quant(
quant_out, norm_out, scale_factor)
if scale_factor is None or norm_out is not None:
# we need to return allreduce output
# in cases of non quant fused AR + RMS norm
# and fused AR + RMS norm + quant without fused add
allreduce_in.copy_(allreduce_out)
def call_trtllm_fused_allreduce_norm_fake(
allreduce_in: torch.Tensor,
residual: torch.Tensor,
rms_gamma: torch.Tensor,
rms_eps: float,
world_rank: int,
world_size: int,
launch_with_pdl: bool,
trigger_completion_at_end: bool,
fp32_acc: bool,
max_token_num: int,
norm_out: Optional[torch.Tensor] = None,
) -> None:
allreduce_in: torch.Tensor,
residual: torch.Tensor,
rms_gamma: torch.Tensor,
rms_eps: float,
world_rank: int,
world_size: int,
launch_with_pdl: bool,
trigger_completion_at_end: bool,
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
fuse_rms_quant: bool,
norm_out: Optional[torch.Tensor] = None,
quant_out: Optional[torch.Tensor] = None,
scale_out: Optional[torch.Tensor] = None,
scale_factor: Optional[torch.Tensor] = None) -> None:
pass
direct_register_custom_op(
@ -495,6 +528,8 @@ if flashinfer_comm is not None:
"allreduce_in",
"residual",
"norm_out",
"quant_out",
"scale_out",
],
fake_impl=call_trtllm_fused_allreduce_norm_fake,
dispatch_key=current_platform.dispatch_key,
@ -512,6 +547,7 @@ class FlashInferFusedAllReduceParams:
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
fuse_rms_quant: bool = False,
):
self.rank = rank
self.world_size = world_size
@ -521,6 +557,7 @@ class FlashInferFusedAllReduceParams:
self.fp32_acc = True
self.use_oneshot = False
self.max_token_num = max_token_num
self.fuse_rms_quant = fuse_rms_quant
def get_trtllm_fused_allreduce_kwargs(self):
return {
@ -530,10 +567,16 @@ class FlashInferFusedAllReduceParams:
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num,
"fuse_rms_quant": self.fuse_rms_quant,
}
class AllReduceRMSNORMPattern(BasePattern):
class AllReduceRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (without residual)
with fused flashinfer implementation.
Applies to allreduce + rmsnorm before attn in the first Transformer block.
"""
def __init__(
self,
@ -559,29 +602,34 @@ class AllReduceRMSNORMPattern(BasePattern):
def pattern(input: torch.Tensor, rms_result: torch.Tensor,
weight: torch.Tensor):
all_reduce_output = tensor_model_parallel_all_reduce(input)
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized(
RMS_OP,
result=rms_result,
input=all_reduce_output,
input=allreduce_output,
weight=weight,
epsilon=self.epsilon,
)
return rms[1], all_reduce_output
# rms_result, allreduce_output
return rms[1], allreduce_output
def replacement(input: torch.Tensor, rms_result: torch.Tensor,
weight: torch.Tensor):
residual = torch.zeros_like(input)
allreduce = auto_functionalized(
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=rms_result,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# rms_result, allreduce_in
return allreduce[3], allreduce[1]
pm.register_replacement(pattern, replacement, self.get_inputs(),
@ -589,6 +637,11 @@ class AllReduceRMSNORMPattern(BasePattern):
class AllReduceFusedAddRMSNormPattern(BasePattern):
"""
This pattern replaces the allreduce + rms norm (with residual)
with fused flashinfer implementation.
Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
"""
def __init__(
self,
@ -615,33 +668,390 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
def pattern(residual: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor):
all_reduce_output = tensor_model_parallel_all_reduce(input)
allreduce_output = tensor_model_parallel_all_reduce(input)
rms = auto_functionalized(
RMS_ADD_OP,
input=all_reduce_output,
input=allreduce_output,
residual=residual,
weight=weight,
epsilon=self.epsilon,
)
# input, residual
return rms[1], rms[2]
def replacement(residual: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor):
allreduce = auto_functionalized(
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
residual=residual,
norm_out=None,
quant_out=None,
scale_out=None,
rms_gamma=weight,
rms_eps=self.epsilon,
norm_out=None,
pattern_code=flashinfer_comm.AllReduceFusionPattern.
kARResidualRMSNorm,
**self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
)
# allreduce_in, residual
return allreduce[1], allreduce[2]
pm.register_replacement(pattern, replacement, self.get_inputs(),
pm.fwd_only, pm_pass)
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 allreduce_params: FlashInferFusedAllReduceParams):
        super().__init__(dtype, device)
        # RMS norm epsilon baked into both the pattern and the replacement.
        self.epsilon = epsilon
        # Source of the extra keyword arguments forwarded to the fused op.
        self.allreduce_params = allreduce_params
        # dtype of the example quantized output buffer.
        self.quant_dtype = torch.float8_e4m3fn

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def get_inputs():
            # Example tensors handed to register_replacement; the order must
            # match the parameter order of pattern() and replacement().
            input = torch.zeros([1, 8, 4],
                                device=self.device,
                                dtype=self.dtype)
            rmsnorm_result = torch.empty([1, 8, 4],
                                         device=self.device,
                                         dtype=self.dtype)
            quant_result = torch.empty([1, 8, 4],
                                       device=self.device,
                                       dtype=self.quant_dtype)
            weight = torch.empty([4], device=self.device, dtype=self.dtype)
            scale = torch.tensor(1.0, device=self.device, dtype=torch.float32)
            return [input, rmsnorm_result, quant_result, weight, scale]

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unfused graph: allreduce -> rms norm -> static fp8 quant.
            all_reduce = tensor_model_parallel_all_reduce(input)
            rmsnorm_out_tuple = auto_functionalized(RMS_OP,
                                                    result=rmsnorm_result,
                                                    input=all_reduce,
                                                    weight=weight,
                                                    epsilon=self.epsilon)
            quant_out_tuple = auto_functionalized(STATIC_FP8_QUANT_OP,
                                                  result=quant_result,
                                                  input=rmsnorm_out_tuple[1],
                                                  scale=scale)
            # quant_out, allreduce_output
            return quant_out_tuple[1], all_reduce

        def replacement(input: torch.Tensor, result_rms: torch.Tensor,
                        quant_result: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            # The fused op always takes a residual argument; this pattern has
            # none, so a zero tensor is passed instead.
            residual = torch.zeros_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNormFP8Quant,  # we don't use norm_out afterwards
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, allreduce_output
            # NOTE(review): indices presumably follow the op's mutated-argument
            # order (1=allreduce_in, 2=residual, 3=norm_out, 4=quant_out,
            # 5=scale_out) - confirm against the custom op registration.
            return allreduce[4], allreduce[1]

        pm.register_replacement(pattern, replacement, get_inputs(),
                                pm.fwd_only, pm_pass)
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static fp8 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 allreduce_params: FlashInferFusedAllReduceParams):
        super().__init__(dtype, device)
        # RMS norm epsilon baked into both the pattern and the replacement.
        self.epsilon = epsilon
        # Source of the extra keyword arguments forwarded to the fused op.
        self.allreduce_params = allreduce_params
        # dtype of the example quantized output buffer.
        self.quant_dtype = torch.float8_e4m3fn

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def get_inputs():
            # Example tensors handed to register_replacement; the returned
            # order must match the parameter order of pattern() and
            # replacement().
            input = torch.empty([4, 4], device=self.device, dtype=self.dtype)
            residual = torch.empty([4, 4],
                                   device=self.device,
                                   dtype=self.dtype)
            weight = torch.empty([4, 4], device=self.device, dtype=self.dtype)
            quant_result = torch.empty([4, 4],
                                       device=self.device,
                                       dtype=self.quant_dtype)
            scale = torch.empty([1, 1],
                                device=self.device,
                                dtype=torch.float32)
            return [
                quant_result,
                residual,
                input,
                weight,
                scale,
            ]

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ):
            # Unfused graph: allreduce -> fused add + rms norm -> static fp8
            # quant.
            allreduce_output = tensor_model_parallel_all_reduce(input)
            fused_add_rmsnorm_out_tuple = \
            auto_functionalized(
                RMS_ADD_OP,
                input=allreduce_output,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon)
            quant_out_tuple = auto_functionalized(
                STATIC_FP8_QUANT_OP,
                result=quant_result,
                input=fused_add_rmsnorm_out_tuple[1],
                scale=scale)
            # quant_out, rms_norm_residual
            return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[2]

        def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
                        input: torch.Tensor, weight: torch.Tensor,
                        scale: torch.Tensor):
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_result,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNormFP8Quant,  # we don't use norm_out afterwards
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual
            # NOTE(review): indices presumably follow the op's mutated-argument
            # order (1=allreduce_in, 2=residual, 3=norm_out, 4=quant_out,
            # 5=scale_out) - confirm against the custom op registration.
            return allreduce[4], allreduce[2]

        pm.register_replacement(pattern, replacement, get_inputs(),
                                pm.fwd_only, pm_pass)
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (without residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 allreduce_params: FlashInferFusedAllReduceParams):
        super().__init__(dtype, device)
        # RMS norm epsilon baked into both the pattern and the replacement.
        self.epsilon = epsilon
        # Source of the extra keyword arguments forwarded to the fused op.
        self.allreduce_params = allreduce_params

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def get_inputs():
            # Example tensors handed to register_replacement; the returned
            # order must match the parameter order of pattern() and
            # replacement().
            input = torch.empty([1, 16, 16],
                                device=self.device,
                                dtype=self.dtype)
            rmsnorm_result = torch.empty([1, 16, 16],
                                         device=self.device,
                                         dtype=self.dtype)
            quant_result = torch.empty((16, 8),
                                       device=self.device,
                                       dtype=torch.uint8)
            input_global_scale = torch.empty([1, 1],
                                             device=self.device,
                                             dtype=torch.float32)
            weight = torch.empty([16], device=self.device, dtype=self.dtype)
            output_scale = torch.empty([128, 4],
                                       device=self.device,
                                       dtype=torch.int32)
            return [
                input, rmsnorm_result, quant_result, weight,
                input_global_scale, output_scale
            ]

        def pattern(
            input: torch.Tensor,
            rmsnorm_result: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ):
            # Unfused graph: allreduce -> rms norm -> fp4 quant.
            all_reduce = tensor_model_parallel_all_reduce(input)
            rmsnorm_out_tuple = auto_functionalized(RMS_OP,
                                                    result=rmsnorm_result,
                                                    input=all_reduce,
                                                    weight=weight,
                                                    epsilon=self.epsilon)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=rmsnorm_out_tuple[1],
                output_scale=output_scale,
                input_scale=input_global_scale)
            # quant_out, allreduce_output, output_scale
            return quant_out_tuple[1], all_reduce, quant_out_tuple[2]

        def replacement(input: torch.Tensor, result_rms: torch.Tensor,
                        quant_result: torch.Tensor, weight: torch.Tensor,
                        input_global_scale: torch.Tensor,
                        output_scale: torch.Tensor):
            # The fused op always takes a residual argument; this pattern has
            # none, so a zero tensor is passed instead.
            residual = torch.zeros_like(input)
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=result_rms,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNormFP4Quant,  # we don't use norm_out afterwards
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, allreduce_output, output_scale
            # NOTE(review): indices presumably follow the op's mutated-argument
            # order (1=allreduce_in, 2=residual, 3=norm_out, 4=quant_out,
            # 5=scale_out) - confirm against the custom op registration.
            return allreduce[4], allreduce[1], allreduce[5]

        pm.register_replacement(pattern, replacement, get_inputs(),
                                pm.fwd_only, pm_pass)
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    This pattern replaces the allreduce + rms norm (with residual)
    + static nvfp4 quant with fused flashinfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(self, epsilon: float, dtype: torch.dtype, device: str,
                 allreduce_params: FlashInferFusedAllReduceParams):
        super().__init__(dtype, device)
        # RMS norm epsilon baked into both the pattern and the replacement.
        self.epsilon = epsilon
        # Source of the extra keyword arguments forwarded to the fused op.
        self.allreduce_params = allreduce_params

    def register(self, pm_pass: PatternMatcherPass):
        """Register the pattern/replacement pair on ``pm_pass``."""

        def get_inputs():
            # Example tensors handed to register_replacement; the returned
            # order must match the parameter order of pattern() and
            # replacement().
            input = torch.empty([16, 16], device=self.device, dtype=self.dtype)
            residual = torch.empty([16, 16],
                                   device=self.device,
                                   dtype=self.dtype)
            weight = torch.empty([16, 16],
                                 device=self.device,
                                 dtype=self.dtype)
            quant_result = torch.empty((16, 8),
                                       device=self.device,
                                       dtype=torch.uint8)
            input_global_scale = torch.empty([1, 1],
                                             device=self.device,
                                             dtype=torch.float32)
            output_scale = torch.empty([128, 4],
                                       device=self.device,
                                       dtype=torch.int32)
            return [
                quant_result,
                residual,
                input,
                output_scale,
                weight,
                input_global_scale,
            ]

        def pattern(quant_result: torch.Tensor, residual: torch.Tensor,
                    input: torch.Tensor, output_scale: torch.Tensor,
                    weight: torch.Tensor, input_global_scale: torch.Tensor):
            # Unfused graph: allreduce -> fused add + rms norm -> fp4 quant.
            allreduce_output = tensor_model_parallel_all_reduce(input)
            fused_add_rmsnorm_out_tuple = \
            auto_functionalized(
                RMS_ADD_OP,
                input=allreduce_output,
                residual=residual,
                weight=weight,
                epsilon=self.epsilon)
            quant_out_tuple = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                output=quant_result,
                input=fused_add_rmsnorm_out_tuple[1],
                output_scale=output_scale,
                input_scale=input_global_scale)
            # quant_out, rms_norm_residual, output_scale
            return quant_out_tuple[1], fused_add_rmsnorm_out_tuple[
                2], quant_out_tuple[2]

        def replacement(quant_result: torch.Tensor, residual: torch.Tensor,
                        input: torch.Tensor, output_scale: torch.Tensor,
                        weight: torch.Tensor,
                        input_global_scale: torch.Tensor):
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.
                kARResidualRMSNormFP4Quant,  # we don't use norm_out afterwards
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # quant_out, rms_norm_residual, output_scale
            # NOTE(review): indices presumably follow the op's mutated-argument
            # order (1=allreduce_in, 2=residual, 3=norm_out, 4=quant_out,
            # 5=scale_out) - confirm against the custom op registration.
            return allreduce[4], allreduce[2], allreduce[5]

        pm.register_replacement(pattern, replacement, get_inputs(),
                                pm.fwd_only, pm_pass)
class AllReduceFusionPass(VllmInductorPass):
def __init__(self, config: VllmConfig):
@ -671,13 +1081,16 @@ class AllReduceFusionPass(VllmInductorPass):
self.tp_size,
)
return
max_num_token = min(
_FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE) //
(self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
config.compilation_config.pass_config.
fi_allreduce_fusion_max_token_num)
self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=self.tp_size,
max_token_num=config.compilation_config.pass_config.
fi_allreduce_fusion_max_token_num,
max_token_num=max_num_token,
hidden_dim=self.hidden_dim,
group=self.group,
use_fp32_lamport=use_fp32_lamport,
@ -689,12 +1102,38 @@ class AllReduceFusionPass(VllmInductorPass):
rank=rank,
world_size=self.tp_size,
use_fp32_lamport=use_fp32_lamport,
max_token_num=config.compilation_config.pass_config.
fi_allreduce_fusion_max_token_num,
)
max_token_num=max_num_token,
# fuse rms norm static fp8 quant fused op
# in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion)
for epsilon in [1e-5, 1e-6]:
AllReduceRMSNORMPattern(
AllReduceFusedRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantFP8Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
if current_platform.has_device_capability(100):
AllReduceFusedRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(
epsilon,
self.model_dtype,
self.device,
self.allreduce_params,
).register(self.patterns)
AllReduceRMSNormPattern(
epsilon,
self.model_dtype,
self.device,
@ -707,6 +1146,10 @@ class AllReduceFusionPass(VllmInductorPass):
self.allreduce_params,
).register(self.patterns)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
self.disabled = False
def __call__(self, graph: fx.Graph):
@ -723,5 +1166,5 @@ class AllReduceFusionPass(VllmInductorPass):
if self.disabled:
return
if flashinfer_comm is not None:
flashinfer_comm.trtllm_destroy_ipc_workspace(
flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
self.ipc_handles, self.group)

View File

@ -108,7 +108,7 @@ def support_torch_compile(
During runtime, when we actually mark dimensions of tensors,
it depends on the value of arguments:
- if it is a single integer (can be negative), the corresponding dimension
- if it is a single integer (can be negative), the corresponding dimension
of the argument will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate

View File

@ -4062,7 +4062,7 @@ class PassConfig:
"""Whether to enable async TP."""
enable_fi_allreduce_fusion: bool = False
"""Whether to enable flashinfer allreduce fusion."""
fi_allreduce_fusion_max_token_num: int = 1024
fi_allreduce_fusion_max_token_num: int = 16384
"""Max number of tokens to used in flashinfer allreduce fusion."""
# TODO(luka) better pass enabling system.

View File

@ -18,6 +18,7 @@ from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
__all__ = [
@ -40,4 +41,5 @@ __all__ = [
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
"Step3ToolParser",
]

View File

@ -0,0 +1,296 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import json
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ToolParserManager.register_module(["step3"])
class Step3ToolParser(ToolParser):
    """
    Tool parser for a model that emits tool calls in an XML-like format.

    The layout recognised by this parser is::

        <tool_calls_begin>
          <tool_call_begin>function<tool_sep>
            <steptml:invoke name="FUNC">
              <steptml:parameter name="KEY">VALUE</steptml:parameter>
          <tool_call_end>
        <tool_calls_end>

    Streaming is handled by a stateful, cursor-based parser: ``position``
    is the index into the accumulated text up to which output has already
    been consumed.  Arguments of a tool call are sent as one consolidated
    JSON message once the call is complete.
    """

    TOOL_CALLS_BEGIN = "<tool_calls_begin>"
    TOOL_CALLS_END = "<tool_calls_end>"
    TOOL_CALL_BEGIN = "<tool_call_begin>"
    TOOL_CALL_END = "<tool_call_end>"
    TOOL_SEP = "<tool_sep>"
    SPECIAL_TOKENS = [
        TOOL_CALLS_BEGIN, TOOL_CALLS_END, TOOL_CALL_BEGIN, TOOL_CALL_END
    ]

    # Compiled once at class creation instead of on every parse call.
    _INVOKE_NAME_RE = re.compile(r'<steptml:invoke name="([^"]+)">')
    # The value is captured non-greedily up to the first closing tag (with
    # DOTALL so it may span lines).  A `[^<]*` capture would make the whole
    # parameter fail to match - and be silently dropped - as soon as its
    # value contained a '<' character (comparisons, HTML, code, ...).
    _PARAMETER_RE = re.compile(
        r'<steptml:parameter name="([^"]+)">(.*?)</steptml:parameter>',
        re.DOTALL)

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # Cursor into `current_text`; everything before it was handled.
        self.position = 0
        # Explicit state flags for robust streaming
        self.tool_block_started = False
        self.tool_block_finished = False

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the output when tools may be called.

        The markers this parser looks for are only visible in the decoded
        text if special tokens are not skipped.
        """
        if request.tools and request.tool_choice != 'none':
            request.skip_special_tokens = False
        return request

    @staticmethod
    def _parse_steptml_invoke(
        action_text: str
    ) -> tuple[Optional[str], Optional[dict[str, str]]]:
        """Extract the function name and raw string parameters.

        Args:
            action_text: Text expected to contain a ``<steptml:invoke>`` tag
                and zero or more complete ``<steptml:parameter>`` elements.

        Returns:
            ``(name, params)``; ``(None, None)`` if no invoke tag is found.
            Parameter values are whitespace-stripped strings.  Parameters
            whose closing tag has not arrived yet are not included.
        """
        name_match = Step3ToolParser._INVOKE_NAME_RE.search(action_text)
        if not name_match:
            return None, None

        func_name = name_match.group(1)
        params: dict[str, str] = {
            name: value.strip()
            for name, value in Step3ToolParser._PARAMETER_RE.findall(
                action_text)
        }
        return func_name, params

    def _cast_arguments(
        self,
        func_name: str,
        params: dict[str, Any],
        request: ChatCompletionRequest,
    ) -> dict[str, Any]:
        """Convert string parameter values according to the tool's schema.

        Looks up ``func_name`` among ``request.tools`` and, for each string
        value, applies the JSON-schema ``type`` declared under
        ``properties``.  Values that cannot be converted are left as the
        original string.  ``params`` is modified in place and returned.
        """
        for tool in request.tools or []:
            if tool.function.name == func_name:
                schema = tool.function.parameters or {}
                properties = schema.get("properties", {})
                # Only existing keys are reassigned, so mutating the dict
                # while iterating over it is safe here.
                for key, value in params.items():
                    if not isinstance(value, str):
                        continue
                    prop = properties.get(key, {})
                    typ = prop.get("type")
                    if typ == "string":
                        params[key] = value.strip()
                    elif typ == "integer":
                        with contextlib.suppress(ValueError):
                            params[key] = int(value)
                    elif typ == "number":
                        with contextlib.suppress(ValueError):
                            params[key] = float(value)
                    elif typ == "boolean":
                        lower_val = value.lower()
                        params[key] = lower_val == "true" if lower_val in (
                            "true", "false") else value
                    elif typ == "null":
                        params[key] = None if value.lower(
                        ) == "null" else value
                break
        return params

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Produce the next streaming delta from the accumulated text.

        Only ``current_text`` and ``request`` are read; progress is tracked
        through ``self.position`` and the state flags.  At most one
        ``DeltaMessage`` is returned per call: plain content, a tool-call
        header (index, id, function name), or the complete JSON arguments
        of a finished tool call.  ``None`` means nothing can be emitted yet.

        NOTE(review): ``prev_tool_call_arr``, ``current_tool_id`` and
        ``current_tool_name_sent`` are not initialised in this class; they
        are presumably provided by ``ToolParser`` - confirm.
        NOTE(review): inside the tool block, text that is neither a known
        marker nor part of an active tool call makes this method return
        ``None`` without advancing ``position``.
        """
        # The main loop processes the stream from the last known position.
        while True:
            if self.position >= len(current_text):
                return None  # We've processed the entire stream.

            unprocessed_text = current_text[self.position:]

            # STATE: After all tools are done, all subsequent text is content.
            if self.tool_block_finished:
                self.position = len(current_text)
                return DeltaMessage(content=unprocessed_text)

            # STATE: Before the tool block has started.
            if not self.tool_block_started:
                if unprocessed_text.startswith(self.TOOL_CALLS_BEGIN):
                    self.position += len(self.TOOL_CALLS_BEGIN)
                    self.tool_block_started = True
                    continue  # Token consumed, re-loop.

                start_pos = unprocessed_text.find(self.TOOL_CALLS_BEGIN)
                if start_pos == -1:
                    if self.TOOL_CALLS_BEGIN.startswith(
                            unprocessed_text.strip()) and unprocessed_text:
                        return None  # It's a prefix, wait.
                    self.position = len(current_text)
                    return DeltaMessage(content=unprocessed_text)
                else:
                    # Emit the text in front of the marker; the marker itself
                    # is consumed on the next call.
                    content = unprocessed_text[:start_pos]
                    self.position += len(content)
                    return DeltaMessage(content=content)

            # STATE: Inside the main tool block.
            # Leading whitespace between markers is skipped.
            offset = len(unprocessed_text) - len(unprocessed_text.lstrip())
            unprocessed_text = unprocessed_text.lstrip()
            self.position += offset

            if unprocessed_text.startswith(self.TOOL_CALLS_END):
                self.position += len(self.TOOL_CALLS_END)
                self.tool_block_finished = True
                self.current_tool_id = -1
                continue

            # Check if we are between tool calls.
            tool_finished = (
                self.current_tool_id != -1 and
                self.prev_tool_call_arr[self.current_tool_id].get("finished"))
            if self.current_tool_id == -1 or tool_finished:
                if unprocessed_text.startswith(self.TOOL_CALL_BEGIN):
                    self.position += len(self.TOOL_CALL_BEGIN)
                    if self.current_tool_id == -1:
                        self.current_tool_id = 0
                    else:
                        self.current_tool_id += 1
                    self.current_tool_name_sent = False
                    while len(self.prev_tool_call_arr) <= self.current_tool_id:
                        self.prev_tool_call_arr.append({})
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = False
                    continue

                # A partial begin marker: wait for the rest of it.
                if self.TOOL_CALL_BEGIN.startswith(unprocessed_text):
                    return None

            # STATE: Parsing an active tool call.
            if self.current_tool_id != -1 and not self.prev_tool_call_arr[
                    self.current_tool_id].get("finished", False):
                end_tool_pos = unprocessed_text.find(self.TOOL_CALL_END)
                if end_tool_pos == -1:
                    tool_body = unprocessed_text
                else:
                    tool_body = unprocessed_text[:end_tool_pos]

                # A partial end marker: wait for the rest of it.
                if end_tool_pos == -1 and self.TOOL_CALL_END.startswith(
                        tool_body):
                    return None

                function_name, arguments = self._parse_steptml_invoke(
                    tool_body)
                if not function_name:
                    return None

                tool_call_arr = {
                    "name": function_name,
                    "parameters": arguments or {}
                }

                # Send the function name as soon as it's parsed.
                if not self.current_tool_name_sent:
                    self.current_tool_name_sent = True
                    self.prev_tool_call_arr[self.current_tool_id].update(
                        tool_call_arr)

                    return DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name))
                    ])

                # Update our internal state with the latest parsed arguments.
                self.prev_tool_call_arr[self.current_tool_id].update(
                    tool_call_arr)

                # Only send arguments when the tool call is complete.
                if end_tool_pos != -1:
                    self.position += end_tool_pos + len(self.TOOL_CALL_END)
                    self.prev_tool_call_arr[
                        self.current_tool_id]["finished"] = True

                    final_args = self._cast_arguments(
                        function_name,
                        tool_call_arr.get("parameters", {}),  # type: ignore
                        request)
                    # An empty argument dict produces no arguments delta.
                    if final_args:
                        final_args_json = json.dumps(final_args,
                                                     ensure_ascii=False)
                        return DeltaMessage(tool_calls=[
                            DeltaToolCall(index=self.current_tool_id,
                                          function=DeltaFunctionCall(
                                              arguments=final_args_json))
                        ])

                # If tool is not finished, return None to wait for more tokens.
                return None

            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Parse a complete model output into tool calls and content.

        The output is treated as plain content (``tools_called=False``)
        when the begin or end marker of the tool block is missing, or when
        the block yields no valid call.  Otherwise the text before and
        after the block is joined and stripped to form ``content``
        (``None`` if empty).  A call is kept only if it has an end marker,
        a ``<tool_sep>``, the type ``function`` and a parsable invoke tag.
        """
        if self.TOOL_CALLS_BEGIN not in model_output:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        pre_text, rest = model_output.split(self.TOOL_CALLS_BEGIN, 1)
        if self.TOOL_CALLS_END not in rest:
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        tool_block, post_text = rest.split(self.TOOL_CALLS_END, 1)
        content = (pre_text + post_text).strip()

        tool_calls: list[ToolCall] = []
        call_parts = tool_block.split(self.TOOL_CALL_BEGIN)

        for part in call_parts:
            if not part or self.TOOL_CALL_END not in part:
                continue

            call_content = part.split(self.TOOL_CALL_END, 1)[0]
            if self.TOOL_SEP not in call_content:
                continue

            type_part, invoke_part = call_content.split(self.TOOL_SEP, 1)
            if type_part.strip() != "function":
                continue

            function_name, params_dict = self._parse_steptml_invoke(
                invoke_part)

            if function_name and params_dict is not None:
                params_dict = self._cast_arguments(function_name, params_dict,
                                                   request)
                params_str = json.dumps(params_dict, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(function=FunctionCall(name=function_name,
                                                   arguments=params_str)))
        if tool_calls:
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None)
        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=model_output)

View File

@ -182,6 +182,8 @@ class DeepSeekMTP(nn.Module, SupportsPP):
stacked_params_mapping = [
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
expert_params_mapping = FusedMoE.make_expert_params_mapping(
@ -212,6 +214,13 @@ class DeepSeekMTP(nn.Module, SupportsPP):
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
if ((param_name == "fused_qkv_a_proj")
and name not in params_dict):
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

View File

@ -256,6 +256,7 @@ class Llama4DecoderLayer(nn.Module):
super().__init__()
self.layer_idx = extract_layer_index(prefix)
self.global_layer = config.no_rope_layers[self.layer_idx] == 0
self.hidden_size = config.hidden_size
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling

View File

@ -37,8 +37,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama4 import (Llama4DecoderLayer,
Llama4ForCausalLM)
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.inputs import NestedTensors
from .utils import AutoWeightsLoader, maybe_prefix
from .utils import AutoWeightsLoader, maybe_prefix, merge_multimodal_embeddings
logger = init_logger(__name__)
@ -78,15 +79,23 @@ class LlamaModel(nn.Module):
self.norm = RMSNorm(self.config.hidden_size,
eps=self.config.rms_norm_eps)
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
) -> torch.Tensor:
    """Return the result of applying ``self.embed_tokens`` to ``input_ids``."""
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
def forward(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
torch.cat((inputs_embeds, hidden_states), dim=-1))
residual = None
for layer in self.layers:
hidden_states, residual = layer(
@ -190,8 +199,9 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states)
return self.model(input_ids, positions, hidden_states, inputs_embeds)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> None:
@ -212,3 +222,20 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
model_weights[name] = loaded_weight
loader.load_weights(model_weights.items())
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and optionally merge multimodal embeddings.

    The text embeddings come from ``self.model.get_input_embeddings``.
    When ``multimodal_embeddings`` is given, it is merged in through
    ``merge_multimodal_embeddings`` using ``self.config.image_token_index``;
    otherwise the text embeddings are returned unchanged.
    """
    text_embeds = self.model.get_input_embeddings(input_ids)
    if multimodal_embeddings is None:
        return text_embeds

    return merge_multimodal_embeddings(
        input_ids,
        text_embeds,
        multimodal_embeddings,
        self.config.image_token_index,
    )

View File

@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional
import torch
import torch.nn as nn
@ -148,7 +149,12 @@ class EagleLlamaForCausalLM(LlamaForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
raise NotImplementedError(
f"{type(self).__name__} does not support multimodal inputs yet."
)
return self.model(input_ids, positions, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):

View File

@ -202,7 +202,12 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if inputs_embeds is not None:
raise NotImplementedError(
f"{type(self).__name__} does not support multimodal inputs yet."
)
return self.model(input_ids, positions, hidden_states)
def compute_logits(

View File

@ -129,6 +129,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
@ -238,6 +239,7 @@ _MULTIMODAL_MODELS = {
"Qwen2_5OmniModel": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"Qwen2_5OmniForConditionalGeneration": ("qwen2_5_omni_thinker", "Qwen2_5OmniThinkerForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
"Step3VLForConditionalGeneration": ("step3_vl", "Step3VLForConditionalGeneration"), # noqa: E501
"TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501
"Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501

View File

@ -0,0 +1,521 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only Jurassic model."""
from collections.abc import Iterable
from typing import Any, Optional
import torch
from torch import nn
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
class FusedMoEBlock(nn.Module):
    """Mixture-of-experts block: a replicated router gate plus fused experts.

    The experts are built with ``reduce_results=False``; the all-reduce over
    tensor-parallel ranks is done explicitly in ``forward``.
    """

    def __init__(self,
                 config: ModelConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """Build the experts and the router gate.

        Args:
            config: Object providing ``moe_num_experts``, ``moe_top_k``,
                ``hidden_size``, ``moe_intermediate_size`` and
                ``norm_expert_weight``.
            quant_config: Quantization config for the experts; the gate is
                always created without quantization.
            prefix: Module name prefix.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                experts.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        num_experts = config.moe_num_experts
        hidden_size = config.hidden_size

        if self.tp_size > num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.moe_num_experts}.")

        # Submodules are created in this order: experts first, then gate.
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=config.moe_top_k,
            hidden_size=hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_expert_weight,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts; output has the input's shape."""
        input_shape = hidden_states.shape
        # Flatten every leading dimension into a single token dimension.
        tokens = hidden_states.view(-1, input_shape[-1])

        router_logits, _ = self.gate(tokens)
        expert_out = self.experts(hidden_states=tokens,
                                  router_logits=router_logits)
        if self.tp_size > 1:
            expert_out = tensor_model_parallel_all_reduce(expert_out)

        return expert_out.view(input_shape)
class Step3TextMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul,
    then a down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Create the projections.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Gate and up projections share one merged column-parallel layer.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

        self.act_fn = SiluAndMul()
        self.hidden_size = hidden_size

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection."""
        # Each linear layer returns a pair; only the first item is used.
        projected = self.gate_up_proj(hidden_states)[0]
        activated = self.act_fn(projected)
        return self.down_proj(activated)[0]
class Step3TextAttention(nn.Module):
    """Attention block with exactly one key/value head.

    A replicated ``qkv_proj`` produces a query of width ``q_size`` along
    with the key and value.  The query is RMS-normalised and then expanded
    by ``wq`` to ``head_dim * total_num_heads`` before rotary embedding and
    attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        norm_eps: float,
        rope_theta: int,
        share_q_dim: Optional[int] = None,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 8192,
        head_dim: int = 256,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build projections, rotary embedding and the attention layer.

        Args:
            share_q_dim: Width of the query produced by ``qkv_proj``; when
                falsy, ``head_dim`` is used instead.

        Raises:
            ValueError: If ``num_kv_heads`` is not 1.
        """
        super().__init__()
        self.hidden_size = hidden_size

        world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        if num_kv_heads != 1:
            raise ValueError(f"Step3TextAttention num_kv_heads must be 1, "
                             f"but got {num_kv_heads}.")
        self.num_kv_heads = num_kv_heads

        self.head_dim = head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.q_size = share_q_dim or self.head_dim

        # Submodules are registered in the same order as listed here.
        fused_out_width = self.q_size + self.kv_size * 2
        self.qkv_proj = ReplicatedLinear(
            hidden_size,
            fused_out_width,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.inter_norm = RMSNorm(self.q_size, eps=norm_eps)
        self.wq = ColumnParallelLinear(
            self.q_size,
            self.head_dim * self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embedding,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.head_dim**-0.5,
            self.num_kv_heads,
            cache_config=cache_config,
            prefix=f"{prefix}.attn",
        )

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Run attention over ``hidden_states`` and apply ``o_proj``."""
        fused, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused.split(split_sizes, dim=-1)

        # Normalise the compact query, then expand it to all heads.
        query, _ = self.wq(self.inter_norm(query))
        query, key = self.rotary_emb(positions, query, key)

        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected
class Step3TextDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by a dense MLP or by a
    MoE block combined with a shared expert MLP."""

    def __init__(self,
                 config: ModelConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        """Build the layer.

        Args:
            config: Model config; its ``hf_config`` attribute supplies all
                hyper-parameters.
            prefix: Module name; must contain ``"layers.<idx>"`` because the
                layer index is parsed from it.
        """
        super().__init__()
        hf_config = config.hf_config
        self.hidden_size = hf_config.hidden_size

        self.self_attn = Step3TextAttention(
            hidden_size=self.hidden_size,
            num_heads=hf_config.num_attention_heads,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            norm_eps=hf_config.rms_norm_eps,
            max_position_embedding=hf_config.max_position_embedding,
            head_dim=hf_config.head_dim,
            share_q_dim=hf_config.share_q_dim,
            rope_theta=hf_config.rope_theta,
            rope_scaling=getattr(hf_config, "rope_scaling", None),
            prefix=f"{prefix}.self_attn")

        # The layer index is the path component right after "layers.".
        layer_idx = int(prefix.split("layers.")[1].split(".")[0])

        moe_layers_enum = getattr(hf_config, "moe_layers_enum", None)
        if moe_layers_enum is None:
            # Default to 1dense.
            moe_layers_idx = list(range(1, hf_config.num_hidden_layers))
        else:
            moe_layers_idx = [
                int(idx) for idx in moe_layers_enum.strip().split(',')
            ]

        self.use_moe = layer_idx in moe_layers_idx
        if self.use_moe:
            self.moe = FusedMoEBlock(config=hf_config,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.moe")
            self.share_expert = Step3TextMLP(
                hidden_size=self.hidden_size,
                intermediate_size=hf_config.share_expert_dim,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.share_expert")
        else:
            self.mlp = Step3TextMLP(
                hidden_size=hf_config.hidden_size,
                intermediate_size=hf_config.intermediate_size,
                hidden_act="silu",
                quant_config=quant_config,
                prefix=f"{prefix}.mlp")

        self.input_layernorm = RMSNorm(hf_config.hidden_size,
                                       eps=hf_config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(hf_config.hidden_size,
                                                eps=hf_config.rms_norm_eps)

    def forward(
        self, positions: torch.Tensor, hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become the
        residual; otherwise the input norm is called with both tensors and
        returns the updated pair.
        """
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions,
                                  hidden_states=hidden_states)
        normed, residual = self.post_attention_layernorm(attn_out, residual)

        if not self.use_moe:
            return self.mlp(normed), residual

        # Shared expert runs first, then the routed experts; outputs add up.
        shared_out = self.share_expert(normed)
        routed_out = self.moe(normed)
        return shared_out + routed_out, residual
@support_torch_compile
class Step3TextModel(nn.Module):
    """Decoder stack of the Step3 text model with pipeline-parallel support.

    Token embeddings exist only on the first pipeline rank (and on the last
    one when word embeddings are tied); the final norm exists only on the
    last rank.  Missing parts are replaced by ``PPMissingLayer``.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.vocab_size = config.vocab_size
        self.config = config

        if get_pp_group().is_first_rank or (config.tie_word_embeddings
                                            and get_pp_group().is_last_rank):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Step3TextDecoderLayer(
                config=vllm_config.model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # `forward` exchanges both "hidden_states" and "residual" between
        # pipeline ranks, so the empty-tensor factory has to create both
        # keys; with only "hidden_states" the lookup of "residual" on a
        # non-first rank fails for factory-built intermediate tensors.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the result of applying ``embed_tokens`` to ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the layers owned by this pipeline rank.

        On the first rank the input is ``inputs_embeds`` if given, else the
        embedding of ``input_ids``.  On later ranks the hidden states and
        residual are read from ``intermediate_tensors``, which must not be
        ``None`` there.

        Returns:
            The normalised hidden states on the last rank.  On any other
            rank an ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` is returned instead, despite the annotation.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class Step3TextForCausalLM(nn.Module, SupportsPP):
    """Step3 text model with a language-model head.

    The LM head, logits processor and sampler are created only on the last
    pipeline rank; other ranks hold a ``PPMissingLayer`` as ``lm_head``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.vllm_config = vllm_config

        self.model = Step3TextModel(vllm_config=vllm_config, prefix=prefix)

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                if not lora_config else lora_config.lora_vocab_padding_size,
            )
            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                    config.vocab_size)
            self.sampler = get_sampler()
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None):
        """Delegate to the wrapped ``Step3TextModel`` and return its output."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Compute logits through the logits processor and ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the configured sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module's parameters.

        Each ``(name, tensor)`` pair is tried against four routes, in order:

        1. dense ``gate_proj`` / ``up_proj`` -> merged ``gate_up_proj``
           (names belonging to MoE expert weights are excluded here);
        2. stacked MoE expert weights -> fused ``w13`` / ``w2`` parameters,
           loaded one expert at a time along the tensor's first dimension;
        3. ``q_proj`` / ``k_proj`` / ``v_proj`` -> copied into the matching
           slice of the fused ``qkv_proj`` parameter;
        4. everything else -> the parameter's own ``weight_loader`` or
           ``default_weight_loader``.

        Returns:
            The names of the parameters that received data.
        """
        # Widths of the q / k / v segments inside the fused projection.
        # The query width uses the same fallback as Step3TextAttention.
        head_dim = self.config.head_dim
        q_size = self.config.share_q_dim or head_dim
        qkv_total = q_size + head_dim * 2
        qkv_params_mapping = [
            # (param_name, shard_name, start_offset, end_offset); offsets
            # are expressed in units of `qkv_total`.
            (".qkv_proj", ".q_proj", 0, q_size),
            (".qkv_proj", ".k_proj", q_size, q_size + head_dim),
            (".qkv_proj", ".v_proj", q_size + head_dim, qkv_total),
        ]
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        expert_params_mapping = [
            (".moe.experts.w13_weight", ".moe.gate_proj.weight", "w1"),
            (".moe.experts.w13_weight", ".moe.up_proj.weight", "w3"),
            (".moe.experts.w2_weight", ".moe.down_proj.weight", "w2")
        ]

        # Expert weight names also contain ".gate_proj" / ".up_proj"; they
        # must bypass the dense stacked mapping.
        disable_moe_stacked_params = [
            data[1] for data in expert_params_mapping
        ]

        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                if any(disable_moe_stacked_param in name
                       for disable_moe_stacked_param in
                       disable_moe_stacked_params):
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    for expert_id in range(loaded_weight.shape[0]):
                        loaded_weight_expert = loaded_weight[expert_id]
                        weight_loader(param,
                                      loaded_weight_expert,
                                      name,
                                      shard_id=shard_id,
                                      expert_id=expert_id)
                    loaded_params.add(name)
                    break
                else:
                    for (param_name, weight_name, start_offset,
                         end_offset) in qkv_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        dim = param.shape[param.output_dim]
                        # Scale the offsets to the parameter's real size
                        # with integer arithmetic; int(float_ratio * dim)
                        # can truncate to one element too few.
                        begin_idx = start_offset * dim // qkv_total
                        stop_idx = end_offset * dim // qkv_total
                        param_slice = param.narrow(param.output_dim, begin_idx,
                                                   stop_idx - begin_idx)
                        param_slice.copy_(loaded_weight)
                        loaded_params.add(name)
                        break
                    else:
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
                        loaded_params.add(name)
        return loaded_params

File diff suppressed because it is too large Load Diff

View File

@ -8,6 +8,7 @@ from .granite_reasoning_parser import GraniteReasoningParser
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
from .mistral_reasoning_parser import MistralReasoningParser
from .qwen3_reasoning_parser import Qwen3ReasoningParser
from .step3_reasoning_parser import Step3ReasoningParser
__all__ = [
"ReasoningParser",
@ -18,4 +19,5 @@ __all__ = [
"Qwen3ReasoningParser",
"Glm4MoeModelReasoningParser",
"MistralReasoningParser",
"Step3ReasoningParser",
]

View File

@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Optional, Union
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
@ReasoningParserManager.register_module("step3")
class Step3ReasoningParser(ReasoningParser):
    """
    Reasoning parser for Step3 model.

    Step3 marks the end of its reasoning with a </think> token.  Text in
    front of that token is reported as reasoning content, text after it as
    regular content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Resolve the id of the </think> token.

        Raises:
            ValueError: If no model tokenizer is available.
            RuntimeError: If the vocabulary has no </think> entry.
        """
        super().__init__(tokenizer)
        self.think_end_token = "</think>"

        self.reasoning_regex = re.compile(rf"(.*?){self.think_end_token}",
                                          re.DOTALL)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction.")

        token_id = self.vocab.get(self.think_end_token)
        self.think_end_token_id = token_id
        if token_id is None:
            raise RuntimeError(
                "Step3 reasoning parser could not locate think end "
                "token in the tokenizer!")

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """
        Classify one streamed delta as reasoning and/or content.

        Decisions are made on token ids.  For the text "abc</think>xyz",
        'abc' ends up in reasoning_content and 'xyz' in content.  A delta
        that consists of nothing but the </think> token yields None.
        """
        end_id = self.think_end_token_id

        # Skip single special token
        if len(delta_token_ids) == 1 and delta_token_ids[0] == end_id:
            return None

        if end_id in delta_token_ids:
            # The delta straddles </think>: split it around the marker.
            marker = self.think_end_token
            cut = delta_text.find(marker)
            before = delta_text[:cut]
            after = delta_text[cut + len(marker):]
            return DeltaMessage(reasoning_content=before,
                                content=after if after else None)

        if end_id in previous_token_ids:
            # </think> was seen earlier, so this delta is plain content.
            return DeltaMessage(content=delta_text)

        # </think> has not appeared yet: still reasoning.
        return DeltaMessage(reasoning_content=delta_text)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Split a complete output at the first </think>.

        Returns:
            ``(reasoning_content, content)``.  Without </think> the whole
            output is reasoning and content is ``None``; content is also
            ``None`` when nothing follows the token.
        """
        reasoning, marker, remainder = model_output.partition(
            self.think_end_token)
        if not marker:
            # No </think> at all: everything is reasoning content.
            return model_output, None
        return reasoning, (remainder or None)

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True if the </think> token id occurs in ``input_ids``."""
        return self.think_end_token_id in input_ids

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return the ids after the first </think>.

        The result is empty unless </think> occurs somewhere before the
        last position of ``input_ids``.
        """
        end_id = self.think_end_token_id
        if end_id in input_ids[:-1]:
            return input_ids[input_ids.index(end_id) + 1:]
        return []

View File

@ -35,7 +35,8 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
RWConfig, UltravoxConfig)
RWConfig, Step3TextConfig,
Step3VLConfig, UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
@ -83,6 +84,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"ultravox": UltravoxConfig,
"step3_vl": Step3VLConfig,
"step3_text": Step3TextConfig,
**_CONFIG_REGISTRY_OVERRIDE_HF
}

View File

@ -24,6 +24,9 @@ from vllm.transformers_utils.configs.nemotron import NemotronConfig
from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig
from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config
from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
from vllm.transformers_utils.configs.step3_vl import (Step3TextConfig,
Step3VisionEncoderConfig,
Step3VLConfig)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
@ -42,4 +45,7 @@ __all__ = [
"Nemotron_Nano_VL_Config",
"NVLM_D_Config",
"UltravoxConfig",
"Step3VLConfig",
"Step3VisionEncoderConfig",
"Step3TextConfig",
]

View File

@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional, Union
from transformers.configuration_utils import PretrainedConfig
class Step3VisionEncoderConfig(PretrainedConfig):
    """
    Configuration container with model_type "step3_vision_encoder".

    Every constructor argument is stored unchanged as an attribute of the
    same name; any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "step3_vision_encoder"

    def __init__(
        self,
        hidden_size: int = 1792,
        intermediate_size: int = 3072,
        output_hidden_size: int = 4096,
        num_hidden_layers: int = 63,
        num_attention_heads: int = 16,
        num_channels: int = 3,
        image_size: int = 728,
        patch_size: int = 14,
        hidden_act: str = "quick_gelu",
        layer_norm_eps: float = 1e-5,
        **kwargs,
    ) -> None:
        # Layer widths and depth.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.output_hidden_size = output_hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Input image description.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size

        # Activation and normalization settings.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps

        # The base class consumes the remaining kwargs last, after the
        # fields above are populated.
        super().__init__(**kwargs)
class Step3TextConfig(PretrainedConfig):
    """
    Configuration container with model_type "step3_text".

    Declares ``Step3TextForCausalLM`` as its architecture. Every constructor
    argument is stored unchanged as an attribute of the same name; any extra
    keyword arguments are forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "step3_text"
    architectures = ["Step3TextForCausalLM"]

    def __init__(
        self,
        hidden_size: int = 7168,
        intermediate_size: int = 18432,
        num_attention_heads: int = 64,
        num_attention_groups: int = 1,
        num_hidden_layers: int = 61,
        max_seq_len: int = 65536,
        vocab_size: int = 128815,
        rms_norm_eps: float = 1e-5,
        moe_intermediate_size: int = 5120,
        moe_num_experts: int = 48,
        moe_top_k: int = 3,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embedding: int = 65536,
        share_expert_dim: int = 5120,
        share_q_dim: int = 2048,
        head_dim: int = 256,
        norm_expert_weight: bool = False,
        # Default is the contiguous run of indices 4 through 59 inclusive.
        moe_layers_enum: tuple[int, ...] = tuple(range(4, 60)),
        **kwargs,
    ) -> None:
        # Core dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_groups = num_attention_groups
        self.num_hidden_layers = num_hidden_layers
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.rms_norm_eps = rms_norm_eps

        # Mixture-of-experts settings.
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_num_experts = moe_num_experts
        self.moe_top_k = moe_top_k

        # Rotary position embedding settings.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.max_position_embedding = max_position_embedding

        # Remaining model-specific fields.
        self.share_expert_dim = share_expert_dim
        self.share_q_dim = share_q_dim
        self.head_dim = head_dim
        self.norm_expert_weight = norm_expert_weight
        self.moe_layers_enum = moe_layers_enum

        # The base class consumes the remaining kwargs last, after the
        # fields above are populated.
        super().__init__(**kwargs)
class Step3VLConfig(PretrainedConfig):
    """
    Composite configuration with model_type "step3_vl".

    Holds a ``Step3VisionEncoderConfig`` and a ``Step3TextConfig`` plus a
    few top-level fields. ``hidden_size`` is copied from the text config.

    Args:
        vision_config: a config object, a dict of constructor kwargs for
            ``Step3VisionEncoderConfig``, or None for its defaults.
        text_config: a config object, a dict of constructor kwargs for
            ``Step3TextConfig``, or None for its defaults.
        understand_projector_stride: stored as-is.
        projector_bias: stored as-is.
        image_token_id: stored as-is.
        **kwargs: forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "step3_vl"

    def __init__(
        self,
        vision_config: Optional[Union[dict, Step3VisionEncoderConfig]] = None,
        text_config: Optional[Union[dict, Step3TextConfig]] = None,
        understand_projector_stride: int = 1,
        projector_bias: bool = True,
        image_token_id: int = 128001,
        **kwargs,
    ) -> None:
        # Normalize the vision sub-config: expand a dict, default a None,
        # and pass an already-built object through untouched.
        if isinstance(vision_config, dict):
            vision_config = Step3VisionEncoderConfig(**vision_config)
        elif vision_config is None:
            vision_config = Step3VisionEncoderConfig()
        self.vision_config = vision_config

        # Same normalization for the text sub-config.
        if isinstance(text_config, dict):
            text_config = Step3TextConfig(**text_config)
        elif text_config is None:
            text_config = Step3TextConfig()
        self.text_config = text_config

        self.understand_projector_stride = understand_projector_stride
        self.projector_bias = projector_bias
        # Mirror the text config's hidden size at the top level.
        self.hidden_size = text_config.hidden_size
        self.image_token_id = image_token_id

        # The base class consumes the remaining kwargs last, after the
        # fields above are populated.
        super().__init__(**kwargs)

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
@ -51,6 +53,9 @@ class EagleProposer:
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = self.draft_model_config.get_hidden_size()
self.is_multimodal_model = vllm_config.model_config \
.is_multimodal_model
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
@ -76,6 +81,11 @@ class EagleProposer:
device=device,
dtype=torch.int32)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
def propose(
self,
# [num_tokens]
@ -88,6 +98,7 @@ class EagleProposer:
next_token_ids: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
@ -128,14 +139,27 @@ class EagleProposer:
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
@ -218,15 +242,24 @@ class EagleProposer:
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
self.inputs_embeds[:batch_size] = inputs_embeds
inputs_embeds = self.inputs_embeds[:input_batch_size]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
self.input_ids[:input_batch_size],
self.positions[:input_batch_size],
self.hidden_states[:input_batch_size],
input_ids=input_ids,
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
@ -391,10 +424,18 @@ class EagleProposer:
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = None
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
input_ids=input_ids,
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
inputs_embeds=inputs_embeds,
)
def validate_same_kv_cache_group(self,

View File

@ -1314,13 +1314,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _gather_mm_embeddings(
self,
scheduler_output: "SchedulerOutput",
shift_computed_tokens: int = 0,
) -> list[torch.Tensor]:
mm_embeds: list[torch.Tensor] = []
for req_id in self.input_batch.req_ids:
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
req_id]
req_state = self.requests[req_id]
num_computed_tokens = req_state.num_computed_tokens
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
@ -2298,6 +2300,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
mm_embeds = None
if self.is_multimodal_model:
mm_embeds = self._gather_mm_embeddings(scheduler_output,
shift_computed_tokens=1)
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
@ -2305,6 +2312,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
next_token_ids=next_token_ids,
sampling_metadata=sampling_metadata,
common_attn_metadata=common_attn_metadata,
mm_embeds=mm_embeds,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids