Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-08-23 11:30:29 -07:00
Commit: 48bca9a109
119 changed files with 5080 additions and 953 deletions

View File

@ -382,7 +382,7 @@ run_genai_perf_tests() {
client_command="genai-perf profile \
-m $model \
--service-kind openai \
--backend vllm \
--backend "$backend" \
--endpoint-type chat \
--streaming \
--url localhost:$port \

View File

@ -843,3 +843,10 @@ steps:
commands:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4
- label: Qwen MoE EP Test # optional
gpu: h200
optional: true
num_gpus: 2
commands:
- CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048

View File

@ -9,8 +9,11 @@ from typing import Optional
import flashinfer
import torch
from vllm.utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -61,13 +64,13 @@ def benchmark_decode(
else:
raise ValueError(f"Invalid kv_layout: {kv_layout}")
query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
# Always using 1.0 scale to reflect the real perf in benchmarking
q_scale = 1.0
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
query, _ = to_float8(ref_query)
else:
q_scale = 1.0
ref_query = query
query = ref_query
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_seq_len
@ -75,14 +78,13 @@ def benchmark_decode(
seq_lens = kv_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
# Always using 1.0 scale to reflect the real perf in benchmarking
k_scale = v_scale = 1.0
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
kv_cache, _ = to_float8(ref_kv_cache)
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
kv_cache = ref_kv_cache
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
@ -142,11 +144,31 @@ def benchmark_decode(
return sum(times) / len(times), torch.std(torch.tensor(times))
o_scale = 1.0
o_sf_scale = None
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
o_sf_scale = 500.0
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
torch.empty(
(
round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4),
),
dtype=torch.float8_e4m3fn,
),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_decode():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
return wrapper.run(
ref_query,
ref_kv_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output_baseline,
)
def trtllm_decode():
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
@ -158,6 +180,7 @@ def benchmark_decode(
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@ -237,6 +260,7 @@ if __name__ == "__main__":
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
for quant_dtype in quant_dtypes:

View File

@ -9,8 +9,11 @@ from typing import Optional
import flashinfer
import torch
from vllm.utils import round_up
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -72,13 +75,15 @@ def benchmark_prefill(
]
)
query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype)
# Always using 1.0 scale to reflect the real perf in benchmarking
q_scale = 1.0
ref_query = torch.randn(
torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
)
if q_quant_dtype == FP8_DTYPE:
query, q_scale = to_float8(query)
ref_query = query.to(dtype) * q_scale
query, _ = to_float8(ref_query)
else:
q_scale = 1.0
ref_query = query
query = ref_query
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
kv_lens[-1] = max_kv_len
@ -86,14 +91,13 @@ def benchmark_prefill(
seq_lens = kv_lens + q_lens
max_seq_len = torch.max(seq_lens).item()
kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
# Always using 1.0 scale to reflect the real perf in benchmarking
k_scale = v_scale = 1.0
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
if kv_quant_dtype == FP8_DTYPE:
kv_cache, kv_scale = to_float8(kv_cache)
ref_kv_cache = kv_cache.to(dtype) * kv_scale
kv_cache, _ = to_float8(ref_kv_cache)
else:
kv_scale = 1.0
ref_kv_cache = kv_cache
k_scale = v_scale = kv_scale
kv_cache = ref_kv_cache
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
block_tables = torch.randint(
@ -152,11 +156,31 @@ def benchmark_prefill(
return sum(times) / len(times), torch.std(torch.tensor(times))
o_scale = 1.0
o_sf_scale = None
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
o_sf_scale = 500.0
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
torch.empty(
(
round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4),
),
dtype=torch.float8_e4m3fn,
),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
def baseline_prefill():
return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
return wrapper.run(
ref_query,
ref_kv_cache,
k_scale=k_scale,
v_scale=v_scale,
out=output_baseline,
)
def trtllm_prefill():
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
@ -172,6 +196,7 @@ def benchmark_prefill(
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@ -250,6 +275,7 @@ if __name__ == "__main__":
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
for quant_dtype in quant_dtypes:

View File

@ -432,31 +432,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
# Install DeepGEMM from source
ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
. /etc/environment
CUDA_MAJOR="${CUDA_VERSION%%.*}"
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
CUDA_MINOR="${CUDA_MINOR%%.*}"
if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then
git clone --recursive --shallow-submodules \
${DEEPGEMM_GIT_REPO} deepgemm
echo "🏗️ Building DeepGEMM"
pushd deepgemm
git checkout ${DEEPGEMM_GIT_REF}
# Build DeepGEMM
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
rm -rf build dist
rm -rf *.egg-info
python3 setup.py bdist_wheel
uv pip install --system dist/*.whl
popd
rm -rf deepgemm
else
echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
fi
BASH
COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh
RUN --mount=type=cache,target=/root/.cache/uv \
VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \
&& rm /tmp/install_deepgemm.sh
# Install EP kernels (pplx-kernels and DeepEP), NixL
COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh
COPY tools/install_nixl.sh install_nixl.sh
ENV CUDA_HOME=/usr/local/cuda
RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \
&& bash install_python_libraries.sh \
&& bash install_nixl.sh --force
#################### vLLM installation IMAGE ####################

View File

@ -172,6 +172,7 @@ The availability of batch-level DP is based on model implementation.
Currently, the following models support `mm_encoder_tp_mode="data"`:
- Llama4 (<gh-pr:18368>)
- MiniCPM-V-4 (<gh-pr:23327>)
- Qwen2.5-VL (<gh-pr:22742>)
- Step3 (<gh-pr:22697>)
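For the models listed above, batch-level DP for the multimodal encoder can be requested through the engine arguments. A minimal sketch with the offline `LLM` API, assuming the `mm_encoder_tp_mode` argument is forwarded to the engine config (model name and parallel sizes are placeholders):

```python
from vllm import LLM

# Sketch: run the vision encoder with batch-level data parallelism while the
# language model uses tensor parallelism across 2 GPUs.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",  # placeholder model
    tensor_parallel_size=2,
    mm_encoder_tp_mode="data",
)
```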

View File

@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`.
There are other miscellaneous places hard-coding the use of `spawn`:
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/distributed/device_communicators/all_reduce_utils.py#L135>
- <https://github.com/vllm-project/vllm/blob/d05f88679bedd73939251a17c3d785a354b2946c/vllm/entrypoints/openai/api_server.py#L184>
Related PRs:

View File

@ -284,6 +284,14 @@ Supported models:
Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}`
### DeepSeek-V3.1 Models (`deepseek_v31`)
Supported models:
* `deepseek-ai/DeepSeek-V3.1` (use with <gh-file:examples/tool_chat_template_deepseekv31.jinja>)
Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}`
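An illustrative server launch with these flags (a sketch, not a verified command; adapt the model name and template path to your deployment):

```bash
vllm serve deepseek-ai/DeepSeek-V3.1 \
    --enable-auto-tool-choice \
    --tool-call-parser deepseek_v31 \
    --chat-template examples/tool_chat_template_deepseekv31.jinja
```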
### Kimi-K2 Models (`kimi_k2`)
Supported models:

View File

@ -401,6 +401,7 @@ th {
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |

View File

@ -166,7 +166,7 @@ Processed means the values after applying all processors, including temperature
##### Prompt Logprobs with Prefix Caching
Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414).
Logprobs are not cached. For a request requiring prompt logprobs, the engine will ignore the prefix cache and recompute the prefill of the full prompt to generate the logprobs.
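A hedged sketch of requesting prompt logprobs with prefix caching enabled (offline `LLM` API; the model name is a placeholder):

```python
from vllm import LLM, SamplingParams

# Prompt logprobs are recomputed even on a prefix cache hit, since logprobs
# themselves are never cached.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
params = SamplingParams(max_tokens=8, prompt_logprobs=2)
out = llm.generate(["The capital of France is"], params)
print(out[0].prompt_logprobs)
```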
#### Deprecated Features

View File

@ -0,0 +1,91 @@
{% if not add_generation_prompt is defined %}
{% set add_generation_prompt = false %}
{% endif %}
{% if not thinking is defined %}
{% set thinking = false %}
{% endif %}
{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}
{%- for message in messages %}
{%- if message['role'] == 'system' %}
{%- if ns.is_first_sp %}
{% set ns.system_prompt = ns.system_prompt + message['content'] %}
{% set ns.is_first_sp = false %}
{%- else %}
{% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %}
{%- endif %}
{%- endif %}
{%- endfor %}
{% if tools is defined and tools is not none %}
{% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %}
{% for tool in tools %}
{% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %}
{% endfor %}
{% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<tool▁calls▁begin><tool▁call▁begin>tool_call_name<tool▁sep>tool_call_arguments<tool▁call▁end>{{additional_tool_calls}}<tool▁calls▁end>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %}
{% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %}
{% endif %}
{{ bos_token }}{{ ns.system_prompt }}
{%- for message in messages %}
{%- if message['role'] == 'user' %}
{%- set ns.is_tool = false -%}
{%- set ns.is_first = false -%}
{%- set ns.is_last_user = true -%}
{{'<User>' + message['content']}}
{%- endif %}
{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}
{%- if ns.is_last_user %}
{{'<Assistant></think>'}}
{%- endif %}
{%- set ns.is_last_user = false -%}
{%- set ns.is_first = false %}
{%- set ns.is_tool = false -%}
{%- for tool in message['tool_calls'] %}
{%- if not ns.is_first %}
{%- if message['content'] is none %}
{{'<tool▁calls▁begin><tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments']|tojson + '<tool▁call▁end>'}}
{%- else %}
{{message['content'] + '<tool▁calls▁begin><tool▁call▁begin>' + tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments']|tojson + '<tool▁call▁end>'}}
{%- endif %}
{%- set ns.is_first = true -%}
{%- else %}
{{'<tool▁call▁begin>'+ tool['function']['name'] + '<tool▁sep>' + tool['function']['arguments']|tojson + '<tool▁call▁end>'}}
{%- endif %}
{%- endfor %}
{{'<tool▁calls▁end><end▁of▁sentence>'}}
{%- endif %}
{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}
{%- if ns.is_last_user %}
{{'<Assistant>'}}
{%- if message['prefix'] is defined and message['prefix'] and thinking %}
{{'<think>'}}
{%- else %}
{{'</think>'}}
{%- endif %}
{%- endif %}
{%- set ns.is_last_user = false -%}
{%- if ns.is_tool %}
{{message['content'] + '<end▁of▁sentence>'}}
{%- set ns.is_tool = false -%}
{%- else %}
{%- set content = message['content'] -%}
{%- if '</think>' in content %}
{%- set content = content.split('</think>', 1)[1] -%}
{%- endif %}
{{content + '<end▁of▁sentence>'}}
{%- endif %}
{%- endif %}
{%- if message['role'] == 'tool' %}
{%- set ns.is_last_user = false -%}
{%- set ns.is_tool = true -%}
{{'<tool▁output▁begin>' + message['content'] + '<tool▁output▁end>'}}
{%- endif %}
{%- endfor -%}
{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}
{{'<Assistant>'}}
{%- if not thinking %}
{{'</think>'}}
{%- else %}
{{'<think>'}}
{%- endif %}
{% endif %}

View File

@ -13,7 +13,7 @@ protobuf # Required by LlamaTokenizer.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
openai >= 1.99.1 # For Responses API with reasoning content
pydantic >= 2.10
pydantic >= 2.11.7
prometheus_client >= 0.18.0
pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0

View File

@ -6,7 +6,7 @@ torch==2.7.0
torchvision==0.22.0
torchaudio==2.7.0
triton==3.2
triton==3.3.0
cmake>=3.26.1,<4
packaging>=24.2
setuptools>=77.0.3,<80.0.0

View File

@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
runai-model-streamer==0.11.0
runai-model-streamer-s3==0.11.0
conch-triton-kernels==1.2.1
conch-triton-kernels==1.2.1

View File

@ -742,7 +742,7 @@ pycparser==2.22
# via cffi
pycryptodomex==3.22.0
# via blobfile
pydantic==2.11.5
pydantic==2.11.7
# via
# -r requirements/test.in
# albumentations

View File

@ -695,6 +695,8 @@ setup(
"video": [], # Kept for backwards compatibility
# FlashInfer should be updated together with the Dockerfile
"flashinfer": ["flashinfer-python==0.2.12"],
# Optional deps for AMD FP4 quantization support
"petit-kernel": ["petit-kernel"],
},
cmdclass=cmdclass,
package_data=package_data,
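Assuming a build or wheel that exposes this extra, the optional AMD FP4 dependency could be pulled in at install time, for example:

```bash
# Sketch: install vLLM together with the optional petit-kernel extra
pip install "vllm[petit-kernel]"
```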

View File

@ -8,11 +8,12 @@ import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fusion import FUSED_OPS, FusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym)
from .backend import TestBackend

View File

@ -7,11 +7,13 @@ import torch
import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, GroupShape, QuantKey)
FusionPass)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, QuantKey, ScaleDesc)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform
@ -30,10 +32,8 @@ class TestModel(torch.nn.Module):
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.key = QuantKey(dtype=FP8_DTYPE,
static=static,
group_shape=group_shape,
symmetric=True)
quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
else:

View File

@ -11,9 +11,10 @@ from tests.models.utils import check_outputs_equal
from tests.v1.attention.utils import (BatchSpec, _Backend,
create_common_attn_metadata)
from vllm import LLM, SamplingParams
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.attention import Attention
from vllm.attention.selector import global_force_attn_backend_context_manager
from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
@ -22,13 +23,14 @@ from vllm.config import (CacheConfig, CompilationConfig, CompilationLevel,
set_current_vllm_config)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
QuantKey, kFp8StaticTensorSym, kNvfp4Quant)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp)
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import AttentionSpec
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
# globals needed for string-import custom Dynamo backend field
backend: Optional[TestBackend] = None
@ -105,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
# check support
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape)
layer.impl.fused_output_quant_supported(quant_key)
for key, layer in compile_config.static_forward_context.items()
]
@ -149,12 +149,12 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend = None
class TestAttentionStaticQuantPatternModel(torch.nn.Module):
"""Test model for AttentionStaticQuantPattern fusion."""
class AttentionQuantPatternModel(torch.nn.Module):
"""Base model for AttentionQuantPattern fusion."""
def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int,
kv_cache_dtype: torch.dtype, device: torch.device,
vllm_config: VllmConfig):
vllm_config: VllmConfig, **kwargs):
super().__init__()
self.num_qo_heads = num_qo_heads
self.num_kv_heads = num_kv_heads
@ -172,11 +172,6 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
prefix="model.layers.0.self_attn.attn",
)
self.fp8_linear = Fp8LinearOp(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR)
self.wscale = torch.tensor([1.0], dtype=torch.float32)
self.scale = torch.tensor([1.0], dtype=torch.float32)
self.block_size = 16
# Initialize attn MetadataBuilder
@ -230,23 +225,86 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
return self.attn_metadata
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
w: torch.Tensor):
class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionFp8StaticQuantPattern fusion."""
quant_key = kFp8StaticTensorSym
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.quant_key.scale.static,
act_quant_group_shape=self.quant_key.scale.group_shape)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randn(hidden_size, hidden_size).to(
dtype=FP8_DTYPE, device=self.device).t(),
"wscale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([1.0], dtype=torch.float32, device=self.device),
})
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
return self.fp8_linear.apply(input=attn_output,
weight=w,
weight_scale=self.wscale,
input_scale=self.scale)
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"])
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
"""Test model for AttentionNvfp4QuantPattern fusion."""
quant_key = kNvfp4Quant
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get(
"w", {
"weight":
torch.randint(256, (hidden_size, hidden_size // 2),
dtype=FP4_DTYPE,
device=self.device),
"wscale_swizzled":
torch.randn(hidden_size, hidden_size // 16).to(
dtype=FP8_DTYPE, device=self.device),
"wscale":
torch.tensor([500], dtype=torch.float32, device=self.device),
"scale":
torch.tensor([0.002], dtype=torch.float32, device=self.device),
})
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v)
quant_output, output_block_scale = scaled_fp4_quant(
attn_output, 1 / self.w["scale"])
return cutlass_scaled_fp4_mm(a=quant_output,
b=self.w["weight"],
block_scale_a=output_block_scale,
block_scale_b=self.w["wscale_swizzled"],
alpha=self.w["scale"] * self.w["wscale"],
out_dtype=attn_output.dtype)
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize(
"model_name, quant_key",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)])
@pytest.mark.parametrize("model_name, model_class",
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)])
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER])
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@ -255,8 +313,8 @@ class TestAttentionStaticQuantPatternModel(torch.nn.Module):
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str,
quant_key: QuantKey, backend: _Backend,
monkeypatch, dist_init):
model_class: type[AttentionQuantPatternModel],
backend: _Backend, monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_USE_V1", "1")
@ -277,8 +335,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
cache_config=CacheConfig(cache_dtype="fp8"))
# Create test inputs
hidden_size = num_qo_heads * head_size
q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
q = torch.randn(batch_size,
num_qo_heads * head_size,
dtype=dtype,
device=device)
k = torch.randn(batch_size,
num_kv_heads * head_size,
dtype=dtype,
@ -287,7 +347,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
num_kv_heads * head_size,
dtype=dtype,
device=device)
linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t()
# Mark first dimension as dynamic for realistic testing
torch._dynamo.mark_dynamic(q, 0)
@ -299,9 +358,12 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
with set_current_vllm_config(vllm_config_unfused), set_forward_context(
attn_metadata=None, vllm_config=vllm_config_unfused
), global_force_attn_backend_context_manager(backend):
model_unfused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config_unfused)
model_unfused = model_class(num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config_unfused)
model_unfused = model_unfused.to(device)
forward_ctx = get_forward_context()
@ -309,7 +371,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
batch_size)
# Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v, linear_w)
result_unfused = model_unfused(q, k, v)
# Run model with attn fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
@ -317,9 +379,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
with set_current_vllm_config(vllm_config), set_forward_context(
attn_metadata=None, vllm_config=vllm_config
), global_force_attn_backend_context_manager(backend):
model_fused = TestAttentionStaticQuantPatternModel(
num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device,
vllm_config)
model_fused = model_class(num_qo_heads=num_qo_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
kv_cache_dtype=FP8_DTYPE,
device=device,
vllm_config=vllm_config,
w=model_unfused.w)
model_fused = model_fused.to(device)
forward_ctx = get_forward_context()
@ -336,21 +402,20 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
backend=test_backend,
fullgraph=True)
assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v, linear_w)
result_fused_1 = model_compiled(q, k, v)
# After the 1st round of the forward pass, output quant scale should be
# loaded into the attn layer's _o_scale_float, the 2nd round should
# reuse the loaded _o_scale_float
assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v, linear_w)
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None
# Check attn fusion support
quant_key = model_class.quant_key
attn_fusion_supported = [
layer.impl.fused_output_quant_supported(quant_key.dtype,
quant_key.static,
quant_key.group_shape) for key,
layer in vllm_config.compilation_config.static_forward_context.items()
layer.impl.fused_output_quant_supported(quant_key) for key, layer in
vllm_config.compilation_config.static_forward_context.items()
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
@ -370,6 +435,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \
"Attention should have output_scale after fusion"
assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \
"Attention should not have output_block_scale before fusion"
if quant_key.dtype == FP8_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \
"Attention should not have output_block_scale after FP8 fusion"
elif quant_key.dtype == FP4_DTYPE:
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
"Attention should have output_block_scale after FP4 fusion" # noqa: E501
# Check that results are close
torch.testing.assert_close(result_unfused,
result_fused_1,

View File

@ -233,6 +233,7 @@ MULTIMODAL_MODELS = {
"openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(),
"allenai/Molmo-7B-D-0924": PPTestSettings.fast(),
"AIDC-AI/Ovis2-1B": PPTestSettings.fast(),
"AIDC-AI/Ovis2.5-2B": PPTestSettings.fast(),
"microsoft/Phi-3.5-vision-instruct": PPTestSettings.fast(),
"mistralai/Pixtral-12B-2409": PPTestSettings.fast(load_format="dummy"),
"Qwen/Qwen-VL-Chat": PPTestSettings.fast(),

View File

@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import random
import typing
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import vllm.envs as envs
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_group,
init_distributed_environment,
initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables
torch.manual_seed(42)
random.seed(44)
test_size_elements = 4 * 1024 * 1024
def symm_mem_allreduce_worker(local_rank: int, world_size: int):
monkeypatch = pytest.MonkeyPatch()
with monkeypatch.context() as m:
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
dtype = torch.bfloat16
device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)
update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)
cuda_communicator = typing.cast(CudaCommunicator,
get_tp_group().device_communicator)
symm_mem_comm = cuda_communicator.symm_mem_comm
if symm_mem_comm is None or symm_mem_comm.disabled:
pytest.skip("SymmMemCommunicator is not available or disabled.")
inp_direct_symm_mem = torch.randint(1,
23, (test_size_elements, ),
dtype=dtype,
device=device)
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
pytest.skip(
"SymmMemCommunicator isn't used for this world and input size."
)
original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
assert out_direct_symm_mem is not None
group = get_tensor_model_parallel_group().device_group
dist.all_reduce(original_inp_direct_symm_mem, group=group)
torch.testing.assert_close(out_direct_symm_mem,
original_inp_direct_symm_mem,
atol=2.5,
rtol=0.1)
# Test tensor_model_parallel_all_reduce which should use symm_mem
inp_tensor_parallel = torch.randint(-23,
1, (test_size_elements, ),
dtype=dtype,
device=device)
original_inp_tensor_parallel = inp_tensor_parallel.clone()
out_tensor_parallel = tensor_model_parallel_all_reduce(
inp_tensor_parallel)
dist.all_reduce(original_inp_tensor_parallel, group=group)
torch.testing.assert_close(out_tensor_parallel,
original_inp_tensor_parallel,
atol=2.5,
rtol=0.1)
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("pipeline_parallel_size", [1])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
pipeline_parallel_size):
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Enable SymmMemCommunicator
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
cleanup_dist_env_and_memory()

View File

@ -74,18 +74,15 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI):
}
with pytest.raises(openai.BadRequestError) as err:
err = await client.post(path="embeddings",
cast_to=object,
body={**kwargs})
await client.post(path="embeddings", cast_to=object, body={**kwargs})
assert str(err) == f"""openai.BadRequestError:
Error code: 400 - {{'object': 'error',
'message': 'truncate_prompt_tokens value
({truncation_size})
is greater than max_model_len ({max_model_len}).
Please, select a smaller truncation size.',
'type': 'BadRequestError',
'param': None, 'code': 400}}"""
assert err.value.status_code == 400
error_details = err.value.response.json()["error"]
assert error_details["type"] == "BadRequestError"
expected_message = ("truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
assert error_details["message"] == expected_message
@pytest.mark.asyncio

View File

@ -6,7 +6,11 @@ import flashinfer
import pytest
import torch
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
FLOAT8_E4M3_MAX,
dequantize_nvfp4_to_dtype)
from vllm.platforms import current_platform
from vllm.utils import round_up
if not current_platform.is_device_capability(100):
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
@ -14,6 +18,7 @@ if not current_platform.is_device_capability(100):
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def to_float8(x, dtype=torch.float8_e4m3fn):
@ -29,7 +34,9 @@ DTYPE = [torch.bfloat16]
QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
BATCH_SIZE = [4, 12]
MAX_SEQ_LENS = [(1024, 4096)]
@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Decode
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.decode.trtllm_batch_decode_with_kv_cache(
query=query,
kv_cache=kv_cache,
@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline(
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 3e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2
rtol, atol = 1e-2, 2e-2
torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
f"{torch.max(torch.abs(output - output_trtllm))}"
@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline(
kv_quant_dtype = kv_quant_dtype or dtype
o_quant_dtype = o_quant_dtype or dtype
if q_quant_dtype != kv_quant_dtype:
pytest.skip("Skipped mixed QKV dtypes for prefill")
max_q_len, max_kv_len = max_seq_lens
num_qo_heads, num_kv_heads = num_heads
@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
o_sf_scale = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.amax(output.flatten(), dim=-1)).to(torch.float32)
# TRTLLM Prefill
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
if o_quant_dtype == FP4_DTYPE:
output_trtllm = flashinfer.utils.FP4Tensor(
torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ),
dtype=torch.uint8),
torch.empty((round_up(query.shape[0], 128),
round_up(query.shape[1] * query.shape[2] // 16, 4)),
dtype=torch.float8_e4m3fn),
)
else:
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
flashinfer.prefill.trtllm_batch_context_with_kv_cache(
query=query,
kv_cache=kv_cache,
@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline(
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
output_trtllm = output_trtllm.to(dtype) * o_scale
elif o_quant_dtype == FP4_DTYPE:
output_trtllm.data = output_trtllm.data.reshape(
-1, query.shape[1] * query.shape[2] // 2)
output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data,
output_trtllm.scale,
o_sf_scale, dtype,
query.device)
output_trtllm = output_trtllm.reshape(-1, query.shape[1],
query.shape[2])
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
else:
rtol, atol = 1e-2, 1e-2

View File

@ -429,11 +429,11 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
torch.cuda.empty_cache()
vllm_moe.experts.w2_weight = Parameter(F.pad(
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[...,
0:-128],
requires_grad=False)
torch.cuda.synchronize()
torch.cuda.empty_cache()
# Run forward passes for both MoE blocks

View File

@ -3,15 +3,13 @@
import tempfile
from collections import OrderedDict
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
import vllm
from vllm.config import LoRAConfig
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
@ -21,7 +19,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.platforms import current_platform
@ -104,6 +101,7 @@ def dummy_model() -> nn.Module:
]))
model.config = MagicMock()
model.embedding_modules = {"lm_head": "lm_head"}
model.unpadded_vocab_size = 32000
return model
@ -137,6 +135,8 @@ def dummy_model_gate_up() -> nn.Module:
],
}
model.embedding_modules = {"lm_head": "lm_head"}
model.unpadded_vocab_size = 32000
return model
@ -221,29 +221,6 @@ def phi2_lora_files():
return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")
@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
cleanup_dist_env_and_memory(shutdown_ray=True)
get_model_old = get_model
def get_model_patched(**kwargs):
kwargs["vllm_config"].lora_config = LoRAConfig(max_loras=4,
max_lora_rank=8)
return get_model_old(**kwargs)
with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
yield engine.llm_engine
del engine
cleanup_dist_env_and_memory(shutdown_ray=True)
@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)
@pytest.fixture
def reset_default_device():
"""

View File

@ -5,7 +5,6 @@ import time
import pytest
import vllm.envs as env
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
@ -98,12 +97,10 @@ async def test_add_lora(chatglm3_lora_files):
# Run with warmup
add_lora_tasks = [llm.add_lora(lr) for lr in warmup_run_requests]
add_lora_results = await asyncio.gather(*add_lora_tasks)
if env.VLLM_USE_V1:
# Test that all all_lora calls are successful.
assert all(add_lora_results)
else:
# No way to check V0 engine results as the calls just return None.
pass
# Test that all add_lora calls are successful.
assert all(add_lora_results)
time_with_add_lora = await requests_processing_time(
llm, warmup_run_requests)

View File

@ -113,8 +113,7 @@ def test_llama_lora(sql_lora_files):
enable_lora=True,
# also test odd max_num_seqs
max_num_seqs=13,
max_loras=4,
enable_chunked_prefill=True)
max_loras=4)
generate_and_test(llm, sql_lora_files)
@ -128,7 +127,6 @@ def test_llama_lora_tp4(sql_lora_files):
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)
@ -144,7 +142,6 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, sql_lora_files)

View File

@ -21,6 +21,8 @@ from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
WorkerLoRAManager)
from vllm.platforms import current_platform
from .utils import create_peft_lora
EMBEDDING_MODULES = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
@ -35,17 +37,6 @@ DEVICES = ([
DEFAULT_DTYPE = torch.get_default_dtype()
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch: pytest.MonkeyPatch):
"""
Some tests depend on V0 internals. Since both V0 and V1 use the same
LoRAModelManager it is okay to just test V0.
"""
with monkeypatch.context() as m:
m.setenv('VLLM_USE_V1', '0')
yield
@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
tensors = load_file(
@ -326,7 +317,6 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
max_loras=2,
lora_dtype=DEFAULT_DTYPE),
device=device)
assert all(x is None for x in manager.lora_index_to_id)
# Add up to capacity
@ -430,32 +420,40 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
tmp_path):
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)
worker_adapter_manager = LRUCacheWorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
4, 2,
dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(dummy_model)
mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("3", 3, dummy_lora_files),
LoRARequest("4", 4, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@ -464,9 +462,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files),
LoRARequest("5", 5, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@ -475,9 +473,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@ -486,9 +484,9 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
LoRARequest("6", 6, dummy_lora_files),
LoRARequest("7", 7, dummy_lora_files),
LoRARequest("8", 8, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@ -499,11 +497,11 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
# Over capacity
with pytest.raises(RuntimeError):
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
LoRARequest("10", 10, dummy_lora_files),
LoRARequest("11", 11, dummy_lora_files),
LoRARequest("12", 12, dummy_lora_files),
LoRARequest("13", 13, dummy_lora_files),
LoRARequest("14", 14, dummy_lora_files)
], mapping)
assert worker_adapter_manager.device == device
@ -512,33 +510,41 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
sql_lora_files, device):
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
tmp_path):
# Should remove every LoRA not specified in the request.
lora_config = LoRAConfig(max_lora_rank=8,
max_cpu_loras=4,
max_loras=4,
lora_dtype=DEFAULT_DTYPE)
worker_adapter_manager = WorkerLoRAManager(
4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
4, 2, dummy_model_gate_up.unpadded_vocab_size -
lora_config.lora_extra_vocab_size, lora_config, device,
EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
worker_adapter_manager.create_lora_manager(
llama_2_7b_model_extra_embeddings)
worker_adapter_manager.create_lora_manager(dummy_model_gate_up)
dummy_lora_files = f"{tmp_path}/lora_adapter"
os.makedirs(dummy_lora_files, exist_ok=True)
create_peft_lora(
dummy_model_gate_up,
save_dir=dummy_lora_files,
target_modules=["layer1.dense1", "dense2"],
lora_dtype=DEFAULT_DTYPE,
)
mapping = LoRAMapping([], [])
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("3", 3, sql_lora_files),
LoRARequest("4", 4, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("3", 3, dummy_lora_files),
LoRARequest("4", 4, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 3, 4}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@ -546,9 +552,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("2", 2, sql_lora_files),
LoRARequest("5", 5, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("2", 2, dummy_lora_files),
LoRARequest("5", 5, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1, 2, 5}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@ -556,9 +562,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
worker_adapter_manager.set_active_adapters([
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files),
LoRARequest("1", 1, sql_lora_files)
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files),
LoRARequest("1", 1, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {1}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
@ -566,9 +572,9 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None
worker_adapter_manager.set_active_adapters([
LoRARequest("6", 6, sql_lora_files),
LoRARequest("7", 7, sql_lora_files),
LoRARequest("8", 8, sql_lora_files)
LoRARequest("6", 6, dummy_lora_files),
LoRARequest("7", 7, dummy_lora_files),
LoRARequest("8", 8, dummy_lora_files)
], mapping)
assert worker_adapter_manager.list_adapters() == {6, 7, 8}
assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
@ -578,11 +584,11 @@ def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
# Over capacity
with pytest.raises(RuntimeError):
worker_adapter_manager.set_active_adapters([
LoRARequest("10", 10, sql_lora_files),
LoRARequest("11", 11, sql_lora_files),
LoRARequest("12", 12, sql_lora_files),
LoRARequest("13", 13, sql_lora_files),
LoRARequest("14", 14, sql_lora_files)
LoRARequest("10", 10, dummy_lora_files),
LoRARequest("11", 11, dummy_lora_files),
LoRARequest("12", 12, dummy_lora_files),
LoRARequest("13", 13, dummy_lora_files),
LoRARequest("14", 14, dummy_lora_files)
], mapping)
assert worker_adapter_manager.device == device

View File

@ -50,7 +50,6 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size,
enable_chunked_prefill=True,
)
expected_lora_output = [

View File

@ -4,17 +4,14 @@
import os
import random
import tempfile
from typing import Union
from unittest.mock import patch
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VllmConfig)
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.v1.worker.gpu_worker import Worker as V1Worker
from vllm.worker.worker import Worker
from vllm.v1.worker.gpu_worker import Worker
NUM_LORAS = 16
@ -22,18 +19,11 @@ NUM_LORAS = 16
@patch.dict(os.environ, {"RANK": "0"})
def test_worker_apply_lora(sql_lora_files):
def set_active_loras(worker: Union[Worker, V1Worker],
lora_requests: list[LoRARequest]):
def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]):
lora_mapping = LoRAMapping([], [])
if isinstance(worker, Worker):
# v0 case
worker.model_runner.set_active_loras(lora_requests, lora_mapping)
else:
# v1 case
worker.model_runner.lora_manager.set_active_adapters(
lora_requests, lora_mapping)
worker_cls = V1Worker if envs.VLLM_USE_V1 else Worker
worker.model_runner.lora_manager.set_active_adapters(
lora_requests, lora_mapping)
vllm_config = VllmConfig(
model_config=ModelConfig(
@ -62,7 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
max_cpu_loras=NUM_LORAS,
max_loras=NUM_LORAS),
)
worker = worker_cls(
worker = Worker(
vllm_config=vllm_config,
local_rank=0,
rank=0,

View File

@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from dataclasses import dataclass
from typing import Optional, Union
import torch
from safetensors.torch import save_file
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
@ -340,3 +343,76 @@ def generate_data_for_nslices(
seq_len_tensor,
indices,
)
def create_peft_lora(
model: torch.nn.Module,
save_dir: str,
target_modules: list[str],
rank: int = 8,
alpha: int = 16,
dropout: float = 0.1,
lora_dtype: torch.dtype = torch.float16,
) -> dict[str, torch.Tensor]:
lora_weights = {}
adapter_config = {
"peft_type": "LORA",
"auto_mapping": None,
"base_model_name_or_path": "dummy_model",
"revision": None,
"task_type": "CAUSAL_LM",
"inference_mode": False,
"r": rank,
"lora_alpha": alpha,
"lora_dropout": dropout,
"fan_in_fan_out": False,
"bias": "none",
"modules_to_save": None,
"init_lora_weights": True,
"layers_to_transform": None,
"layers_pattern": None,
"target_modules": target_modules,
"exclude_modules": None,
"use_rslora": False,
"use_dora": False,
"loftq_config": None,
}
for module_name in target_modules:
module = model
for attr in module_name.split("."):
module = getattr(module, attr)
if hasattr(module, "input_size") and hasattr(module, "output_size"):
in_features = module.input_size
out_features = module.output_size
elif hasattr(module, "embedding_dim") and hasattr(
module, "num_embeddings"):
# ParallelLMHead
in_features = module.embedding_dim
out_features = module.num_embeddings
else:
raise ValueError(
f"Unable to determine dimensions for module {module_name}")
lora_A = torch.randn(rank, in_features, dtype=lora_dtype)
torch.nn.init.kaiming_uniform_(lora_A, a=5**0.5)
lora_B = torch.zeros(out_features, rank, dtype=lora_dtype)
# PEFT style
lora_weights[f"base_model.model.{module_name}.lora_A.weight"] = lora_A
lora_weights[f"base_model.model.{module_name}.lora_B.weight"] = lora_B
config_path = os.path.join(save_dir, "adapter_config.json")
with open(config_path, "w", encoding="utf-8") as f:
json.dump(adapter_config, f, indent=2, ensure_ascii=False)
weights_path = os.path.join(save_dir, "adapter_model.safetensors")
save_file(lora_weights, weights_path)
return lora_weights

View File

@ -11,7 +11,6 @@ from pathlib import PosixPath
import pytest
from transformers import (AutoModel, AutoModelForImageTextToText,
AutoModelForTextToWaveform, AutoModelForVision2Seq)
from transformers.utils import is_flash_attn_2_available
from vllm.platforms import current_platform
from vllm.utils import identity
@ -637,10 +636,7 @@ VLM_TEST_SETTINGS = {
dtype="half",
num_logprobs=10,
patch_hf_runner=model_utils.ovis2_5_patch_hf_runner,
marks=[pytest.mark.skipif(
not is_flash_attn_2_available(),
reason="HF model needs `flash_attn` installed"
)],
hf_model_kwargs={"revision": "refs/pr/5"},
),
"phi3v": VLMTestInfo(
models=["microsoft/Phi-3.5-vision-instruct"],

View File

@ -292,6 +292,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True,
is_available_online=False),
"SmolLM3ForCausalLM": _HfExamplesInfo("HuggingFaceTB/SmolLM3-3B"),
"StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501
"StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"),
@ -465,9 +468,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
extras={"1.6-llama": "AIDC-AI/Ovis1.6-Llama3.2-3B",
"1.6-gemma": "AIDC-AI/Ovis1.6-Gemma2-9B"}), # noqa: E501
"Ovis2_5": _HfExamplesInfo("AIDC-AI/Ovis2.5-2B",
trust_remote_code=True,
max_transformers_version="4.53",
transformers_version_reason="HF model is not compatible"), # noqa: E501
trust_remote_code=True),
"PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501
extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501
"Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct",

View File

@ -0,0 +1,459 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
import json
from collections.abc import Generator
from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import SeedOssToolParser
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
MODEL = "ByteDance-Seed/Seed-OSS-36B-Instruct"
@pytest.fixture(scope="module")
def seed_oss_tokenizer():
return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True)
@pytest.fixture
def seed_oss_tool_parser(seed_oss_tokenizer):
return SeedOssToolParser(seed_oss_tokenizer)
@pytest.fixture
def sample_tools():
return [
ChatCompletionToolsParam(
type="function",
function={
"name": "get_weather",
"description": "Get current temperature for a given location.",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description":
"City and country e.g. Bogotá, Colombia"
},
"unit": {
"type": "string",
"description": "this is the unit of temperature"
}
},
"required": ["location"],
"additionalProperties": False
},
"returns": {
"type": "object",
"properties": {
"temperature": {
"type": "number",
"description": "temperature in celsius"
}
},
"required": ["temperature"],
"additionalProperties": False
},
"strict": True
}),
]
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
# Seed-OSS tool call will not generate id
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
assert actual_tool_call.function.name == expected_tool_call.function.name
assert actual_tool_call.function.arguments == expected_tool_call.function.arguments
def test_extract_tool_calls_no_tools(seed_oss_tool_parser):
model_output = "This is a test response without any tool calls"
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"tool_call_0_thinking_budget",
"tool_call_512_thinkg_budget",
"tool_call_unlimited_thinking_budget",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n"""
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
),
(
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
"""\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
),
(
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
"unit": "celsius",
}, ),
),
type='function')
],
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think>""",
),
],
)
def test_extract_tool_calls(seed_oss_tool_parser, sample_tools, model_output,
expected_tool_calls, expected_content):
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
extracted_tool_calls = seed_oss_tool_parser.extract_tool_calls(
model_output, request=request) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def test_streaming_tool_calls_no_tools(seed_oss_tool_parser):
model_output = "This is a test response without any tool calls"
result = seed_oss_tool_parser.extract_tool_calls_streaming(
previous_text="his is a test response",
current_text=model_output,
delta_text=" without any tool calls.",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# Should return the delta text as content
assert result is not None
assert hasattr(result, 'content')
assert result.content == " without any tool calls."
def stream_delta_message_generator(
seed_oss_tool_parser: SeedOssToolParser,
seed_oss_tokenizer: AnyTokenizer,
model_output: str,
request: Optional[ChatCompletionRequest] = None
) -> Generator[DeltaMessage, None, None]:
all_token_ids = seed_oss_tokenizer.encode(model_output,
add_special_tokens=False)
previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = detokenize_incrementally(
tokenizer=seed_oss_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
)
current_text = previous_text + delta_text
delta_message = seed_oss_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=request,
)
if delta_message:
yield delta_message
previous_text = current_text
previous_tokens = (previous_tokens +
new_tokens if previous_tokens else new_tokens)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
@pytest.mark.parametrize(
ids=[
"tool_call_0_thinking_budget",
"tool_call_512_thinkg_budget",
"tool_call_unlimited_thinking_budget",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
("""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n"""
"""<parameter=location>Barcelona, Spain</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>\n</seed:cot_budget_reflect>\n</seed:cot_budget_reflect>\n"""
"""The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\n"""
),
(
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n"""
"""<seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, Spain</parameter>\n</function>"""
"""\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps({
"location": "Barcelona, Spain",
}, ),
),
type='function')
],
"""<seed:think>The user\'s current thinking budget is 512.</seed:cot_budget_reflect>\nLet me analyze the """
"""question. The user wants to know the weather in Barcelona, Spain. Looking at the functions available, """
"""there\'s a get_weather function that can retrieve the current temperature for a given location. \n\nFirst, """
"""check the parameters required by get_weather: location is mandatory (needs city and country), and unit is """
"""optional. The user provided "Barcelona Spain" as the location, which fits the required format (city, """
"""country). \n<seed:cot_budget_reflect>I have used 131 tokens, and there are 381 tokens remaining for use."""
"""</seed:cot_budget_reflect>\n Since the unit isn\'t specified, the function will default to Celsius, which """
"""is fine. \n\nThere\'s no need to ask for more information because the location is clear. So I should call """
"""the get_weather function with location set to "Barcelona, Spain" (adding a comma for clarity, though the """
"""user\'s input has a space, but the function might accept either; to be safe, using the standard format """
"""with a comma).\n<seed:cot_budget_reflect>I have used 257 tokens, and there are 255 tokens remaining for """
"""use.</seed:cot_budget_reflect>\n The unit parameter can be omitted since it\'s optional.</seed:think>\n""",
),
(
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think><seed:tool_call>\n<function=get_weather>\n<parameter=location>"""
"""Barcelona, Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>""",
[
ToolCall(function=FunctionCall(
name="get_weather",
arguments=json.dumps(
{
"location": "Barcelona, Spain",
"unit": "celsius",
}, ),
),
type='function')
],
"""<seed:think>\nGot it, let\'s see. The user asked for the weather in Barcelona, Spain. """
"""First, I need to remember the function I can use: get_weather. The function requires a """
"""location (city and country) which is "Barcelona, Spain" here, and unit is optional. Since """
"""the user didn\'t specify the unit, the default in the function is Celsius, right? Wait, """
"""let me check the function docstring again. Oh, the function says unit is optional, and """
"""returns temperature in Celsius. So I should call get_weather with location "Barcelona, """
"""Spain" and maybe omit unit or set to Celsius. Let me format the function call correctly. """
"""The format is <seed:tool_call>\n<function=get_weather>\n<parameter=location>Barcelona, """
"""Spain</parameter>\n<parameter=unit>celsius</parameter>\n</function>\n</seed:tool_call>. """
"""Wait, but does the unit parameter accept "celsius"? The docstring says unit is the unit """
"""of temperature, but the return is in Celsius anyway. Maybe even if I don\'t pass unit, """
"""it\'s okay, but to be explicit, maybe pass "celsius". Let me go with that. So the function """
"""call should be as above. Then wait for the result to come back and tell the user the """
"""temperature in Celsius.</seed:think>""",
),
],
)
def test_streaming_tool_calls(seed_oss_tool_parser, seed_oss_tokenizer,
sample_tools, model_output, expected_tool_calls,
expected_content):
"""Test incremental streaming behavior"""
request = ChatCompletionRequest(model=MODEL,
messages=[],
tools=sample_tools)
other_content = ''
tool_states = {} # Track state per tool index
for delta_message in stream_delta_message_generator(
seed_oss_tool_parser, seed_oss_tokenizer, model_output, request):
# role should never be streamed from tool parser
assert not delta_message.role
if delta_message.content:
other_content += delta_message.content
if delta_message.tool_calls:
for tool_call in delta_message.tool_calls:
idx = tool_call.index
# Initialize state for new tool
if idx not in tool_states:
tool_states[idx] = {
"id": None,
"name": None,
"arguments": "",
"type": None
}
# First chunk should have id, name, and type
if tool_call.id:
tool_states[idx]["id"] = tool_call.id
if tool_call.type:
assert tool_call.type == "function"
tool_states[idx]["type"] = tool_call.type
if tool_call.function:
if tool_call.function.name:
# Should only be set once
assert tool_states[idx]["name"] is None
tool_states[idx]["name"] = tool_call.function.name
if tool_call.function.arguments is not None:
# Accumulate arguments incrementally
tool_states[idx][
"arguments"] += tool_call.function.arguments
# Verify final content
assert other_content == expected_content
# Verify we got all expected tool calls
assert len(tool_states) == len(expected_tool_calls)
# Verify each tool call
for idx, expected_tool in enumerate(expected_tool_calls):
state = tool_states[idx]
assert state["id"] is not None
assert state["type"] == "function"
assert state["name"] == expected_tool.function.name
# Parse accumulated arguments
arguments_str = state["arguments"]
assert arguments_str is not None
actual_args = json.loads(arguments_str)
expected_args = json.loads(expected_tool.function.arguments)
assert actual_args == expected_args

View File

@ -11,9 +11,11 @@ from typing import TYPE_CHECKING, Any
import jsonschema
import pytest
import regex as re
import torch
from pydantic import BaseModel
from tests.reasoning.utils import run_reasoning_extraction
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
@ -727,3 +729,83 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch):
assert "a4" not in generated
assert "a5" not in generated
assert "a6" not in generated
@pytest.mark.parametrize("guided_decoding_backend",
["guidance", "xgrammar", "outlines"])
def test_structured_output_batched_with_non_guided_requests(
monkeypatch: pytest.MonkeyPatch,
sample_json_schema: dict[str, Any],
guided_decoding_backend: str,
):
monkeypatch.setenv("VLLM_USE_V1", "1")
# Don't use eager execution on TPUs because we want to test for no
# recompilation at runtime
enforce_eager = bool(not current_platform.is_tpu())
llm = LLM(
model="meta-llama/Meta-Llama-3.1-8B-Instruct",
enforce_eager=enforce_eager,
max_model_len=1024,
guided_decoding_backend=guided_decoding_backend,
guided_decoding_disable_any_whitespace=(guided_decoding_backend
in {"xgrammar", "guidance"}),
)
guided_prompt = (
"Give an example JSON for an employee profile that fits this "
"schema. Make the response as short as possible. Schema: "
f"{sample_json_schema}")
non_guided_prompt = "The diameter of the Earth in kilometers is "
prompts = [guided_prompt, non_guided_prompt]
sampling_params = [
SamplingParams(
temperature=1.0,
max_tokens=400,
guided_decoding=GuidedDecodingParams(json=sample_json_schema)),
# No max tokens, temp=0 to assert on contents
SamplingParams(
seed=42,
temperature=0,
top_p=1.0,
),
]
outputs = llm.generate(prompts=prompts,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
# Free memory as soon as possible as failed assertions
# will short circuit and not free up memory
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
for index, output in enumerate(outputs):
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt:\n{prompt!r}\nGenerated text:\n{generated_text!r}")
if index == 0:
# First prompt is guided, expect valid JSON
assert "\n" not in generated_text
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json,
schema=sample_json_schema)
else:
# Second prompt is not guided, expect valid output
# Cannot assert on exact output, but we can expect it to be factual
assert "12,742" in generated_text
# non-guided requests should not return a valid JSON here
with pytest.raises(ValueError):
output_json = json.loads(generated_text)

View File

@ -14,6 +14,7 @@ from unittest.mock import patch
import pytest
import ray
import torch
from vllm import LLM
from vllm.config import KVTransferConfig
@ -22,6 +23,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from .utils import create_request, create_scheduler, create_vllm_config
@ -98,7 +100,6 @@ class FakeNixlWrapper:
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles
@contextlib.contextmanager
@ -562,3 +563,86 @@ def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
sampling_params)
# Request-0 times out and is cleared!
assert '0' not in req_to_blocks
def test_register_kv_caches(dist_init):
"""
Test that register_kv_caches() properly calls nixl_wrapper methods with
correct data.
This test verifies:
1. nixl_wrapper.get_reg_descs() is called with caches_data containing
tensor metadata
2. nixl_wrapper.get_xfer_descs() is called with blocks_data containing
block layout info
"""
vllm_config = create_vllm_config()
# Create test kv cache tensors using proper backend shape
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(num_blocks=2,
block_size=16,
num_kv_heads=4,
head_size=64)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
expected_tensor_size = shared_tensor[0].element_size(
) * shared_tensor[0].numel()
expected_base_addrs = [
shared_tensor[0].data_ptr(), shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(), unique_tensor[1].data_ptr()
]
with patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper") as mock_nixl_wrapper, \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Event"), \
patch("vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"): # noqa: E501
# Create connector
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(
vllm_config, connector.engine_id, hand_shake_latency=0)
# Get the mock instance
mock_wrapper_instance = mock_nixl_wrapper.return_value
connector.connector_worker.nixl_wrapper = mock_wrapper_instance
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
# Verify get_reg_descs was called with caches_data
assert mock_wrapper_instance.get_reg_descs.called
caches_data, _ = mock_wrapper_instance.get_reg_descs.call_args[0]
assert len(caches_data) == 4
for i, cache_entry in enumerate(caches_data):
base_addr, size, _tp_rank, _ = cache_entry
assert size == expected_tensor_size, \
f"Entry {i}: Expected tensor size {expected_tensor_size}, " \
f"got {size}"
assert base_addr == expected_base_addrs[i], \
f"Entry {i}: Expected base address {expected_base_addrs[i]}, " \
f"got {base_addr}"
# Verify get_xfer_descs was called with blocks_data
assert mock_wrapper_instance.get_xfer_descs.called
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, \
f"Expected {expected_blocks_count} blocks, " \
f"got {len(blocks_data)}"
expected_block_len = expected_tensor_size // 2
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
assert block_len == expected_block_len, \
f"Block entry {i}: Expected block len {expected_block_len}, " \
f"got {block_len}"

View File

@ -680,6 +680,7 @@ def test_init_kv_cache_with_kv_sharing_valid():
kv_cache_spec[layer_0].page_size_bytes
runner.initialize_kv_cache(kv_cache_config)
kv_cache_config_after_init = runner.kv_cache_config
layer_0_kv = vllm_ctx[layer_0].kv_cache[0]
layer_1_kv = vllm_ctx[layer_1].kv_cache[0]
@ -687,10 +688,12 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert id(layer_1_kv) == id(layer_0_kv)
# check layer 1 added to kv cache group's layer names
assert len(kv_cache_config.kv_cache_groups) == 1
assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1
assert len(kv_cache_config_after_init.kv_cache_groups) == 1
assert len(kv_cache_config_after_init.kv_cache_groups[0].layer_names) == 2
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
0] == layer_0
assert kv_cache_config_after_init.kv_cache_groups[0].layer_names[
1] == layer_1
def test_hybrid_attention_mamba_tensor_shapes(monkeypatch):

View File

@ -37,7 +37,7 @@ ALLOWED_FILES = set([
'vllm/distributed/utils.py',
'vllm/distributed/parallel_state.py',
'vllm/engine/multiprocessing/client.py',
'vllm/distributed/device_communicators/custom_all_reduce_utils.py',
'vllm/distributed/device_communicators/all_reduce_utils.py',
'vllm/distributed/device_communicators/shm_broadcast.py',
'vllm/engine/multiprocessing/engine.py',
'benchmarks/kernels/graph_machete_bench.py',

108
tools/install_deepgemm.sh Executable file
View File

@ -0,0 +1,108 @@
#!/bin/bash
# Script to install DeepGEMM from source
# This script can be used both in Docker builds and by users locally
set -e
# Default values
DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git"
DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c"
# Parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--ref)
if [[ -z "$2" || "$2" =~ ^- ]]; then
echo "Error: --ref requires an argument." >&2
exit 1
fi
DEEPGEMM_GIT_REF="$2"
shift 2
;;
--cuda-version)
if [[ -z "$2" || "$2" =~ ^- ]]; then
echo "Error: --cuda-version requires an argument." >&2
exit 1
fi
CUDA_VERSION="$2"
shift 2
;;
-h|--help)
echo "Usage: $0 [OPTIONS]"
echo "Options:"
echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)"
echo " --cuda-version VER CUDA version (auto-detected if not provided)"
echo " -h, --help Show this help message"
exit 0
;;
*)
echo "Unknown option: $1" >&2
exit 1
;;
esac
done
# Auto-detect CUDA version if not provided
if [ -z "$CUDA_VERSION" ]; then
if command -v nvcc >/dev/null 2>&1; then
CUDA_VERSION=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/p')
echo "Auto-detected CUDA version: $CUDA_VERSION"
else
echo "Warning: Could not auto-detect CUDA version. Please specify with --cuda-version"
exit 1
fi
fi
# Extract major and minor version numbers
CUDA_MAJOR="${CUDA_VERSION%%.*}"
CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}"
CUDA_MINOR="${CUDA_MINOR%%.*}"
echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)"
# Check CUDA version requirement
if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then
echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})"
exit 0
fi
echo "Installing DeepGEMM from source..."
echo "Repository: $DEEPGEMM_GIT_REPO"
echo "Reference: $DEEPGEMM_GIT_REF"
# Create a temporary directory for the build
INSTALL_DIR=$(mktemp -d)
trap 'rm -rf "$INSTALL_DIR"' EXIT
# Clone the repository
git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm"
echo "🏗️ Building DeepGEMM"
pushd "$INSTALL_DIR/deepgemm"
# Checkout the specific reference
git checkout "$DEEPGEMM_GIT_REF"
# Build DeepGEMM
# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh)
rm -rf build dist
rm -rf *.egg-info
python3 setup.py bdist_wheel
# Install the wheel
if command -v uv >/dev/null 2>&1; then
echo "Installing DeepGEMM wheel using uv..."
# Use --system in Docker contexts, respect user's environment otherwise
if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then
uv pip install --system dist/*.whl
else
uv pip install dist/*.whl
fi
else
echo "Installing DeepGEMM wheel using pip..."
python3 -m pip install dist/*.whl
fi
popd
echo "✅ DeepGEMM installation completed successfully"

View File

@ -9,8 +9,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional,
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.multimodal import MultiModalPlaceholderMap
if TYPE_CHECKING:
@ -285,20 +284,17 @@ class AttentionImpl(ABC, Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
def fused_output_quant_supported(self, quant_key: QuantKey):
"""
Does this attention implementation support fused output quantization.
This is used by the AttnFusionPass to only fuse output quantization
onto implementations that support it.
TODO(luka) merge parameters into QuantDescriptor
:param dtype: quantized dtype
:param static: static or dynamic quantization
:param group_shape: quant group shape.
:param quant_key: QuantKey object that describes the quantization op
:return: is fusion supported for this type of quantization
"""
return False
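To make the new single-argument contract concrete, here is a minimal sketch (not part of the diff) of how a backend could advertise fused FP8 static per-tensor output quantization; the class name is hypothetical, and the constants come from the quant_utils import used elsewhere in this commit.
# Illustrative sketch only; MyAttentionImpl is a hypothetical backend.
from vllm.attention.backends.abstract import AttentionImpl
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey, kFp8StaticTensorSym)
class MyAttentionImpl(AttentionImpl):
    def fused_output_quant_supported(self, quant_key: QuantKey):
        # Fuse only for static, per-tensor, symmetric FP8 output quantization.
        return quant_key == kFp8StaticTensorSym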
@ -317,6 +313,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError

View File

@ -800,6 +800,7 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
attn_metadata: DifferentialFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -817,6 +818,11 @@ class DifferentialFlashAttentionImpl(AttentionImpl):
{q,k,v}_descale to be (num_sequences, num_kv_heads).
We use torch's .expand() to avoid duplicating values
"""
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for DifferentialFlashAttentionImpl")
if self.lambda_full is None:
self.lambda_init = self.differential_flash_attention_config[
"lambda_init"]

View File

@ -371,6 +371,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
attn_metadata: DualChunkFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with DualChunkFlashAttention.
Args:
@ -386,7 +387,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
"""
assert output is None, "Output tensor not supported for DualChunk"
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

View File

@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
@ -615,7 +616,7 @@ class FlashAttentionImpl(AttentionImpl):
"""
assert output is not None, "Output tensor must be provided."
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for FlashAttentionImpl")

View File

@ -1238,12 +1238,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_metadata: T,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for MLAImplBase")

View File

@ -20,7 +20,7 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -529,11 +529,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
head_dim).reshape(tokens, n_kv_heads * n_rep,
head_dim))
def fused_output_quant_supported(self, dtype: torch.dtype, static: bool,
group_shape: GroupShape):
def fused_output_quant_supported(self, quant_key: QuantKey):
if self.use_triton_flash_attn:
return dtype == current_platform.fp8_dtype(
) and static and group_shape == GroupShape.PER_TENSOR
return quant_key == kFp8StaticTensorSym
# Only supported in the Triton backend
return False
@ -548,6 +546,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_metadata: ROCmFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@ -606,6 +605,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"fused output quantization only supported for Triton"
" implementation in ROCMFlashAttentionImpl for now")
if output_block_scale is not None:
raise NotImplementedError(
"fused nvfp4 output quantization is not supported"
" for ROCMFlashAttentionImpl")
query = query.view(-1, self.num_heads, self.head_size)
if key is not None:
assert value is not None

View File

@ -432,6 +432,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
attn_metadata: "XFormersMetadata",
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
@ -484,7 +485,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
if output_scale is not None:
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
"fused output quantization is not yet supported"
" for XFormersImpl")

View File

@ -495,6 +495,7 @@ def unified_attention_with_output(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
@ -510,7 +511,8 @@ def unified_attention_with_output(
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale)
output_scale=output_scale,
output_block_scale=output_block_scale)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
@ -522,6 +524,7 @@ def unified_attention_with_output_fake(
output: torch.Tensor,
layer_name: str,
output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> None:
return
@ -529,7 +532,7 @@ def unified_attention_with_output_fake(
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output"],
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -6,12 +6,13 @@ from typing import List, Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig, QuantizationConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata, make_local_attention_virtual_batches,
subclass_attention_backend, subclass_attention_metadata_builder)
subclass_attention_backend)
from ..layer import Attention
@ -24,21 +25,23 @@ def create_chunked_local_attention_backend(
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
def build_preprocess_fn(cm: CommonAttentionMetadata):
return make_local_attention_virtual_batches(attention_chunk_size, cm,
block_size)
underlying_builder = underlying_attn_backend.get_builder_cls()
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
common_attn_metadata = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size)
return super().build(common_prefix_len, common_attn_metadata,
fast_build)
# Dynamically create a new attention backend that wraps the
# underlying attention backend but applies
# `make_local_attention_virtual_batches` before calling `build(...)`
builder_cls = subclass_attention_metadata_builder(
name_prefix=prefix,
builder_cls=underlying_attn_backend.get_builder_cls(),
build_preprocess_fn=build_preprocess_fn)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=builder_cls)
builder_cls=ChunkedLocalAttentionBuilder)
return attn_backend

View File

@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
from typing import Optional
import torch
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata, AttentionType)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
subclass_attention_backend)
@functools.lru_cache
def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend, ) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> AttentionMetadata:
new_common_attn_metadata = copy(common_attn_metadata)
new_common_attn_metadata.causal = False
return super().build(common_prefix_len, new_common_attn_metadata,
fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=EncoderOnlyAttentionBuilder)
return attn_backend
class EncoderOnlyAttention(Attention):
"""
Encoder attention is a special case that doesn't need a KV Cache.
"""
def __init__(self,
num_heads: int,
head_size: int,
scale: float,
cache_config: Optional[CacheConfig] = None,
attn_type: Optional[str] = None,
**kwargs):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if envs.VLLM_USE_V1:
underlying_attn_backend = get_attn_backend(head_size, dtype,
kv_cache_dtype,
block_size)
attn_backend = create_encoder_only_attention_backend(
underlying_attn_backend)
else:
# in v0 encoder only attention is handled inside the backends
attn_backend = None
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_ONLY, \
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
super().__init__(num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_ONLY,
**kwargs)
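Not part of the diff: a minimal construction sketch for the class above, assuming it is imported from its module (path assumed) and built while a vLLM config context is active; the layer sizes are illustrative.
# Hypothetical usage sketch; the import path and layer sizes are assumptions.
import torch
from vllm.attention.layers.encoder_only_attention import (  # assumed path
    EncoderOnlyAttention)
class ToyEncoderSelfAttention(torch.nn.Module):
    def __init__(self, num_heads: int = 8, head_size: int = 64):
        super().__init__()
        # No KV cache is needed; cache_config=None falls back to
        # kv_cache_dtype="auto" and block_size=16, as in __init__ above.
        self.attn = EncoderOnlyAttention(
            num_heads=num_heads,
            head_size=head_size,
            scale=head_size**-0.5,
            cache_config=None,
        )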

View File

@ -18,7 +18,7 @@ class BeamSearchSequence:
The text field is optional and will only be filled when the sequence is
about to be returned to the user.
"""
# The tokens includes the prompt.
# The tokens include the prompt.
tokens: list[int]
logprobs: list[dict[int, Logprob]]
lora_request: Optional[LoRARequest] = None

View File

@ -484,7 +484,7 @@ class VllmBackend:
factors = []
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affects the computation graph.
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash = envs.compute_hash()
factors.append(env_hash)

View File

@ -12,7 +12,8 @@ from torch._ops import OpOverload
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym,
kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from .fx_utils import find_getitem_maybe
@ -21,6 +22,7 @@ from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
def empty_bf16(*args, **kwargs):
@ -31,42 +33,13 @@ def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
def empty_i32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda")
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
group_shape: quantization group shape
symmetric: symmetric if True, asymmetric if False
TODO(luka) use QuantDescriptor once standardized:
https://github.com/vllm-project/vllm/issues/8913
"""
dtype: torch.dtype
static: bool
group_shape: GroupShape
symmetric: bool = True
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape},"
f"{'a' if not self.symmetric else ''}symmetric)")
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True)
QUANT_OPS: dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym:
torch.ops._C.static_scaled_fp8_quant.default, # noqa: E501
@ -74,6 +47,7 @@ QUANT_OPS: dict[QuantKey, OpOverload] = {
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa: E501
kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501
kNvfp4Quant: torch.ops._C.scaled_fp4_quant.default, # noqa: E501
}
@ -187,11 +161,9 @@ class RMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
quant=QuantKey(dtype=quant_dtype,
scale=kStaticTensorScale,
symmetric=symmetric))
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
@ -244,11 +216,9 @@ class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(
dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric))
quant=QuantKey(dtype=quant_dtype,
scale=kStaticTensorScale,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
@ -337,10 +307,10 @@ class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=False,
group_shape=group_shape,
scale=scale,
symmetric=symmetric))
super().__init__(epsilon, key)
@ -435,10 +405,10 @@ class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
quant_dtype: torch.dtype,
group_shape: GroupShape = GroupShape.PER_TOKEN,
symmetric=True):
scale = ScaleDesc(torch.float32, False, group_shape)
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=False,
group_shape=group_shape,
scale=scale,
symmetric=symmetric))
super().__init__(epsilon, key)

View File

@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
@ -11,44 +13,41 @@ from torch._subclasses.fake_tensor import (FakeTensorMode,
from vllm.attention import Attention
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kNvfp4Quant, kStaticTensorScale)
from vllm.platforms import current_platform
from vllm.utils import round_up
from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32
from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
ATTN_OP = torch.ops.vllm.unified_attention_with_output.default
RESHAPE_OP = torch.ops.aten.reshape.default
class AttentionStaticQuantPattern:
class AttentionQuantPattern(ABC):
"""
Fusion for Attention+StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the StaticQuant
op will be removed from the graph, and its scale will be passed into
Attention op as the `output_scale` argument.
The base class for Attn+Quant fusions.
Should not be used directly.
"""
def __init__(
self,
layer: Attention,
quant_dtype: torch.dtype,
symmetric=True,
quant_key: QuantKey,
):
self.layer = layer
self.layer_name = layer.layer_name
self.num_heads = layer.num_heads
self.head_size = layer.head_size
self.quant_dtype = quant_dtype
self.quant_key = QuantKey(dtype=quant_dtype,
static=True,
group_shape=GroupShape.PER_TENSOR,
symmetric=symmetric)
self.quant_key = quant_key
self.quant_dtype = quant_key.dtype
assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key]
@ -57,12 +56,49 @@ class AttentionStaticQuantPattern:
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@staticmethod
def wrap_trace_fn(process_fx, trace_fn):
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
@staticmethod
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
def register_if_supported(self, pm_pass: PatternMatcherPass):
if self.layer.impl.fused_output_quant_supported(
self.quant_dtype, self.quant_key.static,
self.quant_key.group_shape):
if self.layer.impl.fused_output_quant_supported(self.quant_key):
self._register(pm_pass)
@abstractmethod
def _register(self, pm_pass: PatternMatcherPass):
raise NotImplementedError
class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Fp8StaticQuant.
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Fp8StaticQuant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(
self,
layer: Attention,
symmetric: bool = True,
):
quant_key = QuantKey(dtype=FP8_DTYPE,
scale=kStaticTensorScale,
symmetric=symmetric)
super().__init__(layer, quant_key)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
@ -74,9 +110,10 @@ class AttentionStaticQuantPattern:
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None)
attn_out_view = RESHAPE_OP(at1[1],
[-1, self.num_heads * self.head_size])
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
result=output_quant,
input=attn_out_view,
@ -98,7 +135,8 @@ class AttentionStaticQuantPattern:
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=scale)
output_scale=scale,
output_block_scale=None)
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
# Need custom fake mode, otherwise tracing happens with real tensors.
@ -114,21 +152,94 @@ class AttentionStaticQuantPattern:
empty_fp32(1, 1) # scale
]
def wrap_trace_fn(process_fx, trace_fn):
pm.register_replacement(
pattern, replacement, inputs,
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
def wrapped(*args, **kwargs):
return process_fx(trace_fn(*args, **kwargs))
return wrapped
class AttentionNvfp4QuantPattern(AttentionQuantPattern):
"""
Fusion for Attention+Nvfp4Quant.
def fx_view_to_reshape(gm: torch.fx.GraphModule):
from torch._inductor.fx_passes.post_grad import view_to_reshape
view_to_reshape(gm)
return gm
Only triggers when the attention implementation returns True in
`fused_output_quant_supported()`. If the pattern is found, the
Nvfp4Quant op will be removed from the graph, and its scale
will be passed into Attention op as the `output_scale` argument.
"""
def __init__(self, layer: Attention):
super().__init__(layer, kNvfp4Quant)
def _register(self, pm_pass: PatternMatcherPass):
def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
at1 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=None,
output_block_scale=None)
attn_out_view = RESHAPE_OP(
at1[1], [q.shape[0], self.num_heads * self.head_size])
at2 = auto_functionalized(self.QUANT_OP,
output=output_quant,
input=attn_out_view,
output_scale=output_scale,
input_scale=input_scale)
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view
def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
output_attn: torch.Tensor, output_quant: torch.Tensor,
output_scale: torch.Tensor, input_scale: torch.Tensor):
# attention output in quant_dtype
output_attn = torch.ops.aten.full.default(
[q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype,
device=q.device)
# attention output block scale
output_scale_view = torch.ops.aten.view.dtype(
output_scale, FP8_DTYPE)
at2 = auto_functionalized(ATTN_OP,
query=q,
key=k,
value=v,
output=output_attn,
layer_name=self.layer_name,
output_scale=input_scale,
output_block_scale=output_scale_view)
output = RESHAPE_OP(at2[1],
[-1, self.num_heads * self.head_size // 2])
return output, at2[2]
# Need custom fake mode, otherwise tracing happens with real tensors.
# That would not work for the unified_attention custom op.
with unset_fake_temporarily(), FakeTensorMode():
inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q
empty_bf16(5, self.num_heads, self.head_size), # k
empty_bf16(5, self.num_heads, self.head_size), # v
empty_bf16(5, self.num_heads, self.head_size), # output_attn
self.empty_quant(5, self.num_heads * self.head_size //
2), # output_quant
empty_i32(128,
round_up(self.num_heads * self.head_size // 16,
4)), # output_scale
empty_fp32(1, 1), # input_scale
]
pm.register_replacement(
pattern, replacement, inputs,
wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass)
AttentionQuantPattern.wrap_trace_fn(
AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only),
pm_pass)
class AttnFusionPass(VllmInductorPass):
@ -151,8 +262,12 @@ class AttnFusionPass(VllmInductorPass):
attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items():
pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE)
pattern.register_if_supported(self.patterns)
pattern_fp8 = AttentionFp8StaticQuantPattern(layer)
pattern_fp8.register_if_supported(self.patterns)
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer)
pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0:
logger.warning(
"Attention + quant fusion is enabled, but no attention layers "
@ -175,4 +290,6 @@ class AttnFusionPass(VllmInductorPass):
self.end_and_log()
def uuid(self):
return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern)
return VllmInductorPass.hash_source(self, AttentionQuantPattern,
AttentionFp8StaticQuantPattern,
AttentionNvfp4QuantPattern)

View File

@ -1119,9 +1119,20 @@ class ModelConfig:
def _verify_quantization(self) -> None:
supported_quantization = me_quant.QUANTIZATION_METHODS
optimized_quantization_methods = [
"fp8", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed-tensors", "experts_int8", "quark",
"modelopt_fp4", "bitblas", "gptq_bitblas", "inc"
"fp8",
"modelopt",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",
"fbgemm_fp8",
"compressed-tensors",
"experts_int8",
"quark",
"modelopt_fp4",
"bitblas",
"gptq_bitblas",
"inc",
"petit_nvfp4",
]
if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods,
@ -1153,6 +1164,7 @@ class ModelConfig:
"moe_wna16",
"modelopt",
"modelopt_fp4",
"petit_nvfp4",
]
quantization_methods = [
q for q in supported_quantization if q not in overrides

View File

@ -23,6 +23,39 @@ from vllm.utils import (cuda_device_count_stateless,
logger = init_logger(__name__)
MiB = 1024 * 1024
# Max input size for each world size when symmetric memory is available,
# keyed by SM architecture
CUSTOM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: MiB // 2, # 512 KB
8: MiB // 4, # 256 KB
},
"10.0": {
2: 2 * MiB, # 2 MB
4: 2 * MiB, # 2 MB
6: 2 * MiB, # 2 MB
8: 2 * MiB, # 2 MB
}
}
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
"9.0": {
2: 64 * MiB, # 64 MB
4: 32 * MiB, # 32 MB
6: 64 * MiB, # 64 MB
8: 64 * MiB, # 64 MB
},
"10.0": {
2: 8 * MiB, # 8 MB
4: 32 * MiB, # 32 MB
6: 128 * MiB, # 128 MB
8: 128 * MiB, # 128 MB
}
}
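# Minimal lookup sketch (illustrative; `_example_symm_mem_cap` is not an
# actual helper in this module): resolving the symmetric-memory cap for a
# given SM architecture and world size, defaulting to 0 when unsupported.
def _example_symm_mem_cap(device_capability: str, world_size: int) -> int:
    # e.g. ("9.0", 4) -> 32 * MiB, ("10.0", 8) -> 128 * MiB per the table above.
    return SYMM_MEM_ALL_REDUCE_MAX_SIZES.get(device_capability, {}).get(world_size, 0)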
def producer(batch_src: Sequence[int],
producer_queue,

View File

@ -44,6 +44,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
PyNcclCommunicator)
from vllm.distributed.device_communicators.quick_all_reduce import (
QuickAllReduce)
from vllm.distributed.device_communicators.symm_mem import (
SymmMemCommunicator)
self.pynccl_comm: Optional[PyNcclCommunicator] = None
if use_pynccl and self.world_size > 1:
@ -54,6 +56,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
self.ca_comm = CustomAllreduce(
@ -69,6 +72,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
# currently be an MI300 series.
self.qr_comm = QuickAllReduce(group=self.cpu_group,
device=self.device)
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
@ -105,6 +114,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
out = ca_comm.custom_all_reduce(input_)
assert out is not None
return out
symm_mem_comm = self.symm_mem_comm
if symm_mem_comm is not None and \
symm_mem_comm.should_use_symm_mem(input_):
out = symm_mem_comm.all_reduce(input_)
assert out is not None
return out
pynccl_comm = self.pynccl_comm
assert pynccl_comm is not None
out = pynccl_comm.all_reduce(input_)

View File

@ -10,8 +10,8 @@ from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check)
from vllm.distributed.device_communicators.all_reduce_utils import (
CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
from vllm.logger import init_logger
from vllm.platforms import current_platform
@ -109,7 +109,13 @@ class CustomAllreduce:
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability(
).as_version_str()
if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size],
max_size)
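# Worked example (hypothetical setup, using the SM 9.0 table above): with
# world_size=8 the custom all-reduce buffer is clamped to
# min(256 KiB, max_size), so larger tensors fall through to the other
# all-reduce backends.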
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))

View File

@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm.distributed.device_communicators.all_reduce_utils import (
SYMM_MEM_ALL_REDUCE_MAX_SIZES)
from vllm.logger import init_logger
from vllm.platforms import current_platform
try:
import torch.distributed._symmetric_memory as torch_symm_mem
symm_mem_available = True
except ImportError:
symm_mem_available = False
logger = init_logger(__name__)
class SymmMemCommunicator:
_WORLD_SIZES_MULTIMEM = {
"9.0": [4, 6, 8],
"10.0": [6, 8],
}
def __init__(self, group: ProcessGroup, device: Union[int, str,
torch.device]):
self.disabled = True
if not symm_mem_available:
return
if not current_platform.is_cuda():
logger.warning("SymmMemCommunicator: symmetric "
"memory is not available.")
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
torch.cuda.set_device(device)
self.dtype = torch.bfloat16
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = current_platform.get_device_capability(
).as_version_str()
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
"communicator is not available.",
self.device_capability,
)
return
if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[
self.device_capability]:
logger.warning(
"SymmMemCommunicator: World size %d not supported, "
"communicator is not available.",
self.world_size,
)
return
self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][
self.world_size]
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=self.device,
dtype=self.dtype,
)
handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
if handle.multicast_ptr == 0:
logger.warning("SymmMemCommunicator: symmetric memory "
"multicast operations are not supported.")
return
self.disabled = False
def should_use_symm_mem(self, inp: torch.Tensor):
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
return inp_size < self.max_size
def all_reduce(
self,
inp: torch.Tensor,
*,
out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]:
if not self.should_use_symm_mem(inp):
return None
if out is None:
out = torch.empty_like(inp)
self.buffer[:inp.numel()].copy_(inp.view(-1))
if self.world_size in self._WORLD_SIZES_MULTIMEM[
self.device_capability]:
torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
else:
torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()],
"sum",
self.group.group_name)
out.copy_(self.buffer[:inp.numel()].view(out.shape))
return out
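# Usage sketch (hypothetical sizes; the group/device wiring normally comes
# from CudaCommunicator above, this is not a standalone recipe):
#   comm = SymmMemCommunicator(group=cpu_group, device=torch.device("cuda:0"))
#   x = torch.randn(4096, dtype=torch.bfloat16, device="cuda")
#   if comm.should_use_symm_mem(x):
#       x = comm.all_reduce(x)
#   # otherwise the caller falls back to custom-allreduce / NCCL.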

View File

@ -131,8 +131,8 @@ class KVConnectorBase_V1(ABC):
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args: kv_caches:
dictionary of layer names, kv cache
Args:
kv_caches: dictionary of layer names, kv cache
"""
return

View File

@ -686,9 +686,6 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
_, first_kv_cache = next(iter(kv_caches.items()))
kv_elem_size = first_kv_cache.element_size()
if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), (
@ -701,66 +698,16 @@ class NixlConnectorWorker:
"host_xfer_buffer should not be initialized when "
f"kv_buffer_device is {self.kv_buffer_device}")
# TODO(tms): Find a more robust way to detect and handle MLA
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affect the strides. For MLA instead, we require no
# such thing and resort to the standard layout.
use_mla = len(first_kv_cache.shape) == 3
if self.device_type == "tpu":
assert not use_mla, f"{self.kv_buffer_device} does not support MLA."
assert self._use_pallas_v1, f"attn backend: {self.backend_name}"
# tpu (v1) kv shape per layer:
# (num_blocks, block_size, num_kv_heads * 2, head_size)
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads_x_2, head_dim = block_shape
self.slot_size_bytes = kv_elem_size * n_kv_heads_x_2 * head_dim
elif self.device_type == "cuda":
assert use_mla == self.use_mla
# TODO (NickLucche) not compatible with hybrid allocator.
# Enforce check once it goes live, as a single kv layout
# is expected for xfers.
if use_mla:
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, kv_latent_dim = block_shape
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
assert block_size == self.block_size
else:
raise RuntimeError(
f"{self.device_type} ({self.backend_name}) is not supported.")
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
logger.info(
"Registering KV_Caches. use_mla: %s, kv_buffer_device: %s, "
"use_host_buffer: %s, num_blocks: %s, block_shape: %s, "
"per_layer_kv_cache_shape: %s", use_mla, self.kv_buffer_device,
self.use_host_buffer, self.num_blocks, block_shape,
first_kv_cache.shape)
self.dst_num_blocks[self.engine_id] = self.num_blocks
self.device_kv_caches = kv_caches
kv_caches_base_addr = []
"use_host_buffer: %s", self.use_mla, self.kv_buffer_device,
self.use_host_buffer)
caches_data = []
# With hybrid allocator, layers can share a kv cache tensor
seen_base_addresses = []
xfer_buffers = (self.host_xfer_buffers
if self.use_host_buffer else kv_caches)
# Note(tms): I modified this from the original region setup code.
# K and V are now in different regions. Advantage is that we can
@ -770,42 +717,35 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (ie num_blocks is the first dim).
for cache_or_caches in xfer_buffers.values():
# Normalize to always be a list of caches
cache_list = [cache_or_caches] if use_mla \
or self._use_pallas_v1 or self._use_flashinfer \
else cache_or_caches
split_k_and_v = not (self.use_mla or self._use_pallas_v1
or self._use_flashinfer)
tensor_size_bytes = None
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = cache_or_caches if split_k_and_v else [
cache_or_caches
]
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len
# NOTE: use tp_rank for device_id since multi-node TP
# is rarely used.
caches_data.append((base_addr, region_len, self.tp_rank, ""))
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
if base_addr in seen_base_addresses:
continue
seen_base_addresses.append(base_addr)
curr_tensor_size_bytes = cache.numel() * cache.element_size()
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
self.num_blocks = cache.shape[0]
assert tensor_size_bytes == curr_tensor_size_bytes, \
"All kv cache tensors must have the same size"
caches_data.append(
(base_addr, tensor_size_bytes, self.tp_rank, ""))
self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
descs = self.nixl_wrapper.get_reg_descs(caches_data,
self.nixl_memory_type)
logger.debug("Registering descs: %s", caches_data)
@ -813,9 +753,20 @@ class NixlConnectorWorker:
logger.debug("Done registering descs")
self._registered_descs.append(descs)
assert tensor_size_bytes is not None
assert self.num_blocks != 0
assert tensor_size_bytes % self.num_blocks == 0
self.block_len = tensor_size_bytes // self.num_blocks
self.slot_size_bytes = self.block_len // self.block_size
if self._use_flashinfer:
assert self.slot_size_bytes % 2 == 0
self.slot_size_bytes /= 2
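# Worked example (hypothetical shapes): one K (or V) cache tensor of
# 1024 blocks x 16 tokens x 8 KV heads x 128 head dim in bf16 gives
# tensor_size_bytes = 1024 * 16 * 8 * 128 * 2, so
# block_len = tensor_size_bytes // 1024 = 32768 bytes and
# slot_size_bytes = block_len // 16 = 2048 bytes.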
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
# Register local/src descr for NIXL xfer.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for base_addr in seen_base_addresses:
# NOTE With heter-TP, more blocks are prepared than what are
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
# could create fewer, but then _get_block_descs_ids needs to
@ -836,6 +787,26 @@ class NixlConnectorWorker:
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
if self.vllm_config.model_config.hf_config.model_type == "llama4":
from transformers import Llama4TextConfig
assert isinstance(self.vllm_config.model_config.hf_text_config,
Llama4TextConfig)
llama4_config = self.vllm_config.model_config.hf_text_config
no_rope_layers = llama4_config.no_rope_layers
chunk_size = llama4_config.attention_chunk_size
chunk_block_size = math.ceil(chunk_size / self.block_size)
for layer_idx in range(self.num_layers):
# no_rope_layers[layer_idx] == 0 means NoPE (global)
# Any other value means RoPE (local chunked)
is_local_attention = no_rope_layers[layer_idx] != 0
block_window = chunk_block_size if is_local_attention else None
self.block_window_per_layer.append(block_window)
logger.debug("Llama 4 block window per layer mapping: %s",
self.block_window_per_layer)
assert len(self.block_window_per_layer) == self.num_layers
# After KV Caches registered, listen for new connections.
metadata = NixlAgentMetadata(
engine_id=self.engine_id,

View File

@ -605,7 +605,7 @@ class EngineArgs:
**guided_decoding_kwargs["disable_additional_properties"])
guided_decoding_group.add_argument(
"--reasoning-parser",
# This choices is a special case because it's not static
# This choice is a special case because it's not static
choices=list(ReasoningParserManager.reasoning_parsers),
**guided_decoding_kwargs["reasoning_backend"])
@ -1047,7 +1047,7 @@ class EngineArgs:
# details from the config directly
# no user input required / expected
if isinstance(hf_config, SpeculatorsConfig):
# We create one since we dont create one
# We create one since we don't create one
self.speculative_config = {}
self.speculative_config[
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
@ -1775,7 +1775,7 @@ class AsyncEngineArgs(EngineArgs):
def add_cli_args(parser: FlexibleArgumentParser,
async_args_only: bool = False) -> FlexibleArgumentParser:
# Initialize plugin to update the parser, for example, The plugin may
# adding a new kind of quantization method to --quantization argument or
# add a new kind of quantization method to --quantization argument or
# a new device to --device argument.
load_general_plugins()
if not async_args_only:

View File

@ -539,7 +539,7 @@ class MQLLMEngineClient(EngineClient):
if request_id in self.output_queues:
raise ValueError(f"Request {request_id} already exists")
# 1) Create output queue for this requests.
# 1) Create output queue for this request.
queue: asyncio.Queue[Union[RequestOutput,
BaseException]] = asyncio.Queue()
self.output_queues[request_id] = queue
@ -651,7 +651,7 @@ class MQLLMEngineClient(EngineClient):
# Uses the same I/O as generate requests
request = RPCLoadAdapterRequest(lora_request)
# Create output queue for this requests.
# Create output queue for this request.
queue: asyncio.Queue[Union[None, BaseException]] = asyncio.Queue()
self.output_queues[request.request_id] = queue

View File

@ -1330,7 +1330,7 @@ def apply_mistral_chat_template(
# mistral-common uses assert statements to stop processing of input
# if input does not comply with the expected format.
# We convert those assertion errors to ValueErrors so they can be
# are properly caught in the preprocessing_input step
# properly caught in the preprocessing_input step
except (AssertionError, MistralCommonException) as e:
raise ValueError(str(e)) from e

View File

@ -3,6 +3,7 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Union
from openai_harmony import Author, Message, Role, StreamState, TextContent
@ -67,15 +68,27 @@ class HarmonyContext(ConversationContext):
self.parser = get_streamable_parser_for_assistant()
self.num_init_messages = len(messages)
# TODO(woosuk): Implement the following fields.
self.num_prompt_tokens = 0
self.num_cached_tokens = 0
self.num_output_tokens = 0
# TODO(woosuk): Implement the following fields.
self.num_cached_tokens = 0
self.num_reasoning_tokens = 0
def _update_num_prompt_tokens(self, output: RequestOutput):
if output.prompt_token_ids and len(output.prompt_token_ids) > 0:
# NOTE: with built-in tools, there might be multiple rounds in
# the conversation, with the full conversation being resent
# as new prompt each time. Hence the sum.
self.num_prompt_tokens += len(output.prompt_token_ids)
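# e.g. two tool-call rounds whose resent prompts are 120 and 180 tokens
# long leave num_prompt_tokens at 300 after both outputs are appended.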
def _update_num_output_tokens(self, token_ids: Sequence[int]):
self.num_output_tokens += len(token_ids)
def append_output(self, output) -> None:
if isinstance(output, RequestOutput):
self._update_num_prompt_tokens(output)
output_token_ids = output.outputs[0].token_ids
self._update_num_output_tokens(output_token_ids)
self.parser = get_streamable_parser_for_assistant()
for token_id in output_token_ids:
self.parser.process(token_id)
@ -158,6 +171,7 @@ class StreamingHarmonyContext(HarmonyContext):
self.parser = get_streamable_parser_for_assistant()
self.encoding = get_encoding()
self.last_tok = None
self.first_tok_of_message = True
@property
def messages(self) -> list:
@ -165,8 +179,18 @@ class StreamingHarmonyContext(HarmonyContext):
def append_output(self, output) -> None:
if isinstance(output, RequestOutput):
# append_output is called for each output token in streaming case,
# so we only want to add the prompt tokens once for each message.
if self.first_tok_of_message:
self._update_num_prompt_tokens(output)
# Reset self.first_tok_of_message if needed:
# if the current token is the last one of the current message
# (finished=True), then the next token processed will mark the
# beginning of a new message
self.first_tok_of_message = output.finished
tok = output.outputs[0].token_ids[0]
self.parser.process(tok)
self._update_num_output_tokens(output.outputs[0].token_ids)
self.last_tok = tok
else:
# Handle the case of tool output in direct message format

View File

@ -3,6 +3,7 @@
from .abstract_tool_parser import ToolParser, ToolParserManager
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
from .deepseekv31_tool_parser import DeepSeekV31ToolParser
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
from .granite_tool_parser import GraniteToolParser
@ -18,6 +19,7 @@ from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
from .qwen3coder_tool_parser import Qwen3CoderToolParser
from .seed_oss_tool_parser import SeedOssToolParser
from .step3_tool_parser import Step3ToolParser
from .xlam_tool_parser import xLAMToolParser
@ -35,11 +37,13 @@ __all__ = [
"PythonicToolParser",
"Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser",
"DeepSeekV31ToolParser",
"xLAMToolParser",
"MinimaxToolParser",
"KimiK2ToolParser",
"HunyuanA13BToolParser",
"Glm4MoeModelToolParser",
"Qwen3CoderToolParser",
"SeedOssToolParser",
"Step3ToolParser",
]

View File

@ -0,0 +1,367 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import Union
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("deepseek_v31")
class DeepSeekV31ToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = (
[]) # map what has been streamed for each tool so far to a list
self.tool_calls_start_token: str = "<tool▁calls▁begin>"
self.tool_calls_end_token: str = "<tool▁calls▁end>"
self.tool_call_start_token: str = "<tool▁call▁begin>"
self.tool_call_end_token: str = "<tool▁call▁end>"
self.tool_call_regex = re.compile(
r"<tool▁call▁begin>(?P<function_name>.*)<tool▁sep>(?P<function_arguments>.*)<tool▁call▁end>"
)
self.stream_tool_call_portion_regex = re.compile(
r"(?P<function_name>.*)<tool▁sep>(?P<function_arguments>.*)")
self.stream_tool_call_name_regex = re.compile(
r"(?P<function_name>.*)<tool▁sep>")
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_calls_start_token_id = self.vocab.get(
self.tool_calls_start_token)
self.tool_calls_end_token_id = self.vocab.get(
self.tool_calls_end_token)
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if (self.tool_calls_start_token_id is None
or self.tool_calls_end_token_id is None):
raise RuntimeError(
"DeepSeek-V3 Tool parser could not locate tool call start/end "
"tokens in the tokenizer!")
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# sanity check; avoid unnecessary processing
if self.tool_calls_start_token not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
else:
try:
# there are two possible captures - between tags, or between a
# tag and end-of-string so the result of
# findall is an array of tuples where one is a function call and
# the other is None
function_call_tuples = self.tool_call_regex.findall(
model_output)
tool_calls = []
for match in function_call_tuples:
function_name, function_args = match
tool_calls.append(
ToolCall(
type="function",
function=FunctionCall(name=function_name,
arguments=function_args),
))
content = model_output[:model_output.
find(self.tool_calls_start_token)]
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception(
"Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# check to see if we should be streaming a tool call - is there a
if self.tool_calls_start_token_id not in current_token_ids:
logger.debug("No tool call tokens found!")
return DeltaMessage(content=delta_text)
delta_text = delta_text.replace(self.tool_calls_start_token,
"").replace(self.tool_calls_end_token,
"")
try:
# figure out where we are in the parsing by counting tool call
# start & end tags
prev_tool_start_count = previous_token_ids.count(
self.tool_call_start_token_id)
prev_tool_end_count = previous_token_ids.count(
self.tool_call_end_token_id)
cur_tool_start_count = current_token_ids.count(
self.tool_call_start_token_id)
cur_tool_end_count = current_token_ids.count(
self.tool_call_end_token_id)
tool_call_portion = None
text_portion = None
# case: if we're generating text, OR rounding out a tool call
if (cur_tool_start_count == cur_tool_end_count
and prev_tool_end_count == cur_tool_end_count
and self.tool_call_end_token not in delta_text):
logger.debug("Generating text content! skipping tool parsing.")
return DeltaMessage(content=delta_text)
if self.tool_call_end_token in delta_text:
logger.debug("tool_call_end_token in delta_text")
full_text = current_text + delta_text
tool_call_portion = full_text.split(
self.tool_call_start_token)[-1].split(
self.tool_call_end_token)[0].rstrip()
delta_text = delta_text.split(
self.tool_call_end_token)[0].rstrip()
text_portion = delta_text.split(
self.tool_call_end_token)[-1].lstrip()
# case -- we're starting a new tool call
if (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count > prev_tool_start_count):
if len(delta_token_ids) > 1:
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
else:
tool_call_portion = None
delta = None
text_portion = None
# set cursors and state appropriately
self.current_tool_id += 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("Starting on a new tool %s", self.current_tool_id)
# case -- we're updating an existing tool call
elif (cur_tool_start_count > cur_tool_end_count
and cur_tool_start_count == prev_tool_start_count):
# get the portion of the text that's the tool call
tool_call_portion = current_text.split(
self.tool_call_start_token)[-1]
text_portion = None
# case -- the current tool call is being closed.
elif (cur_tool_start_count == cur_tool_end_count
and cur_tool_end_count >= prev_tool_end_count):
if self.prev_tool_call_arr is None or len(
self.prev_tool_call_arr) == 0:
logger.debug(
"attempting to close tool call, but no tool call")
return None
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
if diff:
diff = (diff.encode("utf-8").decode("unicode_escape")
if isinstance(diff, str) else diff)
if '"}' not in delta_text:
return None
end_loc = delta_text.rindex('"}')
diff = delta_text[:end_loc] + '"}'
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s",
diff,
)
self.streamed_args_for_tool[self.current_tool_id] += diff
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(exclude_none=True),
)
])
# case -- otherwise we're just generating text
else:
text = delta_text.replace(self.tool_call_start_token, "")
text = text.replace(self.tool_call_end_token, "")
delta = DeltaMessage(tool_calls=[], content=text)
return delta
current_tool_call = dict()
if tool_call_portion:
current_tool_call_matches = (
self.stream_tool_call_portion_regex.match(
tool_call_portion))
if current_tool_call_matches:
tool_name, tool_args = current_tool_call_matches.groups()
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = tool_args
else:
current_tool_call_name_matches = (
self.stream_tool_call_name_regex.match(
tool_call_portion))
if current_tool_call_name_matches:
tool_name = current_tool_call_name_matches.groups()
current_tool_call["name"] = tool_name
current_tool_call["arguments"] = ""
else:
logger.debug("Not enough token")
return None
# case - we haven't sent the tool name yet. If it's available, send
# it. otherwise, wait until it's available.
if not self.current_tool_name_sent:
if current_tool_call is None:
return None
function_name: Union[str, None] = current_tool_call.get("name")
if function_name:
self.current_tool_name_sent = True
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True),
)
])
else:
return None
# case -- otherwise, send the tool call delta
# if the tool call portion is None, send the delta as text
if tool_call_portion is None:
# if there's text but not tool calls, send that -
# otherwise None to skip chunk
delta = (DeltaMessage(
content=delta_text) if text_portion is not None else None)
return delta
# now, the nitty-gritty of tool calls
# now we have the portion to parse as tool call.
logger.debug("Trying to parse current tool call with ID %s",
self.current_tool_id)
# if we're starting a new tool call, push an empty object in as
# a placeholder for the arguments
if len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
# main logic for tool parsing here - compare prev. partially-parsed
# JSON to the current partially-parsed JSON
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
cur_arguments = current_tool_call.get("arguments")
logger.debug("diffing old arguments: %s", prev_arguments)
logger.debug("against new ones: %s", cur_arguments)
# case -- no arguments have been created yet. skip sending a delta.
if not cur_arguments and not prev_arguments:
logger.debug("Skipping text %s - no arguments", delta_text)
delta = None
# case -- prev arguments are defined, but none are now.
# probably impossible, but not a fatal error - just keep going
elif not cur_arguments and prev_arguments:
logger.error("should be impossible to have arguments reset "
"mid-call. skipping streaming anything.")
delta = None
# case -- we now have the first info about arguments available from
# autocompleting the JSON
elif cur_arguments and not prev_arguments:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=cur_arguments).model_dump(
exclude_none=True),
)
])
self.streamed_args_for_tool[
self.current_tool_id] = cur_arguments
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if (isinstance(delta_text, str)
and cur_arguments != prev_arguments
and len(cur_arguments) > len(prev_arguments)
and cur_arguments.startswith(prev_arguments)):
delta_arguments = cur_arguments[len(prev_arguments):]
logger.debug("got diff %s", delta_text)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=delta_arguments).model_dump(
exclude_none=True),
)
])
self.streamed_args_for_tool[
self.current_tool_id] = cur_arguments
else:
delta = None
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
self.prev_tool_call_arr[
self.current_tool_id] = current_tool_call
else:
self.prev_tool_call_arr.append(current_tool_call)
return delta
except Exception:
logger.exception("Error trying to handle streaming tool call.")
return None # do not stream a delta. skip this token ID.

View File

@ -0,0 +1,676 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from the qwen3coder XML parser. All rights reserved.
# ruff: noqa: E501
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionToolsParam,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("seed_oss")
class SeedOssToolParser(ToolParser):
TOOL_CALL_START = "<seed:tool_call>"
TOOL_CALL_END = "</seed:tool_call>"
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
# --- streaming state ---
self._reset_streaming_state()
self.prev_tool_call_arr: list[dict] = []
self.tool_call_start_token: str = self.TOOL_CALL_START
self.tool_call_end_token: str = self.TOOL_CALL_END
# Sentinel tokens for streaming mode
self.tool_call_prefix: str = "<function="
self.function_end_token: str = "</function>"
self.parameter_prefix: str = "<parameter="
self.parameter_end_token: str = "</parameter>"
self.think_start_token: str = "<seed:think>"
self.think_end_token: str = "</seed:think>"
self.is_tool_call_started: bool = False
self.is_thinking_end: bool = False
self.failed_count: int = 0
self._reset_streaming_state()
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if (self.tool_call_start_token_id is None
or self.tool_call_end_token_id is None):
raise RuntimeError(
"Seed_Oss XML parser: tokenizer did not include "
"<seed:tool_call> or its closing tag.")
tool_start_re = re.escape(self.tool_call_start_token)
tool_end_re = re.escape(self.tool_call_end_token)
self.tool_call_complete_regex = re.compile(
rf"{tool_start_re}(.*?){tool_end_re}", re.DOTALL)
self.tool_call_regex = re.compile(
rf"{tool_start_re}(.*?){tool_end_re}|{tool_start_re}(.*?)$",
re.DOTALL)
self.tool_call_function_regex = re.compile(
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL)
logger.info("vLLM Seed-Oss XML tool parser loaded (%s).",
self.__class__.__name__)
def _generate_tool_call_id(self) -> str:
"""Generate a unique tool call ID."""
return f"call_{uuid.uuid4().hex[:24]}"
def _reset_streaming_state(self):
"""Reset all streaming state."""
self.current_tool_index = 0
self.is_tool_call_started = False
self.header_sent = False
self.current_tool_id = -1
self.current_function_name = None
self.current_param_name = None
self.current_param_value = ""
self.param_count = 0
self.in_param = False
self.in_function = False
self.accumulated_text = ""
self.json_started = False
self.json_closed = False
def _parse_xml_function_call(
self, function_call_str: str,
tools: Optional[list[ChatCompletionToolsParam]]
) -> Optional[ToolCall]:
def get_arguments_config(func_name: str) -> dict:
if tools is None:
return {}
for config in tools:
if not hasattr(config, "type") or not (
hasattr(config, "function")
and hasattr(config.function, "name")):
continue
if (config.type == "function"
and config.function.name == func_name):
if not hasattr(config.function, "parameters"):
return {}
params = config.function.parameters
if isinstance(params, dict) and "properties" in params:
return params["properties"]
elif isinstance(params, dict):
return params
else:
return {}
logger.warning("Tool '%s' is not defined in the tools list.",
func_name)
return {}
def convert_param_value(param_value: str, param_name: str,
param_config: dict, func_name: str) -> Any:
# Handle null value for any type
if param_value.lower() == "null":
return None
if param_name not in param_config:
if param_config != {}:
logger.warning(
"Parsed parameter '%s' is not defined in "
"the tool parameters for tool '%s', "
"directly returning the string value.", param_name,
func_name)
return param_value
if (isinstance(param_config[param_name], dict)
and "type" in param_config[param_name]):
param_type = str(
param_config[param_name]["type"]).strip().lower()
else:
param_type = "string"
if param_type in [
"string", "str", "text", "varchar", "char", "enum"
]:
return param_value
elif (param_type.startswith("int") or param_type.startswith("uint")
or param_type.startswith("long")
or param_type.startswith("short")
or param_type.startswith("unsigned")):
try:
param_value = int(param_value) # type: ignore
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not an integer in tool "
"'%s', degenerating to string.", param_value,
param_name, func_name)
return param_value
elif param_type.startswith("num") or param_type.startswith(
"float"):
try:
float_param_value = float(param_value)
param_value = float_param_value if float_param_value - int(
float_param_value) != 0 else int(
float_param_value) # type: ignore
except (ValueError, TypeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not a float in tool "
"'%s', degenerating to string.", param_value,
param_name, func_name)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
param_value = param_value.lower()
if param_value not in ["true", "false"]:
logger.warning(
"Parsed value '%s' of parameter '%s' is not a boolean "
"(`true` of `false`) in tool '%s', degenerating to false.",
param_value, param_name, func_name)
return param_value == "true"
else:
if param_type == "object" or param_type.startswith("dict"):
try:
param_value = json.loads(param_value)
return param_value
except (ValueError, TypeError, json.JSONDecodeError):
logger.warning(
"Parsed value '%s' of parameter '%s' is not a valid JSON "
"object in tool '%s', will try other methods to parse it.",
param_value, param_name, func_name)
try:
param_value = ast.literal_eval(param_value)
except (ValueError, SyntaxError):
logger.warning(
"Parsed value '%s' of parameter '%s' cannot be converted via "
"Python `ast.literal_eval()` in tool '%s', degenerating to string.",
param_value, param_name, func_name)
return param_value
# Extract function name
end_index = function_call_str.index(">")
function_name = function_call_str[:end_index]
param_config = get_arguments_config(function_name)
parameters = function_call_str[end_index + 1:]
param_dict = {}
for match in self.tool_call_parameter_regex.findall(parameters):
match_text = match[0] if match[0] else match[1]
idx = match_text.index(">")
param_name = match_text[:idx]
param_value = str(match_text[idx + 1:])
# Remove prefix and trailing \n
if param_value.startswith("\n"):
param_value = param_value[1:]
if param_value.endswith("\n"):
param_value = param_value[:-1]
param_dict[param_name] = convert_param_value(
param_value, param_name, param_config, function_name)
return ToolCall(
type="function",
function=FunctionCall(name=function_name,
arguments=json.dumps(param_dict,
ensure_ascii=False)),
)
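# Worked example (hypothetical tool): given the segment
#   'get_weather>\n<parameter=city>\nParis\n</parameter>\n'
# (i.e. everything between "<function=" and "</function>"), the helpers above
# yield function_name == "get_weather" and arguments == '{"city": "Paris"}'.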
def _get_function_calls(self, model_output: str) -> list[str]:
# Find all tool calls
matched_ranges = self.tool_call_regex.findall(model_output)
raw_tool_calls = [
match[0] if match[0] else match[1] for match in matched_ranges
]
# Back-off strategy if no tool_call tags found
if len(raw_tool_calls) == 0:
raw_tool_calls = [model_output]
raw_function_calls = []
for tool_call in raw_tool_calls:
raw_function_calls.extend(
self.tool_call_function_regex.findall(tool_call))
function_calls = [
match[0] if match[0] else match[1] for match in raw_function_calls
]
return function_calls
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# Quick check to avoid unnecessary processing
if self.tool_call_prefix not in model_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
# Check if both think start and end tokens are present
if (self.think_start_token in model_output
and self.think_end_token in model_output):
# Find the position of think end token
think_end_index = model_output.find(self.think_end_token) + len(
self.think_end_token)
# Extract content after think end token
result_content = model_output[think_end_index:]
thinking_content = model_output[:think_end_index]
try:
function_calls = self._get_function_calls(result_content)
if len(function_calls) == 0:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
tool_calls = [
self._parse_xml_function_call(function_call_str, request.tools)
for function_call_str in function_calls
]
# Populate prev_tool_call_arr for serving layer to set finish_reason
self.prev_tool_call_arr.clear() # Clear previous calls
for tool_call in tool_calls:
if tool_call:
self.prev_tool_call_arr.append({
"name":
tool_call.function.name,
"arguments":
tool_call.function.arguments,
})
# Extract content before tool calls
tool_call_start_index = result_content.find(
self.tool_call_start_token)
tool_call_start_index = (
tool_call_start_index if tool_call_start_index >= 0 else
result_content.find(self.tool_call_prefix))
content = thinking_content + result_content[:tool_call_start_index]
return ExtractedToolCallInformation(
tools_called=(len(tool_calls) > 0),
tool_calls=tool_calls,
content=content if content else None,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
# If no delta text, return None unless
# it's an EOS token after tool calls
if not delta_text:
# Check if this is an EOS token after all tool calls are complete
# We check for tool calls in the text even if is_tool_call_started
# is False because it might have been reset after processing all tools
if (delta_token_ids
and self.tool_call_end_token_id not in delta_token_ids):
# Count complete tool calls
complete_calls = len(
self.tool_call_complete_regex.findall(current_text))
# If we have completed tool calls and populated prev_tool_call_arr
if complete_calls > 0 and len(self.prev_tool_call_arr) > 0:
# Check if all tool calls are closed
open_calls = current_text.count(
self.tool_call_start_token) - current_text.count(
self.tool_call_end_token)
if open_calls == 0:
# Return empty delta message to allow finish_reason processing
return DeltaMessage(content="")
elif not self.is_tool_call_started and current_text:
# This is a regular content response that's now complete
return DeltaMessage(content="")
return None
# Check if this is the first call (reset state if needed)
if not previous_text:
self._reset_streaming_state()
# Update accumulated text
self.accumulated_text = current_text
# Check if we need to advance to next tool
if self.json_closed and not self.in_function:
# Check if this tool call has ended
tool_ends = current_text.count(self.tool_call_end_token)
if tool_ends > self.current_tool_index:
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
self.param_count = 0
self.json_started = False
self.json_closed = False
# Check if there are more tool calls
if self.current_tool_index >= current_text.count(
self.tool_call_start_token):
# No more tool calls
self.is_tool_call_started = False
# Continue processing next tool
return None
# Check if end thinking
if (not self.is_thinking_end
and (self.think_end_token_id in delta_token_ids
or self.think_end_token in delta_text)):
self.is_thinking_end = True
# If thinking hasn't ended yet, don't process any tool calls
if not self.is_thinking_end:
return DeltaMessage(content=delta_text)
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
if (self.tool_call_start_token_id in delta_token_ids
or self.tool_call_start_token in delta_text):
self.is_tool_call_started = True
# Return any content before the tool call
if self.tool_call_start_token in delta_text:
content_before = delta_text[:delta_text.index(
self.tool_call_start_token)]
if content_before:
return DeltaMessage(content=content_before)
return None
else:
# Check if we're between tool calls - skip whitespace
if (current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""):
# We just ended a tool call, skip whitespace
return None
# Normal content, no tool call
return DeltaMessage(content=delta_text)
# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
tool_starts_count = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
# We're past all tool calls, shouldn't be here
return None
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
# Only process tool calls after think_end_token
think_end_index = current_text.find(self.think_end_token) + len(
self.think_end_token
) if self.think_end_token in current_text else 0
tool_starts: list[int] = []
idx = think_end_index
while True:
idx = current_text.find(self.tool_call_start_token, idx)
if idx == -1:
break
tool_starts.append(idx)
idx += len(self.tool_call_start_token)
if self.current_tool_index >= len(tool_starts):
# No more tool calls to process yet
return None
tool_start_idx = tool_starts[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
tool_end_idx = current_text.find(self.tool_call_end_token,
tool_start_idx)
if tool_end_idx == -1:
tool_text = current_text[tool_start_idx:]
else:
tool_text = current_text[tool_start_idx:tool_end_idx +
len(self.tool_call_end_token)]
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_end = tool_text.find(">", func_start)
if func_end != -1:
# Found complete function name
self.current_function_name = tool_text[func_start:func_end]
self.current_tool_id = self._generate_tool_call_id(
) # type: ignore
self.header_sent = True
self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when we detect a tool call
# This ensures finish_reason="tool_calls" even if parsing isn't complete
already_added = any(
tool.get("name") == self.current_function_name
for tool in self.prev_tool_call_arr)
if not already_added:
self.prev_tool_call_arr.append({
"name": self.current_function_name,
"arguments":
"{}", # Placeholder, will be updated later
})
# Send header with function info
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
id=self.current_tool_id,
function=DeltaFunctionCall(
name=self.current_function_name, arguments=""),
type="function",
)
])
return None
# We've sent header, now handle function body
if self.in_function:
# Send opening brace if not sent yet
if (not self.json_started
and self.parameter_prefix not in delta_text):
self.json_started = True
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="{"),
)
])
# Make sure json_started is set if we're processing parameters
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.function_end_token in tool_text:
# Close JSON
self.json_closed = True
# Extract the complete tool call to update prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix)
func_content_end = tool_text.find(self.function_end_token,
func_start)
if func_content_end != -1:
func_content = tool_text[func_start:func_content_end]
# Parse to get the complete arguments
try:
parsed_tool = self._parse_xml_function_call(
func_content, request.tools if request else None)
if parsed_tool:
# Update existing entry in prev_tool_call_arr with complete arguments
for i, tool in enumerate(self.prev_tool_call_arr):
if tool.get(
"name") == parsed_tool.function.name:
self.prev_tool_call_arr[i]["arguments"] = (
parsed_tool.function.arguments)
break
except Exception:
logger.warning(
"Failed to parse tool arguments during streaming.",
exc_info=True)
result = DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
])
# Reset state for next tool
self.in_function = False
self.json_closed = True
return result
# Look for parameters
# Count how many complete parameters we have processed
complete_params = tool_text.count(self.parameter_end_token)
# Check if we should start a new parameter
if not self.in_param and self.param_count < complete_params:
# Find the unprocessed parameter
# Count parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
if len(param_starts) > self.param_count:
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end]
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(
self.parameter_end_token)
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Build complete JSON fragment for this parameter
if self.param_count == 0:
json_fragment = (
'"' + self.current_param_name + '": "' +
json.dumps(param_value)[1:-1] + '"')
else:
json_fragment = (
', "' + self.current_param_name + '": "' +
json.dumps(param_value)[1:-1] + '"')
self.param_count += 1
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=json_fragment),
)
])
# Continue parameter value
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
end_idx = delta_text.find(self.parameter_end_token)
value_chunk = delta_text[:end_idx]
# Skip past > if at start
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
# Calculate incremental JSON
full_value = self.current_param_value + value_chunk
prev_escaped = (json.dumps(self.current_param_value)[1:-1]
if self.current_param_value else "")
full_escaped = json.dumps(full_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
self.in_param = False
self.current_param_value = ""
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped + '"'),
)
])
else:
# Continue accumulating value
value_chunk = delta_text
# Handle first chunk after param name
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1:]
if not self.current_param_value and value_chunk.startswith(
"\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = (json.dumps(
self.current_param_value)[1:-1]
if self.current_param_value else "")
self.current_param_value += value_chunk
full_escaped = json.dumps(
self.current_param_value)[1:-1]
delta_escaped = full_escaped[len(prev_escaped):]
if delta_escaped:
return DeltaMessage(tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped),
)
])
return None
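# Streaming sketch (hypothetical tool call): for a generation such as
#   <seed:think>...</seed:think><seed:tool_call><function=get_weather>
#   <parameter=city>Paris</parameter>
#   </function></seed:tool_call>
# the deltas emitted above are the function header (name plus empty
# arguments), then '{', then '"city": "Paris"', then '}'.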

View File

@ -158,8 +158,10 @@ if TYPE_CHECKING:
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None
VLLM_HAS_FLASHINFER_CUBIN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
@ -1105,6 +1107,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set, it means we pre-downloaded cubin files and flashinfer will
# read the cubin files directly.
"VLLM_HAS_FLASHINFER_CUBIN":
lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False),
# If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer.
# Otherwise, uses the first available of: flashinfer cutlass GEMM,
# vllm cutlass GEMM, marlin GEMM.
@ -1150,6 +1157,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_RESPONSES_API_STORE":
lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER":
lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),

View File

@ -0,0 +1,122 @@
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
}
}

View File

@ -0,0 +1,114 @@
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
}
}
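These blocks are keyed by token count. A hedged sketch of how such a tuned-config file might be consulted; the nearest-key heuristic below is an illustrative assumption, not the commit's selection code:
import json
# Illustrative only: assume the entry whose integer key is numerically closest
# to the number of tokens being processed is the one applied.
def pick_tuned_config(path: str, num_tokens: int) -> dict:
    with open(path) as f:
        configs = json.load(f)
    best_key = min(configs, key=lambda k: abs(int(k) - num_tokens))
    return configs[best_key]
# Under this heuristic, num_tokens=300 would select the "256" block above.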

View File

@ -52,6 +52,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"HQQMarlinMethod",
"QuarkLinearMethod",
"ModelOptNvFp4LinearMethod",
"PetitNvFp4LinearMethod",
]

View File

@ -35,6 +35,7 @@ QuantizationMethods = Literal[
"rtn",
"inc",
"mxfp4",
"petit_nvfp4",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .neuron_quant import NeuronQuantConfig
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig
from .torchao import TorchAOConfig
@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"rtn": RTNConfig,
"inc": INCConfig,
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

View File

@ -13,7 +13,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
@ -28,8 +29,10 @@ logger = init_logger(__name__)
class GGUFConfig(QuantizationConfig):
"""Config class for GGUF."""
def __init__(self, ) -> None:
def __init__(self,
unquantized_modules: Optional[list[str]] = None) -> None:
super().__init__()
self.unquantized_modules = unquantized_modules or []
def __repr__(self) -> str:
return ("GGUFConfig()")
@ -55,6 +58,8 @@ class GGUFConfig(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_gguf(prefix, self.unquantized_modules):
return UnquantizedLinearMethod()
return GGUFLinearMethod(self)
elif isinstance(layer, VocabParallelEmbedding):
return GGUFEmbeddingMethod(self)
@ -63,6 +68,10 @@ class GGUFConfig(QuantizationConfig):
return None
def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
return any(module_name in prefix for module_name in unquantized_modules)
UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
STANDARD_QUANT_TYPES = {
WeightType.Q4_0,

View File

@ -0,0 +1,306 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
from typing import Any, Optional
import regex as re
import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.petit_utils import (
apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit,
verify_petit_nvfp4_supported)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
# Initialize logger for the module
logger = init_logger(__name__)
# Configuration class to support the NVFP4 quantized model
# generated by the ModelOpt quantization tool
class PetitNvFp4Config(QuantizationConfig):
"""Config class for Petit FP4."""
def __init__(
self,
is_checkpoint_nvfp4_serialized: bool = False,
kv_cache_quant_algo: Optional[str] = None,
group_size: Optional[int] = None,
exclude_modules: Optional[list[str]] = None,
) -> None:
self._check_hardware_support()
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning("Detected nvfp4 checkpoint. Please note that the "
"format is experimental and subject to change.")
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
def _check_hardware_support(self) -> None:
"""
Verifies that the current hardware is supported by the Petit backend.
This backend is specifically designed for AMD GPUs and is not
supported on the CUDA platform.
"""
# This check ensures the code is NOT running on an NVIDIA GPU.
if current_platform.is_cuda():
raise ValueError(
"The 'petit' quantization backend is designed for AMD GPUs "
"and is not supported on the CUDA platform. For NVIDIA GPUs, "
"please use a different quantization method such as FP8, AWQ, "
"or GPTQ.")
@classmethod
def get_name(cls) -> QuantizationMethods:
return "petit_nvfp4"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
# Petit supports the gfx90a and gfx942 GPUs
return 90
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config":
qc = cls.get_from_keys(config, ["quantization"])
quant_method_raw = qc.get("quant_algo")
if not isinstance(quant_method_raw, str) or not quant_method_raw:
raise ValueError(
"Missing or invalid 'quant_algo' in quantization config.")
quant_method = quant_method_raw.upper()
group_size_raw = qc.get("group_size")
if not isinstance(group_size_raw, int):
raise ValueError(
"Missing or invalid 'group_size' (int) in hf_quant_config.json."
)
group_size = group_size_raw
verify_petit_nvfp4_supported(quant_method, group_size)
kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto"
if not isinstance(kv_cache_quant_algo_raw, str):
raise ValueError(
"'kv_cache_quant_algo' must be a string if provided.")
kv_cache_quant_algo = kv_cache_quant_algo_raw
exclude_raw = qc.get("exclude_modules", [])
if exclude_raw is None:
exclude_modules: list[str] = []
elif isinstance(exclude_raw, list) and all(
isinstance(x, str) for x in exclude_raw):
exclude_modules = exclude_raw
else:
raise ValueError(
"'exclude_modules' must be a list[str] (or omitted).")
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
return cls(
is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized,
kv_cache_quant_algo=kv_cache_quant_algo,
group_size=group_size,
exclude_modules=exclude_modules,
)
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if not current_platform.is_rocm():
return None
qc = hf_quant_cfg.get("quantization", hf_quant_cfg)
algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"):
return cls.get_name() # "petit_nvfp4"
return None
@classmethod
def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool:
qc = quant_config.get("quantization", quant_config)
algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper()
return algo == "NVFP4"
def is_layer_excluded(self, prefix: str,
exclude_modules: list[str]) -> bool:
for pattern in exclude_modules:
regex_str = pattern.replace(".", r"\.").replace("*", r".*")
if re.fullmatch(regex_str, prefix):
return True
return False
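# For illustration (not part of the original class): patterns in
# `exclude_modules` are glob-like. A pattern such as "model.layers.*.mlp"
# becomes the regex r"model\.layers\..*\.mlp" and fully matches a prefix like
# "model.layers.0.mlp", so that layer falls back to UnquantizedLinearMethod.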
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
exclude = self.require_exclude_modules()
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, exclude) or self.is_layer_excluded(
prefix, exclude):
return UnquantizedLinearMethod()
return PetitNvFp4LinearMethod(self)
elif isinstance(layer, Attention):
return PetitFp8KVCacheMethod(self)
return None
def get_scaled_act_names(self) -> list[str]:
return []
def require_group_size(self) -> int:
if self.group_size is None:
logger.warning("group_size not set; defaulting to 16 for NVFP4.")
return 16
return self.group_size
def require_kv_cache_quant_algo(self) -> str:
return self.kv_cache_quant_algo or "auto"
def require_exclude_modules(self) -> list[str]:
return list(self.exclude_modules or [])
class PetitFp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: PetitNvFp4Config):
super().__init__(quant_config)
class PetitNvFp4LinearMethod(LinearMethodBase):
"""Linear method for NVFP4.
Supports loading NVFP4 checkpoints with the following structure:
|Tensor Name    | datatype      | shape       |
|---------------|---------------|-------------|
|input_scale    | torch.float32 | scalar      |
|weight         | NVFP4(SE2M1)  | [1, X, Y/2] |
|weight_scale   | FP8-E4M3      | [X, Y]      |
|weight_scale_2 | torch.float32 | scalar      |
The weights are quantized per block of 16 elements.
Args: quant_config: The Petit NVFP4 quantization config.
"""
def __init__(self, quant_config: PetitNvFp4Config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError("NVFP4 quantization was selected, "
" dynamic quantization is not supported.")
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
if input_size_per_partition % 16 != 0:
raise ValueError("Unsupported model when in features size is "
"not multiple of 16")
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_nvfp4_serialized
else params_dtype)
weight = ModelWeightParameter(
data=torch.empty(
# 2 fp4 data is packed in one uint8 in the input dimension
output_size_per_partition,
input_size_per_partition // 2,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
input_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)
weight_scale_2 = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale_2", weight_scale_2)
group_size = self.quant_config.require_group_size()
weight_scale = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // group_size,
dtype=weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
input_scale_2 = layer.input_scale.max().to(torch.float32)
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
requires_grad=False)
prepare_nvfp4_layer_for_petit(layer)
del layer.input_scale
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_petit_nvfp4_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
weight_scale_2=layer.weight_scale_2,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
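For context, a minimal usage sketch of selecting this backend on a ROCm system; the model path is a placeholder, and only quantization="petit_nvfp4" (the name registered earlier in this commit) comes from the change itself:
from vllm import LLM
# Placeholder path; any ModelOpt-produced NVFP4 checkpoint would be the input.
llm = LLM(model="/path/to/nvfp4-checkpoint", quantization="petit_nvfp4")
print(llm.generate(["Hello from an AMD GPU!"])[0].outputs[0].text)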

View File

@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional
import torch
# TYPE_CHECKING is used for static type analysis to prevent circular imports.
if TYPE_CHECKING:
from types import ModuleType
# 1. Create a global variable as a placeholder for the module
_petit_kernel: Optional["ModuleType"] = None
_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with "
"`pip install petit-kernel`.")
def _import_petit_kernel() -> "ModuleType":
"""
A helper function to handle the lazy import.
The first time this function is called, it will import the petit_kernel
library and store it in the global _petit_kernel variable.
Subsequent calls will return the already-loaded module directly.
"""
global _petit_kernel
if _petit_kernel is not None:
return _petit_kernel
try:
import petit_kernel
_petit_kernel = petit_kernel
return _petit_kernel
except ImportError:
# The 'from None' syntax prevents chaining the original ImportError,
# making the traceback cleaner.
raise ImportError(_PETIT_INSTALL_MSG) from None
# The _require_petit function can now be a simple alias for consistency.
_require_petit = _import_petit_kernel
def _check_petit_nvfp4_supported(
quant_method: str,
group_size: Optional[int]) -> tuple[bool, Optional[str]]:
if quant_method != "NVFP4":
return (
False,
("Petit currently only supports: NVFP4 quantizations in sglang. "
"Please check the `hf_quant_config.json` file for your model's "
"quant configuration."),
)
if group_size is not None and group_size != 16:
return (
False,
"Petit currently only supports: group_size=16 quantizations.",
)
return (True, None)
def verify_petit_nvfp4_supported(quant_method: str,
group_size: Optional[int]) -> None:
supported, error_msg = _check_petit_nvfp4_supported(
quant_method, group_size)
if not supported:
assert error_msg is not None
raise ValueError(error_msg)
def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None:
# 2. Call _import_petit_kernel() to trigger (or get) the import.
petit_kernel = _import_petit_kernel()
# Repack weights to petit format
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
qweight = layer.weight.view(torch.int32).contiguous()
# 3. Call functions through the imported module variable.
petit_qweight = petit_kernel.repack_nvfp4(qweight,
size_n=part_size_n,
size_k=part_size_k)
layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False)
# Permute scales
weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale,
size_k=part_size_k,
size_n=part_size_n)
layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
def apply_petit_nvfp4_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
size_n: int,
size_k: int,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Trigger (or get) the import here as well.
petit_kernel = _import_petit_kernel()
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (size_n, )
# TODO: Use auto-tuning to find the performant solution_id
# Call the function via the module variable.
output = petit_kernel.mul_nvfp4_a16(
a=reshaped_x,
b=weight,
s=weight_scale,
global_scale=weight_scale_2,
size_m=reshaped_x.size(0),
size_n=size_n,
size_k=size_k,
solution_id=-1,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
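To make the packing concrete, a small shape sketch following the layout described in PetitNvFp4LinearMethod.create_weights; the sizes below are illustrative, not taken from the commit:
# Assumed example: output partition N = 4096, input size K = 4096, and the
# only group size Petit supports, 16.
N, K, group_size = 4096, 4096, 16
packed_weight_shape = (N, K // 2)          # two FP4 values packed per uint8
weight_scale_shape = (N, K // group_size)  # one FP8-E4M3 scale per 16 values
# weight_scale_2 is a per-tensor float32 scalar; input_scale is folded into
# layer.alpha during process_weights_after_loading.
assert packed_weight_shape == (4096, 2048)
assert weight_scale_shape == (4096, 256)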

View File

@ -2,16 +2,21 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""This file is used for /tests and /benchmarks"""
from collections.abc import Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import ClassVar, NamedTuple, Optional
import numpy
import torch
from torch import fx
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types
FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8
# Use proxy as NamedTuple direct subclasses cannot have static members
class _GroupShape(NamedTuple):
@ -34,6 +39,64 @@ GroupShape.PER_TENSOR = GroupShape(-1, -1)
GroupShape.PER_TOKEN = GroupShape(1, -1)
@dataclass(frozen=True)
class ScaleDesc:
"""
Class for describing a single quantization scaling factor.
dtype: data type of the scale
static: static scale if True, dynamic if False
group_shape: group shape of the scale
"""
dtype: torch.dtype
static: bool
group_shape: GroupShape
def __str__(self):
group_shape = ('per_tensor'
if self.group_shape == GroupShape.PER_TENSOR else
('per_token' if self.group_shape == GroupShape.PER_TOKEN
else str(self.group_shape)))
return (f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'static' if self.static else 'dynamic'},{group_shape}")
@dataclass(frozen=True)
class QuantKey:
"""
Class for identifying the type of quantization.
dtype: quantized data type
scale: scale descriptor
scale2: second-level scale descriptor
symmetric: symmetric if True, asymmetric if False
"""
dtype: torch.dtype
scale: ScaleDesc
scale2: Optional[ScaleDesc] = None
symmetric: bool = True
def __str__(self):
scale2_str = f"scale2({self.scale2})," if self.scale2 else ""
return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]},"
f"scale({self.scale}),{scale2_str}"
f"{'a' if not self.symmetric else ''}symmetric)")
kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR)
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16))
kNvfp4Quant = QuantKey(FP4_DTYPE,
scale=kNvfp4GroupScale,
scale2=kStaticTensorScale)
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# -1 means full extent

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@ -156,13 +157,10 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
return output.view(*output_shape)
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
def rocm_per_tensor_w8a8_scaled_mm_impl(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor) -> torch.Tensor:
from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
@ -175,10 +173,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_a=scale_a,
scale_b=scale_b,
bias=bias)
return output
def rocm_per_tensor_w8a8_scaled_mm_fake(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor) -> torch.Tensor:
return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
dtype=out_dtype)
def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d)
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape)
direct_register_custom_op(
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
mutates_args=[],
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
dispatch_key=current_platform.dispatch_key,
)
def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,

View File

@ -14,7 +14,8 @@ from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names, gguf_quant_weights_iterator)
get_gguf_extra_tensor_names, get_gguf_weight_type_map,
gguf_quant_weights_iterator)
class GGUFModelLoader(BaseModelLoader):
@ -132,6 +133,17 @@ class GGUFModelLoader(BaseModelLoader):
local_model_path, gguf_weights_map):
model_config.hf_config.update({"tie_word_embeddings": True})
weight_type_map = get_gguf_weight_type_map(model_config.model,
gguf_weights_map)
# filter out unquantized modules to skip
unquant_names = [
name.removesuffix(".weight")
for name, weight_type in weight_type_map.items()
if weight_type == "F32" and name.endswith(".weight")
]
vllm_config.quant_config.unquantized_modules.extend(unquant_names)
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:

View File

@ -31,9 +31,7 @@ from vllm.utils import PlaceholderModule
try:
from runai_model_streamer import SafetensorsStreamer
except (ImportError, OSError):
# see https://github.com/run-ai/runai-model-streamer/issues/26
# OSError will be raised on arm64 platform
except ImportError:
runai_model_streamer = PlaceholderModule(
"runai_model_streamer") # type: ignore[assignment]
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
@ -565,6 +563,18 @@ def get_gguf_extra_tensor_names(
return [gguf_to_hf_name_map[key] for key in extra_keys]
def get_gguf_weight_type_map(
gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]:
"""
Return a mapping from mapped HF weight names to their GGUF quant type names.
"""
reader = gguf.GGUFReader(gguf_file)
return {
gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name
for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map
}
def gguf_quant_weights_iterator(
gguf_file: str, gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, torch.Tensor], None, None]:

View File

@ -8,7 +8,7 @@ import torch
from torch import nn
from transformers import BertConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@ -239,14 +239,13 @@ class BertSelfAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")
self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_ONLY)
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
def forward(
self,

View File

@ -7,7 +7,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
@ -119,14 +119,13 @@ class BertWithRopeAttention(nn.Module):
self.rotary_emb = get_rope(**rotary_kwargs)
self.attn = Attention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=AttentionType.ENCODER_ONLY)
self.attn = EncoderOnlyAttention(num_heads=self.num_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.out_proj = RowParallelLinear(input_size=hidden_size,
output_size=hidden_size,

View File

@ -647,7 +647,8 @@ class Florence2LanguageModel(nn.Module):
encoder_hidden_states = None
if inputs_embeds is not None or encoder_input_ids.numel() > 0:
if ((inputs_embeds is not None and inputs_embeds.numel() > 0)
or encoder_input_ids.numel() > 0):
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
@ -681,6 +682,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
self.lm_head = BartParallelLMHead(self.vocab_size,
config.d_model,
embed_scale=embed_scale)
if self.config.tie_word_embeddings:
self.lm_head.tie_weights(self.model.shared)
self.logits_processor = LogitsProcessor(self.vocab_size,
config.vocab_size)
@ -749,7 +752,8 @@ class Florence2LanguageForConditionalGeneration(nn.Module, SupportsV0Only):
else:
if "final_logits_bias" in name:
continue
if self.config.tie_word_embeddings and "embed_tokens" in name:
if self.config.tie_word_embeddings and ("embed_tokens" in name
or "lm_head" in name):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",

View File

@ -74,7 +74,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from ..layers.activation import SiluAndMul
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .qwen2_vl import _qwen2vl_field_config, apply_rotary_pos_emb_vision
from .qwen2_vl import (_create_qwen2vl_field_factory,
apply_rotary_pos_emb_vision)
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@ -1153,7 +1154,9 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _qwen2vl_field_config(hf_inputs)
return _create_qwen2vl_field_factory(
self.info.get_hf_config().vision_config.spatial_merge_size)(
hf_inputs)
def _get_prompt_updates(
self,

View File

@ -27,13 +27,15 @@ from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config, Idefics2VisionConfig)
from vllm.attention.layer import MultiHeadAttention
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import run_dp_sharded_vision_model
class Idefics2VisionEmbeddings(nn.Module):
@ -118,6 +120,7 @@ class Idefics2VisionAttention(nn.Module):
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.config = config
@ -130,22 +133,43 @@ class Idefics2VisionAttention(nn.Module):
f" {self.num_heads}).")
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
assert self.num_heads % tp_size == 0
self.num_heads_per_partition = self.num_heads // tp_size
if use_data_parallel:
self.q_size = self.num_heads * self.head_dim
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
3 * self.q_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = ReplicatedLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.attn = MultiHeadAttention(self.num_heads_per_partition,
self.head_dim, self.scale)
@ -169,18 +193,23 @@ class Idefics2VisionMLP(nn.Module):
config: Idefics2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
self.fc1 = ColumnParallelLinear(
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(
config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.fc2 = cls_fc2(
config.intermediate_size,
config.hidden_size,
bias=True,
@ -202,17 +231,21 @@ class Idefics2EncoderLayer(nn.Module):
config: Idefics2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = Idefics2VisionAttention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn")
self.self_attn = Idefics2VisionAttention(
config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel)
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
@ -254,6 +287,7 @@ class Idefics2Encoder(nn.Module):
*,
num_hidden_layers_override: Optional[int] = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
@ -267,7 +301,8 @@ class Idefics2Encoder(nn.Module):
self.layers = nn.ModuleList([
Idefics2EncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{layer_idx}")
prefix=f"{prefix}.layers.{layer_idx}",
use_data_parallel=use_data_parallel)
for layer_idx in range(num_hidden_layers)
])
@ -301,17 +336,20 @@ class Idefics2VisionTransformer(nn.Module):
num_hidden_layers_override: Optional[int] = None,
require_post_norm: bool = True,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
super().__init__()
embed_dim = config.hidden_size
self.config = config
self.use_data_parallel = use_data_parallel
self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder")
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
@ -340,10 +378,38 @@ class Idefics2VisionTransformer(nn.Module):
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes,
)
encoder_outputs = self.encoder(hidden_states)
if self.use_data_parallel:
encoder_outputs = run_dp_sharded_vision_model(
hidden_states, self.encoder)
else:
encoder_outputs = self.encoder(hidden_states)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state
def _consolidate_qkv_weights(
self, weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
qkv_idx_mappings = {
".self_attn.q_proj": 0,
".self_attn.k_proj": 1,
".self_attn.v_proj": 2,
}
qkv_weights = {}
for name, loaded_weight in weights:
for weight_name, idx in qkv_idx_mappings.items():
if weight_name not in name:
continue
new_name = name.replace(weight_name, ".self_attn.qkv_proj")
if new_name not in qkv_weights:
qkv_weights[new_name] = [None] * 3
qkv_weights[new_name][idx] = loaded_weight
break
else:
yield name, loaded_weight
for key, weight in qkv_weights.items():
qkv_weight = torch.cat(weight, dim=0)
yield key, qkv_weight
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
@ -356,6 +422,9 @@ class Idefics2VisionTransformer(nn.Module):
loaded_params: set[str] = set()
layer_count = len(self.encoder.layers)
if self.use_data_parallel:
weights = self._consolidate_qkv_weights(weights)
for name, loaded_weight in weights:
# skip pooling header
if name.startswith("head."):
@ -373,7 +442,7 @@ class Idefics2VisionTransformer(nn.Module):
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
if weight_name not in name or self.use_data_parallel:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]

View File

@ -31,6 +31,7 @@ from torch import nn
from transformers import LlamaConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -173,7 +174,10 @@ class LlamaAttention(nn.Module):
if is_sliding:
sliding_window = config.sliding_window
self.attn = Attention(
attn_cls = (EncoderOnlyAttention
if attn_type == AttentionType.ENCODER_ONLY else Attention)
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,

View File

@ -24,7 +24,7 @@
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Callable, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union
import torch
from torch import nn
@ -49,6 +49,7 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
MultiModalDataParser)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
MiniCPMVDummyInputsBuilder,
@ -61,35 +62,52 @@ from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
CPU_DEVICE = torch.device("cpu")
class MiniCPMOAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
audio_features: Union[torch.Tensor, list[torch.Tensor]]
class MiniCPMOAudioFeatureInputs(TensorSchema):
"""
Dimensions:
- bns: Batch size * number of audios * number of slices
- bn: Batch size * number of audios
- c: Number of channels
- l: Length
- s: Number of slices
"""
type: Literal["audio_features"] = "audio_features"
audio_features: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bns", "c", "l", dynamic_dims={"l"}),
]
"""
Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
Slice here means chunk. Audio that is too long will be split into slices,
which is the same as image.
Padding is used therefore `audio_features` is `torch.Tensor`.
which is the same as image. Padding is used therefore `audio_features` is
`torch.Tensor`.
"""
audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
audio_feature_lens: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "s"),
]
"""
Shape: `(batch_size * num_audios, num_slices)`
This should be feature length of each audio slice,
which equals to `audio_features.shape[-1]`
"""
class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
class MiniCPMOAudioEmbeddingInputs(TensorSchema):
"""
Shape: `(batch_size * num_audios, num_slices, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
instead of a batched tensor.
Dimensions:
- bn: Batch size * number of audios
- s: Number of slices
- h: Hidden size (must match language model backbone)
Length of each slice may vary, so pass it as a list.
"""
type: Literal["audio_embeds"] = "audio_embeds"
audio_embeds: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "s", "h", dynamic_dims={"s"}),
]
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,

View File

@ -778,6 +778,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
# and config class
self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.version = get_version_by_config(self.config)
self.llm = self.init_llm(vllm_config=vllm_config,
@ -1325,9 +1326,11 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
prefix: str = "",
) -> nn.Module:
quant_config = self._maybe_ignore_quant_config(quant_config)
model = Idefics2VisionTransformer(config.vision_config,
quant_config=quant_config,
prefix=prefix)
model = Idefics2VisionTransformer(
config.vision_config,
quant_config=quant_config,
prefix=prefix,
use_data_parallel=self.use_data_parallel)
if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1]
return model

View File

@ -7,7 +7,7 @@ import torch
from torch import nn
from transformers import ModernBertConfig
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
@ -104,12 +104,12 @@ class ModernBertAttention(nn.Module):
head_size=self.head_dim,
dim=self.head_dim,
base=rope_theta)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
prefix=f"{layer_id}.attn",
attn_type=AttentionType.ENCODER_ONLY,
per_layer_sliding_window=sliding_window)
self.attn = EncoderOnlyAttention(
self.num_heads,
self.head_dim,
self.scaling,
prefix=f"{layer_id}.attn",
per_layer_sliding_window=sliding_window)
self.Wo = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=config.attention_bias)

View File

@ -30,7 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
@ -70,6 +70,7 @@ class VisualTokenizer(torch.nn.Module):
visual_vocab_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
@ -77,6 +78,7 @@ class VisualTokenizer(torch.nn.Module):
config=config,
quant_config=quant_config,
prefix=f"{prefix}.vit",
use_data_parallel=use_data_parallel,
)
# reserved tokens for INDICATOR_IDS
head_dim = visual_vocab_size - len(INDICATOR_IDS)
@ -93,31 +95,33 @@ class VisualTokenizer(torch.nn.Module):
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
model_type = config.model_type
if model_type == "siglip2_navit":
return Siglip2NavitModel(config=config, )
return Siglip2NavitModel(config=config,
quant_config=quant_config,
prefix=prefix,
use_data_parallel=use_data_parallel)
raise ValueError(
f"Unsupported visual tokenizer model_type: {model_type}")
@property
def dtype(self):
def dtype(self) -> torch.dtype:
return next(self.head.parameters()).dtype
@property
def device(self):
def device(self) -> torch.device:
return next(self.head.parameters()).device
def tokenize(self, logits):
def tokenize(self, logits: torch.Tensor) -> torch.Tensor:
tokens = torch.softmax(logits, dim=-1,
dtype=torch.float32).to(logits.dtype)
return tokens
def encode(self, pixel_values, grid_thws):
features = self.vit(pixel_values,
grid_thws,
output_hidden_states=True,
return_dict=True)
def encode(self, pixel_values: torch.Tensor,
grid_thws: torch.Tensor) -> torch.Tensor:
features = self.vit(pixel_values, grid_thws)
# refer to qwen2.5-vl patchmerger
seq_len, _ = features.shape
features = features.reshape(seq_len // (self.config.hidden_stride**2),
@ -125,7 +129,8 @@ class VisualTokenizer(torch.nn.Module):
return features
def forward(self, pixel_values, grid_thws) -> torch.Tensor:
def forward(self, pixel_values: torch.Tensor,
grid_thws: torch.Tensor) -> torch.Tensor:
features = self.encode(pixel_values, grid_thws)
logits = self.head(features)
tokens = self.tokenize(logits)
@ -395,7 +400,7 @@ class Ovis2_5MultiModalProcessor(BaseMultiModalProcessor[Ovis2_5ProcessingInfo]
@MULTIMODAL_REGISTRY.register_processor(Ovis2_5MultiModalProcessor,
info=Ovis2_5ProcessingInfo,
dummy_inputs=Ovis2_5DummyInputsBuilder)
class Ovis2_5(nn.Module, SupportsMultiModal):
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -421,9 +426,8 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
# TODO(Isotr0py): PP support
# self.make_empty_intermediate_tensors = (
# self.language_model.make_empty_intermediate_tensors)
self.make_empty_intermediate_tensors = (
self.get_language_model().make_empty_intermediate_tensors)
def _parse_and_validate_visual_input(
self, is_video,
@ -567,4 +571,4 @@ class Ovis2_5(nn.Module, SupportsMultiModal):
return loader.load_weights(weights)
def get_language_model(self) -> torch.nn.Module:
return self.llm
return self.llm

View File

@ -32,6 +32,7 @@ from torch import nn
from transformers import Qwen2Config
from vllm.attention import Attention, AttentionType
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
@ -51,7 +52,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import is_interleaved
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
@ -159,7 +160,9 @@ class Qwen2Attention(nn.Module):
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(
attn_cls = (EncoderOnlyAttention
if attn_type == AttentionType.ENCODER_ONLY else Attention)
self.attn = attn_cls(
self.num_heads,
self.head_dim,
self.scaling,
@ -439,7 +442,7 @@ class Qwen2Model(nn.Module):
return loaded_params
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -485,6 +488,13 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def set_aux_hidden_state_layers(self, layers: tuple[int]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
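# Illustrative example (not part of the commit): for a hypothetical 28-layer
# Qwen2 model this default returns (2, 14, 25), i.e. an early, a middle, and
# a late decoder layer whose hidden states are exposed for EAGLE3 drafting.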
def forward(
self,
input_ids: torch.Tensor,

View File

@ -25,7 +25,7 @@
from collections.abc import Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
@ -79,40 +79,57 @@ except (ImportError, ModuleNotFoundError):
logger = init_logger(__name__)
def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
torch.empty((0, )))
def create_qwen2_5_omni_thinker_field_factory(
spatial_merge_size: int
) -> Callable[[Mapping[str, torch.Tensor]], Mapping[str,
MultiModalFieldConfig]]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str,
torch.Tensor]):
audio_feature_lengths = hf_inputs.get("audio_feature_lengths",
torch.empty((0, )))
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_pixel_grid_sizes = image_grid_thw.prod(-1)
image_embed_grid_sizes = (image_pixel_grid_sizes //
spatial_merge_size // spatial_merge_size)
num_videos = len(video_grid_sizes)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
spatial_merge_size)
return dict(
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_feature_lengths, dim=1),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
)
num_videos = len(video_grid_sizes)
return dict(
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_feature_lengths, dim=1),
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
audio_feature_lengths=MultiModalFieldConfig.batched("audio"),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_pixel_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_embed_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_embed_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
use_audio_in_video=MultiModalFieldConfig.shared(
"video", num_videos),
)
return _qwen2_5_omni_thinker_field_config
class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
def __init__(self, spatial_merge_size: int, *args, **kwargs):
self._spatial_merge_size = spatial_merge_size
super().__init__(self._spatial_merge_size, *args, **kwargs)
def _parse_audio_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
@ -124,7 +141,8 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
required_fields={
"input_audio_features", "audio_feature_lengths"
},
fields_factory=_qwen2_5_omni_thinker_field_config,
fields_factory=create_qwen2_5_omni_thinker_field_factory(
self._spatial_merge_size),
)
return super()._parse_audio_data(data)
@ -214,6 +232,8 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return Qwen2_5OmniThinkerMultiModalDataParser(
spatial_merge_size=self.info.get_hf_config(
).vision_config.spatial_merge_size,
target_sr=feature_extractor.sampling_rate)
def _call_hf_processor(
@ -265,7 +285,9 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _qwen2_5_omni_thinker_field_config(hf_inputs)
return create_qwen2_5_omni_thinker_field_factory(
self.info.get_hf_config().vision_config.spatial_merge_size)(
hf_inputs)
def _maybe_apply_prompt_updates(
self,

View File

@ -699,29 +699,46 @@ class Qwen2VisionTransformer(nn.Module):
return loaded_params
def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_grid_sizes = image_grid_thw.prod(-1)
def _create_qwen2vl_field_factory(
spatial_merge_size: int
) -> Callable[
[Mapping[str, torch.Tensor]],
Mapping[str, MultiModalFieldConfig],
]:
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
def _qwen2vl_field_config(hf_inputs: Mapping[str, torch.Tensor]):
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_pixel_grid_sizes = image_grid_thw.prod(-1)
image_embed_grid_sizes = (image_pixel_grid_sizes //
spatial_merge_size // spatial_merge_size)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
video_embed_grid_sizes = (video_grid_sizes // spatial_merge_size //
spatial_merge_size)
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_pixel_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_embed_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_embed_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
return _qwen2vl_field_config
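# Worked example (illustrative, assuming spatial_merge_size=2): an image with
# grid_thw = (1, 32, 32) contributes 1 * 32 * 32 = 1024 pixel patches, so
# `pixel_values` rows are split with size 1024, while `image_embeds` rows use
# 1024 // 2 // 2 = 256, matching the patch merger's 2x2 reduction.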
class Qwen2VLMultiModalDataParser(MultiModalDataParser):
def __init__(self, spatial_merge_size: int, *args, **kwargs):
self._spatial_merge_size = spatial_merge_size
super().__init__(*args, **kwargs)
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
@ -731,7 +748,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data,
modality="image",
required_fields={"image_embeds", "image_grid_thw"},
fields_factory=_qwen2vl_field_config,
fields_factory=_create_qwen2vl_field_factory(
self._spatial_merge_size),
)
return super()._parse_image_data(data)
@ -745,7 +763,8 @@ class Qwen2VLMultiModalDataParser(MultiModalDataParser):
data,
modality="video",
required_fields={"video_embeds", "video_grid_thw"},
fields_factory=_qwen2vl_field_config,
fields_factory=_create_qwen2vl_field_factory(
self._spatial_merge_size),
)
return super()._parse_video_data(data)
@ -967,7 +986,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
return Qwen2VLMultiModalDataParser()
return Qwen2VLMultiModalDataParser(
self.info.get_hf_config().vision_config.spatial_merge_size)
def _get_prompt_updates(
self,
@ -1010,7 +1030,9 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _qwen2vl_field_config(hf_inputs)
return _create_qwen2vl_field_factory(
self.info.get_hf_config().vision_config.spatial_merge_size)(
hf_inputs)
@MULTIMODAL_REGISTRY.register_processor(Qwen2VLMultiModalProcessor,

View File

@ -130,6 +130,7 @@ _TEXT_GENERATION_MODELS = {
"Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"),
"Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"SeedOssForCausalLM": ("seed_oss", "SeedOssForCausalLM"),
"Step3TextForCausalLM": ("step3_text", "Step3TextForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),

View File

@ -0,0 +1,487 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The Seed team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only SeedOss model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig as SeedOssConfig
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
class SeedOssMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class SeedOssAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
self.head_dim = head_dim
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
attn_type=attn_type,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
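A minimal worked example of the head-sharding rule above (illustrative only, not code from this diff): KV heads are partitioned across tensor-parallel ranks when there are at least as many heads as ranks, and replicated otherwise so every rank still holds one.

def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    # Same rule as SeedOssAttention.__init__ above: partition if possible,
    # otherwise replicate so every rank still serves at least one KV head.
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

assert kv_heads_per_rank(8, 4) == 2  # partitioned: 2 KV heads per rank
assert kv_heads_per_rank(2, 4) == 1  # replicated: each KV head is shared by 2 ranks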
class SeedOssDecoderLayer(nn.Module):
def __init__(
self,
config: SeedOssConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
# By default, SeedOss uses causal attention as it is a
# decoder-only model.
# You can override the HF config with `is_causal=False` to enable
# bidirectional attention, which is used in some embedding models
# (see Attention/AttentionType for the supported modes).
if getattr(config, "is_causal", True):
attn_type = AttentionType.DECODER
else:
attn_type = AttentionType.ENCODER_ONLY
self.self_attn = SeedOssAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
rope_theta=rope_theta,
cache_config=cache_config,
quant_config=quant_config,
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
)
self.mlp = SeedOssMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
})
class SeedOssModel(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
assert config.max_window_layers == config.num_hidden_layers, (
"Sliding window for some but all layers is not supported. "
"This model uses sliding window but `max_window_layers` = {} "
"is less than `num_hidden_layers` = {}. Please open an issue "
"to discuss this feature.".format(
config.max_window_layers,
config.num_hidden_layers,
))
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens",
)
else:
self.embed_tokens = PPMissingLayer()
# Use the provided decoder layer type or default to SeedOssDecoderLayer
decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: decoder_layer_type(config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer:self.end_layer]:
hidden_states, residual = layer(
positions,
hidden_states,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache quantization scales
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
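For readers unfamiliar with vLLM's fused-weight loading, here is an isolated, hedged sketch of what the stacked_params_mapping above does to checkpoint names; the layer names are made up for illustration.

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name: str):
    # Map a per-projection checkpoint name onto the fused vLLM parameter
    # plus the shard id that tells the weight loader where to place it.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

assert remap("model.layers.0.self_attn.q_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight", "q")
assert remap("model.layers.0.mlp.up_proj.weight") == (
    "model.layers.0.mlp.gate_up_proj.weight", 1)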
class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = SeedOssModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."]
if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)

View File

@ -3,16 +3,24 @@
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""
from typing import Optional, Union
from collections.abc import Iterable
from typing import Optional
import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from transformers.activations import ACT2FN
from transformers import Siglip2VisionConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithNoAttention
from vllm.config import QuantizationConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import _Backend
from .vision import get_vit_attn_backend
@ -48,10 +56,11 @@ class Siglip2VisionEmbeddings(nn.Module):
# siglip2 naflex
if self.num_patches > 0:
self.patch_embedding = nn.Linear(
in_features=config.num_channels * self.patch_size *
self.patch_embedding = ReplicatedLinear(
input_size=config.num_channels * self.patch_size *
self.patch_size,
out_features=self.embed_dim,
output_size=self.embed_dim,
return_bias=False,
)
if self.preserve_original_pe:
self.position_embedding_size = int(self.num_patches**0.5)
@ -89,7 +98,7 @@ class Siglip2VisionEmbeddings(nn.Module):
# Apply patch embeddings to already patchified pixel values
target_dtype = self.patch_embedding.weight.dtype
if isinstance(self.patch_embedding, nn.Linear):
if isinstance(self.patch_embedding, LinearBase):
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype))
elif isinstance(self.patch_embedding, nn.Conv2d):
@ -184,7 +193,13 @@ def apply_rotary_pos_emb(
class Siglip2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
def __init__(
self,
config: Siglip2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@ -199,11 +214,25 @@ class Siglip2Attention(nn.Module):
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
self.qkv_proj = QKVParallelLinear(
hidden_size=self.embed_dim,
head_size=self.head_dim,
total_num_heads=self.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.out_proj = RowParallelLinear(
input_size=self.embed_dim,
output_size=self.embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
self.use_rope = config.use_rope
# Detect attention implementation.
@ -228,13 +257,15 @@ class Siglip2Attention(nn.Module):
seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
qkv_states, _ = self.qkv_proj(hidden_states)
queries, keys, values = qkv_states.chunk(3, dim=-1)
queries = queries.view(seq_length, self.num_heads, self.head_dim)
keys = keys.view(seq_length, self.num_heads, self.head_dim)
values = values.view(seq_length, self.num_heads, self.head_dim)
queries = queries.view(seq_length, self.num_heads_per_partition,
self.head_dim)
keys = keys.view(seq_length, self.num_heads_per_partition,
self.head_dim)
values = values.view(seq_length, self.num_heads_per_partition,
self.head_dim)
if self.use_rope:
cos, sin = position_embeddings
@ -276,41 +307,72 @@ class Siglip2Attention(nn.Module):
v_i,
dropout_p=0.0)
# (1, num_heads, seq_len, head_dim) -> (seq_len, embed_dim)
output_i = output_i.transpose(1, 2).reshape(-1, self.embed_dim)
output_i = output_i.transpose(1, 2).reshape(
end_idx - start_idx, -1)
outputs.append(output_i)
attn_output = torch.cat(outputs, dim=0)
attn_output = self.out_proj(attn_output)
attn_output, _ = self.out_proj(attn_output)
return attn_output
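A small shape check (illustrative, not from this diff) of why the plain chunk(3) above is valid: this vision tower has no grouped-query attention, so q, k and v occupy equal widths of the fused projection on every rank.

import torch

seq_len, heads_per_rank, head_dim = 16, 8, 64
qkv = torch.randn(seq_len, 3 * heads_per_rank * head_dim)
q, k, v = qkv.chunk(3, dim=-1)  # equal splits because num_q_heads == num_kv_heads
assert q.shape == k.shape == v.shape == (seq_len, heads_per_rank * head_dim)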
class Siglip2MLP(nn.Module):
def __init__(self, config):
def __init__(
self,
config: Siglip2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
self.activation_fn = get_act_fn(config.hidden_act)
# TODO(Isotr0py): Enable data parallel after we support
# disabling TP on parallel linear layer
self.fc1 = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states
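A hedged note on the fc1/fc2 choice above: it follows the standard Megatron-style MLP split, where the column-parallel fc1 shards the intermediate (output) dimension and the row-parallel fc2 shards its input dimension to match, so only fc2's output needs an all-reduce. The sizes below are illustrative.

hidden_size, intermediate_size, tp_size = 1152, 4304, 2
fc1_out_per_rank = intermediate_size // tp_size  # output columns held by each rank
fc2_in_per_rank = intermediate_size // tp_size   # matching input rows on each rank
assert fc1_out_per_rank == fc2_in_per_rank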
class Siglip2EncoderLayer(nn.Module):
def __init__(self, config: PretrainedConfig):
def __init__(
self,
config: Siglip2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.embed_dim = config.hidden_size
self.layer_norm1 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.self_attn = Siglip2Attention(config)
self.self_attn = Siglip2Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
use_data_parallel=use_data_parallel)
self.layer_norm2 = nn.LayerNorm(self.embed_dim,
eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config)
self.mlp = Siglip2MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
use_data_parallel=use_data_parallel)
def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
position_embeddings: torch.Tensor) -> tuple[torch.FloatTensor]:
@ -347,14 +409,22 @@ class Siglip2Encoder(nn.Module):
config: PretrainedConfig
"""
def __init__(self, config: PretrainedConfig):
def __init__(
self,
config: Siglip2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
self.layers = nn.ModuleList([
Siglip2EncoderLayer(config)
for _ in range(config.num_hidden_layers)
Siglip2EncoderLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layers.{idx}",
use_data_parallel=use_data_parallel)
for idx in range(config.num_hidden_layers)
])
self.gradient_checkpointing = False
self.rotary_pos_emb = VisionRotaryEmbedding(
config.hidden_size // config.num_attention_heads // 2)
@ -445,13 +515,11 @@ class Siglip2Encoder(nn.Module):
return window_index, cu_window_seqlens
# Ignore copy
def forward(
self,
inputs_embeds,
inputs_embeds: torch.Tensor,
grid_thws: torch.Tensor,
output_hidden_states: bool = False,
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, ...]]]:
) -> torch.Tensor:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape
@ -506,7 +574,6 @@ class Siglip2Encoder(nn.Module):
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
reverse_indices = torch.argsort(window_index)
encoder_states = () if output_hidden_states else None
hidden_states = inputs_embeds
for index, block in enumerate(self.layers):
@ -517,45 +584,40 @@ class Siglip2Encoder(nn.Module):
cu_seqlens_tmp = cu_window_seqlens
hidden_states = block(hidden_states, cu_seqlens_tmp,
position_embeddings)
if output_hidden_states:
hidden_states_ = hidden_states.reshape(
seq_len // self.spatial_merge_unit,
self.spatial_merge_unit, -1)
encoder_states += (hidden_states_[reverse_indices, :].reshape(
seq_len, -1), )
# tokens = self.post_trunk_norm(tokens)
hidden_states = hidden_states.reshape(
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
hidden_states = hidden_states[reverse_indices, :].reshape(seq_len, -1)
return hidden_states, encoder_states
return hidden_states
class Siglip2VisionTransformer(nn.Module):
def __init__(self, config: PretrainedConfig):
def __init__(
self,
config: Siglip2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2VisionEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.encoder = Siglip2Encoder(config,
quant_config=quant_config,
prefix=f"{prefix}.encoder",
use_data_parallel=use_data_parallel)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
self._use_flash_attention_2 = \
(config._attn_implementation == "flash_attention_2")
def forward(
self,
pixel_values: torch.FloatTensor,
grid_thws: torch.LongTensor,
output_hidden_states: Optional[bool] = True,
return_dict: Optional[bool] = True,
) -> Union[
tuple[torch.Tensor],
tuple[torch.Tensor, tuple[torch.Tensor, ...]],
BaseModelOutputWithNoAttention,
]:
) -> torch.Tensor:
r"""
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
Tensor containing the spatial dimensions (height, width)
@ -563,45 +625,64 @@ class Siglip2VisionTransformer(nn.Module):
"""
hidden_states = self.embeddings(pixel_values, grid_thws)
last_hidden_state, hidden_states = self.encoder(
hidden_states, grid_thws, output_hidden_states)
last_hidden_state = self.encoder(hidden_states, grid_thws)
last_hidden_state = self.post_layernorm(last_hidden_state)
if not return_dict:
output = (last_hidden_state, )
output += (hidden_states, ) if output_hidden_states else ()
return output
return last_hidden_state
class Siglip2NavitModel(torch.nn.Module):
def __init__(self, config: PretrainedConfig):
def __init__(
self,
config: Siglip2VisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
use_data_parallel: bool = False,
):
super().__init__()
self.vision_model = Siglip2VisionTransformer(config)
self.vision_model = Siglip2VisionTransformer(
config,
quant_config=quant_config,
prefix=f"{prefix}.vision_model",
use_data_parallel=use_data_parallel)
def forward(
self,
pixel_values: torch.FloatTensor,
grid_thws: torch.LongTensor,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[
tuple[torch.Tensor],
tuple[torch.Tensor, tuple[torch.Tensor, ...]],
BaseModelOutputWithNoAttention,
]:
if output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states
if return_dict is None:
return_dict = self.config.use_return_dict
) -> torch.Tensor:
return self.vision_model(
pixel_values=pixel_values,
grid_thws=grid_thws,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params

View File

@ -461,6 +461,8 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor,
num_chunks_per_rank, ...]
vision_embeddings = vision_model(image_input_per_rank)
# Ensure tensor is contiguous before all_gather
vision_embeddings = vision_embeddings.contiguous()
vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
dim=0)
vision_embeddings = vision_embeddings[:num_chunks, ...]
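A standalone illustration (not from this diff) of why the .contiguous() call above matters: sliced or rearranged tensors can carry non-trivial strides, which contiguous-buffer collectives such as all_gather cannot consume directly.

import torch

x = torch.arange(24).reshape(4, 6)
view = x[:, :3]                # a strided view, not a packed buffer
assert not view.is_contiguous()
packed = view.contiguous()     # materializes a packed copy safe to hand to a collective
assert packed.is_contiguous()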

View File

@ -171,7 +171,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "mxfp4"
"quark", "ptpc_fp8", "mxfp4", "petit_nvfp4"
]
@classmethod

View File

@ -49,12 +49,11 @@ def decode_tokens(
`skip_special_tokens=None` means to use the backend's default
settings.
"""
decode_method = getattr(tokenizer, "_decode", tokenizer.decode)
if skip_special_tokens is not None:
return decode_method(token_ids,
skip_special_tokens=skip_special_tokens)
return tokenizer.decode(token_ids,
skip_special_tokens=skip_special_tokens)
return decode_method(token_ids)
return tokenizer.decode(token_ids)
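The getattr fallback above, shown in isolation with a toy tokenizer (illustrative only): when a backend exposes a private _decode fast path it is used, otherwise the public decode method is called with the backend's default behavior.

class _ToyTokenizer:
    def decode(self, token_ids, skip_special_tokens=True):
        return f"decoded {len(token_ids)} tokens (skip={skip_special_tokens})"

tok = _ToyTokenizer()
decode_method = getattr(tok, "_decode", tok.decode)  # no _decode -> falls back to decode
print(decode_method([1, 2, 3]))                      # backend default
print(decode_method([1, 2, 3], skip_special_tokens=False))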
def encode_tokens(

View File

@ -1440,6 +1440,12 @@ def _patched_set_stream(stream: torch.cuda.Stream) -> None:
torch.cuda.set_stream = _patched_set_stream
class _StreamPlaceholder:
def __init__(self):
self.synchronize = lambda: None
def current_stream() -> torch.cuda.Stream:
"""
replace `torch.cuda.current_stream()` with `vllm.utils.current_stream()`.
@ -1459,8 +1465,18 @@ def current_stream() -> torch.cuda.Stream:
# On ROCm using the default 0 stream in combination with RCCL
# is hurting performance. Therefore creating a dedicated stream
# per process
_current_stream_tls.value = torch.cuda.Stream(
) if current_platform.is_rocm() else torch.cuda.current_stream()
if current_platform.is_rocm():
_current_stream_tls.value = torch.cuda.Stream()
elif current_platform.is_cpu():
_current_stream_tls.value = _StreamPlaceholder()
else:
current_stream = current_platform.current_stream
if current_stream is not None:
_current_stream_tls.value = current_stream()
else:
raise ValueError(
"Fail to set current stream, current platform "
"may not support current_stream with torch API")
return _current_stream_tls.value
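A hedged sketch of the pattern behind _StreamPlaceholder above: giving the CPU path an object with a no-op synchronize lets callers treat the returned stream uniformly instead of branching on the platform.

class _NoOpStream:
    def synchronize(self) -> None:
        pass  # nothing to wait on for CPU-only execution

def finish(stream) -> None:
    # Works for a real torch.cuda.Stream and for the placeholder alike.
    stream.synchronize()

finish(_NoOpStream())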
@ -2466,7 +2482,7 @@ class PlaceholderModule(_PlaceholderBase):
A placeholder object to use when a module does not exist.
This enables more informative errors when trying to access attributes
of a module that does not exists.
of a module that does not exist.
"""
def __init__(self, name: str) -> None:
@ -3093,7 +3109,7 @@ class LazyLoader(types.ModuleType):
"""
LazyLoader module borrowed from Tensorflow
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
with a addition of "module caching".
with an addition of "module caching".
Lazily import a module, mainly to avoid pulling in large dependencies.
Modules such as `xgrammar` might do additional side effects, so we

View File

@ -132,6 +132,11 @@ def has_nvidia_artifactory() -> bool:
This checks connectivity to the kernel inference library artifactory
which is required for downloading certain cubin kernels like TRTLLM FHMA.
"""
# FLASHINFER_CUBIN_DIR points at pre-downloaded cubins, so when
# VLLM_HAS_FLASHINFER_CUBIN is set we can assume the cubins are already
# available locally and skip the connectivity check.
if envs.VLLM_HAS_FLASHINFER_CUBIN:
return True
try:
# Use a short timeout to avoid blocking for too long
response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5)

Some files were not shown because too many files have changed in this diff.